1/* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
2/*
3 * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include "includes.h"
27#ifdef WITH_XMSS
28
29#include <sys/types.h>
30#include <sys/uio.h>
31
32#include <stdio.h>
33#include <string.h>
34#include <unistd.h>
35#include <fcntl.h>
36#include <errno.h>
37#ifdef HAVE_SYS_FILE_H
38# include <sys/file.h>
39#endif
40
41#include "ssh2.h"
42#include "ssherr.h"
43#include "sshbuf.h"
44#include "cipher.h"
45#include "sshkey.h"
46#include "sshkey-xmss.h"
47#include "atomicio.h"
48#include "log.h"
49
50#include "xmss_fast.h"
51
/*
 * Opaque internal XMSS state, attached to a key via key->xmss_state.
 *
 * XMSS_MAGIC starts the cleartext header of an encrypted state blob
 * (see sshkey_xmss_encrypt_state()/sshkey_xmss_decrypt_state()).
 * XMSS_CIPHERNAME is the cipher whose random key/IV is created for newly
 * generated keys (see sshkey_xmss_generate_private_key()).
 */
#define XMSS_MAGIC		"xmss-state-v1"
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
struct ssh_xmss_state {
	/* derived from n/h/w/k by xmss_set_params() in sshkey_xmss_init() */
	xmss_params	params;
	/* parameter set chosen by name in sshkey_xmss_init() */
	u_int32_t	n, w, h, k;

	/*
	 * BDS traversal state.  The buffers below are allocated with the
	 * sizes given by the num_*() macros and registered in "bds" via
	 * xmss_set_bds_state().
	 */
	bds_state	bds;
	u_char		*stack;
	u_int32_t	stackoffset;
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};
78
/* Prototypes for helper functions defined further down in this file. */
int	 sshkey_xmss_init_bds_state(struct sshkey *);
int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
void	 sshkey_xmss_free_bds(struct sshkey *);
int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
	    int *, int);
int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
90
/*
 * Log an error through sshlog() at SYSLOG_LEVEL_ERROR, but only if the
 * calling function's "printerror" variable is non-zero.  Every function
 * using this macro must therefore have "printerror" in scope.
 */
#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
    0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)
93
94int
95sshkey_xmss_init(struct sshkey *key, const char *name)
96{
97	struct ssh_xmss_state *state;
98
99	if (key->xmss_state != NULL)
100		return SSH_ERR_INVALID_FORMAT;
101	if (name == NULL)
102		return SSH_ERR_INVALID_FORMAT;
103	state = calloc(sizeof(struct ssh_xmss_state), 1);
104	if (state == NULL)
105		return SSH_ERR_ALLOC_FAIL;
106	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
107		state->n = 32;
108		state->w = 16;
109		state->h = 10;
110	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
111		state->n = 32;
112		state->w = 16;
113		state->h = 16;
114	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
115		state->n = 32;
116		state->w = 16;
117		state->h = 20;
118	} else {
119		free(state);
120		return SSH_ERR_KEY_TYPE_UNKNOWN;
121	}
122	if ((key->xmss_name = strdup(name)) == NULL) {
123		free(state);
124		return SSH_ERR_ALLOC_FAIL;
125	}
126	state->k = 2;	/* XXX hardcoded */
127	state->lockfd = -1;
128	if (xmss_set_params(&state->params, state->n, state->h, state->w,
129	    state->k) != 0) {
130		free(state);
131		return SSH_ERR_INVALID_FORMAT;
132	}
133	key->xmss_state = state;
134	return 0;
135}
136
137void
138sshkey_xmss_free_state(struct sshkey *key)
139{
140	struct ssh_xmss_state *state = key->xmss_state;
141
142	sshkey_xmss_free_bds(key);
143	if (state) {
144		if (state->enc_keyiv) {
145			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
146			free(state->enc_keyiv);
147		}
148		free(state->enc_ciphername);
149		free(state);
150	}
151	key->xmss_state = NULL;
152}
153
/*
 * SSH_XMSS_K2_MAGIC starts every serialized BDS state (the "k=2" matches
 * the hard-coded BDS parameter k).
 *
 * Sizes of the BDS buffers for a state pointer "x" with fields n, h and
 * k.  num_treehash() counts treehash_inst entries; all other macros give
 * byte counts (num_stacklevels() is one byte per level).  Every use of
 * the argument is parenthesised, so any pointer expression may be passed.
 */
#define SSH_XMSS_K2_MAGIC	"k=2"
#define num_stack(x)		(((x)->h + 1) * ((x)->n))
#define num_stacklevels(x)	((x)->h + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		(((x)->h >> 1) * ((x)->n))
#define num_th_nodes(x)		(((x)->h - (x)->k) * ((x)->n))
#define num_retain(x)		(((1ULL << (x)->k) - (x)->k - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
162
163int
164sshkey_xmss_init_bds_state(struct sshkey *key)
165{
166	struct ssh_xmss_state *state = key->xmss_state;
167	u_int32_t i;
168
169	state->stackoffset = 0;
170	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
171	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
172	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
173	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
174	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
175	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
176	    (state->treehash = calloc(num_treehash(state),
177	    sizeof(treehash_inst))) == NULL) {
178		sshkey_xmss_free_bds(key);
179		return SSH_ERR_ALLOC_FAIL;
180	}
181	for (i = 0; i < state->h - state->k; i++)
182		state->treehash[i].node = &state->th_nodes[state->n*i];
183	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
184	    state->stacklevels, state->auth, state->keep, state->treehash,
185	    state->retain, 0);
186	return 0;
187}
188
189void
190sshkey_xmss_free_bds(struct sshkey *key)
191{
192	struct ssh_xmss_state *state = key->xmss_state;
193
194	if (state == NULL)
195		return;
196	free(state->stack);
197	free(state->stacklevels);
198	free(state->auth);
199	free(state->keep);
200	free(state->th_nodes);
201	free(state->retain);
202	free(state->treehash);
203	state->stack = NULL;
204	state->stacklevels = NULL;
205	state->auth = NULL;
206	state->keep = NULL;
207	state->th_nodes = NULL;
208	state->retain = NULL;
209	state->treehash = NULL;
210}
211
212void *
213sshkey_xmss_params(const struct sshkey *key)
214{
215	struct ssh_xmss_state *state = key->xmss_state;
216
217	if (state == NULL)
218		return NULL;
219	return &state->params;
220}
221
222void *
223sshkey_xmss_bds_state(const struct sshkey *key)
224{
225	struct ssh_xmss_state *state = key->xmss_state;
226
227	if (state == NULL)
228		return NULL;
229	return &state->bds;
230}
231
232int
233sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
234{
235	struct ssh_xmss_state *state = key->xmss_state;
236
237	if (lenp == NULL)
238		return SSH_ERR_INVALID_ARGUMENT;
239	if (state == NULL)
240		return SSH_ERR_INVALID_FORMAT;
241	*lenp = 4 + state->n +
242	    state->params.wots_par.keysize +
243	    state->h * state->n;
244	return 0;
245}
246
247size_t
248sshkey_xmss_pklen(const struct sshkey *key)
249{
250	struct ssh_xmss_state *state = key->xmss_state;
251
252	if (state == NULL)
253		return 0;
254	return state->n * 2;
255}
256
257size_t
258sshkey_xmss_sklen(const struct sshkey *key)
259{
260	struct ssh_xmss_state *state = key->xmss_state;
261
262	if (state == NULL)
263		return 0;
264	return state->n * 4 + 4;
265}
266
267int
268sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
269{
270	struct ssh_xmss_state *state = k->xmss_state;
271	const struct sshcipher *cipher;
272	size_t keylen = 0, ivlen = 0;
273
274	if (state == NULL)
275		return SSH_ERR_INVALID_ARGUMENT;
276	if ((cipher = cipher_by_name(ciphername)) == NULL)
277		return SSH_ERR_INTERNAL_ERROR;
278	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
279		return SSH_ERR_ALLOC_FAIL;
280	keylen = cipher_keylen(cipher);
281	ivlen = cipher_ivlen(cipher);
282	state->enc_keyiv_len = keylen + ivlen;
283	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
284		free(state->enc_ciphername);
285		state->enc_ciphername = NULL;
286		return SSH_ERR_ALLOC_FAIL;
287	}
288	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
289	return 0;
290}
291
292int
293sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
294{
295	struct ssh_xmss_state *state = k->xmss_state;
296	int r;
297
298	if (state == NULL || state->enc_keyiv == NULL ||
299	    state->enc_ciphername == NULL)
300		return SSH_ERR_INVALID_ARGUMENT;
301	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
302	    (r = sshbuf_put_string(b, state->enc_keyiv,
303	    state->enc_keyiv_len)) != 0)
304		return r;
305	return 0;
306}
307
308int
309sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
310{
311	struct ssh_xmss_state *state = k->xmss_state;
312	size_t len;
313	int r;
314
315	if (state == NULL)
316		return SSH_ERR_INVALID_ARGUMENT;
317	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
318	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
319		return r;
320	state->enc_keyiv_len = len;
321	return 0;
322}
323
324int
325sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
326    enum sshkey_serialize_rep opts)
327{
328	struct ssh_xmss_state *state = k->xmss_state;
329	u_char have_info = 1;
330	u_int32_t idx;
331	int r;
332
333	if (state == NULL)
334		return SSH_ERR_INVALID_ARGUMENT;
335	if (opts != SSHKEY_SERIALIZE_INFO)
336		return 0;
337	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
338	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
339	    (r = sshbuf_put_u32(b, idx)) != 0 ||
340	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
341		return r;
342	return 0;
343}
344
345int
346sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
347{
348	struct ssh_xmss_state *state = k->xmss_state;
349	u_char have_info;
350	int r;
351
352	if (state == NULL)
353		return SSH_ERR_INVALID_ARGUMENT;
354	/* optional */
355	if (sshbuf_len(b) == 0)
356		return 0;
357	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
358		return r;
359	if (have_info != 1)
360		return SSH_ERR_INVALID_ARGUMENT;
361	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
362	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
363		return r;
364	return 0;
365}
366
367int
368sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
369{
370	int r;
371	const char *name;
372
373	if (bits == 10) {
374		name = XMSS_SHA2_256_W16_H10_NAME;
375	} else if (bits == 16) {
376		name = XMSS_SHA2_256_W16_H16_NAME;
377	} else if (bits == 20) {
378		name = XMSS_SHA2_256_W16_H20_NAME;
379	} else {
380		name = XMSS_DEFAULT_NAME;
381	}
382	if ((r = sshkey_xmss_init(k, name)) != 0 ||
383	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
384	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
385		return r;
386	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
387	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
388		return SSH_ERR_ALLOC_FAIL;
389	}
390	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
391	    sshkey_xmss_params(k));
392	return 0;
393}
394
395int
396sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
397    int *have_file, int printerror)
398{
399	struct sshbuf *b = NULL, *enc = NULL;
400	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
401	u_int32_t len;
402	unsigned char buf[4], *data = NULL;
403
404	*have_file = 0;
405	if ((fd = open(filename, O_RDONLY)) >= 0) {
406		*have_file = 1;
407		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
408			PRINT("corrupt state file: %s", filename);
409			goto done;
410		}
411		len = PEEK_U32(buf);
412		if ((data = calloc(len, 1)) == NULL) {
413			ret = SSH_ERR_ALLOC_FAIL;
414			goto done;
415		}
416		if (atomicio(read, fd, data, len) != len) {
417			PRINT("cannot read blob: %s", filename);
418			goto done;
419		}
420		if ((enc = sshbuf_from(data, len)) == NULL) {
421			ret = SSH_ERR_ALLOC_FAIL;
422			goto done;
423		}
424		sshkey_xmss_free_bds(k);
425		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
426			ret = r;
427			goto done;
428		}
429		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
430			ret = r;
431			goto done;
432		}
433		ret = 0;
434	}
435done:
436	if (fd != -1)
437		close(fd);
438	free(data);
439	sshbuf_free(enc);
440	sshbuf_free(b);
441	return ret;
442}
443
444int
445sshkey_xmss_get_state(const struct sshkey *k, int printerror)
446{
447	struct ssh_xmss_state *state = k->xmss_state;
448	u_int32_t idx = 0;
449	char *filename = NULL;
450	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
451	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
452	int ret = SSH_ERR_INVALID_ARGUMENT, r;
453
454	if (state == NULL)
455		goto done;
456	/*
457	 * If maxidx is set, then we are allowed a limited number
458	 * of signatures, but don't need to access the disk.
459	 * Otherwise we need to deal with the on-disk state.
460	 */
461	if (state->maxidx) {
462		/* xmss_sk always contains the current state */
463		idx = PEEK_U32(k->xmss_sk);
464		if (idx < state->maxidx) {
465			state->allow_update = 1;
466			return 0;
467		}
468		return SSH_ERR_INVALID_ARGUMENT;
469	}
470	if ((filename = k->xmss_filename) == NULL)
471		goto done;
472	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
473	    asprintf(&statefile, "%s.state", filename) == -1 ||
474	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
475		ret = SSH_ERR_ALLOC_FAIL;
476		goto done;
477	}
478	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
479		ret = SSH_ERR_SYSTEM_ERROR;
480		PRINT("cannot open/create: %s", lockfile);
481		goto done;
482	}
483	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
484		if (errno != EWOULDBLOCK) {
485			ret = SSH_ERR_SYSTEM_ERROR;
486			PRINT("cannot lock: %s", lockfile);
487			goto done;
488		}
489		if (++tries > 10) {
490			ret = SSH_ERR_SYSTEM_ERROR;
491			PRINT("giving up on: %s", lockfile);
492			goto done;
493		}
494		usleep(1000*100*tries);
495	}
496	/* XXX no longer const */
497	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
498	    statefile, &have_state, printerror)) != 0) {
499		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
500		    ostatefile, &have_ostate, printerror)) == 0) {
501			state->allow_update = 1;
502			r = sshkey_xmss_forward_state(k, 1);
503			state->idx = PEEK_U32(k->xmss_sk);
504			state->allow_update = 0;
505		}
506	}
507	if (!have_state && !have_ostate) {
508		/* check that bds state is initialized */
509		if (state->bds.auth == NULL)
510			goto done;
511		PRINT("start from scratch idx 0: %u", state->idx);
512	} else if (r != 0) {
513		ret = r;
514		goto done;
515	}
516	if (state->idx + 1 < state->idx) {
517		PRINT("state wrap: %u", state->idx);
518		goto done;
519	}
520	state->have_state = have_state;
521	state->lockfd = lockfd;
522	state->allow_update = 1;
523	lockfd = -1;
524	ret = 0;
525done:
526	if (lockfd != -1)
527		close(lockfd);
528	free(lockfile);
529	free(statefile);
530	free(ostatefile);
531	return ret;
532}
533
534int
535sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
536{
537	struct ssh_xmss_state *state = k->xmss_state;
538	u_char *sig = NULL;
539	size_t required_siglen;
540	unsigned long long smlen;
541	u_char data;
542	int ret, r;
543
544	if (state == NULL || !state->allow_update)
545		return SSH_ERR_INVALID_ARGUMENT;
546	if (reserve == 0)
547		return SSH_ERR_INVALID_ARGUMENT;
548	if (state->idx + reserve <= state->idx)
549		return SSH_ERR_INVALID_ARGUMENT;
550	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
551		return r;
552	if ((sig = malloc(required_siglen)) == NULL)
553		return SSH_ERR_ALLOC_FAIL;
554	while (reserve-- > 0) {
555		state->idx = PEEK_U32(k->xmss_sk);
556		smlen = required_siglen;
557		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
558		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
559			r = SSH_ERR_INVALID_ARGUMENT;
560			break;
561		}
562	}
563	free(sig);
564	return r;
565}
566
567int
568sshkey_xmss_update_state(const struct sshkey *k, int printerror)
569{
570	struct ssh_xmss_state *state = k->xmss_state;
571	struct sshbuf *b = NULL, *enc = NULL;
572	u_int32_t idx = 0;
573	unsigned char buf[4];
574	char *filename = NULL;
575	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
576	int fd = -1;
577	int ret = SSH_ERR_INVALID_ARGUMENT;
578
579	if (state == NULL || !state->allow_update)
580		return ret;
581	if (state->maxidx) {
582		/* no update since the number of signatures is limited */
583		ret = 0;
584		goto done;
585	}
586	idx = PEEK_U32(k->xmss_sk);
587	if (idx == state->idx) {
588		/* no signature happened, no need to update */
589		ret = 0;
590		goto done;
591	} else if (idx != state->idx + 1) {
592		PRINT("more than one signature happened: idx %u state %u",
593		    idx, state->idx);
594		goto done;
595	}
596	state->idx = idx;
597	if ((filename = k->xmss_filename) == NULL)
598		goto done;
599	if (asprintf(&statefile, "%s.state", filename) == -1 ||
600	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
601	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
602		ret = SSH_ERR_ALLOC_FAIL;
603		goto done;
604	}
605	unlink(nstatefile);
606	if ((b = sshbuf_new()) == NULL) {
607		ret = SSH_ERR_ALLOC_FAIL;
608		goto done;
609	}
610	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
611		PRINT("SERLIALIZE FAILED: %d", ret);
612		goto done;
613	}
614	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
615		PRINT("ENCRYPT FAILED: %d", ret);
616		goto done;
617	}
618	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
619		ret = SSH_ERR_SYSTEM_ERROR;
620		PRINT("open new state file: %s", nstatefile);
621		goto done;
622	}
623	POKE_U32(buf, sshbuf_len(enc));
624	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
625		ret = SSH_ERR_SYSTEM_ERROR;
626		PRINT("write new state file hdr: %s", nstatefile);
627		close(fd);
628		goto done;
629	}
630	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
631	    sshbuf_len(enc)) {
632		ret = SSH_ERR_SYSTEM_ERROR;
633		PRINT("write new state file data: %s", nstatefile);
634		close(fd);
635		goto done;
636	}
637	if (fsync(fd) == -1) {
638		ret = SSH_ERR_SYSTEM_ERROR;
639		PRINT("sync new state file: %s", nstatefile);
640		close(fd);
641		goto done;
642	}
643	if (close(fd) == -1) {
644		ret = SSH_ERR_SYSTEM_ERROR;
645		PRINT("close new state file: %s", nstatefile);
646		goto done;
647	}
648	if (state->have_state) {
649		unlink(ostatefile);
650		if (link(statefile, ostatefile)) {
651			ret = SSH_ERR_SYSTEM_ERROR;
652			PRINT("backup state %s to %s", statefile, ostatefile);
653			goto done;
654		}
655	}
656	if (rename(nstatefile, statefile) == -1) {
657		ret = SSH_ERR_SYSTEM_ERROR;
658		PRINT("rename %s to %s", nstatefile, statefile);
659		goto done;
660	}
661	ret = 0;
662done:
663	if (state->lockfd != -1) {
664		close(state->lockfd);
665		state->lockfd = -1;
666	}
667	if (nstatefile)
668		unlink(nstatefile);
669	free(statefile);
670	free(ostatefile);
671	free(nstatefile);
672	sshbuf_free(b);
673	sshbuf_free(enc);
674	return ret;
675}
676
677int
678sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
679{
680	struct ssh_xmss_state *state = k->xmss_state;
681	treehash_inst *th;
682	u_int32_t i, node;
683	int r;
684
685	if (state == NULL)
686		return SSH_ERR_INVALID_ARGUMENT;
687	if (state->stack == NULL)
688		return SSH_ERR_INVALID_ARGUMENT;
689	state->stackoffset = state->bds.stackoffset;	/* copy back */
690	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
691	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
692	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
693	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
694	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
695	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
696	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
697	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
698	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
699	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
700		return r;
701	for (i = 0; i < num_treehash(state); i++) {
702		th = &state->treehash[i];
703		node = th->node - state->th_nodes;
704		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
705		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
706		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
707		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
708		    (r = sshbuf_put_u32(b, node)) != 0)
709			return r;
710	}
711	return 0;
712}
713
714int
715sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
716    enum sshkey_serialize_rep opts)
717{
718	struct ssh_xmss_state *state = k->xmss_state;
719	int r = SSH_ERR_INVALID_ARGUMENT;
720	u_char have_stack, have_filename, have_enc;
721
722	if (state == NULL)
723		return SSH_ERR_INVALID_ARGUMENT;
724	if ((r = sshbuf_put_u8(b, opts)) != 0)
725		return r;
726	switch (opts) {
727	case SSHKEY_SERIALIZE_STATE:
728		r = sshkey_xmss_serialize_state(k, b);
729		break;
730	case SSHKEY_SERIALIZE_FULL:
731		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
732			return r;
733		r = sshkey_xmss_serialize_state(k, b);
734		break;
735	case SSHKEY_SERIALIZE_SHIELD:
736		/* all of stack/filename/enc are optional */
737		have_stack = state->stack != NULL;
738		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
739			return r;
740		if (have_stack) {
741			state->idx = PEEK_U32(k->xmss_sk);	/* update */
742			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
743				return r;
744		}
745		have_filename = k->xmss_filename != NULL;
746		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
747			return r;
748		if (have_filename &&
749		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
750			return r;
751		have_enc = state->enc_keyiv != NULL;
752		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
753			return r;
754		if (have_enc &&
755		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
756			return r;
757		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
758		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
759			return r;
760		break;
761	case SSHKEY_SERIALIZE_DEFAULT:
762		r = 0;
763		break;
764	default:
765		r = SSH_ERR_INVALID_ARGUMENT;
766		break;
767	}
768	return r;
769}
770
771int
772sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
773{
774	struct ssh_xmss_state *state = k->xmss_state;
775	treehash_inst *th;
776	u_int32_t i, lh, node;
777	size_t ls, lsl, la, lk, ln, lr;
778	char *magic;
779	int r = SSH_ERR_INTERNAL_ERROR;
780
781	if (state == NULL)
782		return SSH_ERR_INVALID_ARGUMENT;
783	if (k->xmss_sk == NULL)
784		return SSH_ERR_INVALID_ARGUMENT;
785	if ((state->treehash = calloc(num_treehash(state),
786	    sizeof(treehash_inst))) == NULL)
787		return SSH_ERR_ALLOC_FAIL;
788	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
789	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
790	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
791	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
792	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
793	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
794	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
795	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
796	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
797	    (r = sshbuf_get_u32(b, &lh)) != 0)
798		goto out;
799	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
800		r = SSH_ERR_INVALID_ARGUMENT;
801		goto out;
802	}
803	/* XXX check stackoffset */
804	if (ls != num_stack(state) ||
805	    lsl != num_stacklevels(state) ||
806	    la != num_auth(state) ||
807	    lk != num_keep(state) ||
808	    ln != num_th_nodes(state) ||
809	    lr != num_retain(state) ||
810	    lh != num_treehash(state)) {
811		r = SSH_ERR_INVALID_ARGUMENT;
812		goto out;
813	}
814	for (i = 0; i < num_treehash(state); i++) {
815		th = &state->treehash[i];
816		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
817		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
818		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
819		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
820		    (r = sshbuf_get_u32(b, &node)) != 0)
821			goto out;
822		if (node < num_th_nodes(state))
823			th->node = &state->th_nodes[node];
824	}
825	POKE_U32(k->xmss_sk, state->idx);
826	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
827	    state->stacklevels, state->auth, state->keep, state->treehash,
828	    state->retain, 0);
829	/* success */
830	r = 0;
831 out:
832	free(magic);
833	return r;
834}
835
836int
837sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
838{
839	struct ssh_xmss_state *state = k->xmss_state;
840	enum sshkey_serialize_rep opts;
841	u_char have_state, have_stack, have_filename, have_enc;
842	int r;
843
844	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
845		return r;
846
847	opts = have_state;
848	switch (opts) {
849	case SSHKEY_SERIALIZE_DEFAULT:
850		r = 0;
851		break;
852	case SSHKEY_SERIALIZE_SHIELD:
853		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
854			return r;
855		if (have_stack &&
856		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
857			return r;
858		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
859			return r;
860		if (have_filename &&
861		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
862			return r;
863		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
864			return r;
865		if (have_enc &&
866		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
867			return r;
868		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
869		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
870			return r;
871		break;
872	case SSHKEY_SERIALIZE_STATE:
873		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
874			return r;
875		break;
876	case SSHKEY_SERIALIZE_FULL:
877		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
878		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
879			return r;
880		break;
881	default:
882		r = SSH_ERR_INVALID_FORMAT;
883		break;
884	}
885	return r;
886}
887
888int
889sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
890   struct sshbuf **retp)
891{
892	struct ssh_xmss_state *state = k->xmss_state;
893	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
894	struct sshcipher_ctx *ciphercontext = NULL;
895	const struct sshcipher *cipher;
896	u_char *cp, *key, *iv = NULL;
897	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
898	int r = SSH_ERR_INTERNAL_ERROR;
899
900	if (retp != NULL)
901		*retp = NULL;
902	if (state == NULL ||
903	    state->enc_keyiv == NULL ||
904	    state->enc_ciphername == NULL)
905		return SSH_ERR_INTERNAL_ERROR;
906	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
907		r = SSH_ERR_INTERNAL_ERROR;
908		goto out;
909	}
910	blocksize = cipher_blocksize(cipher);
911	keylen = cipher_keylen(cipher);
912	ivlen = cipher_ivlen(cipher);
913	authlen = cipher_authlen(cipher);
914	if (state->enc_keyiv_len != keylen + ivlen) {
915		r = SSH_ERR_INVALID_FORMAT;
916		goto out;
917	}
918	key = state->enc_keyiv;
919	if ((encrypted = sshbuf_new()) == NULL ||
920	    (encoded = sshbuf_new()) == NULL ||
921	    (padded = sshbuf_new()) == NULL ||
922	    (iv = malloc(ivlen)) == NULL) {
923		r = SSH_ERR_ALLOC_FAIL;
924		goto out;
925	}
926
927	/* replace first 4 bytes of IV with index to ensure uniqueness */
928	memcpy(iv, key + keylen, ivlen);
929	POKE_U32(iv, state->idx);
930
931	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
932	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
933		goto out;
934
935	/* padded state will be encrypted */
936	if ((r = sshbuf_putb(padded, b)) != 0)
937		goto out;
938	i = 0;
939	while (sshbuf_len(padded) % blocksize) {
940		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
941			goto out;
942	}
943	encrypted_len = sshbuf_len(padded);
944
945	/* header including the length of state is used as AAD */
946	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
947		goto out;
948	aadlen = sshbuf_len(encoded);
949
950	/* concat header and state */
951	if ((r = sshbuf_putb(encoded, padded)) != 0)
952		goto out;
953
954	/* reserve space for encryption of encoded data plus auth tag */
955	/* encrypt at offset addlen */
956	if ((r = sshbuf_reserve(encrypted,
957	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
958	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
959	    iv, ivlen, 1)) != 0 ||
960	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
961	    encrypted_len, aadlen, authlen)) != 0)
962		goto out;
963
964	/* success */
965	r = 0;
966 out:
967	if (retp != NULL) {
968		*retp = encrypted;
969		encrypted = NULL;
970	}
971	sshbuf_free(padded);
972	sshbuf_free(encoded);
973	sshbuf_free(encrypted);
974	cipher_free(ciphercontext);
975	free(iv);
976	return r;
977}
978
979int
980sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
981   struct sshbuf **retp)
982{
983	struct ssh_xmss_state *state = k->xmss_state;
984	struct sshbuf *copy = NULL, *decrypted = NULL;
985	struct sshcipher_ctx *ciphercontext = NULL;
986	const struct sshcipher *cipher = NULL;
987	u_char *key, *iv = NULL, *dp;
988	size_t keylen, ivlen, authlen, aadlen;
989	u_int blocksize, encrypted_len, index;
990	int r = SSH_ERR_INTERNAL_ERROR;
991
992	if (retp != NULL)
993		*retp = NULL;
994	if (state == NULL ||
995	    state->enc_keyiv == NULL ||
996	    state->enc_ciphername == NULL)
997		return SSH_ERR_INTERNAL_ERROR;
998	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
999		r = SSH_ERR_INVALID_FORMAT;
1000		goto out;
1001	}
1002	blocksize = cipher_blocksize(cipher);
1003	keylen = cipher_keylen(cipher);
1004	ivlen = cipher_ivlen(cipher);
1005	authlen = cipher_authlen(cipher);
1006	if (state->enc_keyiv_len != keylen + ivlen) {
1007		r = SSH_ERR_INTERNAL_ERROR;
1008		goto out;
1009	}
1010	key = state->enc_keyiv;
1011
1012	if ((copy = sshbuf_fromb(encoded)) == NULL ||
1013	    (decrypted = sshbuf_new()) == NULL ||
1014	    (iv = malloc(ivlen)) == NULL) {
1015		r = SSH_ERR_ALLOC_FAIL;
1016		goto out;
1017	}
1018
1019	/* check magic */
1020	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1021	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1022		r = SSH_ERR_INVALID_FORMAT;
1023		goto out;
1024	}
1025	/* parse public portion */
1026	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1027	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1028	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1029		goto out;
1030
1031	/* check size of encrypted key blob */
1032	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1033		r = SSH_ERR_INVALID_FORMAT;
1034		goto out;
1035	}
1036	/* check that an appropriate amount of auth data is present */
1037	if (sshbuf_len(encoded) < authlen ||
1038	    sshbuf_len(encoded) - authlen < encrypted_len) {
1039		r = SSH_ERR_INVALID_FORMAT;
1040		goto out;
1041	}
1042
1043	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1044
1045	/* replace first 4 bytes of IV with index to ensure uniqueness */
1046	memcpy(iv, key + keylen, ivlen);
1047	POKE_U32(iv, index);
1048
1049	/* decrypt private state of key */
1050	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1051	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
1052	    iv, ivlen, 0)) != 0 ||
1053	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1054	    encrypted_len, aadlen, authlen)) != 0)
1055		goto out;
1056
1057	/* there should be no trailing data */
1058	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1059		goto out;
1060	if (sshbuf_len(encoded) != 0) {
1061		r = SSH_ERR_INVALID_FORMAT;
1062		goto out;
1063	}
1064
1065	/* remove AAD */
1066	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1067		goto out;
1068	/* XXX encrypted includes unchecked padding */
1069
1070	/* success */
1071	r = 0;
1072	if (retp != NULL) {
1073		*retp = decrypted;
1074		decrypted = NULL;
1075	}
1076 out:
1077	cipher_free(ciphercontext);
1078	sshbuf_free(copy);
1079	sshbuf_free(decrypted);
1080	free(iv);
1081	return r;
1082}
1083
1084u_int32_t
1085sshkey_xmss_signatures_left(const struct sshkey *k)
1086{
1087	struct ssh_xmss_state *state = k->xmss_state;
1088	u_int32_t idx;
1089
1090	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1091	    state->maxidx) {
1092		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1093		if (idx < state->maxidx)
1094			return state->maxidx - idx;
1095	}
1096	return 0;
1097}
1098
1099int
1100sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1101{
1102	struct ssh_xmss_state *state = k->xmss_state;
1103
1104	if (sshkey_type_plain(k->type) != KEY_XMSS)
1105		return SSH_ERR_INVALID_ARGUMENT;
1106	if (maxsign == 0)
1107		return 0;
1108	if (state->idx + maxsign < state->idx)
1109		return SSH_ERR_INVALID_ARGUMENT;
1110	state->maxidx = state->idx + maxsign;
1111	return 0;
1112}
1113#endif /* WITH_XMSS */
1114