1/*  $OpenBSD: sntrup761.c,v 1.1 2021/05/28 18:01:39 tobhe Exp $ */
2
3/*
4 * Public Domain, Authors:
5 * - Daniel J. Bernstein
6 * - Chitchanok Chuengsatiansup
7 * - Tanja Lange
8 * - Christine van Vredendaal
9 */
10
11#include <string.h>
12#include "crypto_api.h"
13
/*
 * Short aliases for the crypto_* fixed-width integer types, so the code
 * below can keep the type names used by the SUPERCOP sources it was taken
 * from.  The crypto_* names are not declared in this file (presumably they
 * come from crypto_api.h -- confirm).
 */
#define int8 crypto_int8
#define uint8 crypto_uint8
#define int16 crypto_int16
#define uint16 crypto_uint16
#define int32 crypto_int32
#define uint32 crypto_uint32
#define int64 crypto_int64
#define uint64 crypto_uint64
22
23/* from supercop-20201130/crypto_sort/int32/portable4/int32_minmax.inc */
/*
 * int32_MINMAX(a,b): branch-free compare-exchange of two 32-bit signed
 * lvalues.  Afterwards a holds the smaller and b the larger value.
 *
 * The difference b-a is formed in 64 bits so it cannot overflow; its sign
 * (fixed up with the XOR of the operands when their signs differ) is spread
 * into an all-ones/all-zeros selector that conditionally swaps a and b.
 *
 * Both arguments must be modifiable lvalues and are evaluated more than
 * once, so they must be free of side effects.  The temporaries carry a
 * minmax_ prefix and trailing underscore so that caller variables with
 * short names such as `c` or `ab` cannot be captured by the expansion.
 */
#define int32_MINMAX(a,b) \
do { \
  int64_t minmax_xor_ = (int64_t)(b) ^ (int64_t)(a); \
  int64_t minmax_sel_ = (int64_t)(b) - (int64_t)(a); \
  minmax_sel_ ^= minmax_xor_ & (minmax_sel_ ^ (b)); \
  minmax_sel_ >>= 31; \
  minmax_sel_ &= minmax_xor_; \
  (a) ^= minmax_sel_; \
  (b) ^= minmax_sel_; \
} while(0)
34
35/* from supercop-20201130/crypto_sort/int32/portable4/sort.c */
36
37
/*
 * Sort the n int32 values at `array` in place, smallest first.
 *
 * Every access to the data goes through int32_MINMAX, and which pairs are
 * compared is decided only by n and the loop counters, never by the values
 * being sorted.
 *
 * array: pointer to n int32 elements (passed as void *).
 * n:     element count; n < 2 returns immediately.
 */
static void crypto_sort_int32(void *array,long long n)
{
  long long top,p,q,r,i,j;
  int32 *x = array;

  if (n < 2) return;
  /* double top (a power of two) until top >= n - top */
  top = 1;
  while (top < n - top) top += top;

  /* p is the compare distance; it is halved on every outer pass */
  for (p = top;p >= 1;p >>= 1) {
    /* compare-exchange x[j] with x[j+p] over full groups of 2p elements */
    i = 0;
    while (i + 2 * p <= n) {
      for (j = i;j < i + p;++j)
        int32_MINMAX(x[j],x[j+p]);
      i += 2 * p;
    }
    /* leftover partial group at the end of the array */
    for (j = i;j < n - p;++j)
      int32_MINMAX(x[j],x[j+p]);

    /*
     * For each larger distance q (top down to 2p), run x[j+p] against
     * x[j+r] for r = q, q/2, ..., 2p.  The value is held in the local `a`
     * and written back once per position.
     */
    i = 0;
    j = 0;
    for (q = top;q > p;q >>= 1) {
      /* finish a group that the previous q left partially processed */
      if (j != i) for (;;) {
        if (j == n - q) goto done;
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j + r]);
        x[j + p] = a;
        ++j;
        if (j == i + p) {
          i += 2 * p;
          break;
        }
      }
      /* whole groups that fit below n - q */
      while (i + p <= n - q) {
        for (j = i;j < i + p;++j) {
          int32 a = x[j + p];
          for (r = q;r > p;r >>= 1)
            int32_MINMAX(a,x[j+r]);
          x[j + p] = a;
        }
        i += 2 * p;
      }
      /* now i + p > n - q */
      j = i;
      while (j < n - q) {
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j+r]);
        x[j + p] = a;
        ++j;
      }

      /* the empty statement lets the label sit at the end of the loop body */
      done: ;
    }
  }
}
95
96/* from supercop-20201130/crypto_sort/uint32/useint32/sort.c */
97
98/* can save time by vectorizing xor loops */
99/* can save time by integrating xor loops with int32_sort */
100
101static void crypto_sort_uint32(void *array,long long n)
102{
103  crypto_uint32 *x = array;
104  long long j;
105  for (j = 0;j < n;++j) x[j] ^= 0x80000000;
106  crypto_sort_int32(array,n);
107  for (j = 0;j < n;++j) x[j] ^= 0x80000000;
108}
109
110/* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.c */
111
112/*
113CPU division instruction typically takes time depending on x.
114This software is designed to take time independent of x.
115Time still varies depending on m; user must ensure that m is constant.
116Time also varies on CPUs where multiplication is variable-time.
117There could be more CPU issues.
118There could also be compiler issues.
119*/
120
/*
 * Divide x by m without using a division whose operand is x.
 *
 * q: receives the quotient.
 * r: receives the remainder (the final comment below states x < m).
 * x: dividend.
 * m: divisor; the caller must guarantee 0 < m < 16384.
 *
 * The only `/` is 2^31 / m, which depends on m alone.  The quotient is then
 * built from two multiply-and-shift estimates plus one masked correction.
 */
static void uint32_divmod_uint14(uint32 *q,uint16 *r,uint32 x,uint16 m)
{
  uint32 v = 0x80000000;
  uint32 qpart;
  uint32 mask;

  /* v = floor(2^31 / m) */
  v /= m;

  /* caller guarantees m > 0 */
  /* caller guarantees m < 16384 */
  /* vm <= 2^31 <= vm+m-1 */
  /* xvm <= 2^31 x <= xvm+x(m-1) */

  *q = 0;

  /* first quotient estimate; the product is formed in 64 bits */
  qpart = (x*(uint64)v)>>31;
  /* 2^31 qpart <= xv <= 2^31 qpart + 2^31-1 */
  /* 2^31 qpart m <= xvm <= 2^31 qpart m + (2^31-1)m */
  /* 2^31 qpart m <= 2^31 x <= 2^31 qpart m + (2^31-1)m + x(m-1) */
  /* 0 <= 2^31 newx <= (2^31-1)m + x(m-1) */
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= (1-1/2^31)(2^14-1) + (2^32-1)((2^14-1)-1)/2^31 */

  x -= qpart*m; *q += qpart;
  /* x <= 49146 */

  /* second estimate on the reduced x */
  qpart = (x*(uint64)v)>>31;
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= m + 49146(2^14-1)/2^31 */
  /* 0 <= newx <= m + 0.4 */
  /* 0 <= newx <= m */

  x -= qpart*m; *q += qpart;
  /* x <= m */

  /*
   * Subtract m once more unconditionally; if that wrapped (top bit set),
   * mask is all ones and both the subtraction and the quotient increment
   * are undone.
   */
  x -= m; *q += 1;
  mask = -(x>>31);
  x += mask&(uint32)m; *q += mask;
  /* x < m */

  *r = x;
}
163
164
165static uint16 uint32_mod_uint14(uint32 x,uint16 m)
166{
167  uint32 q;
168  uint16 r;
169  uint32_divmod_uint14(&q,&r,x,m);
170  return r;
171}
172
173/* from supercop-20201130/crypto_kem/sntrup761/ref/int32.c */
174
/*
 * Signed counterpart of uint32_divmod_uint14.
 *
 * q: receives the quotient.
 * r: receives the remainder; after the fix-up below it is non-negative.
 * x: signed dividend.
 * m: divisor, passed straight through to uint32_divmod_uint14.
 *
 * x is biased by 2^31 so the unsigned routine can be used; the quotient and
 * remainder of 2^31 itself are then subtracted out again.
 */
static void int32_divmod_uint14(int32 *q,uint16 *r,int32 x,uint16 m)
{
  uint32 uq,uq2;
  uint16 ur,ur2;
  uint32 mask;

  uint32_divmod_uint14(&uq,&ur,0x80000000+(uint32)x,m);
  uint32_divmod_uint14(&uq2,&ur2,0x80000000,m);
  ur -= ur2; uq -= uq2;
  /* if the remainder difference wrapped (bit 15 set), add m and borrow one from the quotient */
  mask = -(uint32)(ur>>15);
  ur += mask&m; uq += mask;
  *r = ur; *q = uq;
}
188
189
190static uint16 int32_mod_uint14(int32 x,uint16 m)
191{
192  int32 q;
193  uint16 r;
194  int32_divmod_uint14(&q,&r,x,m);
195  return r;
196}
197
198/* from supercop-20201130/crypto_kem/sntrup761/ref/paramsmenu.h */
/* pick one of these three: */
/* this file selects SIZE761; the other two size macros are left undefined */
#define SIZE761
#undef SIZE653
#undef SIZE857

/* pick one of these two: */
/* SNTRUP is selected; every #ifdef LPR section below is therefore inactive */
#define SNTRUP /* Streamlined NTRU Prime */
#undef LPR /* NTRU LPRime */

/* from supercop-20201130/crypto_kem/sntrup761/ref/params.h */
#ifndef params_H
#define params_H

/* menu of parameter choices: */


/* what the menu means: */

/*
 * Per-size constants, as they are used further down in this file:
 *   p             number of coefficients per polynomial (array sizes, loop bounds)
 *   q             modulus passed to the mod-q reduction in Fq_freeze
 *   w             weight checked by Weightw_mask and used by Short_fromlist
 *   Rounded_bytes byte length used for encoded rounded polynomials
 *   Rq_bytes      byte length used for encoded Rq elements (non-LPR only)
 *   tau0..tau3    constants used by Top() and Right() (LPR only)
 */
#if defined(SIZE761)
#define p 761
#define q 4591
#define Rounded_bytes 1007
#ifndef LPR
#define Rq_bytes 1158
#define w 286
#else
#define w 250
#define tau0 2156
#define tau1 114
#define tau2 2007
#define tau3 287
#endif

#elif defined(SIZE653)
#define p 653
#define q 4621
#define Rounded_bytes 865
#ifndef LPR
#define Rq_bytes 994
#define w 288
#else
#define w 252
#define tau0 2175
#define tau1 113
#define tau2 2031
#define tau3 290
#endif

#elif defined(SIZE857)
#define p 857
#define q 5167
#define Rounded_bytes 1152
#ifndef LPR
#define Rq_bytes 1322
#define w 322
#else
#define w 281
#define tau0 2433
#define tau1 101
#define tau2 2265
#define tau3 324
#endif

#else
#error "no parameter set defined"
#endif

/* number of input bits handled by the LPR code paths */
#ifdef LPR
#define I 256
#endif

#endif
271
272/* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.h */
273#ifndef Decode_H
274#define Decode_H
275
276
277/* Decode(R,s,M,len) */
278/* assumes 0 < M[i] < 16384 */
279/* produces 0 <= R[i] < M[i] */
280
281#endif
282
283/* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.c */
284
285static void Decode(uint16 *out,const unsigned char *S,const uint16 *M,long long len)
286{
287  if (len == 1) {
288    if (M[0] == 1)
289      *out = 0;
290    else if (M[0] <= 256)
291      *out = uint32_mod_uint14(S[0],M[0]);
292    else
293      *out = uint32_mod_uint14(S[0]+(((uint16)S[1])<<8),M[0]);
294  }
295  if (len > 1) {
296    uint16 R2[(len+1)/2];
297    uint16 M2[(len+1)/2];
298    uint16 bottomr[len/2];
299    uint32 bottomt[len/2];
300    long long i;
301    for (i = 0;i < len-1;i += 2) {
302      uint32 m = M[i]*(uint32) M[i+1];
303      if (m > 256*16383) {
304        bottomt[i/2] = 256*256;
305        bottomr[i/2] = S[0]+256*S[1];
306        S += 2;
307        M2[i/2] = (((m+255)>>8)+255)>>8;
308      } else if (m >= 16384) {
309        bottomt[i/2] = 256;
310        bottomr[i/2] = S[0];
311        S += 1;
312        M2[i/2] = (m+255)>>8;
313      } else {
314        bottomt[i/2] = 1;
315        bottomr[i/2] = 0;
316        M2[i/2] = m;
317      }
318    }
319    if (i < len)
320      M2[i/2] = M[i];
321    Decode(R2,S,M2,(len+1)/2);
322    for (i = 0;i < len-1;i += 2) {
323      uint32 r = bottomr[i/2];
324      uint32 r1;
325      uint16 r0;
326      r += bottomt[i/2]*R2[i/2];
327      uint32_divmod_uint14(&r1,&r0,r,M[i]);
328      r1 = uint32_mod_uint14(r1,M[i+1]); /* only needed for invalid inputs */
329      *out++ = r0;
330      *out++ = r1;
331    }
332    if (i < len)
333      *out++ = R2[i/2];
334  }
335}
336
337/* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.h */
338#ifndef Encode_H
339#define Encode_H
340
341
342/* Encode(s,R,M,len) */
343/* assumes 0 <= R[i] < M[i] < 16384 */
344
345#endif
346
347/* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.c */
348
349/* 0 <= R[i] < M[i] < 16384 */
350static void Encode(unsigned char *out,const uint16 *R,const uint16 *M,long long len)
351{
352  if (len == 1) {
353    uint16 r = R[0];
354    uint16 m = M[0];
355    while (m > 1) {
356      *out++ = r;
357      r >>= 8;
358      m = (m+255)>>8;
359    }
360  }
361  if (len > 1) {
362    uint16 R2[(len+1)/2];
363    uint16 M2[(len+1)/2];
364    long long i;
365    for (i = 0;i < len-1;i += 2) {
366      uint32 m0 = M[i];
367      uint32 r = R[i]+R[i+1]*m0;
368      uint32 m = M[i+1]*m0;
369      while (m >= 16384) {
370        *out++ = r;
371        r >>= 8;
372        m = (m+255)>>8;
373      }
374      R2[i/2] = r;
375      M2[i/2] = m;
376    }
377    if (i < len) {
378      R2[i/2] = R[i];
379      M2[i/2] = M[i];
380    }
381    Encode(out,R2,M2,(len+1)/2);
382  }
383}
384
385/* from supercop-20201130/crypto_kem/sntrup761/ref/kem.c */
386
387#ifdef LPR
388#endif
389
390
391/* ----- masks */
392
393#ifndef LPR
394
395/* return -1 if x!=0; else return 0 */
396static int int16_nonzero_mask(int16 x)
397{
398  uint16 u = x; /* 0, else 1...65535 */
399  uint32 v = u; /* 0, else 1...65535 */
400  v = -v; /* 0, else 2^32-65535...2^32-1 */
401  v >>= 31; /* 0, else 1 */
402  return -v; /* 0, else -1 */
403}
404
405#endif
406
407/* return -1 if x<0; otherwise return 0 */
408static int int16_negative_mask(int16 x)
409{
410  uint16 u = x;
411  u >>= 15;
412  return -(int) u;
413  /* alternative with gcc -fwrapv: */
414  /* x>>15 compiles to CPU's arithmetic right shift */
415}
416
417/* ----- arithmetic mod 3 */
418
419typedef int8 small;
420
421/* F3 is always represented as -1,0,1 */
422/* so ZZ_fromF3 is a no-op */
423
424/* x must not be close to top int16 */
425static small F3_freeze(int16 x)
426{
427  return int32_mod_uint14(x+1,3)-1;
428}
429
430/* ----- arithmetic mod q */
431
432#define q12 ((q-1)/2)
433typedef int16 Fq;
434/* always represented as -q12...q12 */
435/* so ZZ_fromFq is a no-op */
436
437/* x must not be close to top int32 */
438static Fq Fq_freeze(int32 x)
439{
440  return int32_mod_uint14(x+q12,q)-q12;
441}
442
443#ifndef LPR
444
445static Fq Fq_recip(Fq a1)
446{
447  int i = 1;
448  Fq ai = a1;
449
450  while (i < q-2) {
451    ai = Fq_freeze(a1*(int32)ai);
452    i += 1;
453  }
454  return ai;
455}
456
457#endif
458
459/* ----- Top and Right */
460
461#ifdef LPR
462#define tau 16
463
464static int8 Top(Fq C)
465{
466  return (tau1*(int32)(C+tau0)+16384)>>15;
467}
468
469static Fq Right(int8 T)
470{
471  return Fq_freeze(tau3*(int32)T-tau2);
472}
473#endif
474
475/* ----- small polynomials */
476
477#ifndef LPR
478
479/* 0 if Weightw_is(r), else -1 */
480static int Weightw_mask(small *r)
481{
482  int weight = 0;
483  int i;
484
485  for (i = 0;i < p;++i) weight += r[i]&1;
486  return int16_nonzero_mask(weight-w);
487}
488
489/* R3_fromR(R_fromRq(r)) */
490static void R3_fromRq(small *out,const Fq *r)
491{
492  int i;
493  for (i = 0;i < p;++i) out[i] = F3_freeze(r[i]);
494}
495
496/* h = f*g in the ring R3 */
497static void R3_mult(small *h,const small *f,const small *g)
498{
499  small fg[p+p-1];
500  small result;
501  int i,j;
502
503  for (i = 0;i < p;++i) {
504    result = 0;
505    for (j = 0;j <= i;++j) result = F3_freeze(result+f[j]*g[i-j]);
506    fg[i] = result;
507  }
508  for (i = p;i < p+p-1;++i) {
509    result = 0;
510    for (j = i-p+1;j < p;++j) result = F3_freeze(result+f[j]*g[i-j]);
511    fg[i] = result;
512  }
513
514  for (i = p+p-2;i >= p;--i) {
515    fg[i-p] = F3_freeze(fg[i-p]+fg[i]);
516    fg[i-p+1] = F3_freeze(fg[i-p+1]+fg[i]);
517  }
518
519  for (i = 0;i < p;++i) h[i] = fg[i];
520}
521
/*
 * Attempt to invert `in` in R3, writing p coefficients to out.
 * returns 0 if recip succeeded; else -1
 *
 * The loop always runs 2*p-1 times, and every swap is applied through a
 * mask rather than a branch.
 */
static int R3_recip(small *out,const small *in)
{
  small f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int sign,swap,t;

  /* v = 0, r = 1 */
  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = 1;
  /* f = 1,0,...,0,-1,-1 (presumably x^p - x - 1, highest degree first) */
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  /* g = the input with its coefficient order reversed, padded with a zero */
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* shift v up by one position */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    sign = -g[0]*f[0];
    /* swap is all ones when delta > 0 and g[0] != 0, else 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    /* negate delta when swapping, then increment */
    delta ^= swap&(delta^-delta);
    delta += 1;

    /* masked exchange of (f,g) and (v,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* add sign*f to g and sign*v to r, reducing mod 3 */
    for (i = 0;i < p+1;++i) g[i] = F3_freeze(g[i]+sign*f[i]);
    for (i = 0;i < p+1;++i) r[i] = F3_freeze(r[i]+sign*v[i]);

    /* shift g down by one position */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* output v reversed, multiplied by f[0] */
  sign = f[0];
  for (i = 0;i < p;++i) out[i] = sign*v[p-1-i];

  /* success is reported exactly when delta ended at 0 */
  return int16_nonzero_mask(delta);
}
565
566#endif
567
568/* ----- polynomials mod q */
569
570/* h = f*g in the ring Rq */
571static void Rq_mult_small(Fq *h,const Fq *f,const small *g)
572{
573  Fq fg[p+p-1];
574  Fq result;
575  int i,j;
576
577  for (i = 0;i < p;++i) {
578    result = 0;
579    for (j = 0;j <= i;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
580    fg[i] = result;
581  }
582  for (i = p;i < p+p-1;++i) {
583    result = 0;
584    for (j = i-p+1;j < p;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
585    fg[i] = result;
586  }
587
588  for (i = p+p-2;i >= p;--i) {
589    fg[i-p] = Fq_freeze(fg[i-p]+fg[i]);
590    fg[i-p+1] = Fq_freeze(fg[i-p+1]+fg[i]);
591  }
592
593  for (i = 0;i < p;++i) h[i] = fg[i];
594}
595
596#ifndef LPR
597
598/* h = 3f in Rq */
599static void Rq_mult3(Fq *h,const Fq *f)
600{
601  int i;
602
603  for (i = 0;i < p;++i) h[i] = Fq_freeze(3*f[i]);
604}
605
/* out = 1/(3*in) in Rq */
/* returns 0 if recip succeeded; else -1 */
/*
 * The loop always runs 2*p-1 times, and every swap is applied through a
 * mask rather than a branch.
 */
static int Rq_recip3(Fq *out,const small *in)
{
  Fq f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int swap,t;
  int32 f0,g0;
  Fq scale;

  /* v = 0, r = Fq_recip(3) in the constant position */
  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = Fq_recip(3);
  /* f = 1,0,...,0,-1,-1 (presumably x^p - x - 1, highest degree first) */
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  /* g = the input with its coefficient order reversed, padded with a zero */
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* shift v up by one position */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* swap is all ones when delta > 0 and g[0] != 0, else 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    /* negate delta when swapping, then increment */
    delta ^= swap&(delta^-delta);
    delta += 1;

    /* masked exchange of (f,g) and (v,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* g = f0*g - g0*f and r = f0*r - g0*v; this makes the new g[0] zero mod q */
    f0 = f[0];
    g0 = g[0];
    for (i = 0;i < p+1;++i) g[i] = Fq_freeze(f0*g[i]-g0*f[i]);
    for (i = 0;i < p+1;++i) r[i] = Fq_freeze(f0*r[i]-g0*v[i]);

    /* shift g down by one position */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* output v reversed, scaled by Fq_recip(f[0]) */
  scale = Fq_recip(f[0]);
  for (i = 0;i < p;++i) out[i] = Fq_freeze(scale*(int32)v[p-1-i]);

  /* success is reported exactly when delta ended at 0 */
  return int16_nonzero_mask(delta);
}
653
654#endif
655
656/* ----- rounded polynomials mod q */
657
658static void Round(Fq *out,const Fq *a)
659{
660  int i;
661  for (i = 0;i < p;++i) out[i] = a[i]-F3_freeze(a[i]);
662}
663
664/* ----- sorting to generate short polynomial */
665
666static void Short_fromlist(small *out,const uint32 *in)
667{
668  uint32 L[p];
669  int i;
670
671  for (i = 0;i < w;++i) L[i] = in[i]&(uint32)-2;
672  for (i = w;i < p;++i) L[i] = (in[i]&(uint32)-3)|1;
673  crypto_sort_uint32(L,p);
674  for (i = 0;i < p;++i) out[i] = (L[i]&3)-1;
675}
676
677/* ----- underlying hash function */
678
#define Hash_bytes 32

/*
 * out = first 32 bytes of crypto_hash_sha512 over the one-byte prefix b
 * followed by the inlen bytes at in.
 * e.g., b = 0 means out = Hash0(in)
 */
static void Hash_prefix(unsigned char *out,int b,const unsigned char *in,int inlen)
{
  unsigned char msg[inlen+1];
  unsigned char digest[64];

  msg[0] = b;
  memcpy(msg+1,in,inlen);
  crypto_hash_sha512(digest,msg,inlen+1);
  memcpy(out,digest,32);
}
693
694/* ----- higher-level randomness */
695
696static uint32 urandom32(void)
697{
698  unsigned char c[4];
699  uint32 out[4];
700
701  randombytes(c,4);
702  out[0] = (uint32)c[0];
703  out[1] = ((uint32)c[1])<<8;
704  out[2] = ((uint32)c[2])<<16;
705  out[3] = ((uint32)c[3])<<24;
706  return out[0]+out[1]+out[2]+out[3];
707}
708
709static void Short_random(small *out)
710{
711  uint32 L[p];
712  int i;
713
714  for (i = 0;i < p;++i) L[i] = urandom32();
715  Short_fromlist(out,L);
716}
717
718#ifndef LPR
719
720static void Small_random(small *out)
721{
722  int i;
723
724  for (i = 0;i < p;++i) out[i] = (((urandom32()&0x3fffffff)*3)>>30)-1;
725}
726
727#endif
728
729/* ----- Streamlined NTRU Prime Core */
730
731#ifndef LPR
732
733/* h,(f,ginv) = KeyGen() */
734static void KeyGen(Fq *h,small *f,small *ginv)
735{
736  small g[p];
737  Fq finv[p];
738
739  for (;;) {
740    Small_random(g);
741    if (R3_recip(ginv,g) == 0) break;
742  }
743  Short_random(f);
744  Rq_recip3(finv,f); /* always works */
745  Rq_mult_small(h,finv,g);
746}
747
748/* c = Encrypt(r,h) */
749static void Encrypt(Fq *c,const small *r,const Fq *h)
750{
751  Fq hr[p];
752
753  Rq_mult_small(hr,h,r);
754  Round(c,hr);
755}
756
757/* r = Decrypt(c,(f,ginv)) */
758static void Decrypt(small *r,const Fq *c,const small *f,const small *ginv)
759{
760  Fq cf[p];
761  Fq cf3[p];
762  small e[p];
763  small ev[p];
764  int mask;
765  int i;
766
767  Rq_mult_small(cf,c,f);
768  Rq_mult3(cf3,cf);
769  R3_fromRq(e,cf3);
770  R3_mult(ev,e,ginv);
771
772  mask = Weightw_mask(ev); /* 0 if weight w, else -1 */
773  for (i = 0;i < w;++i) r[i] = ((ev[i]^1)&~mask)^1;
774  for (i = w;i < p;++i) r[i] = ev[i]&~mask;
775}
776
777#endif
778
779/* ----- NTRU LPRime Core */
780
781#ifdef LPR
782
783/* (G,A),a = KeyGen(G); leaves G unchanged */
784static void KeyGen(Fq *A,small *a,const Fq *G)
785{
786  Fq aG[p];
787
788  Short_random(a);
789  Rq_mult_small(aG,G,a);
790  Round(A,aG);
791}
792
793/* B,T = Encrypt(r,(G,A),b) */
794static void Encrypt(Fq *B,int8 *T,const int8 *r,const Fq *G,const Fq *A,const small *b)
795{
796  Fq bG[p];
797  Fq bA[p];
798  int i;
799
800  Rq_mult_small(bG,G,b);
801  Round(B,bG);
802  Rq_mult_small(bA,A,b);
803  for (i = 0;i < I;++i) T[i] = Top(Fq_freeze(bA[i]+r[i]*q12));
804}
805
806/* r = Decrypt((B,T),a) */
807static void Decrypt(int8 *r,const Fq *B,const int8 *T,const small *a)
808{
809  Fq aB[p];
810  int i;
811
812  Rq_mult_small(aB,B,a);
813  for (i = 0;i < I;++i)
814    r[i] = -int16_negative_mask(Fq_freeze(Right(T[i])-aB[i]+4*w+1));
815}
816
817#endif
818
819/* ----- encoding I-bit inputs */
820
821#ifdef LPR
822
823#define Inputs_bytes (I/8)
824typedef int8 Inputs[I]; /* passed by reference */
825
826static void Inputs_encode(unsigned char *s,const Inputs r)
827{
828  int i;
829  for (i = 0;i < Inputs_bytes;++i) s[i] = 0;
830  for (i = 0;i < I;++i) s[i>>3] |= r[i]<<(i&7);
831}
832
833#endif
834
835/* ----- Expand */
836
837#ifdef LPR
838
839static const unsigned char aes_nonce[16] = {0};
840
841static void Expand(uint32 *L,const unsigned char *k)
842{
843  int i;
844  crypto_stream_aes256ctr((unsigned char *) L,4*p,aes_nonce,k);
845  for (i = 0;i < p;++i) {
846    uint32 L0 = ((unsigned char *) L)[4*i];
847    uint32 L1 = ((unsigned char *) L)[4*i+1];
848    uint32 L2 = ((unsigned char *) L)[4*i+2];
849    uint32 L3 = ((unsigned char *) L)[4*i+3];
850    L[i] = L0+(L1<<8)+(L2<<16)+(L3<<24);
851  }
852}
853
854#endif
855
856/* ----- Seeds */
857
858#ifdef LPR
859
#define Seeds_bytes 32

/* Fill the Seeds_bytes-byte buffer s from randombytes(). */
static void Seeds_random(unsigned char *s)
{
  const int seedlen = Seeds_bytes;

  randombytes(s,seedlen);
}
866
867#endif
868
869/* ----- Generator, HashShort */
870
871#ifdef LPR
872
873/* G = Generator(k) */
874static void Generator(Fq *G,const unsigned char *k)
875{
876  uint32 L[p];
877  int i;
878
879  Expand(L,k);
880  for (i = 0;i < p;++i) G[i] = uint32_mod_uint14(L[i],q)-q12;
881}
882
883/* out = HashShort(r) */
884static void HashShort(small *out,const Inputs r)
885{
886  unsigned char s[Inputs_bytes];
887  unsigned char h[Hash_bytes];
888  uint32 L[p];
889
890  Inputs_encode(s,r);
891  Hash_prefix(h,5,s,sizeof s);
892  Expand(L,h);
893  Short_fromlist(out,L);
894}
895
896#endif
897
898/* ----- NTRU LPRime Expand */
899
900#ifdef LPR
901
902/* (S,A),a = XKeyGen() */
903static void XKeyGen(unsigned char *S,Fq *A,small *a)
904{
905  Fq G[p];
906
907  Seeds_random(S);
908  Generator(G,S);
909  KeyGen(A,a,G);
910}
911
912/* B,T = XEncrypt(r,(S,A)) */
913static void XEncrypt(Fq *B,int8 *T,const int8 *r,const unsigned char *S,const Fq *A)
914{
915  Fq G[p];
916  small b[p];
917
918  Generator(G,S);
919  HashShort(b,r);
920  Encrypt(B,T,r,G,A,b);
921}
922
923#define XDecrypt Decrypt
924
925#endif
926
927/* ----- encoding small polynomials (including short polynomials) */
928
929#define Small_bytes ((p+3)/4)
930
931/* these are the only functions that rely on p mod 4 = 1 */
932
933static void Small_encode(unsigned char *s,const small *f)
934{
935  small x;
936  int i;
937
938  for (i = 0;i < p/4;++i) {
939    x = *f++ + 1;
940    x += (*f++ + 1)<<2;
941    x += (*f++ + 1)<<4;
942    x += (*f++ + 1)<<6;
943    *s++ = x;
944  }
945  x = *f++ + 1;
946  *s++ = x;
947}
948
949static void Small_decode(small *f,const unsigned char *s)
950{
951  unsigned char x;
952  int i;
953
954  for (i = 0;i < p/4;++i) {
955    x = *s++;
956    *f++ = ((small)(x&3))-1; x >>= 2;
957    *f++ = ((small)(x&3))-1; x >>= 2;
958    *f++ = ((small)(x&3))-1; x >>= 2;
959    *f++ = ((small)(x&3))-1;
960  }
961  x = *s++;
962  *f++ = ((small)(x&3))-1;
963}
964
965/* ----- encoding general polynomials */
966
967#ifndef LPR
968
969static void Rq_encode(unsigned char *s,const Fq *r)
970{
971  uint16 R[p],M[p];
972  int i;
973
974  for (i = 0;i < p;++i) R[i] = r[i]+q12;
975  for (i = 0;i < p;++i) M[i] = q;
976  Encode(s,R,M,p);
977}
978
979static void Rq_decode(Fq *r,const unsigned char *s)
980{
981  uint16 R[p],M[p];
982  int i;
983
984  for (i = 0;i < p;++i) M[i] = q;
985  Decode(R,s,M,p);
986  for (i = 0;i < p;++i) r[i] = ((Fq)R[i])-q12;
987}
988
989#endif
990
991/* ----- encoding rounded polynomials */
992
993static void Rounded_encode(unsigned char *s,const Fq *r)
994{
995  uint16 R[p],M[p];
996  int i;
997
998  for (i = 0;i < p;++i) R[i] = ((r[i]+q12)*10923)>>15;
999  for (i = 0;i < p;++i) M[i] = (q+2)/3;
1000  Encode(s,R,M,p);
1001}
1002
1003static void Rounded_decode(Fq *r,const unsigned char *s)
1004{
1005  uint16 R[p],M[p];
1006  int i;
1007
1008  for (i = 0;i < p;++i) M[i] = (q+2)/3;
1009  Decode(R,s,M,p);
1010  for (i = 0;i < p;++i) r[i] = R[i]*3-q12;
1011}
1012
1013/* ----- encoding top polynomials */
1014
1015#ifdef LPR
1016
1017#define Top_bytes (I/2)
1018
1019static void Top_encode(unsigned char *s,const int8 *T)
1020{
1021  int i;
1022  for (i = 0;i < Top_bytes;++i)
1023    s[i] = T[2*i]+(T[2*i+1]<<4);
1024}
1025
1026static void Top_decode(int8 *T,const unsigned char *s)
1027{
1028  int i;
1029  for (i = 0;i < Top_bytes;++i) {
1030    T[2*i] = s[i]&15;
1031    T[2*i+1] = s[i]>>4;
1032  }
1033}
1034
1035#endif
1036
1037/* ----- Streamlined NTRU Prime Core plus encoding */
1038
1039#ifndef LPR
1040
1041typedef small Inputs[p]; /* passed by reference */
1042#define Inputs_random Short_random
1043#define Inputs_encode Small_encode
1044#define Inputs_bytes Small_bytes
1045
1046#define Ciphertexts_bytes Rounded_bytes
1047#define SecretKeys_bytes (2*Small_bytes)
1048#define PublicKeys_bytes Rq_bytes
1049
1050/* pk,sk = ZKeyGen() */
1051static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1052{
1053  Fq h[p];
1054  small f[p],v[p];
1055
1056  KeyGen(h,f,v);
1057  Rq_encode(pk,h);
1058  Small_encode(sk,f); sk += Small_bytes;
1059  Small_encode(sk,v);
1060}
1061
1062/* C = ZEncrypt(r,pk) */
1063static void ZEncrypt(unsigned char *C,const Inputs r,const unsigned char *pk)
1064{
1065  Fq h[p];
1066  Fq c[p];
1067  Rq_decode(h,pk);
1068  Encrypt(c,r,h);
1069  Rounded_encode(C,c);
1070}
1071
1072/* r = ZDecrypt(C,sk) */
1073static void ZDecrypt(Inputs r,const unsigned char *C,const unsigned char *sk)
1074{
1075  small f[p],v[p];
1076  Fq c[p];
1077
1078  Small_decode(f,sk); sk += Small_bytes;
1079  Small_decode(v,sk);
1080  Rounded_decode(c,C);
1081  Decrypt(r,c,f,v);
1082}
1083
1084#endif
1085
1086/* ----- NTRU LPRime Expand plus encoding */
1087
1088#ifdef LPR
1089
1090#define Ciphertexts_bytes (Rounded_bytes+Top_bytes)
1091#define SecretKeys_bytes Small_bytes
1092#define PublicKeys_bytes (Seeds_bytes+Rounded_bytes)
1093
1094static void Inputs_random(Inputs r)
1095{
1096  unsigned char s[Inputs_bytes];
1097  int i;
1098
1099  randombytes(s,sizeof s);
1100  for (i = 0;i < I;++i) r[i] = 1&(s[i>>3]>>(i&7));
1101}
1102
1103/* pk,sk = ZKeyGen() */
1104static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1105{
1106  Fq A[p];
1107  small a[p];
1108
1109  XKeyGen(pk,A,a); pk += Seeds_bytes;
1110  Rounded_encode(pk,A);
1111  Small_encode(sk,a);
1112}
1113
1114/* c = ZEncrypt(r,pk) */
1115static void ZEncrypt(unsigned char *c,const Inputs r,const unsigned char *pk)
1116{
1117  Fq A[p];
1118  Fq B[p];
1119  int8 T[I];
1120
1121  Rounded_decode(A,pk+Seeds_bytes);
1122  XEncrypt(B,T,r,pk,A);
1123  Rounded_encode(c,B); c += Rounded_bytes;
1124  Top_encode(c,T);
1125}
1126
1127/* r = ZDecrypt(C,sk) */
1128static void ZDecrypt(Inputs r,const unsigned char *c,const unsigned char *sk)
1129{
1130  small a[p];
1131  Fq B[p];
1132  int8 T[I];
1133
1134  Small_decode(a,sk);
1135  Rounded_decode(B,c);
1136  Top_decode(T,c+Rounded_bytes);
1137  XDecrypt(r,B,T,a);
1138}
1139
1140#endif
1141
1142/* ----- confirmation hash */
1143
1144#define Confirm_bytes 32
1145
1146/* h = HashConfirm(r,pk,cache); cache is Hash4(pk) */
1147static void HashConfirm(unsigned char *h,const unsigned char *r,const unsigned char *pk,const unsigned char *cache)
1148{
1149#ifndef LPR
1150  unsigned char x[Hash_bytes*2];
1151  int i;
1152
1153  Hash_prefix(x,3,r,Inputs_bytes);
1154  for (i = 0;i < Hash_bytes;++i) x[Hash_bytes+i] = cache[i];
1155#else
1156  unsigned char x[Inputs_bytes+Hash_bytes];
1157  int i;
1158
1159  for (i = 0;i < Inputs_bytes;++i) x[i] = r[i];
1160  for (i = 0;i < Hash_bytes;++i) x[Inputs_bytes+i] = cache[i];
1161#endif
1162  Hash_prefix(h,2,x,sizeof x);
1163}
1164
1165/* ----- session-key hash */
1166
1167/* k = HashSession(b,y,z) */
1168static void HashSession(unsigned char *k,int b,const unsigned char *y,const unsigned char *z)
1169{
1170#ifndef LPR
1171  unsigned char x[Hash_bytes+Ciphertexts_bytes+Confirm_bytes];
1172  int i;
1173
1174  Hash_prefix(x,3,y,Inputs_bytes);
1175  for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Hash_bytes+i] = z[i];
1176#else
1177  unsigned char x[Inputs_bytes+Ciphertexts_bytes+Confirm_bytes];
1178  int i;
1179
1180  for (i = 0;i < Inputs_bytes;++i) x[i] = y[i];
1181  for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Inputs_bytes+i] = z[i];
1182#endif
1183  Hash_prefix(k,b,x,sizeof x);
1184}
1185
1186/* ----- Streamlined NTRU Prime and NTRU LPRime */
1187
1188/* pk,sk = KEM_KeyGen() */
1189static void KEM_KeyGen(unsigned char *pk,unsigned char *sk)
1190{
1191  int i;
1192
1193  ZKeyGen(pk,sk); sk += SecretKeys_bytes;
1194  for (i = 0;i < PublicKeys_bytes;++i) *sk++ = pk[i];
1195  randombytes(sk,Inputs_bytes); sk += Inputs_bytes;
1196  Hash_prefix(sk,4,pk,PublicKeys_bytes);
1197}
1198
1199/* c,r_enc = Hide(r,pk,cache); cache is Hash4(pk) */
1200static void Hide(unsigned char *c,unsigned char *r_enc,const Inputs r,const unsigned char *pk,const unsigned char *cache)
1201{
1202  Inputs_encode(r_enc,r);
1203  ZEncrypt(c,r,pk); c += Ciphertexts_bytes;
1204  HashConfirm(c,r_enc,pk,cache);
1205}
1206
1207/* c,k = Encap(pk) */
1208static void Encap(unsigned char *c,unsigned char *k,const unsigned char *pk)
1209{
1210  Inputs r;
1211  unsigned char r_enc[Inputs_bytes];
1212  unsigned char cache[Hash_bytes];
1213
1214  Hash_prefix(cache,4,pk,PublicKeys_bytes);
1215  Inputs_random(r);
1216  Hide(c,r_enc,r,pk,cache);
1217  HashSession(k,1,r_enc,c);
1218}
1219
1220/* 0 if matching ciphertext+confirm, else -1 */
1221static int Ciphertexts_diff_mask(const unsigned char *c,const unsigned char *c2)
1222{
1223  uint16 differentbits = 0;
1224  int len = Ciphertexts_bytes+Confirm_bytes;
1225
1226  while (len-- > 0) differentbits |= (*c++)^(*c2++);
1227  return (1&((differentbits-1)>>8))-1;
1228}
1229
1230/* k = Decap(c,sk) */
1231static void Decap(unsigned char *k,const unsigned char *c,const unsigned char *sk)
1232{
1233  const unsigned char *pk = sk + SecretKeys_bytes;
1234  const unsigned char *rho = pk + PublicKeys_bytes;
1235  const unsigned char *cache = rho + Inputs_bytes;
1236  Inputs r;
1237  unsigned char r_enc[Inputs_bytes];
1238  unsigned char cnew[Ciphertexts_bytes+Confirm_bytes];
1239  int mask;
1240  int i;
1241
1242  ZDecrypt(r,c,sk);
1243  Hide(cnew,r_enc,r,pk,cache);
1244  mask = Ciphertexts_diff_mask(c,cnew);
1245  for (i = 0;i < Inputs_bytes;++i) r_enc[i] ^= mask&(r_enc[i]^rho[i]);
1246  HashSession(k,1+mask,r_enc,c);
1247}
1248
1249/* ----- crypto_kem API */
1250
1251
/*
 * Public entry point: generate a key pair into pk and sk via KEM_KeyGen.
 * Always returns 0.
 */
int crypto_kem_sntrup761_keypair(unsigned char *pk,unsigned char *sk)
{
  const int status = 0;

  KEM_KeyGen(pk,sk);
  return status;
}
1257
/*
 * Public entry point: encapsulate to pk via Encap, writing the ciphertext
 * to c and the session key to k.  Always returns 0.
 */
int crypto_kem_sntrup761_enc(unsigned char *c,unsigned char *k,const unsigned char *pk)
{
  const int status = 0;

  Encap(c,k,pk);
  return status;
}
1263
/*
 * Public entry point: decapsulate c with sk via Decap, writing the session
 * key to k.  Always returns 0, including when the ciphertext does not
 * verify (Decap then derives the key from the fallback bytes in sk).
 */
int crypto_kem_sntrup761_dec(unsigned char *k,const unsigned char *c,const unsigned char *sk)
{
  const int status = 0;

  Decap(k,c,sk);
  return status;
}
1269
1270