1/* Tests matrix22_mul. 2 3Copyright 2008 Free Software Foundation, Inc. 4 5This file is part of the GNU MP Library. 6 7The GNU MP Library is free software; you can redistribute it and/or modify 8it under the terms of the GNU Lesser General Public License as published by 9the Free Software Foundation; either version 3 of the License, or (at your 10option) any later version. 11 12The GNU MP Library is distributed in the hope that it will be useful, but 13WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 14or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public 15License for more details. 16 17You should have received a copy of the GNU Lesser General Public License 18along with the GNU MP Library. If not, see http://www.gnu.org/licenses/. */ 19 20#include <stdio.h> 21#include <stdlib.h> 22 23#include "gmp.h" 24#include "gmp-impl.h" 25#include "tests.h" 26 27struct matrix { 28 mp_size_t alloc; 29 mp_size_t n; 30 mp_ptr e00, e01, e10, e11; 31}; 32 33static void 34matrix_init (struct matrix *M, mp_size_t n) 35{ 36 mp_ptr p = refmpn_malloc_limbs (4*(n+1)); 37 M->e00 = p; p += n+1; 38 M->e01 = p; p += n+1; 39 M->e10 = p; p += n+1; 40 M->e11 = p; 41 M->alloc = n + 1; 42 M->n = 0; 43} 44 45static void 46matrix_clear (struct matrix *M) 47{ 48 refmpn_free_limbs (M->e00); 49} 50 51static void 52matrix_copy (struct matrix *R, const struct matrix *M) 53{ 54 R->n = M->n; 55 MPN_COPY (R->e00, M->e00, M->n); 56 MPN_COPY (R->e01, M->e01, M->n); 57 MPN_COPY (R->e10, M->e10, M->n); 58 MPN_COPY (R->e11, M->e11, M->n); 59} 60 61/* Used with same size, so no need for normalization. */ 62static int 63matrix_equal_p (const struct matrix *A, const struct matrix *B) 64{ 65 return (A->n == B->n 66 && mpn_cmp (A->e00, B->e00, A->n) == 0 67 && mpn_cmp (A->e01, B->e01, A->n) == 0 68 && mpn_cmp (A->e10, B->e10, A->n) == 0 69 && mpn_cmp (A->e11, B->e11, A->n) == 0); 70} 71 72static void 73matrix_random(struct matrix *M, mp_size_t n, gmp_randstate_ptr rands) 74{ 75 M->n = n; 76 mpn_random (M->e00, n); 77 mpn_random (M->e01, n); 78 mpn_random (M->e10, n); 79 mpn_random (M->e11, n); 80} 81 82#define MUL(rp, ap, an, bp, bn) do { \ 83 if (an > bn) \ 84 mpn_mul (rp, ap, an, bp, bn); \ 85 else \ 86 mpn_mul (rp, bp, bn, ap, an); \ 87 } while(0) 88 89static void 90ref_matrix22_mul (struct matrix *R, 91 const struct matrix *A, 92 const struct matrix *B, mp_ptr tp) 93{ 94 mp_size_t an, bn, n; 95 mp_ptr r00, r01, r10, r11, a00, a01, a10, a11, b00, b01, b10, b11; 96 97 if (A->n >= B->n) 98 { 99 r00 = R->e00; a00 = A->e00; b00 = B->e00; 100 r01 = R->e01; a01 = A->e01; b01 = B->e01; 101 r10 = R->e10; a10 = A->e10; b10 = B->e10; 102 r11 = R->e11; a11 = A->e11; b11 = B->e11; 103 an = A->n, bn = B->n; 104 } 105 else 106 { 107 /* Transpose */ 108 r00 = R->e00; a00 = B->e00; b00 = A->e00; 109 r01 = R->e10; a01 = B->e10; b01 = A->e10; 110 r10 = R->e01; a10 = B->e01; b10 = A->e01; 111 r11 = R->e11; a11 = B->e11; b11 = A->e11; 112 an = B->n, bn = A->n; 113 } 114 n = an + bn; 115 R->n = n + 1; 116 117 mpn_mul (r00, a00, an, b00, bn); 118 mpn_mul (tp, a01, an, b10, bn); 119 r00[n] = mpn_add_n (r00, r00, tp, n); 120 121 mpn_mul (r01, a00, an, b01, bn); 122 mpn_mul (tp, a01, an, b11, bn); 123 r01[n] = mpn_add_n (r01, r01, tp, n); 124 125 mpn_mul (r10, a10, an, b00, bn); 126 mpn_mul (tp, a11, an, b10, bn); 127 r10[n] = mpn_add_n (r10, r10, tp, n); 128 129 mpn_mul (r11, a10, an, b01, bn); 130 mpn_mul (tp, a11, an, b11, bn); 131 r11[n] = mpn_add_n (r11, r11, tp, n); 132} 133 134static void 135one_test (const struct matrix *A, const struct matrix *B, int i) 136{ 137 struct matrix R; 138 struct matrix P; 139 mp_ptr tp; 140 141 matrix_init (&R, A->n + B->n + 1); 142 matrix_init (&P, A->n + B->n + 1); 143 144 tp = refmpn_malloc_limbs (mpn_matrix22_mul_itch (A->n, B->n)); 145 146 ref_matrix22_mul (&R, A, B, tp); 147 matrix_copy (&P, A); 148 mpn_matrix22_mul (P.e00, P.e01, P.e10, P.e11, A->n, 149 B->e00, B->e01, B->e10, B->e11, B->n, tp); 150 P.n = A->n + B->n + 1; 151 if (!matrix_equal_p (&R, &P)) 152 { 153 fprintf (stderr, "ERROR in test %d\n", i); 154 gmp_fprintf (stderr, "A = (%Nx, %Nx\n %Nx, %Nx)\n" 155 "B = (%Nx, %Nx\n %Nx, %Nx)\n" 156 "R = (%Nx, %Nx (expected)\n %Nx, %Nx)\n" 157 "P = (%Nx, %Nx (incorrect)\n %Nx, %Nx)\n", 158 A->e00, A->n, A->e01, A->n, A->e10, A->n, A->e11, A->n, 159 B->e00, B->n, B->e01, B->n, B->e10, B->n, B->e11, B->n, 160 R.e00, R.n, R.e01, R.n, R.e10, R.n, R.e11, R.n, 161 P.e00, P.n, P.e01, P.n, P.e10, P.n, P.e11, P.n); 162 abort(); 163 } 164 refmpn_free_limbs (tp); 165 matrix_clear (&R); 166 matrix_clear (&P); 167} 168 169#define MAX_SIZE (2+2*MATRIX22_STRASSEN_THRESHOLD) 170 171int 172main (int argc, char **argv) 173{ 174 struct matrix A; 175 struct matrix B; 176 177 gmp_randstate_ptr rands; 178 mpz_t bs; 179 int i; 180 181 tests_start (); 182 rands = RANDS; 183 184 matrix_init (&A, MAX_SIZE); 185 matrix_init (&B, MAX_SIZE); 186 mpz_init (bs); 187 188 for (i = 0; i < 1000; i++) 189 { 190 mp_size_t an, bn; 191 mpz_urandomb (bs, rands, 32); 192 an = 1 + mpz_get_ui (bs) % MAX_SIZE; 193 mpz_urandomb (bs, rands, 32); 194 bn = 1 + mpz_get_ui (bs) % MAX_SIZE; 195 196 matrix_random (&A, an, rands); 197 matrix_random (&B, bn, rands); 198 199 one_test (&A, &B, i); 200 } 201 mpz_clear (bs); 202 matrix_clear (&A); 203 matrix_clear (&B); 204 205 tests_end (); 206 return 0; 207} 208