mpasbn.c revision 330897
1/*-
2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3 *
4 * Copyright (c) 2001 Dima Dorfman.
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26 * SUCH DAMAGE.
27 */
28
29/*
30 * This is the traditional Berkeley MP library implemented in terms of
31 * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
32 * is meant to be as compatible with the latter as feasible.
33 *
34 * There seems to be a lack of documentation for the Berkeley MP
35 * interface.  All I could find was libgmp documentation (which didn't
36 * talk about the semantics of the functions) and an old SunOS 4.1
37 * manual page from 1989.  The latter wasn't very detailed, either,
38 * but at least described what the function's arguments were.  In
39 * general the interface seems to be archaic, somewhat poorly
40 * designed, and poorly, if at all, documented.  It is considered
41 * harmful.
42 *
43 * Miscellaneous notes on this implementation:
44 *
45 *  - The SunOS manual page mentioned above indicates that if an error
46 *  occurs, the library should "produce messages and core images."
47 *  Given that most of the functions don't have return values (and
48 *  thus no sane way of alerting the caller to an error), this seems
49 *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
50 *  respectively, then abort().
51 *
52 *  - All the functions which take an argument to be "filled in"
53 *  assume that the argument has been initialized by one of the *tom()
54 *  routines before being passed to it.  I never saw this documented
55 *  anywhere, but this seems to be consistent with the way this
56 *  library is used.
57 *
58 *  - msqrt() is the only routine which had to be implemented which
59 *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
60 *  It was implemented by hand using Newton's recursive formula.
61 *  Doing it this way, although more error-prone, has the positive
62 *  side effect of testing a lot of other functions; if msqrt()
63 *  produces the correct results, most of the other routines will as
64 *  well.
65 *
66 *  - Internal-use-only routines (i.e., those defined here statically
67 *  and not in mp.h) have an underscore prepended to their name (this
68 *  is more for aesthetical reasons than technical).  All such
69 *  routines take an extra argument, 'msg', that denotes what they
70 *  should call themselves in an error message.  This is so a user
71 *  doesn't get an error message from a function they didn't call.
72 */
73
74#include <sys/cdefs.h>
75__FBSDID("$FreeBSD: stable/11/lib/libmp/mpasbn.c 330897 2018-03-14 03:19:51Z eadler $");
76
77#include <ctype.h>
78#include <err.h>
79#include <errno.h>
80#include <stdio.h>
81#include <stdlib.h>
82#include <string.h>
83
84#include <openssl/crypto.h>
85#include <openssl/err.h>
86
87#include "mp.h"
88
/*
 * Error-reporting helpers.  MPERR and MPERRX take a parenthesized
 * argument list which is handed to warn(3) or warnx(3), respectively,
 * and then call abort().  BN_ERRCHECK calls _bnerr(msg) when expr
 * evaluates to zero (or NULL).
 */
#define MPERR(s)			\
	do {				\
		warn s;			\
		abort();		\
	} while (0)
#define MPERRX(s)			\
	do {				\
		warnx s;		\
		abort();		\
	} while (0)
#define BN_ERRCHECK(msg, expr)		\
	do {				\
		if (!(expr))		\
			_bnerr(msg);	\
	} while (0)
94
95static void _bnerr(const char *);
96static MINT *_dtom(const char *, const char *);
97static MINT *_itom(const char *, short);
98static void _madd(const char *, const MINT *, const MINT *, MINT *);
99static int _mcmpa(const char *, const MINT *, const MINT *);
100static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *,
101		BN_CTX *);
102static void _mfree(const char *, MINT *);
103static void _moveb(const char *, const BIGNUM *, MINT *);
104static void _movem(const char *, const MINT *, MINT *);
105static void _msub(const char *, const MINT *, const MINT *, MINT *);
106static char *_mtod(const char *, const MINT *);
107static char *_mtox(const char *, const MINT *);
108static void _mult(const char *, const MINT *, const MINT *, MINT *, BN_CTX *);
109static void _sdiv(const char *, const MINT *, short, MINT *, short *, BN_CTX *);
110static MINT *_xtom(const char *, const char *);
111
/*
 * Report an error from one of the BN_* functions using MPERRX, i.e.,
 * print "msg: reason" with warnx(3) and then abort().
 *
 * ERR_reason_error_string() returns NULL when it has no text for the
 * error code (for instance when nothing was queued and ERR_get_error()
 * returned 0).  Passing NULL for a "%s" conversion is undefined
 * behaviour, so a generic reason is substituted in that case.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
122
123/*
124 * Convert a decimal string to an MINT.
125 */
126static MINT *
127_dtom(const char *msg, const char *s)
128{
129	MINT *mp;
130
131	mp = malloc(sizeof(*mp));
132	if (mp == NULL)
133		MPERR(("%s", msg));
134	mp->bn = BN_new();
135	if (mp->bn == NULL)
136		_bnerr(msg);
137	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
138	return (mp);
139}
140
141/*
142 * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
143 */
144void
145mp_gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
146{
147	BIGNUM b;
148	BN_CTX *c;
149
150	c = BN_CTX_new();
151	if (c == NULL)
152		_bnerr("gcd");
153	BN_init(&b);
154	BN_ERRCHECK("gcd", BN_gcd(&b, mp1->bn, mp2->bn, c));
155	_moveb("gcd", &b, rmp);
156	BN_free(&b);
157	BN_CTX_free(c);
158}
159
160/*
161 * Make an MINT out of a short integer.  Return value must be mfree()'d.
162 */
163static MINT *
164_itom(const char *msg, short n)
165{
166	MINT *mp;
167	char *s;
168
169	asprintf(&s, "%x", n);
170	if (s == NULL)
171		MPERR(("%s", msg));
172	mp = _xtom(msg, s);
173	free(s);
174	return (mp);
175}
176
177MINT *
178mp_itom(short n)
179{
180
181	return (_itom("itom", n));
182}
183
184/*
185 * Compute rmp=mp1+mp2.
186 */
187static void
188_madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
189{
190	BIGNUM b;
191
192	BN_init(&b);
193	BN_ERRCHECK(msg, BN_add(&b, mp1->bn, mp2->bn));
194	_moveb(msg, &b, rmp);
195	BN_free(&b);
196}
197
198void
199mp_madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
200{
201
202	_madd("madd", mp1, mp2, rmp);
203}
204
205/*
206 * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
207 */
208int
209mp_mcmp(const MINT *mp1, const MINT *mp2)
210{
211
212	return (BN_cmp(mp1->bn, mp2->bn));
213}
214
215/*
216 * Same as mcmp but compares absolute values.
217 */
218static int
219_mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
220{
221
222	return (BN_ucmp(mp1->bn, mp2->bn));
223}
224
225/*
226 * Compute qmp=nmp/dmp and rmp=nmp%dmp.
227 */
228static void
229_mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp,
230    BN_CTX *c)
231{
232	BIGNUM q, r;
233
234	BN_init(&r);
235	BN_init(&q);
236	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, c));
237	_moveb(msg, &q, qmp);
238	_moveb(msg, &r, rmp);
239	BN_free(&q);
240	BN_free(&r);
241}
242
243void
244mp_mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
245{
246	BN_CTX *c;
247
248	c = BN_CTX_new();
249	if (c == NULL)
250		_bnerr("mdiv");
251	_mdiv("mdiv", nmp, dmp, qmp, rmp, c);
252	BN_CTX_free(c);
253}
254
255/*
256 * Free memory associated with an MINT.
257 */
258static void
259_mfree(const char *msg __unused, MINT *mp)
260{
261
262	BN_clear(mp->bn);
263	BN_free(mp->bn);
264	free(mp);
265}
266
267void
268mp_mfree(MINT *mp)
269{
270
271	_mfree("mfree", mp);
272}
273
274/*
275 * Read an integer from standard input and stick the result in mp.
276 * The input is treated to be in base 10.  This must be the silliest
277 * API in existence; why can't the program read in a string and call
278 * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
279 * exported.)
280 */
281void
282mp_min(MINT *mp)
283{
284	MINT *rmp;
285	char *line, *nline;
286	size_t linelen;
287
288	line = fgetln(stdin, &linelen);
289	if (line == NULL)
290		MPERR(("min"));
291	nline = malloc(linelen + 1);
292	if (nline == NULL)
293		MPERR(("min"));
294	memcpy(nline, line, linelen);
295	nline[linelen] = '\0';
296	rmp = _dtom("min", nline);
297	_movem("min", rmp, mp);
298	_mfree("min", rmp);
299	free(nline);
300}
301
302/*
303 * Print the value of mp to standard output in base 10.  See blurb
304 * above min() for why this is so useless.
305 */
306void
307mp_mout(const MINT *mp)
308{
309	char *s;
310
311	s = _mtod("mout", mp);
312	printf("%s", s);
313	free(s);
314}
315
316/*
317 * Set the value of tmp to the value of smp (i.e., tmp=smp).
318 */
319void
320mp_move(const MINT *smp, MINT *tmp)
321{
322
323	_movem("move", smp, tmp);
324}
325
326
327/*
328 * Internal routine to set the value of tmp to that of sbp.
329 */
330static void
331_moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
332{
333
334	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
335}
336
337/*
338 * Internal routine to set the value of tmp to that of smp.
339 */
340static void
341_movem(const char *msg, const MINT *smp, MINT *tmp)
342{
343
344	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
345}
346
347/*
348 * Compute the square root of nmp and put the result in xmp.  The
349 * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
350 *
351 * Note that the OpenSSL BIGNUM library does not have a square root
352 * function, so this had to be implemented by hand using Newton's
353 * recursive formula:
354 *
355 *		x = (x + (n / x)) / 2
356 *
357 * where x is the square root of the positive number n.  In the
358 * beginning, x should be a reasonable guess, but the value 1,
359 * although suboptimal, works, too; this is that is used below.
360 */
361void
362mp_msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
363{
364	BN_CTX *c;
365	MINT *tolerance;
366	MINT *ox, *x;
367	MINT *z1, *z2, *z3;
368	short i;
369
370	c = BN_CTX_new();
371	if (c == NULL)
372		_bnerr("msqrt");
373	tolerance = _itom("msqrt", 1);
374	x = _itom("msqrt", 1);
375	ox = _itom("msqrt", 0);
376	z1 = _itom("msqrt", 0);
377	z2 = _itom("msqrt", 0);
378	z3 = _itom("msqrt", 0);
379	do {
380		_movem("msqrt", x, ox);
381		_mdiv("msqrt", nmp, x, z1, z2, c);
382		_madd("msqrt", x, z1, z2);
383		_sdiv("msqrt", z2, 2, x, &i, c);
384		_msub("msqrt", ox, x, z3);
385	} while (_mcmpa("msqrt", z3, tolerance) == 1);
386	_movem("msqrt", x, xmp);
387	_mult("msqrt", x, x, z1, c);
388	_msub("msqrt", nmp, z1, z2);
389	_movem("msqrt", z2, rmp);
390	_mfree("msqrt", tolerance);
391	_mfree("msqrt", ox);
392	_mfree("msqrt", x);
393	_mfree("msqrt", z1);
394	_mfree("msqrt", z2);
395	_mfree("msqrt", z3);
396	BN_CTX_free(c);
397}
398
399/*
400 * Compute rmp=mp1-mp2.
401 */
402static void
403_msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
404{
405	BIGNUM b;
406
407	BN_init(&b);
408	BN_ERRCHECK(msg, BN_sub(&b, mp1->bn, mp2->bn));
409	_moveb(msg, &b, rmp);
410	BN_free(&b);
411}
412
413void
414mp_msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
415{
416
417	_msub("msub", mp1, mp2, rmp);
418}
419
420/*
421 * Return a decimal representation of mp.  Return value must be
422 * free()'d.
423 */
424static char *
425_mtod(const char *msg, const MINT *mp)
426{
427	char *s, *s2;
428
429	s = BN_bn2dec(mp->bn);
430	if (s == NULL)
431		_bnerr(msg);
432	asprintf(&s2, "%s", s);
433	if (s2 == NULL)
434		MPERR(("%s", msg));
435	OPENSSL_free(s);
436	return (s2);
437}
438
439/*
440 * Return a hexadecimal representation of mp.  Return value must be
441 * free()'d.
442 */
443static char *
444_mtox(const char *msg, const MINT *mp)
445{
446	char *p, *s, *s2;
447	int len;
448
449	s = BN_bn2hex(mp->bn);
450	if (s == NULL)
451		_bnerr(msg);
452	asprintf(&s2, "%s", s);
453	if (s2 == NULL)
454		MPERR(("%s", msg));
455	OPENSSL_free(s);
456
457	/*
458	 * This is a kludge for libgmp compatibility.  The latter's
459	 * implementation of this function returns lower-case letters,
460	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
461	 * newkey(1)) are sensitive to this.  Although it's probably
462	 * their fault, it's nice to be compatible.
463	 */
464	len = strlen(s2);
465	for (p = s2; p < s2 + len; p++)
466		*p = tolower(*p);
467
468	return (s2);
469}
470
471char *
472mp_mtox(const MINT *mp)
473{
474
475	return (_mtox("mtox", mp));
476}
477
478/*
479 * Compute rmp=mp1*mp2.
480 */
481static void
482_mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp, BN_CTX *c)
483{
484	BIGNUM b;
485
486	BN_init(&b);
487	BN_ERRCHECK(msg, BN_mul(&b, mp1->bn, mp2->bn, c));
488	_moveb(msg, &b, rmp);
489	BN_free(&b);
490}
491
492void
493mp_mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
494{
495	BN_CTX *c;
496
497	c = BN_CTX_new();
498	if (c == NULL)
499		_bnerr("mult");
500	_mult("mult", mp1, mp2, rmp, c);
501	BN_CTX_free(c);
502}
503
504/*
505 * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
506 * means 'raise to power', not 'bitwise XOR'.)
507 */
508void
509mp_pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
510{
511	BIGNUM b;
512	BN_CTX *c;
513
514	c = BN_CTX_new();
515	if (c == NULL)
516		_bnerr("pow");
517	BN_init(&b);
518	BN_ERRCHECK("pow", BN_mod_exp(&b, bmp->bn, emp->bn, mmp->bn, c));
519	_moveb("pow", &b, rmp);
520	BN_free(&b);
521	BN_CTX_free(c);
522}
523
524/*
525 * Compute rmp=bmp^e.  (See note above pow().)
526 */
527void
528mp_rpow(const MINT *bmp, short e, MINT *rmp)
529{
530	MINT *emp;
531	BIGNUM b;
532	BN_CTX *c;
533
534	c = BN_CTX_new();
535	if (c == NULL)
536		_bnerr("rpow");
537	BN_init(&b);
538	emp = _itom("rpow", e);
539	BN_ERRCHECK("rpow", BN_exp(&b, bmp->bn, emp->bn, c));
540	_moveb("rpow", &b, rmp);
541	_mfree("rpow", emp);
542	BN_free(&b);
543	BN_CTX_free(c);
544}
545
546/*
547 * Compute qmp=nmp/d and ro=nmp%d.
548 */
549static void
550_sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro,
551    BN_CTX *c)
552{
553	MINT *dmp, *rmp;
554	BIGNUM q, r;
555	char *s;
556
557	BN_init(&q);
558	BN_init(&r);
559	dmp = _itom(msg, d);
560	rmp = _itom(msg, 0);
561	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, c));
562	_moveb(msg, &q, qmp);
563	_moveb(msg, &r, rmp);
564	s = _mtox(msg, rmp);
565	errno = 0;
566	*ro = strtol(s, NULL, 16);
567	if (errno != 0)
568		MPERR(("%s underflow or overflow", msg));
569	free(s);
570	_mfree(msg, dmp);
571	_mfree(msg, rmp);
572	BN_free(&r);
573	BN_free(&q);
574}
575
576void
577mp_sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
578{
579	BN_CTX *c;
580
581	c = BN_CTX_new();
582	if (c == NULL)
583		_bnerr("sdiv");
584	_sdiv("sdiv", nmp, d, qmp, ro, c);
585	BN_CTX_free(c);
586}
587
588/*
589 * Convert a hexadecimal string to an MINT.
590 */
591static MINT *
592_xtom(const char *msg, const char *s)
593{
594	MINT *mp;
595
596	mp = malloc(sizeof(*mp));
597	if (mp == NULL)
598		MPERR(("%s", msg));
599	mp->bn = BN_new();
600	if (mp->bn == NULL)
601		_bnerr(msg);
602	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
603	return (mp);
604}
605
606MINT *
607mp_xtom(const char *s)
608{
609
610	return (_xtom("xtom", s));
611}
612