1/*-
2 * SPDX-License-Identifier: BSD-2-Clause
3 *
4 * Copyright (c) 2001 Dima Dorfman.
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26 * SUCH DAMAGE.
27 */
28
29/*
30 * This is the traditional Berkeley MP library implemented in terms of
31 * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
32 * is meant to be as compatible with the latter as feasible.
33 *
34 * There seems to be a lack of documentation for the Berkeley MP
35 * interface.  All I could find was libgmp documentation (which didn't
36 * talk about the semantics of the functions) and an old SunOS 4.1
37 * manual page from 1989.  The latter wasn't very detailed, either,
38 * but at least described what the function's arguments were.  In
39 * general the interface seems to be archaic, somewhat poorly
40 * designed, and poorly, if at all, documented.  It is considered
41 * harmful.
42 *
43 * Miscellaneous notes on this implementation:
44 *
45 *  - The SunOS manual page mentioned above indicates that if an error
46 *  occurs, the library should "produce messages and core images."
47 *  Given that most of the functions don't have return values (and
48 *  thus no sane way of alerting the caller to an error), this seems
49 *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
50 *  respectively, then abort().
51 *
52 *  - All the functions which take an argument to be "filled in"
53 *  assume that the argument has been initialized by one of the *tom()
54 *  routines before being passed to it.  I never saw this documented
55 *  anywhere, but this seems to be consistent with the way this
56 *  library is used.
57 *
58 *  - msqrt() is the only routine which had to be implemented which
59 *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
60 *  It was implemented by hand using Newton's recursive formula.
61 *  Doing it this way, although more error-prone, has the positive
 *  side effect of testing a lot of other functions; if msqrt()
63 *  produces the correct results, most of the other routines will as
64 *  well.
65 *
66 *  - Internal-use-only routines (i.e., those defined here statically
67 *  and not in mp.h) have an underscore prepended to their name (this
 *  is more for aesthetic reasons than technical).  All such
69 *  routines take an extra argument, 'msg', that denotes what they
70 *  should call themselves in an error message.  This is so a user
71 *  doesn't get an error message from a function they didn't call.
72 */
73
74#include <sys/cdefs.h>
75#include <ctype.h>
76#include <err.h>
77#include <errno.h>
78#include <stdio.h>
79#include <stdlib.h>
80#include <string.h>
81
82#include <openssl/crypto.h>
83#include <openssl/err.h>
84
85#include "mp.h"
86
/*
 * Fatal-error helpers.  MPERR and MPERRX take a parenthesized
 * printf-style argument list, e.g. MPERR(("%s", msg)); they hand it to
 * warn() or warnx(), respectively, and then call abort().
 */
#define MPERR(s)	do { warn s; abort(); } while (0)
#define MPERRX(s)	do { warnx s; abort(); } while (0)
/*
 * Evaluate expr; if the result is zero (or a NULL pointer), report the
 * failure through _bnerr() under the routine name msg.
 */
#define BN_ERRCHECK(msg, expr) do {		\
	if (!(expr)) _bnerr(msg);		\
} while (0)
92
/*
 * Forward declarations of the file-local helpers.  Apart from _bnerr(),
 * whose only argument is the name to report, each takes as its first
 * argument the name it should use for itself in error messages.
 */
static void _bnerr(const char *);
static MINT *_dtom(const char *, const char *);
static MINT *_itom(const char *, short);
static void _madd(const char *, const MINT *, const MINT *, MINT *);
static int _mcmpa(const char *, const MINT *, const MINT *);
static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *,
		BN_CTX *);
static void _mfree(const char *, MINT *);
static void _moveb(const char *, const BIGNUM *, MINT *);
static void _movem(const char *, const MINT *, MINT *);
static void _msub(const char *, const MINT *, const MINT *, MINT *);
static char *_mtod(const char *, const MINT *);
static char *_mtox(const char *, const MINT *);
static void _mult(const char *, const MINT *, const MINT *, MINT *, BN_CTX *);
static void _sdiv(const char *, const MINT *, short, MINT *, short *, BN_CTX *);
static MINT *_xtom(const char *, const char *);
109
/*
 * Report an error from one of the BN_* functions using MPERRX, which
 * prints the message and then calls abort().
 *
 * msg names the routine the user called.  The reason text is looked up
 * for the code returned by ERR_get_error().  ERR_reason_error_string()
 * returns NULL when it has no text for a code (for instance when no
 * error was queued at all), so a placeholder is substituted instead of
 * handing a NULL pointer to the "%s" conversion.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
120
121/*
122 * Convert a decimal string to an MINT.
123 */
124static MINT *
125_dtom(const char *msg, const char *s)
126{
127	MINT *mp;
128
129	mp = malloc(sizeof(*mp));
130	if (mp == NULL)
131		MPERR(("%s", msg));
132	mp->bn = BN_new();
133	if (mp->bn == NULL)
134		_bnerr(msg);
135	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
136	return (mp);
137}
138
139/*
140 * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
141 */
142void
143mp_gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
144{
145	BIGNUM *b;
146	BN_CTX *c;
147
148	b = NULL;
149	c = BN_CTX_new();
150	if (c != NULL)
151		b = BN_new();
152	if (c == NULL || b == NULL)
153		_bnerr("gcd");
154	BN_ERRCHECK("gcd", BN_gcd(b, mp1->bn, mp2->bn, c));
155	_moveb("gcd", b, rmp);
156	BN_free(b);
157	BN_CTX_free(c);
158}
159
160/*
161 * Make an MINT out of a short integer.  Return value must be mfree()'d.
162 */
163static MINT *
164_itom(const char *msg, short n)
165{
166	MINT *mp;
167	char *s;
168
169	asprintf(&s, "%x", n);
170	if (s == NULL)
171		MPERR(("%s", msg));
172	mp = _xtom(msg, s);
173	free(s);
174	return (mp);
175}
176
177MINT *
178mp_itom(short n)
179{
180
181	return (_itom("itom", n));
182}
183
184/*
185 * Compute rmp=mp1+mp2.
186 */
187static void
188_madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
189{
190	BIGNUM *b;
191
192	b = BN_new();
193	if (b == NULL)
194		_bnerr(msg);
195	BN_ERRCHECK(msg, BN_add(b, mp1->bn, mp2->bn));
196	_moveb(msg, b, rmp);
197	BN_free(b);
198}
199
200void
201mp_madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
202{
203
204	_madd("madd", mp1, mp2, rmp);
205}
206
207/*
208 * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
209 */
210int
211mp_mcmp(const MINT *mp1, const MINT *mp2)
212{
213
214	return (BN_cmp(mp1->bn, mp2->bn));
215}
216
217/*
218 * Same as mcmp but compares absolute values.
219 */
220static int
221_mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
222{
223
224	return (BN_ucmp(mp1->bn, mp2->bn));
225}
226
227/*
228 * Compute qmp=nmp/dmp and rmp=nmp%dmp.
229 */
230static void
231_mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp,
232    BN_CTX *c)
233{
234	BIGNUM *q, *r;
235
236	q = NULL;
237	r = BN_new();
238	if (r != NULL)
239		q = BN_new();
240	if (r == NULL || q == NULL)
241		_bnerr(msg);
242	BN_ERRCHECK(msg, BN_div(q, r, nmp->bn, dmp->bn, c));
243	_moveb(msg, q, qmp);
244	_moveb(msg, r, rmp);
245	BN_free(q);
246	BN_free(r);
247}
248
249void
250mp_mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
251{
252	BN_CTX *c;
253
254	c = BN_CTX_new();
255	if (c == NULL)
256		_bnerr("mdiv");
257	_mdiv("mdiv", nmp, dmp, qmp, rmp, c);
258	BN_CTX_free(c);
259}
260
261/*
262 * Free memory associated with an MINT.
263 */
264static void
265_mfree(const char *msg __unused, MINT *mp)
266{
267
268	BN_clear(mp->bn);
269	BN_free(mp->bn);
270	free(mp);
271}
272
273void
274mp_mfree(MINT *mp)
275{
276
277	_mfree("mfree", mp);
278}
279
280/*
281 * Read an integer from standard input and stick the result in mp.
282 * The input is treated to be in base 10.  This must be the silliest
283 * API in existence; why can't the program read in a string and call
284 * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
285 * exported.)
286 */
287void
288mp_min(MINT *mp)
289{
290	MINT *rmp;
291	char *line, *nline;
292	size_t linelen;
293
294	line = fgetln(stdin, &linelen);
295	if (line == NULL)
296		MPERR(("min"));
297	nline = malloc(linelen + 1);
298	if (nline == NULL)
299		MPERR(("min"));
300	memcpy(nline, line, linelen);
301	nline[linelen] = '\0';
302	rmp = _dtom("min", nline);
303	_movem("min", rmp, mp);
304	_mfree("min", rmp);
305	free(nline);
306}
307
308/*
309 * Print the value of mp to standard output in base 10.  See blurb
310 * above min() for why this is so useless.
311 */
312void
313mp_mout(const MINT *mp)
314{
315	char *s;
316
317	s = _mtod("mout", mp);
318	printf("%s", s);
319	free(s);
320}
321
322/*
323 * Set the value of tmp to the value of smp (i.e., tmp=smp).
324 */
325void
326mp_move(const MINT *smp, MINT *tmp)
327{
328
329	_movem("move", smp, tmp);
330}
331
332
333/*
334 * Internal routine to set the value of tmp to that of sbp.
335 */
336static void
337_moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
338{
339
340	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
341}
342
343/*
344 * Internal routine to set the value of tmp to that of smp.
345 */
346static void
347_movem(const char *msg, const MINT *smp, MINT *tmp)
348{
349
350	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
351}
352
353/*
354 * Compute the square root of nmp and put the result in xmp.  The
355 * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
356 *
357 * Note that the OpenSSL BIGNUM library does not have a square root
358 * function, so this had to be implemented by hand using Newton's
359 * recursive formula:
360 *
361 *		x = (x + (n / x)) / 2
362 *
363 * where x is the square root of the positive number n.  In the
364 * beginning, x should be a reasonable guess, but the value 1,
365 * although suboptimal, works, too; this is that is used below.
366 */
367void
368mp_msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
369{
370	BN_CTX *c;
371	MINT *tolerance;
372	MINT *ox, *x;
373	MINT *z1, *z2, *z3;
374	short i;
375
376	c = BN_CTX_new();
377	if (c == NULL)
378		_bnerr("msqrt");
379	tolerance = _itom("msqrt", 1);
380	x = _itom("msqrt", 1);
381	ox = _itom("msqrt", 0);
382	z1 = _itom("msqrt", 0);
383	z2 = _itom("msqrt", 0);
384	z3 = _itom("msqrt", 0);
385	do {
386		_movem("msqrt", x, ox);
387		_mdiv("msqrt", nmp, x, z1, z2, c);
388		_madd("msqrt", x, z1, z2);
389		_sdiv("msqrt", z2, 2, x, &i, c);
390		_msub("msqrt", ox, x, z3);
391	} while (_mcmpa("msqrt", z3, tolerance) == 1);
392	_movem("msqrt", x, xmp);
393	_mult("msqrt", x, x, z1, c);
394	_msub("msqrt", nmp, z1, z2);
395	_movem("msqrt", z2, rmp);
396	_mfree("msqrt", tolerance);
397	_mfree("msqrt", ox);
398	_mfree("msqrt", x);
399	_mfree("msqrt", z1);
400	_mfree("msqrt", z2);
401	_mfree("msqrt", z3);
402	BN_CTX_free(c);
403}
404
405/*
406 * Compute rmp=mp1-mp2.
407 */
408static void
409_msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
410{
411	BIGNUM *b;
412
413	b = BN_new();
414	if (b == NULL)
415		_bnerr(msg);
416	BN_ERRCHECK(msg, BN_sub(b, mp1->bn, mp2->bn));
417	_moveb(msg, b, rmp);
418	BN_free(b);
419}
420
421void
422mp_msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
423{
424
425	_msub("msub", mp1, mp2, rmp);
426}
427
428/*
429 * Return a decimal representation of mp.  Return value must be
430 * free()'d.
431 */
432static char *
433_mtod(const char *msg, const MINT *mp)
434{
435	char *s, *s2;
436
437	s = BN_bn2dec(mp->bn);
438	if (s == NULL)
439		_bnerr(msg);
440	asprintf(&s2, "%s", s);
441	if (s2 == NULL)
442		MPERR(("%s", msg));
443	OPENSSL_free(s);
444	return (s2);
445}
446
447/*
448 * Return a hexadecimal representation of mp.  Return value must be
449 * free()'d.
450 */
451static char *
452_mtox(const char *msg, const MINT *mp)
453{
454	char *p, *s, *s2;
455	int len;
456
457	s = BN_bn2hex(mp->bn);
458	if (s == NULL)
459		_bnerr(msg);
460	asprintf(&s2, "%s", s);
461	if (s2 == NULL)
462		MPERR(("%s", msg));
463	OPENSSL_free(s);
464
465	/*
466	 * This is a kludge for libgmp compatibility.  The latter's
467	 * implementation of this function returns lower-case letters,
468	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
469	 * newkey(1)) are sensitive to this.  Although it's probably
470	 * their fault, it's nice to be compatible.
471	 */
472	len = strlen(s2);
473	for (p = s2; p < s2 + len; p++)
474		*p = tolower(*p);
475
476	return (s2);
477}
478
479char *
480mp_mtox(const MINT *mp)
481{
482
483	return (_mtox("mtox", mp));
484}
485
486/*
487 * Compute rmp=mp1*mp2.
488 */
489static void
490_mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp, BN_CTX *c)
491{
492	BIGNUM *b;
493
494	b = BN_new();
495	if (b == NULL)
496		_bnerr(msg);
497	BN_ERRCHECK(msg, BN_mul(b, mp1->bn, mp2->bn, c));
498	_moveb(msg, b, rmp);
499	BN_free(b);
500}
501
502void
503mp_mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
504{
505	BN_CTX *c;
506
507	c = BN_CTX_new();
508	if (c == NULL)
509		_bnerr("mult");
510	_mult("mult", mp1, mp2, rmp, c);
511	BN_CTX_free(c);
512}
513
514/*
515 * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
516 * means 'raise to power', not 'bitwise XOR'.)
517 */
518void
519mp_pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
520{
521	BIGNUM *b;
522	BN_CTX *c;
523
524	b = NULL;
525	c = BN_CTX_new();
526	if (c != NULL)
527		b = BN_new();
528	if (c == NULL || b == NULL)
529		_bnerr("pow");
530	BN_ERRCHECK("pow", BN_mod_exp(b, bmp->bn, emp->bn, mmp->bn, c));
531	_moveb("pow", b, rmp);
532	BN_free(b);
533	BN_CTX_free(c);
534}
535
536/*
537 * Compute rmp=bmp^e.  (See note above pow().)
538 */
539void
540mp_rpow(const MINT *bmp, short e, MINT *rmp)
541{
542	MINT *emp;
543	BIGNUM *b;
544	BN_CTX *c;
545
546	b = NULL;
547	c = BN_CTX_new();
548	if (c != NULL)
549		b = BN_new();
550	if (c == NULL || b == NULL)
551		_bnerr("rpow");
552	emp = _itom("rpow", e);
553	BN_ERRCHECK("rpow", BN_exp(b, bmp->bn, emp->bn, c));
554	_moveb("rpow", b, rmp);
555	_mfree("rpow", emp);
556	BN_free(b);
557	BN_CTX_free(c);
558}
559
560/*
561 * Compute qmp=nmp/d and ro=nmp%d.
562 */
563static void
564_sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro,
565    BN_CTX *c)
566{
567	MINT *dmp, *rmp;
568	BIGNUM *q, *r;
569	char *s;
570
571	r = NULL;
572	q = BN_new();
573	if (q != NULL)
574		r = BN_new();
575	if (q == NULL || r == NULL)
576		_bnerr(msg);
577	dmp = _itom(msg, d);
578	rmp = _itom(msg, 0);
579	BN_ERRCHECK(msg, BN_div(q, r, nmp->bn, dmp->bn, c));
580	_moveb(msg, q, qmp);
581	_moveb(msg, r, rmp);
582	s = _mtox(msg, rmp);
583	errno = 0;
584	*ro = strtol(s, NULL, 16);
585	if (errno != 0)
586		MPERR(("%s underflow or overflow", msg));
587	free(s);
588	_mfree(msg, dmp);
589	_mfree(msg, rmp);
590	BN_free(r);
591	BN_free(q);
592}
593
594void
595mp_sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
596{
597	BN_CTX *c;
598
599	c = BN_CTX_new();
600	if (c == NULL)
601		_bnerr("sdiv");
602	_sdiv("sdiv", nmp, d, qmp, ro, c);
603	BN_CTX_free(c);
604}
605
606/*
607 * Convert a hexadecimal string to an MINT.
608 */
609static MINT *
610_xtom(const char *msg, const char *s)
611{
612	MINT *mp;
613
614	mp = malloc(sizeof(*mp));
615	if (mp == NULL)
616		MPERR(("%s", msg));
617	mp->bn = BN_new();
618	if (mp->bn == NULL)
619		_bnerr(msg);
620	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
621	return (mp);
622}
623
624MINT *
625mp_xtom(const char *s)
626{
627
628	return (_xtom("xtom", s));
629}
630