1/* Program for computing integer expressions using the GNU Multiple Precision
2   Arithmetic Library.
3
4Copyright 1997, 1999, 2000, 2001, 2002, 2005 Free Software Foundation, Inc.
5
6This program is free software; you can redistribute it and/or modify it under
7the terms of the GNU General Public License as published by the Free Software
8Foundation; either version 3 of the License, or (at your option) any later
9version.
10
11This program is distributed in the hope that it will be useful, but WITHOUT ANY
12WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
13PARTICULAR PURPOSE.  See the GNU General Public License for more details.
14
15You should have received a copy of the GNU General Public License along with
16this program.  If not, see http://www.gnu.org/licenses/.  */
17
18
/* This expression evaluator works by building an expression tree (using a
20   recursive descent parser) which is then evaluated.  The expression tree is
21   useful since we want to optimize certain expressions (like a^b % c).
22
23   Usage: pexpr [options] expr ...
24   (Assuming you called the executable `pexpr' of course.)
25
26   Command line options:
27
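   -b        print output in binary
   -o        print output in octal
   -d        print output in decimal (the default)
   -x        print output in hexadecimal
   -X        print output in hexadecimal with upper-case digits
   -b<NUM>   print output in base NUM (2 to 62)
   -t        print timing information
   -v        print the GMP version pexpr is linked with
   -html     output html
   -wml      output wml
   -split    split long lines at every 80th digit
   -noprint  do not print the result, only its approximate number of digits
*/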
38
39/* Define LIMIT_RESOURCE_USAGE if you want to make sure the program doesn't
40   use up extensive resources (cpu, memory).  Useful for the GMP demo on the
41   GMP web site, since we cannot load the server too much.  */
42
43#include "pexpr-config.h"
44
#include <ctype.h>
#include <limits.h>
#include <setjmp.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <time.h>
#include <sys/types.h>
#include <sys/time.h>
#if HAVE_SYS_RESOURCE_H
#include <sys/resource.h>
#endif
58
59#include "gmp.h"
60
/* SunOS 4 and HPUX 9 don't define a canonical SIGSTKSZ, use a default.
   SIGSTKSZ is the size of the alternate signal stack that
   setup_error_handler allocates.  */
#ifndef SIGSTKSZ
#define SIGSTKSZ  4096
#endif
65
66
/* Run the statement(s) FUNC and store in the int lvalue T the user CPU
   time it consumed, in milliseconds as reported by cputime().  The local
   uses a trailing underscore because identifiers that begin with "__" are
   reserved for the implementation.  */
#define TIME(t,func)							\
  do { int time_start_;							\
    time_start_ = cputime ();						\
    {func;}								\
    (t) = cputime () - time_start_;					\
  } while (0)
74
/* GMP version 1.x compatibility.  Maps the GMP 2+ type and function names
   used in this file onto their GMP 1 counterparts.  */
#if ! (__GNU_MP_VERSION >= 2)
typedef MP_INT __mpz_struct;
typedef __mpz_struct mpz_t[1];
typedef __mpz_struct *mpz_ptr;
#define mpz_fdiv_q	mpz_div
#define mpz_fdiv_r	mpz_mod
#define mpz_tdiv_q_2exp	mpz_div_2exp
#define mpz_sgn(Z) ((Z)->size < 0 ? -1 : (Z)->size > 0)
#endif

/* GMP version 2.0 compatibility.  Supplies mpz_swap (presumably absent in
   that release) by exchanging the two structs.  The arguments are not
   parenthesized in the expansion, so pass plain names only.  */
#if ! (__GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1)
#define mpz_swap(a,b) \
  do { __mpz_struct __t; __t = *a; *a = *b; *b = __t;} while (0)
#endif
91
/* Error-recovery target established by setjmp in main().  The parser
   longjmps here with the failing input position (a pointer cast to int)
   and the evaluator with 1, both after storing a message in `error'.  */
jmp_buf errjmpbuf;

/* Operation of an expression-tree node.  LIT marks a literal leaf; NOP
   terminates the fns[] table.  SLL and SRA are not referenced elsewhere in
   this file.  */
enum op_t {NOP, LIT, NEG, NOT, PLUS, MINUS, MULT, DIV, MOD, REM, INVMOD, POW,
	   AND, IOR, XOR, SLL, SRA, POPCNT, HAMDIST, GCD, LCM, SQRT, ROOT, FAC,
	   LOG, LOG2, FERMAT, MERSENNE, FIBONACCI, RANDOM, NEXTPRIME, BINOM,
	   TIMING};
98
/* Type for the expression tree.  A LIT node keeps its value in
   operands.val; every other node keeps its operands in operands.ops, with
   rhs set to NULL for unary operations.  */
struct expr
{
  enum op_t op;
  union
  {
    struct {struct expr *lhs, *rhs;} ops;
    mpz_t val;
  } operands;
};

/* Handle for a heap-allocated expression-tree node.  */
typedef struct expr *expr_t;
111
112void cleanup_and_exit __GMP_PROTO ((int));
113
114char *skipspace __GMP_PROTO ((char *));
115void makeexp __GMP_PROTO ((expr_t *, enum op_t, expr_t, expr_t));
116void free_expr __GMP_PROTO ((expr_t));
117char *expr __GMP_PROTO ((char *, expr_t *));
118char *term __GMP_PROTO ((char *, expr_t *));
119char *power __GMP_PROTO ((char *, expr_t *));
120char *factor __GMP_PROTO ((char *, expr_t *));
121int match __GMP_PROTO ((char *, char *));
122int matchp __GMP_PROTO ((char *, char *));
123int cputime __GMP_PROTO ((void));
124
125void mpz_eval_expr __GMP_PROTO ((mpz_ptr, expr_t));
126void mpz_eval_mod_expr __GMP_PROTO ((mpz_ptr, expr_t, mpz_ptr));
127
/* Message for the most recent parse or evaluation error; set just before
   longjmp'ing to errjmpbuf.  */
char *error;
/* Print results (cleared by -noprint).  */
int flag_print = 1;
/* Print timing information (-t).  */
int print_timing = 0;
/* HTML (-html) and WML (-wml) output modes.  */
int flag_html = 0;
int flag_wml = 0;
/* Break printed results every 80 characters (-split).  */
int flag_splitup_output = 0;
/* Markup emitted before each "\n": "" for plain text, "<br>" for HTML,
   "<br/>" for WML.  */
char *newline = "";
/* Random state used by random(); seeded from the clock in main.  */
gmp_randstate_t rstate;
136
137
138
/* cputime() returns user CPU time measured in milliseconds.  Unless the
   configuration supplies one (HAVE_CPUTIME), getrusage() is preferred,
   then clock(), and as a last resort a stub that always reports 0.  */
#if ! HAVE_CPUTIME
#if HAVE_GETRUSAGE
int
cputime (void)
{
  struct rusage rus;

  /* Use the symbolic RUSAGE_SELF rather than a literal 0: POSIX names
     the constant but does not fix its value.  */
  getrusage (RUSAGE_SELF, &rus);
  return rus.ru_utime.tv_sec * 1000 + rus.ru_utime.tv_usec / 1000;
}
#else
#if HAVE_CLOCK
int
cputime (void)
{
  /* Order the scaling so the intermediate value stays small when
     CLOCKS_PER_SEC is large.  */
  if (CLOCKS_PER_SEC < 100000)
    return clock () * 1000 / CLOCKS_PER_SEC;
  return clock () / (CLOCKS_PER_SEC / 1000);
}
#else
int
cputime (void)
{
  /* No CPU timer available.  */
  return 0;
}
#endif
#endif
#endif
168
169
/* Compare the address of a local in this (callee) frame with XP, the
   address of a local in the caller's frame.  Nonzero means the callee's
   local lies lower, i.e. the stack grows downwards.
   NOTE(review): relational comparison of pointers to distinct objects is
   formally undefined; this relies on a conventional flat address space.  */
int
stack_downwards_helper (char *xp)
{
  char callee_local;
  int lies_below = &callee_local < xp;

  return lies_below;
}
/* Nonzero if the machine stack grows towards lower addresses.  Used when
   choosing which end of a sigstack() area to hand to the kernel.  */
int
stack_downwards_p (void)
{
  char caller_local;
  int grows_down = stack_downwards_helper (&caller_local);

  return grows_down;
}
182
183
184void
185setup_error_handler (void)
186{
187#if HAVE_SIGACTION
188  struct sigaction act;
189  act.sa_handler = cleanup_and_exit;
190  sigemptyset (&(act.sa_mask));
191#define SIGNAL(sig)  sigaction (sig, &act, NULL)
192#else
193  struct { int sa_flags; } act;
194#define SIGNAL(sig)  signal (sig, cleanup_and_exit)
195#endif
196  act.sa_flags = 0;
197
198  /* Set up a stack for signal handling.  A typical cause of error is stack
199     overflow, and in such situation a signal can not be delivered on the
200     overflown stack.  */
201#if HAVE_SIGALTSTACK
202  {
203    /* AIX uses stack_t, MacOS uses struct sigaltstack, various other
204       systems have both. */
205#if HAVE_STACK_T
206    stack_t s;
207#else
208    struct sigaltstack s;
209#endif
210    s.ss_sp = malloc (SIGSTKSZ);
211    s.ss_size = SIGSTKSZ;
212    s.ss_flags = 0;
213    if (sigaltstack (&s, NULL) != 0)
214      perror("sigaltstack");
215    act.sa_flags = SA_ONSTACK;
216  }
217#else
218#if HAVE_SIGSTACK
219  {
220    struct sigstack s;
221    s.ss_sp = malloc (SIGSTKSZ);
222    if (stack_downwards_p ())
223      s.ss_sp += SIGSTKSZ;
224    s.ss_onstack = 0;
225    if (sigstack (&s, NULL) != 0)
226      perror("sigstack");
227    act.sa_flags = SA_ONSTACK;
228  }
229#else
230#endif
231#endif
232
233#ifdef LIMIT_RESOURCE_USAGE
234  {
235    struct rlimit limit;
236
237    limit.rlim_cur = limit.rlim_max = 0;
238    setrlimit (RLIMIT_CORE, &limit);
239
240    limit.rlim_cur = 3;
241    limit.rlim_max = 4;
242    setrlimit (RLIMIT_CPU, &limit);
243
244    limit.rlim_cur = limit.rlim_max = 16 * 1024 * 1024;
245    setrlimit (RLIMIT_DATA, &limit);
246
247    getrlimit (RLIMIT_STACK, &limit);
248    limit.rlim_cur = 4 * 1024 * 1024;
249    setrlimit (RLIMIT_STACK, &limit);
250
251    SIGNAL (SIGXCPU);
252  }
253#endif /* LIMIT_RESOURCE_USAGE */
254
255  SIGNAL (SIGILL);
256  SIGNAL (SIGSEGV);
257#ifdef SIGBUS /* not in mingw */
258  SIGNAL (SIGBUS);
259#endif
260  SIGNAL (SIGFPE);
261  SIGNAL (SIGABRT);
262}
263
264int
265main (int argc, char **argv)
266{
267  struct expr *e;
268  int i;
269  mpz_t r;
270  int errcode = 0;
271  char *str;
272  int base = 10;
273
274  setup_error_handler ();
275
276  gmp_randinit (rstate, GMP_RAND_ALG_LC, 128);
277
278  {
279#if HAVE_GETTIMEOFDAY
280    struct timeval tv;
281    gettimeofday (&tv, NULL);
282    gmp_randseed_ui (rstate, tv.tv_sec + tv.tv_usec);
283#else
284    time_t t;
285    time (&t);
286    gmp_randseed_ui (rstate, t);
287#endif
288  }
289
290  mpz_init (r);
291
292  while (argc > 1 && argv[1][0] == '-')
293    {
294      char *arg = argv[1];
295
296      if (arg[1] >= '0' && arg[1] <= '9')
297	break;
298
299      if (arg[1] == 't')
300	print_timing = 1;
301      else if (arg[1] == 'b' && arg[2] >= '0' && arg[2] <= '9')
302	{
303	  base = atoi (arg + 2);
304	  if (base < 2 || base > 62)
305	    {
306	      fprintf (stderr, "error: invalid output base\n");
307	      exit (-1);
308	    }
309	}
310      else if (arg[1] == 'b' && arg[2] == 0)
311	base = 2;
312      else if (arg[1] == 'x' && arg[2] == 0)
313	base = 16;
314      else if (arg[1] == 'X' && arg[2] == 0)
315	base = -16;
316      else if (arg[1] == 'o' && arg[2] == 0)
317	base = 8;
318      else if (arg[1] == 'd' && arg[2] == 0)
319	base = 10;
320      else if (arg[1] == 'v' && arg[2] == 0)
321	{
322	  printf ("pexpr linked to gmp %s\n", __gmp_version);
323	}
324      else if (strcmp (arg, "-html") == 0)
325	{
326	  flag_html = 1;
327	  newline = "<br>";
328	}
329      else if (strcmp (arg, "-wml") == 0)
330	{
331	  flag_wml = 1;
332	  newline = "<br/>";
333	}
334      else if (strcmp (arg, "-split") == 0)
335	{
336	  flag_splitup_output = 1;
337	}
338      else if (strcmp (arg, "-noprint") == 0)
339	{
340	  flag_print = 0;
341	}
342      else
343	{
344	  fprintf (stderr, "error: unknown option `%s'\n", arg);
345	  exit (-1);
346	}
347      argv++;
348      argc--;
349    }
350
351  for (i = 1; i < argc; i++)
352    {
353      int s;
354      int jmpval;
355
356      /* Set up error handler for parsing expression.  */
357      jmpval = setjmp (errjmpbuf);
358      if (jmpval != 0)
359	{
360	  fprintf (stderr, "error: %s%s\n", error, newline);
361	  fprintf (stderr, "       %s%s\n", argv[i], newline);
362	  if (! flag_html)
363	    {
364	      /* ??? Dunno how to align expression position with arrow in
365		 HTML ??? */
366	      fprintf (stderr, "       ");
367	      for (s = jmpval - (long) argv[i]; --s >= 0; )
368		putc (' ', stderr);
369	      fprintf (stderr, "^\n");
370	    }
371
372	  errcode |= 1;
373	  continue;
374	}
375
376      str = expr (argv[i], &e);
377
378      if (str[0] != 0)
379	{
380	  fprintf (stderr,
381		   "error: garbage where end of expression expected%s\n",
382		   newline);
383	  fprintf (stderr, "       %s%s\n", argv[i], newline);
384	  if (! flag_html)
385	    {
386	      /* ??? Dunno how to align expression position with arrow in
387		 HTML ??? */
388	      fprintf (stderr, "        ");
389	      for (s = str - argv[i]; --s; )
390		putc (' ', stderr);
391	      fprintf (stderr, "^\n");
392	    }
393
394	  errcode |= 1;
395	  free_expr (e);
396	  continue;
397	}
398
399      /* Set up error handler for evaluating expression.  */
400      if (setjmp (errjmpbuf))
401	{
402	  fprintf (stderr, "error: %s%s\n", error, newline);
403	  fprintf (stderr, "       %s%s\n", argv[i], newline);
404	  if (! flag_html)
405	    {
406	      /* ??? Dunno how to align expression position with arrow in
407		 HTML ??? */
408	      fprintf (stderr, "       ");
409	      for (s = str - argv[i]; --s >= 0; )
410		putc (' ', stderr);
411	      fprintf (stderr, "^\n");
412	    }
413
414	  errcode |= 2;
415	  continue;
416	}
417
418      if (print_timing)
419	{
420	  int t;
421	  TIME (t, mpz_eval_expr (r, e));
422	  printf ("computation took %d ms%s\n", t, newline);
423	}
424      else
425	mpz_eval_expr (r, e);
426
427      if (flag_print)
428	{
429	  size_t out_len;
430	  char *tmp, *s;
431
432	  out_len = mpz_sizeinbase (r, base >= 0 ? base : -base) + 2;
433#ifdef LIMIT_RESOURCE_USAGE
434	  if (out_len > 100000)
435	    {
436	      printf ("result is about %ld digits, not printing it%s\n",
437		      (long) out_len - 3, newline);
438	      exit (-2);
439	    }
440#endif
441	  tmp = malloc (out_len);
442
443	  if (print_timing)
444	    {
445	      int t;
446	      printf ("output conversion ");
447	      TIME (t, mpz_get_str (tmp, base, r));
448	      printf ("took %d ms%s\n", t, newline);
449	    }
450	  else
451	    mpz_get_str (tmp, base, r);
452
453	  out_len = strlen (tmp);
454	  if (flag_splitup_output)
455	    {
456	      for (s = tmp; out_len > 80; s += 80)
457		{
458		  fwrite (s, 1, 80, stdout);
459		  printf ("%s\n", newline);
460		  out_len -= 80;
461		}
462
463	      fwrite (s, 1, out_len, stdout);
464	    }
465	  else
466	    {
467	      fwrite (tmp, 1, out_len, stdout);
468	    }
469
470	  free (tmp);
471	  printf ("%s\n", newline);
472	}
473      else
474	{
475	  printf ("result is approximately %ld digits%s\n",
476		  (long) mpz_sizeinbase (r, base >= 0 ? base : -base),
477		  newline);
478	}
479
480      free_expr (e);
481    }
482
483  exit (errcode);
484}
485
486char *
487expr (char *str, expr_t *e)
488{
489  expr_t e2;
490
491  str = skipspace (str);
492  if (str[0] == '+')
493    {
494      str = term (str + 1, e);
495    }
496  else if (str[0] == '-')
497    {
498      str = term (str + 1, e);
499      makeexp (e, NEG, *e, NULL);
500    }
501  else if (str[0] == '~')
502    {
503      str = term (str + 1, e);
504      makeexp (e, NOT, *e, NULL);
505    }
506  else
507    {
508      str = term (str, e);
509    }
510
511  for (;;)
512    {
513      str = skipspace (str);
514      switch (str[0])
515	{
516	case 'p':
517	  if (match ("plus", str))
518	    {
519	      str = term (str + 4, &e2);
520	      makeexp (e, PLUS, *e, e2);
521	    }
522	  else
523	    return str;
524	  break;
525	case 'm':
526	  if (match ("minus", str))
527	    {
528	      str = term (str + 5, &e2);
529	      makeexp (e, MINUS, *e, e2);
530	    }
531	  else
532	    return str;
533	  break;
534	case '+':
535	  str = term (str + 1, &e2);
536	  makeexp (e, PLUS, *e, e2);
537	  break;
538	case '-':
539	  str = term (str + 1, &e2);
540	  makeexp (e, MINUS, *e, e2);
541	  break;
542	default:
543	  return str;
544	}
545    }
546}
547
548char *
549term (char *str, expr_t *e)
550{
551  expr_t e2;
552
553  str = power (str, e);
554  for (;;)
555    {
556      str = skipspace (str);
557      switch (str[0])
558	{
559	case 'm':
560	  if (match ("mul", str))
561	    {
562	      str = power (str + 3, &e2);
563	      makeexp (e, MULT, *e, e2);
564	      break;
565	    }
566	  if (match ("mod", str))
567	    {
568	      str = power (str + 3, &e2);
569	      makeexp (e, MOD, *e, e2);
570	      break;
571	    }
572	  return str;
573	case 'd':
574	  if (match ("div", str))
575	    {
576	      str = power (str + 3, &e2);
577	      makeexp (e, DIV, *e, e2);
578	      break;
579	    }
580	  return str;
581	case 'r':
582	  if (match ("rem", str))
583	    {
584	      str = power (str + 3, &e2);
585	      makeexp (e, REM, *e, e2);
586	      break;
587	    }
588	  return str;
589	case 'i':
590	  if (match ("invmod", str))
591	    {
592	      str = power (str + 6, &e2);
593	      makeexp (e, REM, *e, e2);
594	      break;
595	    }
596	  return str;
597	case 't':
598	  if (match ("times", str))
599	    {
600	      str = power (str + 5, &e2);
601	      makeexp (e, MULT, *e, e2);
602	      break;
603	    }
604	  if (match ("thru", str))
605	    {
606	      str = power (str + 4, &e2);
607	      makeexp (e, DIV, *e, e2);
608	      break;
609	    }
610	  if (match ("through", str))
611	    {
612	      str = power (str + 7, &e2);
613	      makeexp (e, DIV, *e, e2);
614	      break;
615	    }
616	  return str;
617	case '*':
618	  str = power (str + 1, &e2);
619	  makeexp (e, MULT, *e, e2);
620	  break;
621	case '/':
622	  str = power (str + 1, &e2);
623	  makeexp (e, DIV, *e, e2);
624	  break;
625	case '%':
626	  str = power (str + 1, &e2);
627	  makeexp (e, MOD, *e, e2);
628	  break;
629	default:
630	  return str;
631	}
632    }
633}
634
635char *
636power (char *str, expr_t *e)
637{
638  expr_t e2;
639
640  str = factor (str, e);
641  while (str[0] == '!')
642    {
643      str++;
644      makeexp (e, FAC, *e, NULL);
645    }
646  str = skipspace (str);
647  if (str[0] == '^')
648    {
649      str = power (str + 1, &e2);
650      makeexp (e, POW, *e, e2);
651    }
652  return str;
653}
654
655int
656match (char *s, char *str)
657{
658  char *ostr = str;
659  int i;
660
661  for (i = 0; s[i] != 0; i++)
662    {
663      if (str[i] != s[i])
664	return 0;
665    }
666  str = skipspace (str + i);
667  return str - ostr;
668}
669
/* If STR starts with the name S followed (after optional blanks) by `(',
   return the number of characters up to and including that `('; otherwise
   return 0.  Used to recognize function calls.  */
int
matchp (char *s, char *str)
{
  char *p = str;

  while (*s != 0)
    if (*p++ != *s++)
      return 0;

  p = skipspace (p);
  return *p == '(' ? (int) (p - str) + 1 : 0;
}
686
/* One entry of the fns[] table: maps a function name written as
   NAME(...) to the operation of the node built for the call.  */
struct functions
{
  char *spelling;
  enum op_t op;
  int arity; /* 1 or 2 means real arity; 0 means two or more operands,
		combined left to right.  */
};
693
694struct functions fns[] =
695{
696  {"sqrt", SQRT, 1},
697#if __GNU_MP_VERSION >= 2
698  {"root", ROOT, 2},
699  {"popc", POPCNT, 1},
700  {"hamdist", HAMDIST, 2},
701#endif
702  {"gcd", GCD, 0},
703#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
704  {"lcm", LCM, 0},
705#endif
706  {"and", AND, 0},
707  {"ior", IOR, 0},
708#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
709  {"xor", XOR, 0},
710#endif
711  {"plus", PLUS, 0},
712  {"pow", POW, 2},
713  {"minus", MINUS, 2},
714  {"mul", MULT, 0},
715  {"div", DIV, 2},
716  {"mod", MOD, 2},
717  {"rem", REM, 2},
718#if __GNU_MP_VERSION >= 2
719  {"invmod", INVMOD, 2},
720#endif
721  {"log", LOG, 2},
722  {"log2", LOG2, 1},
723  {"F", FERMAT, 1},
724  {"M", MERSENNE, 1},
725  {"fib", FIBONACCI, 1},
726  {"Fib", FIBONACCI, 1},
727  {"random", RANDOM, 1},
728  {"nextprime", NEXTPRIME, 1},
729  {"binom", BINOM, 2},
730  {"binomial", BINOM, 2},
731  {"fac", FAC, 1},
732  {"fact", FAC, 1},
733  {"factorial", FAC, 1},
734  {"time", TIMING, 1},
735  {"", NOP, 0}
736};
737
738char *
739factor (char *str, expr_t *e)
740{
741  expr_t e1, e2;
742
743  str = skipspace (str);
744
745  if (isalpha (str[0]))
746    {
747      int i;
748      int cnt;
749
750      for (i = 0; fns[i].op != NOP; i++)
751	{
752	  if (fns[i].arity == 1)
753	    {
754	      cnt = matchp (fns[i].spelling, str);
755	      if (cnt != 0)
756		{
757		  str = expr (str + cnt, &e1);
758		  str = skipspace (str);
759		  if (str[0] != ')')
760		    {
761		      error = "expected `)'";
762		      longjmp (errjmpbuf, (int) (long) str);
763		    }
764		  makeexp (e, fns[i].op, e1, NULL);
765		  return str + 1;
766		}
767	    }
768	}
769
770      for (i = 0; fns[i].op != NOP; i++)
771	{
772	  if (fns[i].arity != 1)
773	    {
774	      cnt = matchp (fns[i].spelling, str);
775	      if (cnt != 0)
776		{
777		  str = expr (str + cnt, &e1);
778		  str = skipspace (str);
779
780		  if (str[0] != ',')
781		    {
782		      error = "expected `,' and another operand";
783		      longjmp (errjmpbuf, (int) (long) str);
784		    }
785
786		  str = skipspace (str + 1);
787		  str = expr (str, &e2);
788		  str = skipspace (str);
789
790		  if (fns[i].arity == 0)
791		    {
792		      while (str[0] == ',')
793			{
794			  makeexp (&e1, fns[i].op, e1, e2);
795			  str = skipspace (str + 1);
796			  str = expr (str, &e2);
797			  str = skipspace (str);
798			}
799		    }
800
801		  if (str[0] != ')')
802		    {
803		      error = "expected `)'";
804		      longjmp (errjmpbuf, (int) (long) str);
805		    }
806
807		  makeexp (e, fns[i].op, e1, e2);
808		  return str + 1;
809		}
810	    }
811	}
812    }
813
814  if (str[0] == '(')
815    {
816      str = expr (str + 1, e);
817      str = skipspace (str);
818      if (str[0] != ')')
819	{
820	  error = "expected `)'";
821	  longjmp (errjmpbuf, (int) (long) str);
822	}
823      str++;
824    }
825  else if (str[0] >= '0' && str[0] <= '9')
826    {
827      expr_t res;
828      char *s, *sc;
829
830      res = malloc (sizeof (struct expr));
831      res -> op = LIT;
832      mpz_init (res->operands.val);
833
834      s = str;
835      while (isalnum (str[0]))
836	str++;
837      sc = malloc (str - s + 1);
838      memcpy (sc, s, str - s);
839      sc[str - s] = 0;
840
841      mpz_set_str (res->operands.val, sc, 0);
842      *e = res;
843      free (sc);
844    }
845  else
846    {
847      error = "operand expected";
848      longjmp (errjmpbuf, (int) (long) str);
849    }
850  return str;
851}
852
/* Return a pointer to the first character of STR that is not a blank
   (' ').  Tabs and other whitespace are not skipped.  */
char *
skipspace (char *str)
{
  char *p;

  for (p = str; *p == ' '; ++p)
    continue;
  return p;
}
860
861/* Make a new expression with operation OP and right hand side
862   RHS and left hand side lhs.  Put the result in R.  */
863void
864makeexp (expr_t *r, enum op_t op, expr_t lhs, expr_t rhs)
865{
866  expr_t res;
867  res = malloc (sizeof (struct expr));
868  res -> op = op;
869  res -> operands.ops.lhs = lhs;
870  res -> operands.ops.rhs = rhs;
871  *r = res;
872  return;
873}
874
875/* Free the memory used by expression E.  */
876void
877free_expr (expr_t e)
878{
879  if (e->op != LIT)
880    {
881      free_expr (e->operands.ops.lhs);
882      if (e->operands.ops.rhs != NULL)
883	free_expr (e->operands.ops.rhs);
884    }
885  else
886    {
887      mpz_clear (e->operands.val);
888    }
889}
890
891/* Evaluate the expression E and put the result in R.  */
892void
893mpz_eval_expr (mpz_ptr r, expr_t e)
894{
895  mpz_t lhs, rhs;
896
897  switch (e->op)
898    {
899    case LIT:
900      mpz_set (r, e->operands.val);
901      return;
902    case PLUS:
903      mpz_init (lhs); mpz_init (rhs);
904      mpz_eval_expr (lhs, e->operands.ops.lhs);
905      mpz_eval_expr (rhs, e->operands.ops.rhs);
906      mpz_add (r, lhs, rhs);
907      mpz_clear (lhs); mpz_clear (rhs);
908      return;
909    case MINUS:
910      mpz_init (lhs); mpz_init (rhs);
911      mpz_eval_expr (lhs, e->operands.ops.lhs);
912      mpz_eval_expr (rhs, e->operands.ops.rhs);
913      mpz_sub (r, lhs, rhs);
914      mpz_clear (lhs); mpz_clear (rhs);
915      return;
916    case MULT:
917      mpz_init (lhs); mpz_init (rhs);
918      mpz_eval_expr (lhs, e->operands.ops.lhs);
919      mpz_eval_expr (rhs, e->operands.ops.rhs);
920      mpz_mul (r, lhs, rhs);
921      mpz_clear (lhs); mpz_clear (rhs);
922      return;
923    case DIV:
924      mpz_init (lhs); mpz_init (rhs);
925      mpz_eval_expr (lhs, e->operands.ops.lhs);
926      mpz_eval_expr (rhs, e->operands.ops.rhs);
927      mpz_fdiv_q (r, lhs, rhs);
928      mpz_clear (lhs); mpz_clear (rhs);
929      return;
930    case MOD:
931      mpz_init (rhs);
932      mpz_eval_expr (rhs, e->operands.ops.rhs);
933      mpz_abs (rhs, rhs);
934      mpz_eval_mod_expr (r, e->operands.ops.lhs, rhs);
935      mpz_clear (rhs);
936      return;
937    case REM:
938      /* Check if lhs operand is POW expression and optimize for that case.  */
939      if (e->operands.ops.lhs->op == POW)
940	{
941	  mpz_t powlhs, powrhs;
942	  mpz_init (powlhs);
943	  mpz_init (powrhs);
944	  mpz_init (rhs);
945	  mpz_eval_expr (powlhs, e->operands.ops.lhs->operands.ops.lhs);
946	  mpz_eval_expr (powrhs, e->operands.ops.lhs->operands.ops.rhs);
947	  mpz_eval_expr (rhs, e->operands.ops.rhs);
948	  mpz_powm (r, powlhs, powrhs, rhs);
949	  if (mpz_cmp_si (rhs, 0L) < 0)
950	    mpz_neg (r, r);
951	  mpz_clear (powlhs);
952	  mpz_clear (powrhs);
953	  mpz_clear (rhs);
954	  return;
955	}
956
957      mpz_init (lhs); mpz_init (rhs);
958      mpz_eval_expr (lhs, e->operands.ops.lhs);
959      mpz_eval_expr (rhs, e->operands.ops.rhs);
960      mpz_fdiv_r (r, lhs, rhs);
961      mpz_clear (lhs); mpz_clear (rhs);
962      return;
963#if __GNU_MP_VERSION >= 2
964    case INVMOD:
965      mpz_init (lhs); mpz_init (rhs);
966      mpz_eval_expr (lhs, e->operands.ops.lhs);
967      mpz_eval_expr (rhs, e->operands.ops.rhs);
968      mpz_invert (r, lhs, rhs);
969      mpz_clear (lhs); mpz_clear (rhs);
970      return;
971#endif
972    case POW:
973      mpz_init (lhs); mpz_init (rhs);
974      mpz_eval_expr (lhs, e->operands.ops.lhs);
975      if (mpz_cmpabs_ui (lhs, 1) <= 0)
976	{
977	  /* For 0^rhs and 1^rhs, we just need to verify that
978	     rhs is well-defined.  For (-1)^rhs we need to
979	     determine (rhs mod 2).  For simplicity, compute
980	     (rhs mod 2) for all three cases.  */
981	  expr_t two, et;
982	  two = malloc (sizeof (struct expr));
983	  two -> op = LIT;
984	  mpz_init_set_ui (two->operands.val, 2L);
985	  makeexp (&et, MOD, e->operands.ops.rhs, two);
986	  e->operands.ops.rhs = et;
987	}
988
989      mpz_eval_expr (rhs, e->operands.ops.rhs);
990      if (mpz_cmp_si (rhs, 0L) == 0)
991	/* x^0 is 1 */
992	mpz_set_ui (r, 1L);
993      else if (mpz_cmp_si (lhs, 0L) == 0)
994	/* 0^y (where y != 0) is 0 */
995	mpz_set_ui (r, 0L);
996      else if (mpz_cmp_ui (lhs, 1L) == 0)
997	/* 1^y is 1 */
998	mpz_set_ui (r, 1L);
999      else if (mpz_cmp_si (lhs, -1L) == 0)
1000	/* (-1)^y just depends on whether y is even or odd */
1001	mpz_set_si (r, (mpz_get_ui (rhs) & 1) ? -1L : 1L);
1002      else if (mpz_cmp_si (rhs, 0L) < 0)
1003	/* x^(-n) is 0 */
1004	mpz_set_ui (r, 0L);
1005      else
1006	{
1007	  unsigned long int cnt;
1008	  unsigned long int y;
1009	  /* error if exponent does not fit into an unsigned long int.  */
1010	  if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1011	    goto pow_err;
1012
1013	  y = mpz_get_ui (rhs);
1014	  /* x^y == (x/(2^c))^y * 2^(c*y) */
1015#if __GNU_MP_VERSION >= 2
1016	  cnt = mpz_scan1 (lhs, 0);
1017#else
1018	  cnt = 0;
1019#endif
1020	  if (cnt != 0)
1021	    {
1022	      if (y * cnt / cnt != y)
1023		goto pow_err;
1024	      mpz_tdiv_q_2exp (lhs, lhs, cnt);
1025	      mpz_pow_ui (r, lhs, y);
1026	      mpz_mul_2exp (r, r, y * cnt);
1027	    }
1028	  else
1029	    mpz_pow_ui (r, lhs, y);
1030	}
1031      mpz_clear (lhs); mpz_clear (rhs);
1032      return;
1033    pow_err:
1034      error = "result of `pow' operator too large";
1035      mpz_clear (lhs); mpz_clear (rhs);
1036      longjmp (errjmpbuf, 1);
1037    case GCD:
1038      mpz_init (lhs); mpz_init (rhs);
1039      mpz_eval_expr (lhs, e->operands.ops.lhs);
1040      mpz_eval_expr (rhs, e->operands.ops.rhs);
1041      mpz_gcd (r, lhs, rhs);
1042      mpz_clear (lhs); mpz_clear (rhs);
1043      return;
1044#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1045    case LCM:
1046      mpz_init (lhs); mpz_init (rhs);
1047      mpz_eval_expr (lhs, e->operands.ops.lhs);
1048      mpz_eval_expr (rhs, e->operands.ops.rhs);
1049      mpz_lcm (r, lhs, rhs);
1050      mpz_clear (lhs); mpz_clear (rhs);
1051      return;
1052#endif
1053    case AND:
1054      mpz_init (lhs); mpz_init (rhs);
1055      mpz_eval_expr (lhs, e->operands.ops.lhs);
1056      mpz_eval_expr (rhs, e->operands.ops.rhs);
1057      mpz_and (r, lhs, rhs);
1058      mpz_clear (lhs); mpz_clear (rhs);
1059      return;
1060    case IOR:
1061      mpz_init (lhs); mpz_init (rhs);
1062      mpz_eval_expr (lhs, e->operands.ops.lhs);
1063      mpz_eval_expr (rhs, e->operands.ops.rhs);
1064      mpz_ior (r, lhs, rhs);
1065      mpz_clear (lhs); mpz_clear (rhs);
1066      return;
1067#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1068    case XOR:
1069      mpz_init (lhs); mpz_init (rhs);
1070      mpz_eval_expr (lhs, e->operands.ops.lhs);
1071      mpz_eval_expr (rhs, e->operands.ops.rhs);
1072      mpz_xor (r, lhs, rhs);
1073      mpz_clear (lhs); mpz_clear (rhs);
1074      return;
1075#endif
1076    case NEG:
1077      mpz_eval_expr (r, e->operands.ops.lhs);
1078      mpz_neg (r, r);
1079      return;
1080    case NOT:
1081      mpz_eval_expr (r, e->operands.ops.lhs);
1082      mpz_com (r, r);
1083      return;
1084    case SQRT:
1085      mpz_init (lhs);
1086      mpz_eval_expr (lhs, e->operands.ops.lhs);
1087      if (mpz_sgn (lhs) < 0)
1088	{
1089	  error = "cannot take square root of negative numbers";
1090	  mpz_clear (lhs);
1091	  longjmp (errjmpbuf, 1);
1092	}
1093      mpz_sqrt (r, lhs);
1094      return;
1095#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1096    case ROOT:
1097      mpz_init (lhs); mpz_init (rhs);
1098      mpz_eval_expr (lhs, e->operands.ops.lhs);
1099      mpz_eval_expr (rhs, e->operands.ops.rhs);
1100      if (mpz_sgn (rhs) <= 0)
1101	{
1102	  error = "cannot take non-positive root orders";
1103	  mpz_clear (lhs); mpz_clear (rhs);
1104	  longjmp (errjmpbuf, 1);
1105	}
1106      if (mpz_sgn (lhs) < 0 && (mpz_get_ui (rhs) & 1) == 0)
1107	{
1108	  error = "cannot take even root orders of negative numbers";
1109	  mpz_clear (lhs); mpz_clear (rhs);
1110	  longjmp (errjmpbuf, 1);
1111	}
1112
1113      {
1114	unsigned long int nth = mpz_get_ui (rhs);
1115	if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1116	  {
1117	    /* If we are asked to take an awfully large root order, cheat and
1118	       ask for the largest order we can pass to mpz_root.  This saves
1119	       some error prone special cases.  */
1120	    nth = ~(unsigned long int) 0;
1121	  }
1122	mpz_root (r, lhs, nth);
1123      }
1124      mpz_clear (lhs); mpz_clear (rhs);
1125      return;
1126#endif
1127    case FAC:
1128      mpz_eval_expr (r, e->operands.ops.lhs);
1129      if (mpz_size (r) > 1)
1130	{
1131	  error = "result of `!' operator too large";
1132	  longjmp (errjmpbuf, 1);
1133	}
1134      mpz_fac_ui (r, mpz_get_ui (r));
1135      return;
1136#if __GNU_MP_VERSION >= 2
1137    case POPCNT:
1138      mpz_eval_expr (r, e->operands.ops.lhs);
1139      { long int cnt;
1140	cnt = mpz_popcount (r);
1141	mpz_set_si (r, cnt);
1142      }
1143      return;
1144    case HAMDIST:
1145      { long int cnt;
1146	mpz_init (lhs); mpz_init (rhs);
1147	mpz_eval_expr (lhs, e->operands.ops.lhs);
1148	mpz_eval_expr (rhs, e->operands.ops.rhs);
1149	cnt = mpz_hamdist (lhs, rhs);
1150	mpz_clear (lhs); mpz_clear (rhs);
1151	mpz_set_si (r, cnt);
1152      }
1153      return;
1154#endif
1155    case LOG2:
1156      mpz_eval_expr (r, e->operands.ops.lhs);
1157      { unsigned long int cnt;
1158	if (mpz_sgn (r) <= 0)
1159	  {
1160	    error = "logarithm of non-positive number";
1161	    longjmp (errjmpbuf, 1);
1162	  }
1163	cnt = mpz_sizeinbase (r, 2);
1164	mpz_set_ui (r, cnt - 1);
1165      }
1166      return;
1167    case LOG:
1168      { unsigned long int cnt;
1169	mpz_init (lhs); mpz_init (rhs);
1170	mpz_eval_expr (lhs, e->operands.ops.lhs);
1171	mpz_eval_expr (rhs, e->operands.ops.rhs);
1172	if (mpz_sgn (lhs) <= 0)
1173	  {
1174	    error = "logarithm of non-positive number";
1175	    mpz_clear (lhs); mpz_clear (rhs);
1176	    longjmp (errjmpbuf, 1);
1177	  }
1178	if (mpz_cmp_ui (rhs, 256) >= 0)
1179	  {
1180	    error = "logarithm base too large";
1181	    mpz_clear (lhs); mpz_clear (rhs);
1182	    longjmp (errjmpbuf, 1);
1183	  }
1184	cnt = mpz_sizeinbase (lhs, mpz_get_ui (rhs));
1185	mpz_set_ui (r, cnt - 1);
1186	mpz_clear (lhs); mpz_clear (rhs);
1187      }
1188      return;
1189    case FERMAT:
1190      {
1191	unsigned long int t;
1192	mpz_init (lhs);
1193	mpz_eval_expr (lhs, e->operands.ops.lhs);
1194	t = (unsigned long int) 1 << mpz_get_ui (lhs);
1195	if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0 || t == 0)
1196	  {
1197	    error = "too large Mersenne number index";
1198	    mpz_clear (lhs);
1199	    longjmp (errjmpbuf, 1);
1200	  }
1201	mpz_set_ui (r, 1);
1202	mpz_mul_2exp (r, r, t);
1203	mpz_add_ui (r, r, 1);
1204	mpz_clear (lhs);
1205      }
1206      return;
1207    case MERSENNE:
1208      mpz_init (lhs);
1209      mpz_eval_expr (lhs, e->operands.ops.lhs);
1210      if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0)
1211	{
1212	  error = "too large Mersenne number index";
1213	  mpz_clear (lhs);
1214	  longjmp (errjmpbuf, 1);
1215	}
1216      mpz_set_ui (r, 1);
1217      mpz_mul_2exp (r, r, mpz_get_ui (lhs));
1218      mpz_sub_ui (r, r, 1);
1219      mpz_clear (lhs);
1220      return;
1221    case FIBONACCI:
1222      { mpz_t t;
1223	unsigned long int n, i;
1224	mpz_init (lhs);
1225	mpz_eval_expr (lhs, e->operands.ops.lhs);
1226	if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1227	  {
1228	    error = "Fibonacci index out of range";
1229	    mpz_clear (lhs);
1230	    longjmp (errjmpbuf, 1);
1231	  }
1232	n = mpz_get_ui (lhs);
1233	mpz_clear (lhs);
1234
1235#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1236	mpz_fib_ui (r, n);
1237#else
1238	mpz_init_set_ui (t, 1);
1239	mpz_set_ui (r, 1);
1240
1241	if (n <= 2)
1242	  mpz_set_ui (r, 1);
1243	else
1244	  {
1245	    for (i = 3; i <= n; i++)
1246	      {
1247		mpz_add (t, t, r);
1248		mpz_swap (t, r);
1249	      }
1250	  }
1251	mpz_clear (t);
1252#endif
1253      }
1254      return;
1255    case RANDOM:
1256      {
1257	unsigned long int n;
1258	mpz_init (lhs);
1259	mpz_eval_expr (lhs, e->operands.ops.lhs);
1260	if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1261	  {
1262	    error = "random number size out of range";
1263	    mpz_clear (lhs);
1264	    longjmp (errjmpbuf, 1);
1265	  }
1266	n = mpz_get_ui (lhs);
1267	mpz_clear (lhs);
1268	mpz_urandomb (r, rstate, n);
1269      }
1270      return;
1271    case NEXTPRIME:
1272      {
1273	mpz_eval_expr (r, e->operands.ops.lhs);
1274	mpz_nextprime (r, r);
1275      }
1276      return;
1277    case BINOM:
1278      mpz_init (lhs); mpz_init (rhs);
1279      mpz_eval_expr (lhs, e->operands.ops.lhs);
1280      mpz_eval_expr (rhs, e->operands.ops.rhs);
1281      {
1282	unsigned long int k;
1283	if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1284	  {
1285	    error = "k too large in (n over k) expression";
1286	    mpz_clear (lhs); mpz_clear (rhs);
1287	    longjmp (errjmpbuf, 1);
1288	  }
1289	k = mpz_get_ui (rhs);
1290	mpz_bin_ui (r, lhs, k);
1291      }
1292      mpz_clear (lhs); mpz_clear (rhs);
1293      return;
1294    case TIMING:
1295      {
1296	int t0;
1297	t0 = cputime ();
1298	mpz_eval_expr (r, e->operands.ops.lhs);
1299	printf ("time: %d\n", cputime () - t0);
1300      }
1301      return;
1302    default:
1303      abort ();
1304    }
1305}
1306
1307/* Evaluate the expression E modulo MOD and put the result in R.  */
1308void
1309mpz_eval_mod_expr (mpz_ptr r, expr_t e, mpz_ptr mod)
1310{
1311  mpz_t lhs, rhs;
1312
1313  switch (e->op)
1314    {
1315      case POW:
1316	mpz_init (lhs); mpz_init (rhs);
1317	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1318	mpz_eval_expr (rhs, e->operands.ops.rhs);
1319	mpz_powm (r, lhs, rhs, mod);
1320	mpz_clear (lhs); mpz_clear (rhs);
1321	return;
1322      case PLUS:
1323	mpz_init (lhs); mpz_init (rhs);
1324	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1325	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1326	mpz_add (r, lhs, rhs);
1327	if (mpz_cmp_si (r, 0L) < 0)
1328	  mpz_add (r, r, mod);
1329	else if (mpz_cmp (r, mod) >= 0)
1330	  mpz_sub (r, r, mod);
1331	mpz_clear (lhs); mpz_clear (rhs);
1332	return;
1333      case MINUS:
1334	mpz_init (lhs); mpz_init (rhs);
1335	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1336	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1337	mpz_sub (r, lhs, rhs);
1338	if (mpz_cmp_si (r, 0L) < 0)
1339	  mpz_add (r, r, mod);
1340	else if (mpz_cmp (r, mod) >= 0)
1341	  mpz_sub (r, r, mod);
1342	mpz_clear (lhs); mpz_clear (rhs);
1343	return;
1344      case MULT:
1345	mpz_init (lhs); mpz_init (rhs);
1346	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1347	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1348	mpz_mul (r, lhs, rhs);
1349	mpz_mod (r, r, mod);
1350	mpz_clear (lhs); mpz_clear (rhs);
1351	return;
1352      default:
1353	mpz_init (lhs);
1354	mpz_eval_expr (lhs, e);
1355	mpz_mod (r, lhs, mod);
1356	mpz_clear (lhs);
1357	return;
1358    }
1359}
1360
1361void
1362cleanup_and_exit (int sig)
1363{
1364  switch (sig) {
1365#ifdef LIMIT_RESOURCE_USAGE
1366  case SIGXCPU:
1367    printf ("expression took too long to evaluate%s\n", newline);
1368    break;
1369#endif
1370  case SIGFPE:
1371    printf ("divide by zero%s\n", newline);
1372    break;
1373  default:
1374    printf ("expression required too much memory to evaluate%s\n", newline);
1375    break;
1376  }
1377  exit (-2);
1378}
1379