1/* $OpenBSD: xmss_fast.c,v 1.3 2018/03/22 07:06:11 markus Exp $ */
2/*
3xmss_fast.c version 20160722
Andreas Hülsing
5Joost Rijneveld
6Public domain.
7*/
8
9#include "includes.h"
10#ifdef WITH_XMSS
11
12#include <stdlib.h>
13#include <string.h>
14#ifdef HAVE_STDINT_H
15# include <stdint.h>
16#endif
17
18#include "xmss_fast.h"
19#include "crypto_api.h"
20#include "xmss_wots.h"
21#include "xmss_hash.h"
22
23#include "xmss_commons.h"
24#include "xmss_hash_address.h"
25// For testing
26#include "stdio.h"
27
28
29
/**
 * Derives the secret seed of the WOTS key pair selected by addr.
 *
 * The chain, hash and key/mask words of addr are cleared first (addr is
 * modified in place), then the 32-byte serialised address is passed through
 * prf() keyed with the n-byte sk_seed. The n-byte result is written to seed.
 */
static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, uint32_t addr[8])
{
  unsigned char addr_bytes[32];

  /* The seed must not depend on the position inside the WOTS chains. */
  setChainADRS(addr, 0);
  setHashADRS(addr, 0);
  setKeyAndMask(addr, 0);

  addr_to_byte(addr_bytes, addr);
  prf(seed, addr_bytes, sk_seed, n);
}
47
48/**
49 * Initialize xmss params struct
50 * parameter names are the same as in the draft
51 * parameter k is K as used in the BDS algorithm
52 */
53int xmss_set_params(xmss_params *params, int n, int h, int w, int k)
54{
55  if (k >= h || k < 2 || (h - k) % 2) {
56    fprintf(stderr, "For BDS traversal, H - K must be even, with H > K >= 2!\n");
57    return 1;
58  }
59  params->h = h;
60  params->n = n;
61  params->k = k;
62  wots_params wots_par;
63  wots_set_params(&wots_par, n, w);
64  params->wots_par = wots_par;
65  return 0;
66}
67
68/**
69 * Initialize BDS state struct
70 * parameter names are the same as used in the description of the BDS traversal
71 */
72void xmss_set_bds_state(bds_state *state, unsigned char *stack, int stackoffset, unsigned char *stacklevels, unsigned char *auth, unsigned char *keep, treehash_inst *treehash, unsigned char *retain, int next_leaf)
73{
74  state->stack = stack;
75  state->stackoffset = stackoffset;
76  state->stacklevels = stacklevels;
77  state->auth = auth;
78  state->keep = keep;
79  state->treehash = treehash;
80  state->retain = retain;
81  state->next_leaf = next_leaf;
82}
83
84/**
85 * Initialize xmssmt_params struct
86 * parameter names are the same as in the draft
87 *
88 * Especially h is the total tree height, i.e. the XMSS trees have height h/d
89 */
90int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k)
91{
92  if (h % d) {
93    fprintf(stderr, "d must divide h without remainder!\n");
94    return 1;
95  }
96  params->h = h;
97  params->d = d;
98  params->n = n;
99  params->index_len = (h + 7) / 8;
100  xmss_params xmss_par;
101  if (xmss_set_params(&xmss_par, n, (h/d), w, k)) {
102    return 1;
103  }
104  params->xmss_par = xmss_par;
105  return 0;
106}
107
108/**
109 * Computes a leaf from a WOTS public key using an L-tree.
110 */
111static void l_tree(unsigned char *leaf, unsigned char *wots_pk, const xmss_params *params, const unsigned char *pub_seed, uint32_t addr[8])
112{
113  unsigned int l = params->wots_par.len;
114  unsigned int n = params->n;
115  uint32_t i = 0;
116  uint32_t height = 0;
117  uint32_t bound;
118
119  //ADRS.setTreeHeight(0);
120  setTreeHeight(addr, height);
121
122  while (l > 1) {
123     bound = l >> 1; //floor(l / 2);
124     for (i = 0; i < bound; i++) {
125       //ADRS.setTreeIndex(i);
126       setTreeIndex(addr, i);
127       //wots_pk[i] = RAND_HASH(pk[2i], pk[2i + 1], SEED, ADRS);
128       hash_h(wots_pk+i*n, wots_pk+i*2*n, pub_seed, addr, n);
129     }
130     //if ( l % 2 == 1 ) {
131     if (l & 1) {
132       //pk[floor(l / 2) + 1] = pk[l];
133       memcpy(wots_pk+(l>>1)*n, wots_pk+(l-1)*n, n);
134       //l = ceil(l / 2);
135       l=(l>>1)+1;
136     }
137     else {
138       //l = ceil(l / 2);
139       l=(l>>1);
140     }
141     //ADRS.setTreeHeight(ADRS.getTreeHeight() + 1);
142     height++;
143     setTreeHeight(addr, height);
144   }
145   //return pk[0];
146   memcpy(leaf, wots_pk, n);
147}
148
149/**
150 * Computes the leaf at a given address. First generates the WOTS key pair, then computes leaf using l_tree. As this happens position independent, we only require that addr encodes the right ltree-address.
151 */
152static void gen_leaf_wots(unsigned char *leaf, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, uint32_t ltree_addr[8], uint32_t ots_addr[8])
153{
154  unsigned char seed[params->n];
155  unsigned char pk[params->wots_par.keysize];
156
157  get_seed(seed, sk_seed, params->n, ots_addr);
158  wots_pkgen(pk, seed, &(params->wots_par), pub_seed, ots_addr);
159
160  l_tree(leaf, pk, params, pub_seed, ltree_addr);
161}
162
163static int treehash_minheight_on_stack(bds_state* state, const xmss_params *params, const treehash_inst *treehash) {
164  unsigned int r = params->h, i;
165  for (i = 0; i < treehash->stackusage; i++) {
166    if (state->stacklevels[state->stackoffset - i - 1] < r) {
167      r = state->stacklevels[state->stackoffset - i - 1];
168    }
169  }
170  return r;
171}
172
173/**
174 * Merkle's TreeHash algorithm. The address only needs to initialize the first 78 bits of addr. Everything else will be set by treehash.
175 * Currently only used for key generation.
176 *
177 */
178static void treehash_setup(unsigned char *node, int height, int index, bds_state *state, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, const uint32_t addr[8])
179{
180  unsigned int idx = index;
181  unsigned int n = params->n;
182  unsigned int h = params->h;
183  unsigned int k = params->k;
184  // use three different addresses because at this point we use all three formats in parallel
185  uint32_t ots_addr[8];
186  uint32_t ltree_addr[8];
187  uint32_t  node_addr[8];
188  // only copy layer and tree address parts
189  memcpy(ots_addr, addr, 12);
190  // type = ots
191  setType(ots_addr, 0);
192  memcpy(ltree_addr, addr, 12);
193  setType(ltree_addr, 1);
194  memcpy(node_addr, addr, 12);
195  setType(node_addr, 2);
196
197  uint32_t lastnode, i;
198  unsigned char stack[(height+1)*n];
199  unsigned int stacklevels[height+1];
200  unsigned int stackoffset=0;
201  unsigned int nodeh;
202
203  lastnode = idx+(1<<height);
204
205  for (i = 0; i < h-k; i++) {
206    state->treehash[i].h = i;
207    state->treehash[i].completed = 1;
208    state->treehash[i].stackusage = 0;
209  }
210
211  i = 0;
212  for (; idx < lastnode; idx++) {
213    setLtreeADRS(ltree_addr, idx);
214    setOTSADRS(ots_addr, idx);
215    gen_leaf_wots(stack+stackoffset*n, sk_seed, params, pub_seed, ltree_addr, ots_addr);
216    stacklevels[stackoffset] = 0;
217    stackoffset++;
218    if (h - k > 0 && i == 3) {
219      memcpy(state->treehash[0].node, stack+stackoffset*n, n);
220    }
221    while (stackoffset>1 && stacklevels[stackoffset-1] == stacklevels[stackoffset-2])
222    {
223      nodeh = stacklevels[stackoffset-1];
224      if (i >> nodeh == 1) {
225        memcpy(state->auth + nodeh*n, stack+(stackoffset-1)*n, n);
226      }
227      else {
228        if (nodeh < h - k && i >> nodeh == 3) {
229          memcpy(state->treehash[nodeh].node, stack+(stackoffset-1)*n, n);
230        }
231        else if (nodeh >= h - k) {
232          memcpy(state->retain + ((1 << (h - 1 - nodeh)) + nodeh - h + (((i >> nodeh) - 3) >> 1)) * n, stack+(stackoffset-1)*n, n);
233        }
234      }
235      setTreeHeight(node_addr, stacklevels[stackoffset-1]);
236      setTreeIndex(node_addr, (idx >> (stacklevels[stackoffset-1]+1)));
237      hash_h(stack+(stackoffset-2)*n, stack+(stackoffset-2)*n, pub_seed,
238          node_addr, n);
239      stacklevels[stackoffset-2]++;
240      stackoffset--;
241    }
242    i++;
243  }
244
245  for (i = 0; i < n; i++)
246    node[i] = stack[i];
247}
248
/**
 * Performs one step of a BDS treehash instance.
 *
 * Generates the leaf at treehash->next_idx. It is then merged with this
 * instance's nodes on the shared stack in state for as long as the top stack
 * entry has the same height as the running node.
 * If the running node reaches height treehash->h, it is stored in
 * treehash->node and the instance is marked completed (next_idx is not
 * advanced). Otherwise it is pushed onto the shared stack and next_idx
 * advances by one.
 *
 * Only the layer and tree address parts of addr (first 12 bytes) are used.
 */
static void treehash_update(treehash_inst *treehash, bds_state *state, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, const uint32_t addr[8]) {
  int n = params->n;

  uint32_t ots_addr[8];
  uint32_t ltree_addr[8];
  uint32_t  node_addr[8];
  // only copy layer and tree address parts
  memcpy(ots_addr, addr, 12);
  // type = ots
  setType(ots_addr, 0);
  memcpy(ltree_addr, addr, 12);
  setType(ltree_addr, 1);
  memcpy(node_addr, addr, 12);
  setType(node_addr, 2);

  setLtreeADRS(ltree_addr, treehash->next_idx);
  setOTSADRS(ots_addr, treehash->next_idx);

  // nodebuffer holds left || right child while merging; the running node is the first n bytes
  unsigned char nodebuffer[2 * n];
  unsigned int nodeheight = 0;
  gen_leaf_wots(nodebuffer, sk_seed, params, pub_seed, ltree_addr, ots_addr);
  // only pop entries owned by this instance (stackusage) that match the running height
  while (treehash->stackusage > 0 && state->stacklevels[state->stackoffset-1] == nodeheight) {
    // running node becomes the right child, popped stack node the left child
    memcpy(nodebuffer + n, nodebuffer, n);
    memcpy(nodebuffer, state->stack + (state->stackoffset-1)*n, n);
    setTreeHeight(node_addr, nodeheight);
    setTreeIndex(node_addr, (treehash->next_idx >> (nodeheight+1)));
    hash_h(nodebuffer, nodebuffer, pub_seed, node_addr, n);
    nodeheight++;
    treehash->stackusage--;
    state->stackoffset--;
  }
  if (nodeheight == treehash->h) { // this also implies stackusage == 0
    memcpy(treehash->node, nodebuffer, n);
    treehash->completed = 1;
  }
  else {
    // not done yet: park the running node on the shared stack
    memcpy(state->stack + state->stackoffset*n, nodebuffer, n);
    treehash->stackusage++;
    state->stacklevels[state->stackoffset] = nodeheight;
    state->stackoffset++;
    treehash->next_idx++;
  }
}
292
293/**
294 * Computes a root node given a leaf and an authapth
295 */
296static void validate_authpath(unsigned char *root, const unsigned char *leaf, unsigned long leafidx, const unsigned char *authpath, const xmss_params *params, const unsigned char *pub_seed, uint32_t addr[8])
297{
298  unsigned int n = params->n;
299
300  uint32_t i, j;
301  unsigned char buffer[2*n];
302
303  // If leafidx is odd (last bit = 1), current path element is a right child and authpath has to go to the left.
304  // Otherwise, it is the other way around
305  if (leafidx & 1) {
306    for (j = 0; j < n; j++)
307      buffer[n+j] = leaf[j];
308    for (j = 0; j < n; j++)
309      buffer[j] = authpath[j];
310  }
311  else {
312    for (j = 0; j < n; j++)
313      buffer[j] = leaf[j];
314    for (j = 0; j < n; j++)
315      buffer[n+j] = authpath[j];
316  }
317  authpath += n;
318
319  for (i=0; i < params->h-1; i++) {
320    setTreeHeight(addr, i);
321    leafidx >>= 1;
322    setTreeIndex(addr, leafidx);
323    if (leafidx&1) {
324      hash_h(buffer+n, buffer, pub_seed, addr, n);
325      for (j = 0; j < n; j++)
326        buffer[j] = authpath[j];
327    }
328    else {
329      hash_h(buffer, buffer, pub_seed, addr, n);
330      for (j = 0; j < n; j++)
331        buffer[j+n] = authpath[j];
332    }
333    authpath += n;
334  }
335  setTreeHeight(addr, (params->h-1));
336  leafidx >>= 1;
337  setTreeIndex(addr, leafidx);
338  hash_h(root, buffer, pub_seed, addr, n);
339}
340
341/**
342 * Performs one treehash update on the instance that needs it the most.
343 * Returns 1 if such an instance was not found
344 **/
345static char bds_treehash_update(bds_state *state, unsigned int updates, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, const uint32_t addr[8]) {
346  uint32_t i, j;
347  unsigned int level, l_min, low;
348  unsigned int h = params->h;
349  unsigned int k = params->k;
350  unsigned int used = 0;
351
352  for (j = 0; j < updates; j++) {
353    l_min = h;
354    level = h - k;
355    for (i = 0; i < h - k; i++) {
356      if (state->treehash[i].completed) {
357        low = h;
358      }
359      else if (state->treehash[i].stackusage == 0) {
360        low = i;
361      }
362      else {
363        low = treehash_minheight_on_stack(state, params, &(state->treehash[i]));
364      }
365      if (low < l_min) {
366        level = i;
367        l_min = low;
368      }
369    }
370    if (level == h - k) {
371      break;
372    }
373    treehash_update(&(state->treehash[level]), state, sk_seed, params, pub_seed, addr);
374    used++;
375  }
376  return updates - used;
377}
378
379/**
380 * Updates the state (typically NEXT_i) by adding a leaf and updating the stack
381 * Returns 1 if all leaf nodes have already been processed
382 **/
383static char bds_state_update(bds_state *state, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, const uint32_t addr[8]) {
384  uint32_t ltree_addr[8];
385  uint32_t node_addr[8];
386  uint32_t ots_addr[8];
387
388  int n = params->n;
389  int h = params->h;
390  int k = params->k;
391
392  int nodeh;
393  int idx = state->next_leaf;
394  if (idx == 1 << h) {
395    return 1;
396  }
397
398  // only copy layer and tree address parts
399  memcpy(ots_addr, addr, 12);
400  // type = ots
401  setType(ots_addr, 0);
402  memcpy(ltree_addr, addr, 12);
403  setType(ltree_addr, 1);
404  memcpy(node_addr, addr, 12);
405  setType(node_addr, 2);
406
407  setOTSADRS(ots_addr, idx);
408  setLtreeADRS(ltree_addr, idx);
409
410  gen_leaf_wots(state->stack+state->stackoffset*n, sk_seed, params, pub_seed, ltree_addr, ots_addr);
411
412  state->stacklevels[state->stackoffset] = 0;
413  state->stackoffset++;
414  if (h - k > 0 && idx == 3) {
415    memcpy(state->treehash[0].node, state->stack+state->stackoffset*n, n);
416  }
417  while (state->stackoffset>1 && state->stacklevels[state->stackoffset-1] == state->stacklevels[state->stackoffset-2]) {
418    nodeh = state->stacklevels[state->stackoffset-1];
419    if (idx >> nodeh == 1) {
420      memcpy(state->auth + nodeh*n, state->stack+(state->stackoffset-1)*n, n);
421    }
422    else {
423      if (nodeh < h - k && idx >> nodeh == 3) {
424        memcpy(state->treehash[nodeh].node, state->stack+(state->stackoffset-1)*n, n);
425      }
426      else if (nodeh >= h - k) {
427        memcpy(state->retain + ((1 << (h - 1 - nodeh)) + nodeh - h + (((idx >> nodeh) - 3) >> 1)) * n, state->stack+(state->stackoffset-1)*n, n);
428      }
429    }
430    setTreeHeight(node_addr, state->stacklevels[state->stackoffset-1]);
431    setTreeIndex(node_addr, (idx >> (state->stacklevels[state->stackoffset-1]+1)));
432    hash_h(state->stack+(state->stackoffset-2)*n, state->stack+(state->stackoffset-2)*n, pub_seed, node_addr, n);
433
434    state->stacklevels[state->stackoffset-2]++;
435    state->stackoffset--;
436  }
437  state->next_leaf++;
438  return 0;
439}
440
/**
 * Advances state->auth in place from the authentication path of leaf leaf_idx
 * to the one of leaf leaf_idx + 1. It follows the algorithm described by
 * Buchmann, Dahmen and Szydlo in "Post Quantum Cryptography", Springer 2009.
 *
 * tau is the height of the lowest zero bit of leaf_idx. The auth node at
 * height tau is recomputed:
 *  - when tau == 0, it is the leaf leaf_idx itself;
 *  - otherwise it is the hash of auth[tau-1] and the matching keep node.
 * Lower auth nodes are refilled from the treehash instances (heights < h - k)
 * or from state->retain (heights >= h - k). Each treehash instance i below
 * min(tau, h - k) is restarted at leaf leaf_idx + 1 + 3*2^i when that leaf
 * exists.
 *
 * Only the layer and tree address parts of addr (first 12 bytes) are used.
 * The callers in this file only invoke it for leaves that are not the last
 * leaf of their tree.
 */
static void bds_round(bds_state *state, const unsigned long leaf_idx, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, uint32_t addr[8])
{
  unsigned int i;
  unsigned int n = params->n;
  unsigned int h = params->h;
  unsigned int k = params->k;

  unsigned int tau = h;
  unsigned int startidx;
  unsigned int offset, rowidx;
  unsigned char buf[2 * n];

  uint32_t ots_addr[8];
  uint32_t ltree_addr[8];
  uint32_t  node_addr[8];
  // only copy layer and tree address parts
  memcpy(ots_addr, addr, 12);
  // type = ots
  setType(ots_addr, 0);
  memcpy(ltree_addr, addr, 12);
  setType(ltree_addr, 1);
  memcpy(node_addr, addr, 12);
  setType(node_addr, 2);

  // tau = position of the lowest 0 bit of leaf_idx (stays h if there is none)
  for (i = 0; i < h; i++) {
    if (! ((leaf_idx >> i) & 1)) {
      tau = i;
      break;
    }
  }

  if (tau > 0) {
    // buf = auth[tau-1] || keep[(tau-1)/2]; hashed below into the new auth[tau]
    memcpy(buf,     state->auth + (tau-1) * n, n);
    // we need to do this before refreshing state->keep to prevent overwriting
    memcpy(buf + n, state->keep + ((tau-1) >> 1) * n, n);
  }
  // save the current auth[tau] in keep when bit tau+1 of leaf_idx is 0
  if (!((leaf_idx >> (tau + 1)) & 1) && (tau < h - 1)) {
    memcpy(state->keep + (tau >> 1)*n, state->auth + tau*n, n);
  }
  if (tau == 0) {
    // leaf_idx is even: the next leaf's sibling is leaf_idx itself
    setLtreeADRS(ltree_addr, leaf_idx);
    setOTSADRS(ots_addr, leaf_idx);
    gen_leaf_wots(state->auth, sk_seed, params, pub_seed, ltree_addr, ots_addr);
  }
  else {
    setTreeHeight(node_addr, (tau-1));
    setTreeIndex(node_addr, leaf_idx >> tau);
    hash_h(state->auth + tau * n, buf, pub_seed, node_addr, n);
    // refill the auth nodes below tau
    for (i = 0; i < tau; i++) {
      if (i < h - k) {
        memcpy(state->auth + i * n, state->treehash[i].node, n);
      }
      else {
        offset = (1 << (h - 1 - i)) + i - h;
        rowidx = ((leaf_idx >> i) - 1) >> 1;
        memcpy(state->auth + i * n, state->retain + (offset + rowidx) * n, n);
      }
    }

    // restart the treehash instances that were just consumed
    for (i = 0; i < ((tau < h - k) ? tau : (h - k)); i++) {
      startidx = leaf_idx + 1 + 3 * (1 << i);
      if (startidx < 1U << h) {
        state->treehash[i].h = i;
        state->treehash[i].next_idx = startidx;
        state->treehash[i].completed = 0;
        state->treehash[i].stackusage = 0;
      }
    }
  }
}
516
517/*
518 * Generates a XMSS key pair for a given parameter set.
519 * Format sk: [(32bit) idx || SK_SEED || SK_PRF || PUB_SEED || root]
520 * Format pk: [root || PUB_SEED] omitting algo oid.
521 */
522int xmss_keypair(unsigned char *pk, unsigned char *sk, bds_state *state, xmss_params *params)
523{
524  unsigned int n = params->n;
525  // Set idx = 0
526  sk[0] = 0;
527  sk[1] = 0;
528  sk[2] = 0;
529  sk[3] = 0;
530  // Init SK_SEED (n byte), SK_PRF (n byte), and PUB_SEED (n byte)
531  randombytes(sk+4, 3*n);
532  // Copy PUB_SEED to public key
533  memcpy(pk+n, sk+4+2*n, n);
534
535  uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
536
537  // Compute root
538  treehash_setup(pk, params->h, 0, state, sk+4, params, sk+4+2*n, addr);
539  // copy root to sk
540  memcpy(sk+4+3*n, pk, n);
541  return 0;
542}
543
544/**
545 * Signs a message.
546 * Returns
547 * 1. an array containing the signature followed by the message AND
548 * 2. an updated secret key!
549 *
550 */
551int xmss_sign(unsigned char *sk, bds_state *state, unsigned char *sig_msg, unsigned long long *sig_msg_len, const unsigned char *msg, unsigned long long msglen, const xmss_params *params)
552{
553  unsigned int h = params->h;
554  unsigned int n = params->n;
555  unsigned int k = params->k;
556  uint16_t i = 0;
557
558  // Extract SK
559  unsigned long idx = ((unsigned long)sk[0] << 24) | ((unsigned long)sk[1] << 16) | ((unsigned long)sk[2] << 8) | sk[3];
560  unsigned char sk_seed[n];
561  memcpy(sk_seed, sk+4, n);
562  unsigned char sk_prf[n];
563  memcpy(sk_prf, sk+4+n, n);
564  unsigned char pub_seed[n];
565  memcpy(pub_seed, sk+4+2*n, n);
566
567  // index as 32 bytes string
568  unsigned char idx_bytes_32[32];
569  to_byte(idx_bytes_32, idx, 32);
570
571  unsigned char hash_key[3*n];
572
573  // Update SK
574  sk[0] = ((idx + 1) >> 24) & 255;
575  sk[1] = ((idx + 1) >> 16) & 255;
576  sk[2] = ((idx + 1) >> 8) & 255;
577  sk[3] = (idx + 1) & 255;
578  // -- Secret key for this non-forward-secure version is now updated.
579  // -- A productive implementation should use a file handle instead and write the updated secret key at this point!
580
581  // Init working params
582  unsigned char R[n];
583  unsigned char msg_h[n];
584  unsigned char ots_seed[n];
585  uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
586
587  // ---------------------------------
588  // Message Hashing
589  // ---------------------------------
590
591  // Message Hash:
592  // First compute pseudorandom value
593  prf(R, idx_bytes_32, sk_prf, n);
594  // Generate hash key (R || root || idx)
595  memcpy(hash_key, R, n);
596  memcpy(hash_key+n, sk+4+3*n, n);
597  to_byte(hash_key+2*n, idx, n);
598  // Then use it for message digest
599  h_msg(msg_h, msg, msglen, hash_key, 3*n, n);
600
601  // Start collecting signature
602  *sig_msg_len = 0;
603
604  // Copy index to signature
605  sig_msg[0] = (idx >> 24) & 255;
606  sig_msg[1] = (idx >> 16) & 255;
607  sig_msg[2] = (idx >> 8) & 255;
608  sig_msg[3] = idx & 255;
609
610  sig_msg += 4;
611  *sig_msg_len += 4;
612
613  // Copy R to signature
614  for (i = 0; i < n; i++)
615    sig_msg[i] = R[i];
616
617  sig_msg += n;
618  *sig_msg_len += n;
619
620  // ----------------------------------
621  // Now we start to "really sign"
622  // ----------------------------------
623
624  // Prepare Address
625  setType(ots_addr, 0);
626  setOTSADRS(ots_addr, idx);
627
628  // Compute seed for OTS key pair
629  get_seed(ots_seed, sk_seed, n, ots_addr);
630
631  // Compute WOTS signature
632  wots_sign(sig_msg, msg_h, ots_seed, &(params->wots_par), pub_seed, ots_addr);
633
634  sig_msg += params->wots_par.keysize;
635  *sig_msg_len += params->wots_par.keysize;
636
637  // the auth path was already computed during the previous round
638  memcpy(sig_msg, state->auth, h*n);
639
640  if (idx < (1U << h) - 1) {
641    bds_round(state, idx, sk_seed, params, pub_seed, ots_addr);
642    bds_treehash_update(state, (h - k) >> 1, sk_seed, params, pub_seed, ots_addr);
643  }
644
645/* TODO: save key/bds state here! */
646
647  sig_msg += params->h*n;
648  *sig_msg_len += params->h*n;
649
650  //Whipe secret elements?
651  //zerobytes(tsk, CRYPTO_SECRETKEYBYTES);
652
653
654  memcpy(sig_msg, msg, msglen);
655  *sig_msg_len += msglen;
656
657  return 0;
658}
659
660/**
661 * Verifies a given message signature pair under a given public key.
662 */
663int xmss_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigned char *sig_msg, unsigned long long sig_msg_len, const unsigned char *pk, const xmss_params *params)
664{
665  unsigned int n = params->n;
666
667  unsigned long long i, m_len;
668  unsigned long idx=0;
669  unsigned char wots_pk[params->wots_par.keysize];
670  unsigned char pkhash[n];
671  unsigned char root[n];
672  unsigned char msg_h[n];
673  unsigned char hash_key[3*n];
674
675  unsigned char pub_seed[n];
676  memcpy(pub_seed, pk+n, n);
677
678  // Init addresses
679  uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
680  uint32_t ltree_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
681  uint32_t node_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
682
683  setType(ots_addr, 0);
684  setType(ltree_addr, 1);
685  setType(node_addr, 2);
686
687  // Extract index
688  idx = ((unsigned long)sig_msg[0] << 24) | ((unsigned long)sig_msg[1] << 16) | ((unsigned long)sig_msg[2] << 8) | sig_msg[3];
689  printf("verify:: idx = %lu\n", idx);
690
691  // Generate hash key (R || root || idx)
692  memcpy(hash_key, sig_msg+4,n);
693  memcpy(hash_key+n, pk, n);
694  to_byte(hash_key+2*n, idx, n);
695
696  sig_msg += (n+4);
697  sig_msg_len -= (n+4);
698
699  // hash message
700  unsigned long long tmp_sig_len = params->wots_par.keysize+params->h*n;
701  m_len = sig_msg_len - tmp_sig_len;
702  h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n);
703
704  //-----------------------
705  // Verify signature
706  //-----------------------
707
708  // Prepare Address
709  setOTSADRS(ots_addr, idx);
710  // Check WOTS signature
711  wots_pkFromSig(wots_pk, sig_msg, msg_h, &(params->wots_par), pub_seed, ots_addr);
712
713  sig_msg += params->wots_par.keysize;
714  sig_msg_len -= params->wots_par.keysize;
715
716  // Compute Ltree
717  setLtreeADRS(ltree_addr, idx);
718  l_tree(pkhash, wots_pk, params, pub_seed, ltree_addr);
719
720  // Compute root
721  validate_authpath(root, pkhash, idx, sig_msg, params, pub_seed, node_addr);
722
723  sig_msg += params->h*n;
724  sig_msg_len -= params->h*n;
725
726  for (i = 0; i < n; i++)
727    if (root[i] != pk[i])
728      goto fail;
729
730  *msglen = sig_msg_len;
731  for (i = 0; i < *msglen; i++)
732    msg[i] = sig_msg[i];
733
734  return 0;
735
736
737fail:
738  *msglen = sig_msg_len;
739  for (i = 0; i < *msglen; i++)
740    msg[i] = 0;
741  *msglen = -1;
742  return -1;
743}
744
745/*
746 * Generates a XMSSMT key pair for a given parameter set.
747 * Format sk: [(ceil(h/8) bit) idx || SK_SEED || SK_PRF || PUB_SEED || root]
748 * Format pk: [root || PUB_SEED] omitting algo oid.
749 */
750int xmssmt_keypair(unsigned char *pk, unsigned char *sk, bds_state *states, unsigned char *wots_sigs, xmssmt_params *params)
751{
752  unsigned int n = params->n;
753  unsigned int i;
754  unsigned char ots_seed[params->n];
755  // Set idx = 0
756  for (i = 0; i < params->index_len; i++) {
757    sk[i] = 0;
758  }
759  // Init SK_SEED (n byte), SK_PRF (n byte), and PUB_SEED (n byte)
760  randombytes(sk+params->index_len, 3*n);
761  // Copy PUB_SEED to public key
762  memcpy(pk+n, sk+params->index_len+2*n, n);
763
764  // Set address to point on the single tree on layer d-1
765  uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
766  setLayerADRS(addr, (params->d-1));
767  // Set up state and compute wots signatures for all but topmost tree root
768  for (i = 0; i < params->d - 1; i++) {
769    // Compute seed for OTS key pair
770    treehash_setup(pk, params->xmss_par.h, 0, states + i, sk+params->index_len, &(params->xmss_par), pk+n, addr);
771    setLayerADRS(addr, (i+1));
772    get_seed(ots_seed, sk+params->index_len, n, addr);
773    wots_sign(wots_sigs + i*params->xmss_par.wots_par.keysize, pk, ots_seed, &(params->xmss_par.wots_par), pk+n, addr);
774  }
775  treehash_setup(pk, params->xmss_par.h, 0, states + i, sk+params->index_len, &(params->xmss_par), pk+n, addr);
776  memcpy(sk+params->index_len+3*n, pk, n);
777  return 0;
778}
779
780/**
781 * Signs a message.
782 * Returns
783 * 1. an array containing the signature followed by the message AND
784 * 2. an updated secret key!
785 *
786 */
787int xmssmt_sign(unsigned char *sk, bds_state *states, unsigned char *wots_sigs, unsigned char *sig_msg, unsigned long long *sig_msg_len, const unsigned char *msg, unsigned long long msglen, const xmssmt_params *params)
788{
789  unsigned int n = params->n;
790
791  unsigned int tree_h = params->xmss_par.h;
792  unsigned int h = params->h;
793  unsigned int k = params->xmss_par.k;
794  unsigned int idx_len = params->index_len;
795  uint64_t idx_tree;
796  uint32_t idx_leaf;
797  uint64_t i, j;
798  int needswap_upto = -1;
799  unsigned int updates;
800
801  unsigned char sk_seed[n];
802  unsigned char sk_prf[n];
803  unsigned char pub_seed[n];
804  // Init working params
805  unsigned char R[n];
806  unsigned char msg_h[n];
807  unsigned char hash_key[3*n];
808  unsigned char ots_seed[n];
809  uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
810  uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
811  unsigned char idx_bytes_32[32];
812  bds_state tmp;
813
814  // Extract SK
815  unsigned long long idx = 0;
816  for (i = 0; i < idx_len; i++) {
817    idx |= ((unsigned long long)sk[i]) << 8*(idx_len - 1 - i);
818  }
819
820  memcpy(sk_seed, sk+idx_len, n);
821  memcpy(sk_prf, sk+idx_len+n, n);
822  memcpy(pub_seed, sk+idx_len+2*n, n);
823
824  // Update SK
825  for (i = 0; i < idx_len; i++) {
826    sk[i] = ((idx + 1) >> 8*(idx_len - 1 - i)) & 255;
827  }
828  // -- Secret key for this non-forward-secure version is now updated.
829  // -- A productive implementation should use a file handle instead and write the updated secret key at this point!
830
831
832  // ---------------------------------
833  // Message Hashing
834  // ---------------------------------
835
836  // Message Hash:
837  // First compute pseudorandom value
838  to_byte(idx_bytes_32, idx, 32);
839  prf(R, idx_bytes_32, sk_prf, n);
840  // Generate hash key (R || root || idx)
841  memcpy(hash_key, R, n);
842  memcpy(hash_key+n, sk+idx_len+3*n, n);
843  to_byte(hash_key+2*n, idx, n);
844
845  // Then use it for message digest
846  h_msg(msg_h, msg, msglen, hash_key, 3*n, n);
847
848  // Start collecting signature
849  *sig_msg_len = 0;
850
851  // Copy index to signature
852  for (i = 0; i < idx_len; i++) {
853    sig_msg[i] = (idx >> 8*(idx_len - 1 - i)) & 255;
854  }
855
856  sig_msg += idx_len;
857  *sig_msg_len += idx_len;
858
859  // Copy R to signature
860  for (i = 0; i < n; i++)
861    sig_msg[i] = R[i];
862
863  sig_msg += n;
864  *sig_msg_len += n;
865
866  // ----------------------------------
867  // Now we start to "really sign"
868  // ----------------------------------
869
870  // Handle lowest layer separately as it is slightly different...
871
872  // Prepare Address
873  setType(ots_addr, 0);
874  idx_tree = idx >> tree_h;
875  idx_leaf = (idx & ((1 << tree_h)-1));
876  setLayerADRS(ots_addr, 0);
877  setTreeADRS(ots_addr, idx_tree);
878  setOTSADRS(ots_addr, idx_leaf);
879
880  // Compute seed for OTS key pair
881  get_seed(ots_seed, sk_seed, n, ots_addr);
882
883  // Compute WOTS signature
884  wots_sign(sig_msg, msg_h, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr);
885
886  sig_msg += params->xmss_par.wots_par.keysize;
887  *sig_msg_len += params->xmss_par.wots_par.keysize;
888
889  memcpy(sig_msg, states[0].auth, tree_h*n);
890  sig_msg += tree_h*n;
891  *sig_msg_len += tree_h*n;
892
893  // prepare signature of remaining layers
894  for (i = 1; i < params->d; i++) {
895    // put WOTS signature in place
896    memcpy(sig_msg, wots_sigs + (i-1)*params->xmss_par.wots_par.keysize, params->xmss_par.wots_par.keysize);
897
898    sig_msg += params->xmss_par.wots_par.keysize;
899    *sig_msg_len += params->xmss_par.wots_par.keysize;
900
901    // put AUTH nodes in place
902    memcpy(sig_msg, states[i].auth, tree_h*n);
903    sig_msg += tree_h*n;
904    *sig_msg_len += tree_h*n;
905  }
906
907  updates = (tree_h - k) >> 1;
908
909  setTreeADRS(addr, (idx_tree + 1));
910  // mandatory update for NEXT_0 (does not count towards h-k/2) if NEXT_0 exists
911  if ((1 + idx_tree) * (1 << tree_h) + idx_leaf < (1ULL << h)) {
912    bds_state_update(&states[params->d], sk_seed, &(params->xmss_par), pub_seed, addr);
913  }
914
915  for (i = 0; i < params->d; i++) {
916    // check if we're not at the end of a tree
917    if (! (((idx + 1) & ((1ULL << ((i+1)*tree_h)) - 1)) == 0)) {
918      idx_leaf = (idx >> (tree_h * i)) & ((1 << tree_h)-1);
919      idx_tree = (idx >> (tree_h * (i+1)));
920      setLayerADRS(addr, i);
921      setTreeADRS(addr, idx_tree);
922      if (i == (unsigned int) (needswap_upto + 1)) {
923        bds_round(&states[i], idx_leaf, sk_seed, &(params->xmss_par), pub_seed, addr);
924      }
925      updates = bds_treehash_update(&states[i], updates, sk_seed, &(params->xmss_par), pub_seed, addr);
926      setTreeADRS(addr, (idx_tree + 1));
927      // if a NEXT-tree exists for this level;
928      if ((1 + idx_tree) * (1 << tree_h) + idx_leaf < (1ULL << (h - tree_h * i))) {
929        if (i > 0 && updates > 0 && states[params->d + i].next_leaf < (1ULL << h)) {
930          bds_state_update(&states[params->d + i], sk_seed, &(params->xmss_par), pub_seed, addr);
931          updates--;
932        }
933      }
934    }
935    else if (idx < (1ULL << h) - 1) {
936      memcpy(&tmp, states+params->d + i, sizeof(bds_state));
937      memcpy(states+params->d + i, states + i, sizeof(bds_state));
938      memcpy(states + i, &tmp, sizeof(bds_state));
939
940      setLayerADRS(ots_addr, (i+1));
941      setTreeADRS(ots_addr, ((idx + 1) >> ((i+2) * tree_h)));
942      setOTSADRS(ots_addr, (((idx >> ((i+1) * tree_h)) + 1) & ((1 << tree_h)-1)));
943
944      get_seed(ots_seed, sk+params->index_len, n, ots_addr);
945      wots_sign(wots_sigs + i*params->xmss_par.wots_par.keysize, states[i].stack, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr);
946
947      states[params->d + i].stackoffset = 0;
948      states[params->d + i].next_leaf = 0;
949
950      updates--; // WOTS-signing counts as one update
951      needswap_upto = i;
952      for (j = 0; j < tree_h-k; j++) {
953        states[i].treehash[j].completed = 1;
954      }
955    }
956  }
957
958  //Whipe secret elements?
959  //zerobytes(tsk, CRYPTO_SECRETKEYBYTES);
960
961  memcpy(sig_msg, msg, msglen);
962  *sig_msg_len += msglen;
963
964  return 0;
965}
966
967/**
968 * Verifies a given message signature pair under a given public key.
969 */
970int xmssmt_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigned char *sig_msg, unsigned long long sig_msg_len, const unsigned char *pk, const xmssmt_params *params)
971{
972  unsigned int n = params->n;
973
974  unsigned int tree_h = params->xmss_par.h;
975  unsigned int idx_len = params->index_len;
976  uint64_t idx_tree;
977  uint32_t idx_leaf;
978
979  unsigned long long i, m_len;
980  unsigned long long idx=0;
981  unsigned char wots_pk[params->xmss_par.wots_par.keysize];
982  unsigned char pkhash[n];
983  unsigned char root[n];
984  unsigned char msg_h[n];
985  unsigned char hash_key[3*n];
986
987  unsigned char pub_seed[n];
988  memcpy(pub_seed, pk+n, n);
989
990  // Init addresses
991  uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
992  uint32_t ltree_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
993  uint32_t node_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
994
995  // Extract index
996  for (i = 0; i < idx_len; i++) {
997    idx |= ((unsigned long long)sig_msg[i]) << (8*(idx_len - 1 - i));
998  }
999  printf("verify:: idx = %llu\n", idx);
1000  sig_msg += idx_len;
1001  sig_msg_len -= idx_len;
1002
1003  // Generate hash key (R || root || idx)
1004  memcpy(hash_key, sig_msg,n);
1005  memcpy(hash_key+n, pk, n);
1006  to_byte(hash_key+2*n, idx, n);
1007
1008  sig_msg += n;
1009  sig_msg_len -= n;
1010
1011
1012  // hash message (recall, R is now on pole position at sig_msg
1013  unsigned long long tmp_sig_len = (params->d * params->xmss_par.wots_par.keysize) + (params->h * n);
1014  m_len = sig_msg_len - tmp_sig_len;
1015  h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n);
1016
1017
1018  //-----------------------
1019  // Verify signature
1020  //-----------------------
1021
1022  // Prepare Address
1023  idx_tree = idx >> tree_h;
1024  idx_leaf = (idx & ((1 << tree_h)-1));
1025  setLayerADRS(ots_addr, 0);
1026  setTreeADRS(ots_addr, idx_tree);
1027  setType(ots_addr, 0);
1028
1029  memcpy(ltree_addr, ots_addr, 12);
1030  setType(ltree_addr, 1);
1031
1032  memcpy(node_addr, ltree_addr, 12);
1033  setType(node_addr, 2);
1034
1035  setOTSADRS(ots_addr, idx_leaf);
1036
1037  // Check WOTS signature
1038  wots_pkFromSig(wots_pk, sig_msg, msg_h, &(params->xmss_par.wots_par), pub_seed, ots_addr);
1039
1040  sig_msg += params->xmss_par.wots_par.keysize;
1041  sig_msg_len -= params->xmss_par.wots_par.keysize;
1042
1043  // Compute Ltree
1044  setLtreeADRS(ltree_addr, idx_leaf);
1045  l_tree(pkhash, wots_pk, &(params->xmss_par), pub_seed, ltree_addr);
1046
1047  // Compute root
1048  validate_authpath(root, pkhash, idx_leaf, sig_msg, &(params->xmss_par), pub_seed, node_addr);
1049
1050  sig_msg += tree_h*n;
1051  sig_msg_len -= tree_h*n;
1052
1053  for (i = 1; i < params->d; i++) {
1054    // Prepare Address
1055    idx_leaf = (idx_tree & ((1 << tree_h)-1));
1056    idx_tree = idx_tree >> tree_h;
1057
1058    setLayerADRS(ots_addr, i);
1059    setTreeADRS(ots_addr, idx_tree);
1060    setType(ots_addr, 0);
1061
1062    memcpy(ltree_addr, ots_addr, 12);
1063    setType(ltree_addr, 1);
1064
1065    memcpy(node_addr, ltree_addr, 12);
1066    setType(node_addr, 2);
1067
1068    setOTSADRS(ots_addr, idx_leaf);
1069
1070    // Check WOTS signature
1071    wots_pkFromSig(wots_pk, sig_msg, root, &(params->xmss_par.wots_par), pub_seed, ots_addr);
1072
1073    sig_msg += params->xmss_par.wots_par.keysize;
1074    sig_msg_len -= params->xmss_par.wots_par.keysize;
1075
1076    // Compute Ltree
1077    setLtreeADRS(ltree_addr, idx_leaf);
1078    l_tree(pkhash, wots_pk, &(params->xmss_par), pub_seed, ltree_addr);
1079
1080    // Compute root
1081    validate_authpath(root, pkhash, idx_leaf, sig_msg, &(params->xmss_par), pub_seed, node_addr);
1082
1083    sig_msg += tree_h*n;
1084    sig_msg_len -= tree_h*n;
1085
1086  }
1087
1088  for (i = 0; i < n; i++)
1089    if (root[i] != pk[i])
1090      goto fail;
1091
1092  *msglen = sig_msg_len;
1093  for (i = 0; i < *msglen; i++)
1094    msg[i] = sig_msg[i];
1095
1096  return 0;
1097
1098
1099fail:
1100  *msglen = sig_msg_len;
1101  for (i = 0; i < *msglen; i++)
1102    msg[i] = 0;
1103  *msglen = -1;
1104  return -1;
1105}
1106#endif /* WITH_XMSS */
1107