1/* Tests matrix22_mul.
2
3Copyright 2008 Free Software Foundation, Inc.
4
5This file is part of the GNU MP Library.
6
7The GNU MP Library is free software; you can redistribute it and/or modify
8it under the terms of the GNU Lesser General Public License as published by
9the Free Software Foundation; either version 3 of the License, or (at your
10option) any later version.
11
12The GNU MP Library is distributed in the hope that it will be useful, but
13WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
15License for more details.
16
17You should have received a copy of the GNU Lesser General Public License
18along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
19
20#include <stdio.h>
21#include <stdlib.h>
22
23#include "gmp.h"
24#include "gmp-impl.h"
25#include "tests.h"
26
27struct matrix {
28  mp_size_t alloc;
29  mp_size_t n;
30  mp_ptr e00, e01, e10, e11;
31};
32
33static void
34matrix_init (struct matrix *M, mp_size_t n)
35{
36  mp_ptr p = refmpn_malloc_limbs (4*(n+1));
37  M->e00 = p; p += n+1;
38  M->e01 = p; p += n+1;
39  M->e10 = p; p += n+1;
40  M->e11 = p;
41  M->alloc = n + 1;
42  M->n = 0;
43}
44
45static void
46matrix_clear (struct matrix *M)
47{
48  refmpn_free_limbs (M->e00);
49}
50
51static void
52matrix_copy (struct matrix *R, const struct matrix *M)
53{
54  R->n = M->n;
55  MPN_COPY (R->e00, M->e00, M->n);
56  MPN_COPY (R->e01, M->e01, M->n);
57  MPN_COPY (R->e10, M->e10, M->n);
58  MPN_COPY (R->e11, M->e11, M->n);
59}
60
61/* Used with same size, so no need for normalization. */
62static int
63matrix_equal_p (const struct matrix *A, const struct matrix *B)
64{
65  return (A->n == B->n
66	  && mpn_cmp (A->e00, B->e00, A->n) == 0
67	  && mpn_cmp (A->e01, B->e01, A->n) == 0
68	  && mpn_cmp (A->e10, B->e10, A->n) == 0
69	  && mpn_cmp (A->e11, B->e11, A->n) == 0);
70}
71
72static void
73matrix_random(struct matrix *M, mp_size_t n, gmp_randstate_ptr rands)
74{
75  M->n = n;
76  mpn_random (M->e00, n);
77  mpn_random (M->e01, n);
78  mpn_random (M->e10, n);
79  mpn_random (M->e11, n);
80}
81
/* Multiply {ap,an} by {bp,bn} into rp using mpn_mul, passing whichever
   operand has more limbs first (when the sizes are equal, {bp,bn} goes
   first).

   Every argument is parenthesized in the expansion, so an expression
   argument such as `0 | 4' cannot re-associate with the `>' comparison
   and select the wrong operand order.  Note that an and bn are still
   expanded more than once, so arguments must be free of side effects.  */
#define MUL(rp, ap, an, bp, bn) do {		\
    if ((an) > (bn))				\
      mpn_mul ((rp), (ap), (an), (bp), (bn));	\
    else					\
      mpn_mul ((rp), (bp), (bn), (ap), (an));	\
  } while (0)
88
89static void
90ref_matrix22_mul (struct matrix *R,
91		  const struct matrix *A,
92		  const struct matrix *B, mp_ptr tp)
93{
94  mp_size_t an, bn, n;
95  mp_ptr r00, r01, r10, r11, a00, a01, a10, a11, b00, b01, b10, b11;
96
97  if (A->n >= B->n)
98    {
99      r00 = R->e00; a00 = A->e00; b00 = B->e00;
100      r01 = R->e01; a01 = A->e01; b01 = B->e01;
101      r10 = R->e10; a10 = A->e10; b10 = B->e10;
102      r11 = R->e11; a11 = A->e11; b11 = B->e11;
103      an = A->n, bn = B->n;
104    }
105  else
106    {
107      /* Transpose */
108      r00 = R->e00; a00 = B->e00; b00 = A->e00;
109      r01 = R->e10; a01 = B->e10; b01 = A->e10;
110      r10 = R->e01; a10 = B->e01; b10 = A->e01;
111      r11 = R->e11; a11 = B->e11; b11 = A->e11;
112      an = B->n, bn = A->n;
113    }
114  n = an + bn;
115  R->n = n + 1;
116
117  mpn_mul (r00, a00, an, b00, bn);
118  mpn_mul (tp, a01, an, b10, bn);
119  r00[n] = mpn_add_n (r00, r00, tp, n);
120
121  mpn_mul (r01, a00, an, b01, bn);
122  mpn_mul (tp, a01, an, b11, bn);
123  r01[n] = mpn_add_n (r01, r01, tp, n);
124
125  mpn_mul (r10, a10, an, b00, bn);
126  mpn_mul (tp, a11, an, b10, bn);
127  r10[n] = mpn_add_n (r10, r10, tp, n);
128
129  mpn_mul (r11, a10, an, b01, bn);
130  mpn_mul (tp, a11, an, b11, bn);
131  r11[n] = mpn_add_n (r11, r11, tp, n);
132}
133
134static void
135one_test (const struct matrix *A, const struct matrix *B, int i)
136{
137  struct matrix R;
138  struct matrix P;
139  mp_ptr tp;
140
141  matrix_init (&R, A->n + B->n + 1);
142  matrix_init (&P, A->n + B->n + 1);
143
144  tp = refmpn_malloc_limbs (mpn_matrix22_mul_itch (A->n, B->n));
145
146  ref_matrix22_mul (&R, A, B, tp);
147  matrix_copy (&P, A);
148  mpn_matrix22_mul (P.e00, P.e01, P.e10, P.e11, A->n,
149		    B->e00, B->e01, B->e10, B->e11, B->n, tp);
150  P.n = A->n + B->n + 1;
151  if (!matrix_equal_p (&R, &P))
152    {
153      fprintf (stderr, "ERROR in test %d\n", i);
154      gmp_fprintf (stderr, "A = (%Nx, %Nx\n      %Nx, %Nx)\n"
155		   "B = (%Nx, %Nx\n      %Nx, %Nx)\n"
156		   "R = (%Nx, %Nx (expected)\n      %Nx, %Nx)\n"
157		   "P = (%Nx, %Nx (incorrect)\n      %Nx, %Nx)\n",
158		   A->e00, A->n, A->e01, A->n, A->e10, A->n, A->e11, A->n,
159		   B->e00, B->n, B->e01, B->n, B->e10, B->n, B->e11, B->n,
160		   R.e00, R.n, R.e01, R.n, R.e10, R.n, R.e11, R.n,
161		   P.e00, P.n, P.e01, P.n, P.e10, P.n, P.e11, P.n);
162      abort();
163    }
164  refmpn_free_limbs (tp);
165  matrix_clear (&R);
166  matrix_clear (&P);
167}
168
/* Largest limb count used for the random test matrices in main.  */
#define MAX_SIZE (2*MATRIX22_STRASSEN_THRESHOLD + 2)
170
171int
172main (int argc, char **argv)
173{
174  struct matrix A;
175  struct matrix B;
176
177  gmp_randstate_ptr rands;
178  mpz_t bs;
179  int i;
180
181  tests_start ();
182  rands = RANDS;
183
184  matrix_init (&A, MAX_SIZE);
185  matrix_init (&B, MAX_SIZE);
186  mpz_init (bs);
187
188  for (i = 0; i < 1000; i++)
189    {
190      mp_size_t an, bn;
191      mpz_urandomb (bs, rands, 32);
192      an = 1 + mpz_get_ui (bs) % MAX_SIZE;
193      mpz_urandomb (bs, rands, 32);
194      bn = 1 + mpz_get_ui (bs) % MAX_SIZE;
195
196      matrix_random (&A, an, rands);
197      matrix_random (&B, bn, rands);
198
199      one_test (&A, &B, i);
200    }
201  mpz_clear (bs);
202  matrix_clear (&A);
203  matrix_clear (&B);
204
205  tests_end ();
206  return 0;
207}
208