1/* Test mpz_perfect_square_p.
2
3Copyright 2000, 2001, 2002 Free Software Foundation, Inc.
4
5This file is part of the GNU MP Library.
6
7The GNU MP Library is free software; you can redistribute it and/or modify
8it under the terms of the GNU Lesser General Public License as published by
9the Free Software Foundation; either version 3 of the License, or (at your
10option) any later version.
11
12The GNU MP Library is distributed in the hope that it will be useful, but
13WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
15License for more details.
16
17You should have received a copy of the GNU Lesser General Public License
18along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
19
20#include <stdio.h>
21#include <stdlib.h>
22
23#include "gmp.h"
24#include "gmp-impl.h"
25#include "tests.h"
26
27#include "mpn/perfsqr.h"
28
29
30/* check_modulo() exercises mpz_perfect_square_p on squares which cover each
31   possible quadratic residue to each divisor used within
32   mpn_perfect_square_p, ensuring those residues aren't incorrectly claimed
33   to be non-residues.
34
35   Each divisor is taken separately.  It's arranged that n is congruent to 0
36   modulo the other divisors, 0 of course being a quadratic residue to any
37   modulus.
38
39   The values "(j*others)^2" cover all quadratic residues mod divisor[i],
40   but in no particular order.  j is run from 1<=j<=divisor[i] so that zero
41   is excluded.  A literal n==0 doesn't reach the residue tests.  */
42
43void
44check_modulo (void)
45{
46  static const unsigned long  divisor[] = PERFSQR_DIVISORS;
47  unsigned long  i, j;
48
49  mpz_t  alldiv, others, n;
50
51  mpz_init (alldiv);
52  mpz_init (others);
53  mpz_init (n);
54
55  /* product of all divisors */
56  mpz_set_ui (alldiv, 1L);
57  for (i = 0; i < numberof (divisor); i++)
58    mpz_mul_ui (alldiv, alldiv, divisor[i]);
59
60  for (i = 0; i < numberof (divisor); i++)
61    {
62      /* product of all divisors except i */
63      mpz_set_ui (others, 1L);
64      for (j = 0; j < numberof (divisor); j++)
65        if (i != j)
66          mpz_mul_ui (others, others, divisor[j]);
67
68      for (j = 1; j <= divisor[i]; j++)
69        {
70          /* square */
71          mpz_mul_ui (n, others, j);
72          mpz_mul (n, n, n);
73          if (! mpz_perfect_square_p (n))
74            {
75              printf ("mpz_perfect_square_p got 0, want 1\n");
76              mpz_trace ("  n", n);
77              abort ();
78            }
79        }
80    }
81
82  mpz_clear (alldiv);
83  mpz_clear (others);
84  mpz_clear (n);
85}
86
87
88/* Exercise mpz_perfect_square_p compared to what mpz_sqrt says. */
89void
90check_sqrt (int reps)
91{
92  mpz_t x2, x2t, x;
93  mp_size_t x2n;
94  int res;
95  int i;
96  /* int cnt = 0; */
97  gmp_randstate_ptr rands = RANDS;
98  mpz_t bs;
99
100  mpz_init (bs);
101
102  mpz_init (x2);
103  mpz_init (x);
104  mpz_init (x2t);
105
106  for (i = 0; i < reps; i++)
107    {
108      mpz_urandomb (bs, rands, 9);
109      x2n = mpz_get_ui (bs);
110      mpz_rrandomb (x2, rands, x2n);
111      /* mpz_out_str (stdout, -16, x2); puts (""); */
112
113      res = mpz_perfect_square_p (x2);
114      mpz_sqrt (x, x2);
115      mpz_mul (x2t, x, x);
116
117      if (res != (mpz_cmp (x2, x2t) == 0))
118        {
119          printf    ("mpz_perfect_square_p and mpz_sqrt differ\n");
120          mpz_trace ("   x  ", x);
121          mpz_trace ("   x2 ", x2);
122          mpz_trace ("   x2t", x2t);
123          printf    ("   mpz_perfect_square_p %d\n", res);
124          printf    ("   mpz_sqrt             %d\n", mpz_cmp (x2, x2t) == 0);
125          abort ();
126        }
127
128      /* cnt += res != 0; */
129    }
130  /* printf ("%d/%d perfect squares\n", cnt, reps); */
131
132  mpz_clear (bs);
133  mpz_clear (x2);
134  mpz_clear (x);
135  mpz_clear (x2t);
136}
137
138
139int
140main (int argc, char **argv)
141{
142  int reps = 200000;
143
144  tests_start ();
145  mp_trace_base = -16;
146
147  if (argc == 2)
148     reps = atoi (argv[1]);
149
150  check_modulo ();
151  check_sqrt (reps);
152
153  tests_end ();
154  exit (0);
155}
156