1/*
2 * Copyright (c) 2016 Thomas Pornin <pornin@bolet.org>
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining
5 * a copy of this software and associated documentation files (the
6 * "Software"), to deal in the Software without restriction, including
7 * without limitation the rights to use, copy, modify, merge, publish,
8 * distribute, sublicense, and/or sell copies of the Software, and to
9 * permit persons to whom the Software is furnished to do so, subject to
10 * the following conditions:
11 *
12 * The above copyright notice and this permission notice shall be
13 * included in all copies or substantial portions of the Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
17 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
18 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
19 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
20 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
21 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include <stdio.h>
26#include <stdlib.h>
27#include <string.h>
28#include <stdarg.h>
29#include <time.h>
30
31#include <gmp.h>
32
33#include "bearssl.h"
34#include "inner.h"
35
/*
 * Table of big-integer primitives for one word format. The test code
 * reaches every operation through such a table (see 'impl'), so the
 * same checks can be run against several implementations.
 *
 * All integers are arrays of uint32_t whose first word (x[0]) is a
 * header describing the bit length; value words follow.
 */
typedef struct {
	/* Number of value bits per 32-bit word (31 or 32 in the tables
	   defined in this file). */
	uint32_t word_size;

	/* Not exercised by the tests in this file; presumably sets x to
	   zero with the given (encoded) bit length. */
	void (*zero)(uint32_t *x, uint32_t bit_len);

	/* Decode 'len' bytes (big-endian, as produced by the callers in
	   this file) from 'src' into x. */
	void (*decode)(uint32_t *x, const void *src, size_t len);

	/* Decode into x, relative to modulus m. The tests only feed
	   values lower than m and expect 1 to be returned; any other
	   result is reported as a decode error. */
	uint32_t (*decode_mod)(uint32_t *x,
		const void *src, size_t len, const uint32_t *m);

	/* Set x to a modulo m (a is an already decoded integer). */
	void (*reduce)(uint32_t *x, const uint32_t *a, const uint32_t *m);

	/* Decode 'len' bytes from 'src' and reduce the value modulo m,
	   result in x. */
	void (*decode_reduce)(uint32_t *x,
		const void *src, size_t len, const uint32_t *m);

	/* Write x into dst as exactly 'len' bytes (big-endian, with
	   leading zeros when 'len' exceeds the value size). */
	void (*encode)(void *dst, size_t len, const uint32_t *x);

	/* a += b and a -= b, under control of 'ctl'. The returned word
	   is used by the tests as a carry/borrow indicator; with
	   ctl = 0 the call is used only for its return value
	   (presumably a is then left unchanged -- confirm against the
	   implementation). */
	uint32_t (*add)(uint32_t *a, const uint32_t *b, uint32_t ctl);
	uint32_t (*sub)(uint32_t *a, const uint32_t *b, uint32_t ctl);

	/* Called on the least significant value word of the modulus;
	   the result is the 'm0i' argument of the Montgomery
	   functions below. */
	uint32_t (*ninv)(uint32_t x);

	/* Montgomery multiplication of x and y modulo m, result in d. */
	void (*montymul)(uint32_t *d, const uint32_t *x, const uint32_t *y,
		const uint32_t *m, uint32_t m0i);

	/* Conversions to and from Montgomery representation, in place. */
	void (*to_monty)(uint32_t *x, const uint32_t *m);
	void (*from_monty)(uint32_t *x, const uint32_t *m, uint32_t m0i);

	/* x = x^e mod m, with e given as 'elen' bytes; t1 and t2 are
	   caller-provided work arrays. */
	void (*modpow)(uint32_t *x, const unsigned char *e, size_t elen,
		const uint32_t *m, uint32_t m0i, uint32_t *t1, uint32_t *t2);
} int_impl;
59
60static const int_impl i31_impl = {
61	31,
62	&br_i31_zero,
63	&br_i31_decode,
64	&br_i31_decode_mod,
65	&br_i31_reduce,
66	&br_i31_decode_reduce,
67	&br_i31_encode,
68	&br_i31_add,
69	&br_i31_sub,
70	&br_i31_ninv31,
71	&br_i31_montymul,
72	&br_i31_to_monty,
73	&br_i31_from_monty,
74	&br_i31_modpow
75};
76static const int_impl i32_impl = {
77	32,
78	&br_i32_zero,
79	&br_i32_decode,
80	&br_i32_decode_mod,
81	&br_i32_reduce,
82	&br_i32_decode_reduce,
83	&br_i32_encode,
84	&br_i32_add,
85	&br_i32_sub,
86	&br_i32_ninv32,
87	&br_i32_montymul,
88	&br_i32_to_monty,
89	&br_i32_from_monty,
90	&br_i32_modpow
91};
92
93static const int_impl *impl;
94
95static gmp_randstate_t RNG;
96
97/*
98 * Get a random prime of length 'size' bits. This function also guarantees
99 * that x-1 is not a multiple of 65537.
100 */
101static void
102rand_prime(mpz_t x, int size)
103{
104	for (;;) {
105		mpz_urandomb(x, RNG, size - 1);
106		mpz_setbit(x, 0);
107		mpz_setbit(x, size - 1);
108		if (mpz_probab_prime_p(x, 50)) {
109			mpz_sub_ui(x, x, 1);
110			if (mpz_divisible_ui_p(x, 65537)) {
111				continue;
112			}
113			mpz_add_ui(x, x, 1);
114			return;
115		}
116	}
117}
118
119/*
120 * Print out a GMP integer (for debug).
121 */
122static void
123print_z(mpz_t z)
124{
125	unsigned char zb[1000];
126	size_t zlen, k;
127
128	mpz_export(zb, &zlen, 1, 1, 0, 0, z);
129	if (zlen == 0) {
130		printf(" 00");
131		return;
132	}
133	if ((zlen & 3) != 0) {
134		k = 4 - (zlen & 3);
135		memmove(zb + k, zb, zlen);
136		memset(zb, 0, k);
137		zlen += k;
138	}
139	for (k = 0; k < zlen; k += 4) {
140		printf(" %02X%02X%02X%02X",
141			zb[k], zb[k + 1], zb[k + 2], zb[k + 3]);
142	}
143}
144
/*
 * Print out an i31 or i32 integer (for debug).
 *
 * x[0] is the header word. When it is zero, a fixed placeholder is
 * printed. Otherwise the (x[0] + 31) >> 5 value words are printed from
 * the highest index down to x[1], each as eight hex digits, followed by
 * the header split as (x[0] >> 5, x[0] & 31). No newline is emitted.
 */
static void
print_u(uint32_t *x)
{
	uint32_t hdr;
	size_t idx;

	hdr = x[0];
	if (hdr == 0) {
		printf(" 00000000 (0, 0)");
		return;
	}

	idx = (hdr + 31) >> 5;
	while (idx != 0) {
		printf(" %08lX", (unsigned long)x[idx]);
		idx --;
	}
	printf(" (%u, %u)", (unsigned)(hdr >> 5), (unsigned)(hdr & 31));
}
162
163/*
164 * Check that an i31/i32 number and a GMP number are equal.
165 */
166static void
167check_eqz(uint32_t *x, mpz_t z)
168{
169	unsigned char xb[1000];
170	unsigned char zb[1000];
171	size_t xlen, zlen;
172	int good;
173
174	xlen = ((x[0] + 31) & ~(uint32_t)31) >> 3;
175	impl->encode(xb, xlen, x);
176	mpz_export(zb, &zlen, 1, 1, 0, 0, z);
177	good = 1;
178	if (xlen < zlen) {
179		good = 0;
180	} else if (xlen > zlen) {
181		size_t u;
182
183		for (u = xlen; u > zlen; u --) {
184			if (xb[xlen - u] != 0) {
185				good = 0;
186				break;
187			}
188		}
189	}
190	good = good && memcmp(xb + xlen - zlen, zb, zlen) == 0;
191	if (!good) {
192		size_t u;
193
194		printf("Mismatch:\n");
195		printf("  x = ");
196		print_u(x);
197		printf("\n");
198		printf("  ex = ");
199		for (u = 0; u < xlen; u ++) {
200			printf("%02X", xb[u]);
201		}
202		printf("\n");
203		printf("  z = ");
204		print_z(z);
205		printf("\n");
206		exit(EXIT_FAILURE);
207	}
208}
209
210/* obsolete
211static void
212mp_to_br(uint32_t *mx, uint32_t x_bitlen, mpz_t x)
213{
214	uint32_t x_ebitlen;
215	size_t xlen;
216
217	if (mpz_sizeinbase(x, 2) > x_bitlen) {
218		abort();
219	}
220	x_ebitlen = ((x_bitlen / 31) << 5) + (x_bitlen % 31);
221	br_i31_zero(mx, x_ebitlen);
222	mpz_export(mx + 1, &xlen, -1, sizeof *mx, 0, 1, x);
223}
224*/
225
226static void
227test_modint(void)
228{
229	int i, j, k;
230	mpz_t p, a, b, v, t1;
231
232	printf("Test modular integers: ");
233	fflush(stdout);
234
235	gmp_randinit_mt(RNG);
236	mpz_init(p);
237	mpz_init(a);
238	mpz_init(b);
239	mpz_init(v);
240	mpz_init(t1);
241	mpz_set_ui(t1, (unsigned long)time(NULL));
242	gmp_randseed(RNG, t1);
243	for (k = 2; k <= 128; k ++) {
244		for (i = 0; i < 10; i ++) {
245			unsigned char ep[100], ea[100], eb[100], ev[100];
246			size_t plen, alen, blen, vlen;
247			uint32_t mp[40], ma[40], mb[40], mv[60], mx[100];
248			uint32_t mt1[40], mt2[40], mt3[40];
249			uint32_t ctl;
250			uint32_t mp0i;
251
252			rand_prime(p, k);
253			mpz_urandomm(a, RNG, p);
254			mpz_urandomm(b, RNG, p);
255			mpz_urandomb(v, RNG, k + 60);
256			if (mpz_sgn(b) == 0) {
257				mpz_set_ui(b, 1);
258			}
259			mpz_export(ep, &plen, 1, 1, 0, 0, p);
260			mpz_export(ea, &alen, 1, 1, 0, 0, a);
261			mpz_export(eb, &blen, 1, 1, 0, 0, b);
262			mpz_export(ev, &vlen, 1, 1, 0, 0, v);
263
264			impl->decode(mp, ep, plen);
265			if (impl->decode_mod(ma, ea, alen, mp) != 1) {
266				printf("Decode error\n");
267				printf("  ea = ");
268				print_z(a);
269				printf("\n");
270				printf("  p = ");
271				print_u(mp);
272				printf("\n");
273				exit(EXIT_FAILURE);
274			}
275			mp0i = impl->ninv(mp[1]);
276			if (impl->decode_mod(mb, eb, blen, mp) != 1) {
277				printf("Decode error\n");
278				printf("  eb = ");
279				print_z(b);
280				printf("\n");
281				printf("  p = ");
282				print_u(mp);
283				printf("\n");
284				exit(EXIT_FAILURE);
285			}
286			impl->decode(mv, ev, vlen);
287			check_eqz(mp, p);
288			check_eqz(ma, a);
289			check_eqz(mb, b);
290			check_eqz(mv, v);
291
292			impl->decode_mod(ma, ea, alen, mp);
293			impl->decode_mod(mb, eb, blen, mp);
294			ctl = impl->add(ma, mb, 1);
295			ctl |= impl->sub(ma, mp, 0) ^ (uint32_t)1;
296			impl->sub(ma, mp, ctl);
297			mpz_add(t1, a, b);
298			mpz_mod(t1, t1, p);
299			check_eqz(ma, t1);
300
301			impl->decode_mod(ma, ea, alen, mp);
302			impl->decode_mod(mb, eb, blen, mp);
303			impl->add(ma, mp, impl->sub(ma, mb, 1));
304			mpz_sub(t1, a, b);
305			mpz_mod(t1, t1, p);
306			check_eqz(ma, t1);
307
308			impl->decode_reduce(ma, ev, vlen, mp);
309			mpz_mod(t1, v, p);
310			check_eqz(ma, t1);
311
312			impl->decode(mv, ev, vlen);
313			impl->reduce(ma, mv, mp);
314			mpz_mod(t1, v, p);
315			check_eqz(ma, t1);
316
317			impl->decode_mod(ma, ea, alen, mp);
318			impl->to_monty(ma, mp);
319			mpz_mul_2exp(t1, a, ((k + impl->word_size - 1)
320				/ impl->word_size) * impl->word_size);
321			mpz_mod(t1, t1, p);
322			check_eqz(ma, t1);
323			impl->from_monty(ma, mp, mp0i);
324			check_eqz(ma, a);
325
326			impl->decode_mod(ma, ea, alen, mp);
327			impl->decode_mod(mb, eb, blen, mp);
328			impl->to_monty(ma, mp);
329			impl->montymul(mt1, ma, mb, mp, mp0i);
330			mpz_mul(t1, a, b);
331			mpz_mod(t1, t1, p);
332			check_eqz(mt1, t1);
333
334			impl->decode_mod(ma, ea, alen, mp);
335			impl->modpow(ma, ev, vlen, mp, mp0i, mt1, mt2);
336			mpz_powm(t1, a, v, p);
337			check_eqz(ma, t1);
338
339			/*
340			br_modint_decode(ma, mp, ea, alen);
341			br_modint_decode(mb, mp, eb, blen);
342			if (!br_modint_div(ma, mb, mp, mt1, mt2, mt3)) {
343				fprintf(stderr, "division failed\n");
344				exit(EXIT_FAILURE);
345			}
346			mpz_sub_ui(t1, p, 2);
347			mpz_powm(t1, b, t1, p);
348			mpz_mul(t1, a, t1);
349			mpz_mod(t1, t1, p);
350			check_eqz(ma, t1);
351
352			br_modint_decode(ma, mp, ea, alen);
353			br_modint_decode(mb, mp, eb, blen);
354			for (j = 0; j <= (2 * k + 5); j ++) {
355				br_int_add(mx, j, ma, mb);
356				mpz_add(t1, a, b);
357				mpz_tdiv_r_2exp(t1, t1, j);
358				check_eqz(mx, t1);
359
360				br_int_mul(mx, j, ma, mb);
361				mpz_mul(t1, a, b);
362				mpz_tdiv_r_2exp(t1, t1, j);
363				check_eqz(mx, t1);
364			}
365			*/
366		}
367		printf(".");
368		fflush(stdout);
369	}
370	mpz_clear(p);
371	mpz_clear(a);
372	mpz_clear(b);
373	mpz_clear(v);
374	mpz_clear(t1);
375
376	printf(" done.\n");
377	fflush(stdout);
378}
379
#if 0
/*
 * Randomized test of the RSA private key operation against GMP.
 *
 * This function is currently compiled out. It relies on mp_to_br(),
 * which is itself commented out in this file, so both must be revived
 * together.
 *
 * All GMP integers and the random state are released before returning,
 * as in test_modint().
 */
static void
test_RSA_core(void)
{
	int i, j, k;
	mpz_t n, e, d, p, q, dp, dq, iq, t1, t2, phi;

	printf("Test RSA core: ");
	fflush(stdout);

	gmp_randinit_mt(RNG);
	mpz_init(n);
	mpz_init(e);
	mpz_init(d);
	mpz_init(p);
	mpz_init(q);
	mpz_init(dp);
	mpz_init(dq);
	mpz_init(iq);
	mpz_init(t1);
	mpz_init(t2);
	mpz_init(phi);
	mpz_set_ui(t1, (unsigned long)time(NULL));
	gmp_randseed(RNG, t1);

	/*
	 * To test corner cases, we want to try RSA keys such that the
	 * lengths of both factors can be arbitrary modulo 2^32. Factors
	 * p and q need not be of the same length; p can be greater than
	 * q and q can be greater than p.
	 *
	 * To keep computation time reasonable, we use p and q factors of
	 * less than 128 bits; this is way too small for secure RSA,
	 * but enough to exercise all code paths (since we work only with
	 * 32-bit words).
	 */
	for (i = 64; i <= 96; i ++) {
		rand_prime(p, i);
		for (j = i - 33; j <= i + 33; j ++) {
			uint32_t mp[40], mq[40], mdp[40], mdq[40], miq[40];

			/*
			 * Generate a RSA key pair, with p of length i bits,
			 * and q of length j bits.
			 */
			do {
				rand_prime(q, j);
			} while (mpz_cmp(p, q) == 0);
			mpz_mul(n, p, q);
			mpz_set_ui(e, 65537);
			mpz_sub_ui(t1, p, 1);
			mpz_sub_ui(t2, q, 1);
			mpz_mul(phi, t1, t2);
			mpz_invert(d, e, phi);
			mpz_mod(dp, d, t1);
			mpz_mod(dq, d, t2);
			mpz_invert(iq, q, p);

			/*
			 * Convert the key pair elements to BearSSL arrays.
			 */
			mp_to_br(mp, mpz_sizeinbase(p, 2), p);
			mp_to_br(mq, mpz_sizeinbase(q, 2), q);
			mp_to_br(mdp, mpz_sizeinbase(dp, 2), dp);
			mp_to_br(mdq, mpz_sizeinbase(dq, 2), dq);
			mp_to_br(miq, mp[0], iq);

			/*
			 * Compute and check ten public/private operations.
			 */
			for (k = 0; k < 10; k ++) {
				uint32_t mx[40];

				mpz_urandomm(t1, RNG, n);
				mpz_powm(t2, t1, e, n);
				mp_to_br(mx, mpz_sizeinbase(n, 2), t2);
				br_rsa_private_core(mx, mp, mq, mdp, mdq, miq);
				check_eqz(mx, t1);
			}
		}
		printf(".");
		fflush(stdout);
	}

	/*
	 * Release all GMP resources acquired above.
	 */
	mpz_clear(n);
	mpz_clear(e);
	mpz_clear(d);
	mpz_clear(p);
	mpz_clear(q);
	mpz_clear(dp);
	mpz_clear(dq);
	mpz_clear(iq);
	mpz_clear(t1);
	mpz_clear(t2);
	mpz_clear(phi);
	gmp_randclear(RNG);

	printf(" done.\n");
	fflush(stdout);
}
#endif
468
469int
470main(void)
471{
472	printf("===== i32 ======\n");
473	impl = &i32_impl;
474	test_modint();
475	printf("===== i31 ======\n");
476	impl = &i31_impl;
477	test_modint();
478	/*
479	test_RSA_core();
480	*/
481	return 0;
482}
483