1/*  $OpenBSD: sntrup761.c,v 1.6 2023/01/11 02:13:52 djm Exp $ */
2
3/*
4 * Public Domain, Authors:
5 * - Daniel J. Bernstein
6 * - Chitchanok Chuengsatiansup
7 * - Tanja Lange
8 * - Christine van Vredendaal
9 */
10
11#include "includes.h"
12
13#ifdef USE_SNTRUP761X25519
14
15#include <string.h>
16#include "crypto_api.h"
17
/*
 * Short names for the fixed-width integer types used by the SUPERCOP
 * reference code below; the crypto_* types presumably come from
 * crypto_api.h (included above) -- confirm there.
 */
#define int8 crypto_int8
#define uint8 crypto_uint8
#define int16 crypto_int16
#define uint16 crypto_uint16
#define int32 crypto_int32
#define uint32 crypto_uint32
#define int64 crypto_int64
#define uint64 crypto_uint64
26
/* from supercop-20201130/crypto_sort/int32/portable4/int32_minmax.inc */
/*
 * Branch-free compare-exchange on two int32 lvalues: afterwards a holds
 * min(a,b) and b holds max(a,b).  Each argument is evaluated several times,
 * so neither may have side effects.
 * The difference b-a is formed in 64 bits so it cannot overflow; the
 * "c ^= ab & (c ^ b)" step makes bits 31..63 of c equal to the sign of b
 * when a and b have different signs.  "c >>= 31" therefore yields all ones
 * when b < a and zero otherwise.  That step relies on arithmetic right shift
 * of a negative int64_t, which is implementation-defined in C.
 * c &= ab then turns the mask into a xor-swap of a and b.
 */
#define int32_MINMAX(a,b) \
do { \
  int64_t ab = (int64_t)b ^ (int64_t)a; \
  int64_t c = (int64_t)b - (int64_t)a; \
  c ^= ab & (c ^ b); \
  c >>= 31; \
  c &= ab; \
  a ^= c; \
  b ^= c; \
} while(0)
38
39/* from supercop-20201130/crypto_sort/int32/portable4/sort.c */
40
41
/*
 * Sort the n int32 values at array into ascending order, in place.
 * Every loop bound and branch below depends only on n (never on the array
 * contents), and all data movement goes through int32_MINMAX, so the
 * sequence of memory accesses is independent of the data being sorted.
 * NOTE: the locals p and q are only usable here because the single-letter
 * parameter macros p/q are #defined further down in this file.
 */
static void crypto_sort_int32(void *array,long long n)
{
  long long top,p,q,r,i,j;
  int32 *x = array;

  if (n < 2) return;
  top = 1;
  /* top becomes the smallest power of two with 2*top >= n */
  while (top < n - top) top += top;

  /* one pass per stride p = top, top/2, ..., 1 */
  for (p = top;p >= 1;p >>= 1) {
    /* compare-exchange x[j] with x[j+p] inside each full 2p-wide block... */
    i = 0;
    while (i + 2 * p <= n) {
      for (j = i;j < i + p;++j)
        int32_MINMAX(x[j],x[j+p]);
      i += 2 * p;
    }
    /* ...and inside the partial block at the end */
    for (j = i;j < n - p;++j)
      int32_MINMAX(x[j],x[j+p]);

    /*
     * Follow-up passes for q = top, top/2, ..., 2p: each x[j+p] is run
     * through compare-exchanges against x[j+r] for r = q, q/2, ..., 2p.
     * i and j carry over from one q to the next; the for(;;) finishes a
     * block left partly done by the previous q, and "goto done" abandons
     * the rest of the current q once j reaches n - q.
     */
    i = 0;
    j = 0;
    for (q = top;q > p;q >>= 1) {
      if (j != i) for (;;) {
        if (j == n - q) goto done;
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j + r]);
        x[j + p] = a;
        ++j;
        if (j == i + p) {
          i += 2 * p;
          break;
        }
      }
      while (i + p <= n - q) {
        for (j = i;j < i + p;++j) {
          int32 a = x[j + p];
          for (r = q;r > p;r >>= 1)
            int32_MINMAX(a,x[j+r]);
          x[j + p] = a;
        }
        i += 2 * p;
      }
      /* now i + p > n - q */
      j = i;
      while (j < n - q) {
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j+r]);
        x[j + p] = a;
        ++j;
      }

      done: ;
    }
  }
}
99
100/* from supercop-20201130/crypto_sort/uint32/useint32/sort.c */
101
102/* can save time by vectorizing xor loops */
103/* can save time by integrating xor loops with int32_sort */
104
105static void crypto_sort_uint32(void *array,long long n)
106{
107  crypto_uint32 *x = array;
108  long long j;
109  for (j = 0;j < n;++j) x[j] ^= 0x80000000;
110  crypto_sort_int32(array,n);
111  for (j = 0;j < n;++j) x[j] ^= 0x80000000;
112}
113
114/* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.c */
115
116/*
117CPU division instruction typically takes time depending on x.
118This software is designed to take time independent of x.
119Time still varies depending on m; user must ensure that m is constant.
120Time also varies on CPUs where multiplication is variable-time.
121There could be more CPU issues.
122There could also be compiler issues.
123*/
124
/*
 * Set *q = x / m and *r = x % m for 0 < m < 16384.
 * The only division is 2^31 / m, which depends on m alone.  The quotient of
 * x is built from two multiply-and-shift estimates followed by a masked
 * final correction, so there is no data-dependent branch on x.
 */
static void uint32_divmod_uint14(uint32 *q,uint16 *r,uint32 x,uint16 m)
{
  uint32 v = 0x80000000;
  uint32 qpart;
  uint32 mask;

  v /= m; /* v = floor(2^31 / m): reciprocal estimate, depends only on m */

  /* caller guarantees m > 0 */
  /* caller guarantees m < 16384 */
  /* vm <= 2^31 <= vm+m-1 */
  /* xvm <= 2^31 x <= xvm+x(m-1) */

  *q = 0;

  qpart = (x*(uint64)v)>>31;
  /* 2^31 qpart <= xv <= 2^31 qpart + 2^31-1 */
  /* 2^31 qpart m <= xvm <= 2^31 qpart m + (2^31-1)m */
  /* 2^31 qpart m <= 2^31 x <= 2^31 qpart m + (2^31-1)m + x(m-1) */
  /* 0 <= 2^31 newx <= (2^31-1)m + x(m-1) */
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= (1-1/2^31)(2^14-1) + (2^32-1)((2^14-1)-1)/2^31 */

  x -= qpart*m; *q += qpart;
  /* x <= 49146 */

  qpart = (x*(uint64)v)>>31;
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= m + 49146(2^14-1)/2^31 */
  /* 0 <= newx <= m + 0.4 */
  /* 0 <= newx <= m */

  x -= qpart*m; *q += qpart;
  /* x <= m */

  /* subtract m once more; if that wrapped below zero (bit 31 set), the
     all-ones mask adds m back and cancels the +1 on the quotient */
  x -= m; *q += 1;
  mask = -(x>>31);
  x += mask&(uint32)m; *q += mask;
  /* x < m */

  *r = x;
}
167
168
169static uint16 uint32_mod_uint14(uint32 x,uint16 m)
170{
171  uint32 q;
172  uint16 r;
173  uint32_divmod_uint14(&q,&r,x,m);
174  return r;
175}
176
177/* from supercop-20201130/crypto_kem/sntrup761/ref/int32.c */
178
/*
 * Signed variant of uint32_divmod_uint14 for 0 < m < 16384:
 * *q = floor(x / m) and *r = x - m*floor(x / m), so 0 <= *r < m also for
 * negative x.  x is biased by 2^31 into the unsigned range and divided.
 * The quotient/remainder of the bias 2^31 itself is then subtracted.
 * A negative remainder (it wraps in uint16, setting bit 15) is repaired by
 * a masked "add m, subtract 1 from the quotient".
 * NOTE(review): "*q = uq" converts a uint32 to int32; for negative
 * quotients this relies on two's-complement conversion
 * (implementation-defined before C23).
 */
static void int32_divmod_uint14(int32 *q,uint16 *r,int32 x,uint16 m)
{
  uint32 uq,uq2;
  uint16 ur,ur2;
  uint32 mask;

  uint32_divmod_uint14(&uq,&ur,0x80000000+(uint32)x,m);
  uint32_divmod_uint14(&uq2,&ur2,0x80000000,m);
  ur -= ur2; uq -= uq2;
  mask = -(uint32)(ur>>15);
  ur += mask&m; uq += mask;
  *r = ur; *q = uq;
}
192
193
194static uint16 int32_mod_uint14(int32 x,uint16 m)
195{
196  int32 q;
197  uint16 r;
198  int32_divmod_uint14(&q,&r,x,m);
199  return r;
200}
201
/* from supercop-20201130/crypto_kem/sntrup761/ref/paramsmenu.h */
/* This file is fixed to sntrup761: SIZE761 with the Streamlined variant. */
/* pick one of these three: */
#define SIZE761
#undef SIZE653
#undef SIZE857

/* pick one of these two: */
#define SNTRUP /* Streamlined NTRU Prime */
#undef LPR /* NTRU LPRime */

/* from supercop-20201130/crypto_kem/sntrup761/ref/params.h */
#ifndef params_H
#define params_H

/* menu of parameter choices: */


/* what the menu means: */
/*
 * p: number of coefficients in every polynomial (arrays below are sized p)
 * q: modulus used by Fq_freeze
 * w: number of nonzero coefficients of a short polynomial (Short_fromlist)
 * Rounded_bytes / Rq_bytes: encoded sizes of a rounded / general Rq element
 * tau0..tau3: constants of Top/Right (LPR only); I: LPR input bit count
 * NOTE(review): p, q, w and I are single-letter macros, so any later
 * identifier with one of those names is rewritten by the preprocessor.
 */

#if defined(SIZE761)
#define p 761
#define q 4591
#define Rounded_bytes 1007
#ifndef LPR
#define Rq_bytes 1158
#define w 286
#else
#define w 250
#define tau0 2156
#define tau1 114
#define tau2 2007
#define tau3 287
#endif

#elif defined(SIZE653)
#define p 653
#define q 4621
#define Rounded_bytes 865
#ifndef LPR
#define Rq_bytes 994
#define w 288
#else
#define w 252
#define tau0 2175
#define tau1 113
#define tau2 2031
#define tau3 290
#endif

#elif defined(SIZE857)
#define p 857
#define q 5167
#define Rounded_bytes 1152
#ifndef LPR
#define Rq_bytes 1322
#define w 322
#else
#define w 281
#define tau0 2433
#define tau1 101
#define tau2 2265
#define tau3 324
#endif

#else
#error "no parameter set defined"
#endif

#ifdef LPR
#define I 256
#endif

#endif
275
/* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.h */
#ifndef Decode_H
#define Decode_H


/* Decode(R,s,M,len) */
/* assumes 0 < M[i] < 16384 */
/* produces 0 <= R[i] < M[i] */

#endif

/* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.c */

/*
 * Inverse of Encode: read len values with radices M[0..len-1] from S into out.
 * len == 1: read 0, 1 or 2 bytes depending on M[0] and reduce mod M[0].
 * len > 1: adjacent pairs (M[i],M[i+1]) are merged into radix
 * m = M[i]*M[i+1].  The low 0, 1 or 2 bytes of each merged digit come first
 * in S (bottomr, scaled back by bottomt).  The reduced radices M2 are
 * decoded recursively, and each merged digit is split back with a divmod
 * by M[i].  Each recursion level keeps about len/2-element VLAs on the stack.
 */
static void Decode(uint16 *out,const unsigned char *S,const uint16 *M,long long len)
{
  if (len == 1) {
    if (M[0] == 1)
      *out = 0;
    else if (M[0] <= 256)
      *out = uint32_mod_uint14(S[0],M[0]);
    else
      *out = uint32_mod_uint14(S[0]+(((uint16)S[1])<<8),M[0]);
  }
  if (len > 1) {
    uint16 R2[(len+1)/2];
    uint16 M2[(len+1)/2];
    uint16 bottomr[len/2];
    uint32 bottomt[len/2];
    long long i;
    for (i = 0;i < len-1;i += 2) {
      uint32 m = M[i]*(uint32) M[i+1];
      if (m > 256*16383) {
        /* two low bytes were emitted: radix shrinks by ceil(/256) twice */
        bottomt[i/2] = 256*256;
        bottomr[i/2] = S[0]+256*S[1];
        S += 2;
        M2[i/2] = (((m+255)>>8)+255)>>8;
      } else if (m >= 16384) {
        /* one low byte was emitted */
        bottomt[i/2] = 256;
        bottomr[i/2] = S[0];
        S += 1;
        M2[i/2] = (m+255)>>8;
      } else {
        /* nothing emitted at this level */
        bottomt[i/2] = 1;
        bottomr[i/2] = 0;
        M2[i/2] = m;
      }
    }
    if (i < len)
      M2[i/2] = M[i];
    Decode(R2,S,M2,(len+1)/2);
    for (i = 0;i < len-1;i += 2) {
      uint32 r = bottomr[i/2];
      uint32 r1;
      uint16 r0;
      r += bottomt[i/2]*R2[i/2];
      uint32_divmod_uint14(&r1,&r0,r,M[i]);
      r1 = uint32_mod_uint14(r1,M[i+1]); /* only needed for invalid inputs */
      *out++ = r0;
      *out++ = r1;
    }
    if (i < len)
      *out++ = R2[i/2];
  }
}
340
/* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.h */
#ifndef Encode_H
#define Encode_H


/* Encode(s,R,M,len) */
/* assumes 0 <= R[i] < M[i] < 16384 */

#endif

/* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.c */

/*
 * Serialize R[0..len-1] with radices M[0..len-1] into bytes at out.
 * len == 1: emit low bytes of r while the radix is above 1.
 * len > 1: combine each pair into r = R[i] + R[i+1]*M[i] with radix
 * m = M[i]*M[i+1].  Emit the low byte of r while m >= 16384 (m shrinks to
 * ceil(m/256) each time), then encode the reduced pairs recursively.
 * The number of bytes written depends only on M and len, not on R.
 */
/* 0 <= R[i] < M[i] < 16384 */
static void Encode(unsigned char *out,const uint16 *R,const uint16 *M,long long len)
{
  if (len == 1) {
    uint16 r = R[0];
    uint16 m = M[0];
    while (m > 1) {
      *out++ = r; /* low byte */
      r >>= 8;
      m = (m+255)>>8;
    }
  }
  if (len > 1) {
    uint16 R2[(len+1)/2];
    uint16 M2[(len+1)/2];
    long long i;
    for (i = 0;i < len-1;i += 2) {
      uint32 m0 = M[i];
      uint32 r = R[i]+R[i+1]*m0;
      uint32 m = M[i+1]*m0;
      while (m >= 16384) {
        *out++ = r; /* low byte */
        r >>= 8;
        m = (m+255)>>8;
      }
      R2[i/2] = r;
      M2[i/2] = m;
    }
    if (i < len) {
      R2[i/2] = R[i];
      M2[i/2] = M[i];
    }
    Encode(out,R2,M2,(len+1)/2);
  }
}
388
389/* from supercop-20201130/crypto_kem/sntrup761/ref/kem.c */
390
391#ifdef LPR
392#endif
393
394
395/* ----- masks */
396
397#ifndef LPR
398
399/* return -1 if x!=0; else return 0 */
400static int int16_nonzero_mask(int16 x)
401{
402  uint16 u = x; /* 0, else 1...65535 */
403  uint32 v = u; /* 0, else 1...65535 */
404  v = -v; /* 0, else 2^32-65535...2^32-1 */
405  v >>= 31; /* 0, else 1 */
406  return -v; /* 0, else -1 */
407}
408
409#endif
410
411/* return -1 if x<0; otherwise return 0 */
412static int int16_negative_mask(int16 x)
413{
414  uint16 u = x;
415  u >>= 15;
416  return -(int) u;
417  /* alternative with gcc -fwrapv: */
418  /* x>>15 compiles to CPU's arithmetic right shift */
419}
420
421/* ----- arithmetic mod 3 */
422
423typedef int8 small;
424
425/* F3 is always represented as -1,0,1 */
426/* so ZZ_fromF3 is a no-op */
427
428/* x must not be close to top int16 */
429static small F3_freeze(int16 x)
430{
431  return int32_mod_uint14(x+1,3)-1;
432}
433
434/* ----- arithmetic mod q */
435
436#define q12 ((q-1)/2)
437typedef int16 Fq;
438/* always represented as -q12...q12 */
439/* so ZZ_fromFq is a no-op */
440
441/* x must not be close to top int32 */
442static Fq Fq_freeze(int32 x)
443{
444  return int32_mod_uint14(x+q12,q)-q12;
445}
446
447#ifndef LPR
448
449static Fq Fq_recip(Fq a1)
450{
451  int i = 1;
452  Fq ai = a1;
453
454  while (i < q-2) {
455    ai = Fq_freeze(a1*(int32)ai);
456    i += 1;
457  }
458  return ai;
459}
460
461#endif
462
463/* ----- Top and Right */
464
465#ifdef LPR
466#define tau 16
467
468static int8 Top(Fq C)
469{
470  return (tau1*(int32)(C+tau0)+16384)>>15;
471}
472
473static Fq Right(int8 T)
474{
475  return Fq_freeze(tau3*(int32)T-tau2);
476}
477#endif
478
479/* ----- small polynomials */
480
481#ifndef LPR
482
483/* 0 if Weightw_is(r), else -1 */
484static int Weightw_mask(small *r)
485{
486  int weight = 0;
487  int i;
488
489  for (i = 0;i < p;++i) weight += r[i]&1;
490  return int16_nonzero_mask(weight-w);
491}
492
493/* R3_fromR(R_fromRq(r)) */
494static void R3_fromRq(small *out,const Fq *r)
495{
496  int i;
497  for (i = 0;i < p;++i) out[i] = F3_freeze(r[i]);
498}
499
500/* h = f*g in the ring R3 */
501static void R3_mult(small *h,const small *f,const small *g)
502{
503  small fg[p+p-1];
504  small result;
505  int i,j;
506
507  for (i = 0;i < p;++i) {
508    result = 0;
509    for (j = 0;j <= i;++j) result = F3_freeze(result+f[j]*g[i-j]);
510    fg[i] = result;
511  }
512  for (i = p;i < p+p-1;++i) {
513    result = 0;
514    for (j = i-p+1;j < p;++j) result = F3_freeze(result+f[j]*g[i-j]);
515    fg[i] = result;
516  }
517
518  for (i = p+p-2;i >= p;--i) {
519    fg[i-p] = F3_freeze(fg[i-p]+fg[i]);
520    fg[i-p+1] = F3_freeze(fg[i-p+1]+fg[i]);
521  }
522
523  for (i = 0;i < p;++i) h[i] = fg[i];
524}
525
/*
 * out = 1/in in R3, via a constant-time extended-GCD style loop.
 * f starts as x^p - x - 1 and g as the input, both stored with reversed
 * coefficient order.  The loop always runs 2p-1 iterations; swaps are done
 * with masks, not branches.
 * returns 0 if recip succeeded; else -1
 */
static int R3_recip(small *out,const small *in)
{
  small f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int sign,swap,t;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = 1;
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* v = x*v */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* f[0] is always +-1 (it starts at 1, and a swap only happens when
       g[0] != 0), so it is its own inverse mod 3 and g + sign*f clears g[0] */
    sign = -g[0]*f[0];
    /* swap mask: -1 when delta > 0 and g[0] != 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    delta ^= swap&(delta^-delta); /* delta = -delta when swapping */
    delta += 1;

    /* masked exchange of (f,v) with (g,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    for (i = 0;i < p+1;++i) g[i] = F3_freeze(g[i]+sign*f[i]);
    for (i = 0;i < p+1;++i) r[i] = F3_freeze(r[i]+sign*v[i]);

    /* g = g/x (its constant term is now 0) */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* scale by 1/f[0] (= f[0] mod 3) and undo the coefficient reversal */
  sign = f[0];
  for (i = 0;i < p;++i) out[i] = sign*v[p-1-i];

  /* the input was invertible iff delta ended at 0 */
  return int16_nonzero_mask(delta);
}
569
570#endif
571
572/* ----- polynomials mod q */
573
574/* h = f*g in the ring Rq */
575static void Rq_mult_small(Fq *h,const Fq *f,const small *g)
576{
577  Fq fg[p+p-1];
578  Fq result;
579  int i,j;
580
581  for (i = 0;i < p;++i) {
582    result = 0;
583    for (j = 0;j <= i;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
584    fg[i] = result;
585  }
586  for (i = p;i < p+p-1;++i) {
587    result = 0;
588    for (j = i-p+1;j < p;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
589    fg[i] = result;
590  }
591
592  for (i = p+p-2;i >= p;--i) {
593    fg[i-p] = Fq_freeze(fg[i-p]+fg[i]);
594    fg[i-p+1] = Fq_freeze(fg[i-p+1]+fg[i]);
595  }
596
597  for (i = 0;i < p;++i) h[i] = fg[i];
598}
599
600#ifndef LPR
601
602/* h = 3f in Rq */
603static void Rq_mult3(Fq *h,const Fq *f)
604{
605  int i;
606
607  for (i = 0;i < p;++i) h[i] = Fq_freeze(3*f[i]);
608}
609
/* out = 1/(3*in) in Rq */
/*
 * Same constant-time extended-GCD structure as R3_recip, but over Fq.
 * r starts at 1/3, which supplies the factor 1/3.  g is eliminated by
 * cross-multiplication (f0*g - g0*f), so the final result is scaled by
 * 1/f[0].
 */
/* returns 0 if recip succeeded; else -1 */
static int Rq_recip3(Fq *out,const small *in)
{
  Fq f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int swap,t;
  int32 f0,g0;
  Fq scale;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = Fq_recip(3);
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* v = x*v */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* swap mask: -1 when delta > 0 and g[0] != 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    delta ^= swap&(delta^-delta); /* delta = -delta when swapping */
    delta += 1;

    /* masked exchange of (f,v) with (g,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* g = f0*g - g0*f clears g[0] without any division */
    f0 = f[0];
    g0 = g[0];
    for (i = 0;i < p+1;++i) g[i] = Fq_freeze(f0*g[i]-g0*f[i]);
    for (i = 0;i < p+1;++i) r[i] = Fq_freeze(f0*r[i]-g0*v[i]);

    /* g = g/x */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* scale by 1/f[0] and undo the coefficient reversal */
  scale = Fq_recip(f[0]);
  for (i = 0;i < p;++i) out[i] = Fq_freeze(scale*(int32)v[p-1-i]);

  /* invertible iff delta ended at 0 */
  return int16_nonzero_mask(delta);
}
657
658#endif
659
660/* ----- rounded polynomials mod q */
661
662static void Round(Fq *out,const Fq *a)
663{
664  int i;
665  for (i = 0;i < p;++i) out[i] = a[i]-F3_freeze(a[i]);
666}
667
668/* ----- sorting to generate short polynomial */
669
670static void Short_fromlist(small *out,const uint32 *in)
671{
672  uint32 L[p];
673  int i;
674
675  for (i = 0;i < w;++i) L[i] = in[i]&(uint32)-2;
676  for (i = w;i < p;++i) L[i] = (in[i]&(uint32)-3)|1;
677  crypto_sort_uint32(L,p);
678  for (i = 0;i < p;++i) out[i] = (L[i]&3)-1;
679}
680
681/* ----- underlying hash function */
682
#define Hash_bytes 32

/*
 * out = first 32 bytes of crypto_hash_sha512(b || in[0..inlen-1]).
 * b is a one-byte prefix, e.g. b = 0 means out = Hash0(in).
 */
static void Hash_prefix(unsigned char *out,int b,const unsigned char *in,int inlen)
{
  unsigned char buf[inlen+1];
  unsigned char digest[64];
  unsigned char *dst = buf+1;
  int idx;

  buf[0] = b;
  for (idx = 0;idx < inlen;++idx) dst[idx] = in[idx];
  crypto_hash_sha512(digest,buf,inlen+1);
  for (idx = 0;idx < 32;++idx) out[idx] = digest[idx];
}
697
698/* ----- higher-level randomness */
699
700static uint32 urandom32(void)
701{
702  unsigned char c[4];
703  uint32 out[4];
704
705  randombytes(c,4);
706  out[0] = (uint32)c[0];
707  out[1] = ((uint32)c[1])<<8;
708  out[2] = ((uint32)c[2])<<16;
709  out[3] = ((uint32)c[3])<<24;
710  return out[0]+out[1]+out[2]+out[3];
711}
712
713static void Short_random(small *out)
714{
715  uint32 L[p];
716  int i;
717
718  for (i = 0;i < p;++i) L[i] = urandom32();
719  Short_fromlist(out,L);
720}
721
722#ifndef LPR
723
724static void Small_random(small *out)
725{
726  int i;
727
728  for (i = 0;i < p;++i) out[i] = (((urandom32()&0x3fffffff)*3)>>30)-1;
729}
730
731#endif
732
733/* ----- Streamlined NTRU Prime Core */
734
735#ifndef LPR
736
737/* h,(f,ginv) = KeyGen() */
738static void KeyGen(Fq *h,small *f,small *ginv)
739{
740  small g[p];
741  Fq finv[p];
742
743  for (;;) {
744    Small_random(g);
745    if (R3_recip(ginv,g) == 0) break;
746  }
747  Short_random(f);
748  Rq_recip3(finv,f); /* always works */
749  Rq_mult_small(h,finv,g);
750}
751
752/* c = Encrypt(r,h) */
753static void Encrypt(Fq *c,const small *r,const Fq *h)
754{
755  Fq hr[p];
756
757  Rq_mult_small(hr,h,r);
758  Round(c,hr);
759}
760
761/* r = Decrypt(c,(f,ginv)) */
762static void Decrypt(small *r,const Fq *c,const small *f,const small *ginv)
763{
764  Fq cf[p];
765  Fq cf3[p];
766  small e[p];
767  small ev[p];
768  int mask;
769  int i;
770
771  Rq_mult_small(cf,c,f);
772  Rq_mult3(cf3,cf);
773  R3_fromRq(e,cf3);
774  R3_mult(ev,e,ginv);
775
776  mask = Weightw_mask(ev); /* 0 if weight w, else -1 */
777  for (i = 0;i < w;++i) r[i] = ((ev[i]^1)&~mask)^1;
778  for (i = w;i < p;++i) r[i] = ev[i]&~mask;
779}
780
781#endif
782
783/* ----- NTRU LPRime Core */
784
785#ifdef LPR
786
787/* (G,A),a = KeyGen(G); leaves G unchanged */
788static void KeyGen(Fq *A,small *a,const Fq *G)
789{
790  Fq aG[p];
791
792  Short_random(a);
793  Rq_mult_small(aG,G,a);
794  Round(A,aG);
795}
796
797/* B,T = Encrypt(r,(G,A),b) */
798static void Encrypt(Fq *B,int8 *T,const int8 *r,const Fq *G,const Fq *A,const small *b)
799{
800  Fq bG[p];
801  Fq bA[p];
802  int i;
803
804  Rq_mult_small(bG,G,b);
805  Round(B,bG);
806  Rq_mult_small(bA,A,b);
807  for (i = 0;i < I;++i) T[i] = Top(Fq_freeze(bA[i]+r[i]*q12));
808}
809
810/* r = Decrypt((B,T),a) */
811static void Decrypt(int8 *r,const Fq *B,const int8 *T,const small *a)
812{
813  Fq aB[p];
814  int i;
815
816  Rq_mult_small(aB,B,a);
817  for (i = 0;i < I;++i)
818    r[i] = -int16_negative_mask(Fq_freeze(Right(T[i])-aB[i]+4*w+1));
819}
820
821#endif
822
823/* ----- encoding I-bit inputs */
824
825#ifdef LPR
826
827#define Inputs_bytes (I/8)
828typedef int8 Inputs[I]; /* passed by reference */
829
830static void Inputs_encode(unsigned char *s,const Inputs r)
831{
832  int i;
833  for (i = 0;i < Inputs_bytes;++i) s[i] = 0;
834  for (i = 0;i < I;++i) s[i>>3] |= r[i]<<(i&7);
835}
836
837#endif
838
839/* ----- Expand */
840
841#ifdef LPR
842
static const unsigned char aes_nonce[16] = {0};

/*
 * Fill L[0..p-1] with 32-bit words derived from key k.
 * crypto_stream_aes256ctr writes 4*p bytes into L, using the all-zero
 * 16-byte nonce above (presumably AES-256-CTR keystream; confirm in
 * crypto_api.h).  Each group of 4 bytes is then reinterpreted in place as a
 * little-endian word.  The in-place conversion is safe because word i only
 * reads bytes 4i..4i+3, and later iterations never touch them.
 */
static void Expand(uint32 *L,const unsigned char *k)
{
  int i;
  crypto_stream_aes256ctr((unsigned char *) L,4*p,aes_nonce,k);
  for (i = 0;i < p;++i) {
    uint32 L0 = ((unsigned char *) L)[4*i];
    uint32 L1 = ((unsigned char *) L)[4*i+1];
    uint32 L2 = ((unsigned char *) L)[4*i+2];
    uint32 L3 = ((unsigned char *) L)[4*i+3];
    L[i] = L0+(L1<<8)+(L2<<16)+(L3<<24);
  }
}
857
858#endif
859
860/* ----- Seeds */
861
862#ifdef LPR
863
#define Seeds_bytes 32

/* Fill s with Seeds_bytes fresh random bytes (the public seed fed to Generator). */
static void Seeds_random(unsigned char *s)
{
  randombytes(s,Seeds_bytes);
}
870
871#endif
872
873/* ----- Generator, HashShort */
874
875#ifdef LPR
876
877/* G = Generator(k) */
878static void Generator(Fq *G,const unsigned char *k)
879{
880  uint32 L[p];
881  int i;
882
883  Expand(L,k);
884  for (i = 0;i < p;++i) G[i] = uint32_mod_uint14(L[i],q)-q12;
885}
886
887/* out = HashShort(r) */
888static void HashShort(small *out,const Inputs r)
889{
890  unsigned char s[Inputs_bytes];
891  unsigned char h[Hash_bytes];
892  uint32 L[p];
893
894  Inputs_encode(s,r);
895  Hash_prefix(h,5,s,sizeof s);
896  Expand(L,h);
897  Short_fromlist(out,L);
898}
899
900#endif
901
902/* ----- NTRU LPRime Expand */
903
904#ifdef LPR
905
/* (S,A),a = XKeyGen() */
/* Picks a fresh random seed S, expands it to G = Generator(S), then runs KeyGen(G). */
static void XKeyGen(unsigned char *S,Fq *A,small *a)
{
  Fq G[p];

  Seeds_random(S);
  Generator(G,S);
  KeyGen(A,a,G);
}
915
916/* B,T = XEncrypt(r,(S,A)) */
917static void XEncrypt(Fq *B,int8 *T,const int8 *r,const unsigned char *S,const Fq *A)
918{
919  Fq G[p];
920  small b[p];
921
922  Generator(G,S);
923  HashShort(b,r);
924  Encrypt(B,T,r,G,A,b);
925}
926
927#define XDecrypt Decrypt
928
929#endif
930
931/* ----- encoding small polynomials (including short polynomials) */
932
933#define Small_bytes ((p+3)/4)
934
935/* these are the only functions that rely on p mod 4 = 1 */
936
937static void Small_encode(unsigned char *s,const small *f)
938{
939  small x;
940  int i;
941
942  for (i = 0;i < p/4;++i) {
943    x = *f++ + 1;
944    x += (*f++ + 1)<<2;
945    x += (*f++ + 1)<<4;
946    x += (*f++ + 1)<<6;
947    *s++ = x;
948  }
949  x = *f++ + 1;
950  *s++ = x;
951}
952
953static void Small_decode(small *f,const unsigned char *s)
954{
955  unsigned char x;
956  int i;
957
958  for (i = 0;i < p/4;++i) {
959    x = *s++;
960    *f++ = ((small)(x&3))-1; x >>= 2;
961    *f++ = ((small)(x&3))-1; x >>= 2;
962    *f++ = ((small)(x&3))-1; x >>= 2;
963    *f++ = ((small)(x&3))-1;
964  }
965  x = *s++;
966  *f++ = ((small)(x&3))-1;
967}
968
969/* ----- encoding general polynomials */
970
971#ifndef LPR
972
973static void Rq_encode(unsigned char *s,const Fq *r)
974{
975  uint16 R[p],M[p];
976  int i;
977
978  for (i = 0;i < p;++i) R[i] = r[i]+q12;
979  for (i = 0;i < p;++i) M[i] = q;
980  Encode(s,R,M,p);
981}
982
983static void Rq_decode(Fq *r,const unsigned char *s)
984{
985  uint16 R[p],M[p];
986  int i;
987
988  for (i = 0;i < p;++i) M[i] = q;
989  Decode(R,s,M,p);
990  for (i = 0;i < p;++i) r[i] = ((Fq)R[i])-q12;
991}
992
993#endif
994
995/* ----- encoding rounded polynomials */
996
997static void Rounded_encode(unsigned char *s,const Fq *r)
998{
999  uint16 R[p],M[p];
1000  int i;
1001
1002  for (i = 0;i < p;++i) R[i] = ((r[i]+q12)*10923)>>15;
1003  for (i = 0;i < p;++i) M[i] = (q+2)/3;
1004  Encode(s,R,M,p);
1005}
1006
1007static void Rounded_decode(Fq *r,const unsigned char *s)
1008{
1009  uint16 R[p],M[p];
1010  int i;
1011
1012  for (i = 0;i < p;++i) M[i] = (q+2)/3;
1013  Decode(R,s,M,p);
1014  for (i = 0;i < p;++i) r[i] = R[i]*3-q12;
1015}
1016
1017/* ----- encoding top polynomials */
1018
1019#ifdef LPR
1020
1021#define Top_bytes (I/2)
1022
1023static void Top_encode(unsigned char *s,const int8 *T)
1024{
1025  int i;
1026  for (i = 0;i < Top_bytes;++i)
1027    s[i] = T[2*i]+(T[2*i+1]<<4);
1028}
1029
1030static void Top_decode(int8 *T,const unsigned char *s)
1031{
1032  int i;
1033  for (i = 0;i < Top_bytes;++i) {
1034    T[2*i] = s[i]&15;
1035    T[2*i+1] = s[i]>>4;
1036  }
1037}
1038
1039#endif
1040
1041/* ----- Streamlined NTRU Prime Core plus encoding */
1042
1043#ifndef LPR
1044
/* Streamlined NTRU Prime: the KEM input r is itself a short polynomial. */
typedef small Inputs[p]; /* passed by reference */
#define Inputs_random Short_random
#define Inputs_encode Small_encode
#define Inputs_bytes Small_bytes

/* ciphertext = one rounded polynomial; secret key = two small polynomials
   (f and ginv, written by ZKeyGen); public key = one encoded Rq element */
#define Ciphertexts_bytes Rounded_bytes
#define SecretKeys_bytes (2*Small_bytes)
#define PublicKeys_bytes Rq_bytes
1053
1054/* pk,sk = ZKeyGen() */
1055static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1056{
1057  Fq h[p];
1058  small f[p],v[p];
1059
1060  KeyGen(h,f,v);
1061  Rq_encode(pk,h);
1062  Small_encode(sk,f); sk += Small_bytes;
1063  Small_encode(sk,v);
1064}
1065
1066/* C = ZEncrypt(r,pk) */
1067static void ZEncrypt(unsigned char *C,const Inputs r,const unsigned char *pk)
1068{
1069  Fq h[p];
1070  Fq c[p];
1071  Rq_decode(h,pk);
1072  Encrypt(c,r,h);
1073  Rounded_encode(C,c);
1074}
1075
1076/* r = ZDecrypt(C,sk) */
1077static void ZDecrypt(Inputs r,const unsigned char *C,const unsigned char *sk)
1078{
1079  small f[p],v[p];
1080  Fq c[p];
1081
1082  Small_decode(f,sk); sk += Small_bytes;
1083  Small_decode(v,sk);
1084  Rounded_decode(c,C);
1085  Decrypt(r,c,f,v);
1086}
1087
1088#endif
1089
1090/* ----- NTRU LPRime Expand plus encoding */
1091
1092#ifdef LPR
1093
/* LPRime: ciphertext = rounded B followed by packed T; secret key = a;
   public key = seed S followed by rounded A */
#define Ciphertexts_bytes (Rounded_bytes+Top_bytes)
#define SecretKeys_bytes Small_bytes
#define PublicKeys_bytes (Seeds_bytes+Rounded_bytes)
1097
1098static void Inputs_random(Inputs r)
1099{
1100  unsigned char s[Inputs_bytes];
1101  int i;
1102
1103  randombytes(s,sizeof s);
1104  for (i = 0;i < I;++i) r[i] = 1&(s[i>>3]>>(i&7));
1105}
1106
1107/* pk,sk = ZKeyGen() */
1108static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1109{
1110  Fq A[p];
1111  small a[p];
1112
1113  XKeyGen(pk,A,a); pk += Seeds_bytes;
1114  Rounded_encode(pk,A);
1115  Small_encode(sk,a);
1116}
1117
1118/* c = ZEncrypt(r,pk) */
1119static void ZEncrypt(unsigned char *c,const Inputs r,const unsigned char *pk)
1120{
1121  Fq A[p];
1122  Fq B[p];
1123  int8 T[I];
1124
1125  Rounded_decode(A,pk+Seeds_bytes);
1126  XEncrypt(B,T,r,pk,A);
1127  Rounded_encode(c,B); c += Rounded_bytes;
1128  Top_encode(c,T);
1129}
1130
1131/* r = ZDecrypt(C,sk) */
1132static void ZDecrypt(Inputs r,const unsigned char *c,const unsigned char *sk)
1133{
1134  small a[p];
1135  Fq B[p];
1136  int8 T[I];
1137
1138  Small_decode(a,sk);
1139  Rounded_decode(B,c);
1140  Top_decode(T,c+Rounded_bytes);
1141  XDecrypt(r,B,T,a);
1142}
1143
1144#endif
1145
1146/* ----- confirmation hash */
1147
1148#define Confirm_bytes 32
1149
1150/* h = HashConfirm(r,pk,cache); cache is Hash4(pk) */
1151static void HashConfirm(unsigned char *h,const unsigned char *r,const unsigned char *pk,const unsigned char *cache)
1152{
1153#ifndef LPR
1154  unsigned char x[Hash_bytes*2];
1155  int i;
1156
1157  Hash_prefix(x,3,r,Inputs_bytes);
1158  for (i = 0;i < Hash_bytes;++i) x[Hash_bytes+i] = cache[i];
1159#else
1160  unsigned char x[Inputs_bytes+Hash_bytes];
1161  int i;
1162
1163  for (i = 0;i < Inputs_bytes;++i) x[i] = r[i];
1164  for (i = 0;i < Hash_bytes;++i) x[Inputs_bytes+i] = cache[i];
1165#endif
1166  Hash_prefix(h,2,x,sizeof x);
1167}
1168
1169/* ----- session-key hash */
1170
1171/* k = HashSession(b,y,z) */
1172static void HashSession(unsigned char *k,int b,const unsigned char *y,const unsigned char *z)
1173{
1174#ifndef LPR
1175  unsigned char x[Hash_bytes+Ciphertexts_bytes+Confirm_bytes];
1176  int i;
1177
1178  Hash_prefix(x,3,y,Inputs_bytes);
1179  for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Hash_bytes+i] = z[i];
1180#else
1181  unsigned char x[Inputs_bytes+Ciphertexts_bytes+Confirm_bytes];
1182  int i;
1183
1184  for (i = 0;i < Inputs_bytes;++i) x[i] = y[i];
1185  for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Inputs_bytes+i] = z[i];
1186#endif
1187  Hash_prefix(k,b,x,sizeof x);
1188}
1189
1190/* ----- Streamlined NTRU Prime and NTRU LPRime */
1191
1192/* pk,sk = KEM_KeyGen() */
1193static void KEM_KeyGen(unsigned char *pk,unsigned char *sk)
1194{
1195  int i;
1196
1197  ZKeyGen(pk,sk); sk += SecretKeys_bytes;
1198  for (i = 0;i < PublicKeys_bytes;++i) *sk++ = pk[i];
1199  randombytes(sk,Inputs_bytes); sk += Inputs_bytes;
1200  Hash_prefix(sk,4,pk,PublicKeys_bytes);
1201}
1202
1203/* c,r_enc = Hide(r,pk,cache); cache is Hash4(pk) */
1204static void Hide(unsigned char *c,unsigned char *r_enc,const Inputs r,const unsigned char *pk,const unsigned char *cache)
1205{
1206  Inputs_encode(r_enc,r);
1207  ZEncrypt(c,r,pk); c += Ciphertexts_bytes;
1208  HashConfirm(c,r_enc,pk,cache);
1209}
1210
1211/* c,k = Encap(pk) */
1212static void Encap(unsigned char *c,unsigned char *k,const unsigned char *pk)
1213{
1214  Inputs r;
1215  unsigned char r_enc[Inputs_bytes];
1216  unsigned char cache[Hash_bytes];
1217
1218  Hash_prefix(cache,4,pk,PublicKeys_bytes);
1219  Inputs_random(r);
1220  Hide(c,r_enc,r,pk,cache);
1221  HashSession(k,1,r_enc,c);
1222}
1223
1224/* 0 if matching ciphertext+confirm, else -1 */
1225static int Ciphertexts_diff_mask(const unsigned char *c,const unsigned char *c2)
1226{
1227  uint16 differentbits = 0;
1228  int len = Ciphertexts_bytes+Confirm_bytes;
1229
1230  while (len-- > 0) differentbits |= (*c++)^(*c2++);
1231  return (1&((differentbits-1)>>8))-1;
1232}
1233
1234/* k = Decap(c,sk) */
1235static void Decap(unsigned char *k,const unsigned char *c,const unsigned char *sk)
1236{
1237  const unsigned char *pk = sk + SecretKeys_bytes;
1238  const unsigned char *rho = pk + PublicKeys_bytes;
1239  const unsigned char *cache = rho + Inputs_bytes;
1240  Inputs r;
1241  unsigned char r_enc[Inputs_bytes];
1242  unsigned char cnew[Ciphertexts_bytes+Confirm_bytes];
1243  int mask;
1244  int i;
1245
1246  ZDecrypt(r,c,sk);
1247  Hide(cnew,r_enc,r,pk,cache);
1248  mask = Ciphertexts_diff_mask(c,cnew);
1249  for (i = 0;i < Inputs_bytes;++i) r_enc[i] ^= mask&(r_enc[i]^rho[i]);
1250  HashSession(k,1+mask,r_enc,c);
1251}
1252
1253/* ----- crypto_kem API */
1254
1255
/*
 * Public KEM entry points; each wraps the internal routine and always
 * returns 0.  Sizes written: pk = PublicKeys_bytes;
 * sk = SecretKeys_bytes + PublicKeys_bytes + Inputs_bytes + Hash_bytes
 * (see KEM_KeyGen); c = Ciphertexts_bytes + Confirm_bytes; k = Hash_bytes.
 * These presumably match the size macros in crypto_api.h -- confirm there.
 */

/* pk,sk = KEM_KeyGen() */
int crypto_kem_sntrup761_keypair(unsigned char *pk,unsigned char *sk)
{
  KEM_KeyGen(pk,sk);
  return 0;
}

/* c,k = Encap(pk) */
int crypto_kem_sntrup761_enc(unsigned char *c,unsigned char *k,const unsigned char *pk)
{
  Encap(c,k,pk);
  return 0;
}

/* k = Decap(c,sk).  This never reports failure: an invalid ciphertext
   yields a key derived from the secret rho and c with hash prefix 0. */
int crypto_kem_sntrup761_dec(unsigned char *k,const unsigned char *c,const unsigned char *sk)
{
  Decap(k,c,sk);
  return 0;
}
1273#endif /* USE_SNTRUP761X25519 */
1274