1/* $OpenBSD: kex.c,v 1.185 2024/01/08 00:34:33 djm Exp $ */
2/*
3 * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include "includes.h"
27
28#include <sys/types.h>
29#include <errno.h>
30#include <signal.h>
31#include <stdarg.h>
32#include <stdio.h>
33#include <stdlib.h>
34#include <string.h>
35#include <unistd.h>
36#ifdef HAVE_POLL_H
37#include <poll.h>
38#endif
39
40#ifdef WITH_OPENSSL
41#include <openssl/crypto.h>
42#include <openssl/dh.h>
43#endif
44
45#include "ssh.h"
46#include "ssh2.h"
47#include "atomicio.h"
48#include "version.h"
49#include "packet.h"
50#include "compat.h"
51#include "cipher.h"
52#include "sshkey.h"
53#include "kex.h"
54#include "log.h"
55#include "mac.h"
56#include "match.h"
57#include "misc.h"
58#include "dispatch.h"
59#include "monitor.h"
60#include "myproposal.h"
61
62#include "ssherr.h"
63#include "sshbuf.h"
64#include "digest.h"
65#include "xmalloc.h"
66
/* Forward declarations of static functions defined later in this file. */
static int kex_choose_conf(struct ssh *, uint32_t seq);
static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70
/*
 * Human-readable labels for the KEXINIT name-lists, in the order the
 * lists appear in the packet. kex_buf2prop() uses them when logging each
 * parsed proposal string.
 */
static const char * const proposal_names[PROPOSAL_MAX] = {
	"KEX algorithms",
	"host key algorithms",
	"ciphers ctos",
	"ciphers stoc",
	"MACs ctos",
	"MACs stoc",
	"compression ctos",
	"compression stoc",
	"languages ctos",
	"languages stoc",
};
83
/* Description of one supported key exchange method. */
struct kexalg {
	char *name;	/* method name, matched against negotiated names */
	u_int type;	/* KEX_* type; indexes the kex->kex[] handler table */
	int ec_nid;	/* curve NID for ECDH methods, 0 for the others */
	int hash_alg;	/* SSH_DIGEST_* hash associated with the method */
};
/*
 * Key exchange methods compiled into this build. Which entries exist
 * depends on the feature macros below. The table is terminated by an
 * entry whose name is NULL.
 */
static const struct kexalg kexalgs[] = {
#ifdef WITH_OPENSSL
	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
#ifdef HAVE_EVP_SHA256
	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
#endif /* HAVE_EVP_SHA256 */
#ifdef OPENSSL_HAS_ECC
	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
	    SSH_DIGEST_SHA384 },
# ifdef OPENSSL_HAS_NISTP521
	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
	    SSH_DIGEST_SHA512 },
# endif /* OPENSSL_HAS_NISTP521 */
#endif /* OPENSSL_HAS_ECC */
#endif /* WITH_OPENSSL */
#if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
#ifdef USE_SNTRUP761X25519
	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
	    SSH_DIGEST_SHA512 },
#endif
#endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
	{ NULL, 0, -1, -1},	/* sentinel */
};
122
123char *
124kex_alg_list(char sep)
125{
126	char *ret = NULL, *tmp;
127	size_t nlen, rlen = 0;
128	const struct kexalg *k;
129
130	for (k = kexalgs; k->name != NULL; k++) {
131		if (ret != NULL)
132			ret[rlen++] = sep;
133		nlen = strlen(k->name);
134		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135			free(ret);
136			return NULL;
137		}
138		ret = tmp;
139		memcpy(ret + rlen, k->name, nlen + 1);
140		rlen += nlen;
141	}
142	return ret;
143}
144
145static const struct kexalg *
146kex_alg_by_name(const char *name)
147{
148	const struct kexalg *k;
149
150	for (k = kexalgs; k->name != NULL; k++) {
151		if (strcmp(k->name, name) == 0)
152			return k;
153	}
154	return NULL;
155}
156
/*
 * Validate a comma-separated KEX method name list. Returns 1 if every
 * name (up to the first empty element) is a supported method. Returns 0
 * for a NULL or empty list, on allocation failure, or when an
 * unsupported name is found (which is logged).
 */
int
kex_names_valid(const char *names)
{
	char *copy, *cursor, *tok;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	cursor = copy;
	while ((tok = strsep(&cursor, ",")) != NULL && *tok != '\0') {
		if (kex_alg_by_name(tok) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", tok);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
179
/*
 * Return non-zero if 'proposal' and 'algs' share at least one algorithm
 * according to match_list(); the matched name itself is discarded.
 */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common;
	int found;

	common = match_list(proposal, algs, NULL);
	found = (common != NULL);
	free(common);
	return found;
}
191
/*
 * Concatenate algorithm names, avoiding duplicates in the process.
 *
 * The contents of 'a' are copied unchanged. Then each name in 'b' that
 * has_any_alg() does not already find in the accumulated result is
 * appended after a comma. Parsing of 'b' stops at its first empty
 * element. A NULL or empty 'a' yields a copy of 'b', and a NULL or empty
 * 'b' yields a copy of 'a'. If both are NULL or empty, the result is an
 * empty string.
 *
 * Returns NULL on allocation failure, or if both inputs are non-empty
 * and 'b' is longer than 1MB.
 * Caller must free returned string.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p;
	size_t len;

	/* Trivial cases; never hand NULL to strdup(). */
	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: all of 'a', a comma, every name from 'b', and NUL. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if (has_any_alg(ret, p))
			continue; /* Algorithm already present */
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
228
/*
 * Assemble a list of algorithms from a default list and a string from a
 * configuration file. The user-provided string may begin with '+' to
 * indicate that it should be appended to the default, '-' that the
 * specified names should be removed, or '^' that they should be placed
 * at the head.
 *
 * listp: on entry, the user string. If it is NULL or "", a copy of 'def'
 *        is stored and 0 is returned. Otherwise the string is taken over
 *        by this function (it is released on every later path) and
 *        *listp is NULL unless the call succeeds.
 *        On success *listp holds a newly allocated list.
 * def:   the default algorithm list.
 * all:   every supported algorithm. For the '+', '^' and explicit forms,
 *        each list entry is a pattern expanded against 'all'.
 * Returns 0 on success. Returns SSH_ERR_INVALID_ARGUMENT for NULL
 * arguments, a negated ('!') pattern or an empty resulting list.
 * Returns SSH_ERR_ALLOC_FAIL when a helper returns NULL.
 *
 * NOTE(review): when *listp is a non-NULL empty string it is replaced by
 * a copy of 'def' without being freed.
 */
int
kex_assemble_names(char **listp, const char *def, const char *all)
{
	char *cp, *tmp, *patterns;
	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (listp == NULL || def == NULL || all == NULL)
		return SSH_ERR_INVALID_ARGUMENT;

	if (*listp == NULL || **listp == '\0') {
		if ((*listp = strdup(def)) == NULL)
			return SSH_ERR_ALLOC_FAIL;
		return 0;
	}

	/* Take over the caller's string; *listp is only set on success. */
	list = *listp;
	*listp = NULL;
	if (*list == '+') {
		/* Append names to default list */
		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else if (*list == '-') {
		/* Remove names from default list */
		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		/*
		 * filtering has already been done. NB. the filtered result
		 * is not checked for emptiness on this path.
		 */
		return 0;
	} else if (*list == '^') {
		/* Place names at head of default list */
		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else {
		/* Explicit list, overrides default - just use "list" as is */
	}

	/*
	 * The supplied names may be a pattern-list. For the -list case,
	 * the patterns are applied above. For the +list and explicit list
	 * cases we need to do it now.
	 */
	ret = NULL;
	/* strsep() modifies its input, so iterate over a private copy. */
	if ((patterns = opatterns = strdup(list)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto fail;
	}
	/* Apply positive (i.e. non-negated) patterns from the list */
	while ((cp = strsep(&patterns, ",")) != NULL) {
		if (*cp == '!') {
			/* negated matches are not supported here */
			r = SSH_ERR_INVALID_ARGUMENT;
			goto fail;
		}
		/* Expand this pattern against 'all' and merge, dropping dups */
		free(matching);
		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(ret);
		ret = tmp;
	}
	if (ret == NULL || *ret == '\0') {
		/* An empty name-list is an error */
		/* XXX better error code? */
		r = SSH_ERR_INVALID_ARGUMENT;
		goto fail;
	}

	/* success */
	*listp = ret;
	ret = NULL;
	r = 0;

 fail:
	/* Common exit: release all temporaries (ret is NULL on success). */
	free(matching);
	free(opatterns);
	free(list);
	free(ret);
	return r;
}
331
332/*
333 * Fill out a proposal array with dynamically allocated values, which may
334 * be modified as required for compatibility reasons.
335 * Any of the options may be NULL, in which case the default is used.
336 * Array contents must be freed by calling kex_proposal_free_entries.
337 */
338void
339kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
340    const char *kexalgos, const char *ciphers, const char *macs,
341    const char *comp, const char *hkalgs)
342{
343	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
344	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
345	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
346	u_int i;
347	char *cp;
348
349	if (prop == NULL)
350		fatal_f("proposal missing");
351
352	/* Append EXT_INFO signalling to KexAlgorithms */
353	if (kexalgos == NULL)
354		kexalgos = defprop[PROPOSAL_KEX_ALGS];
355	if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
356	    "ext-info-s,kex-strict-s-v00@openssh.com" :
357	    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
358		fatal_f("kex_names_cat");
359
360	for (i = 0; i < PROPOSAL_MAX; i++) {
361		switch(i) {
362		case PROPOSAL_KEX_ALGS:
363			prop[i] = compat_kex_proposal(ssh, cp);
364			break;
365		case PROPOSAL_ENC_ALGS_CTOS:
366		case PROPOSAL_ENC_ALGS_STOC:
367			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
368			break;
369		case PROPOSAL_MAC_ALGS_CTOS:
370		case PROPOSAL_MAC_ALGS_STOC:
371			prop[i]  = xstrdup(macs ? macs : defprop[i]);
372			break;
373		case PROPOSAL_COMP_ALGS_CTOS:
374		case PROPOSAL_COMP_ALGS_STOC:
375			prop[i] = xstrdup(comp ? comp : defprop[i]);
376			break;
377		case PROPOSAL_SERVER_HOST_KEY_ALGS:
378			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
379			break;
380		default:
381			prop[i] = xstrdup(defprop[i]);
382		}
383	}
384	free(cp);
385}
386
387void
388kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
389{
390	u_int i;
391
392	for (i = 0; i < PROPOSAL_MAX; i++)
393		free(prop[i]);
394}
395
396/* put algorithm proposal into buffer */
397int
398kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
399{
400	u_int i;
401	int r;
402
403	sshbuf_reset(b);
404
405	/*
406	 * add a dummy cookie, the cookie will be overwritten by
407	 * kex_send_kexinit(), each time a kexinit is set
408	 */
409	for (i = 0; i < KEX_COOKIE_LEN; i++) {
410		if ((r = sshbuf_put_u8(b, 0)) != 0)
411			return r;
412	}
413	for (i = 0; i < PROPOSAL_MAX; i++) {
414		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
415			return r;
416	}
417	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
418	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
419		return r;
420	return 0;
421}
422
423/* parse buffer and return algorithm proposal */
424int
425kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
426{
427	struct sshbuf *b = NULL;
428	u_char v;
429	u_int i;
430	char **proposal = NULL;
431	int r;
432
433	*propp = NULL;
434	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
435		return SSH_ERR_ALLOC_FAIL;
436	if ((b = sshbuf_fromb(raw)) == NULL) {
437		r = SSH_ERR_ALLOC_FAIL;
438		goto out;
439	}
440	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
441		error_fr(r, "consume cookie");
442		goto out;
443	}
444	/* extract kex init proposal strings */
445	for (i = 0; i < PROPOSAL_MAX; i++) {
446		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
447			error_fr(r, "parse proposal %u", i);
448			goto out;
449		}
450		debug2("%s: %s", proposal_names[i], proposal[i]);
451	}
452	/* first kex follows / reserved */
453	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
454	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
455		error_fr(r, "parse");
456		goto out;
457	}
458	if (first_kex_follows != NULL)
459		*first_kex_follows = v;
460	debug2("first_kex_follows %d ", v);
461	debug2("reserved %u ", i);
462	r = 0;
463	*propp = proposal;
464 out:
465	if (r != 0 && proposal != NULL)
466		kex_prop_free(proposal);
467	sshbuf_free(b);
468	return r;
469}
470
/* Free an array returned by kex_buf2prop() together with its strings. */
void
kex_prop_free(char **proposal)
{
	if (proposal == NULL)
		return;
	kex_proposal_free_entries(proposal);
	free(proposal);
}
482
483int
484kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
485{
486	int r;
487
488	/* If in strict mode, any unexpected message is an error */
489	if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
490		ssh_packet_disconnect(ssh, "strict KEX violation: "
491		    "unexpected packet type %u (seqnr %u)", type, seq);
492	}
493	error_f("type %u seq %u", type, seq);
494	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
495	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
496	    (r = sshpkt_send(ssh)) != 0)
497		return r;
498	return 0;
499}
500
501static void
502kex_reset_dispatch(struct ssh *ssh)
503{
504	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
505	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
506}
507
508void
509kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
510{
511	char *alg, *oalgs, *algs, *sigalgs;
512	const char *sigalg;
513
514	/*
515	 * NB. allowed algorithms may contain certificate algorithms that
516	 * map to a specific plain signature type, e.g.
517	 * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
518	 * We need to be careful here to match these, retain the mapping
519	 * and only add each signature algorithm once.
520	 */
521	if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
522		fatal_f("sshkey_alg_list failed");
523	oalgs = algs = xstrdup(allowed_algs);
524	free(ssh->kex->server_sig_algs);
525	ssh->kex->server_sig_algs = NULL;
526	for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
527	    (alg = strsep(&algs, ","))) {
528		if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
529			continue;
530		if (!has_any_alg(sigalg, sigalgs))
531			continue;
532		/* Don't add an algorithm twice. */
533		if (ssh->kex->server_sig_algs != NULL &&
534		    has_any_alg(sigalg, ssh->kex->server_sig_algs))
535			continue;
536		xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
537	}
538	free(oalgs);
539	free(sigalgs);
540	if (ssh->kex->server_sig_algs == NULL)
541		ssh->kex->server_sig_algs = xstrdup("");
542}
543
544static int
545kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
546{
547	int r;
548
549	if (ssh->kex->server_sig_algs == NULL &&
550	    (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
551		return SSH_ERR_ALLOC_FAIL;
552	if ((r = sshbuf_put_u32(m, 3)) != 0 ||
553	    (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
554	    (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
555	    (r = sshbuf_put_cstring(m,
556	    "publickey-hostbound@openssh.com")) != 0 ||
557	    (r = sshbuf_put_cstring(m, "0")) != 0 ||
558	    (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
559	    (r = sshbuf_put_cstring(m, "0")) != 0) {
560		error_fr(r, "compose");
561		return r;
562	}
563	return 0;
564}
565
/*
 * Append the client's EXT_INFO body to 'm'. It holds a single pair:
 * ext-info-in-auth@openssh.com at version "0".
 */
static int
kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
{
	int r;

	r = sshbuf_put_u32(m, 1);
	if (r == 0)
		r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com");
	if (r == 0)
		r = sshbuf_put_cstring(m, "0");
	if (r != 0)
		error_fr(r, "compose");
	return r;
}
582
583static int
584kex_maybe_send_ext_info(struct ssh *ssh)
585{
586	int r;
587	struct sshbuf *m = NULL;
588
589	if ((ssh->kex->flags & KEX_INITIAL) == 0)
590		return 0;
591	if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
592		return 0;
593
594	/* Compose EXT_INFO packet. */
595	if ((m = sshbuf_new()) == NULL)
596		fatal_f("sshbuf_new failed");
597	if (ssh->kex->ext_info_c &&
598	    (r = kex_compose_ext_info_server(ssh, m)) != 0)
599		goto fail;
600	if (ssh->kex->ext_info_s &&
601	    (r = kex_compose_ext_info_client(ssh, m)) != 0)
602		goto fail;
603
604	/* Send the actual KEX_INFO packet */
605	debug("Sending SSH2_MSG_EXT_INFO");
606	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
607	    (r = sshpkt_putb(ssh, m)) != 0 ||
608	    (r = sshpkt_send(ssh)) != 0) {
609		error_f("send EXT_INFO");
610		goto fail;
611	}
612
613	r = 0;
614
615 fail:
616	sshbuf_free(m);
617	return r;
618}
619
620int
621kex_server_update_ext_info(struct ssh *ssh)
622{
623	int r;
624
625	if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
626		return 0;
627
628	debug_f("Sending SSH2_MSG_EXT_INFO");
629	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
630	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
631	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
632	    (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
633	    (r = sshpkt_send(ssh)) != 0) {
634		error_f("send EXT_INFO");
635		return r;
636	}
637	return 0;
638}
639
640int
641kex_send_newkeys(struct ssh *ssh)
642{
643	int r;
644
645	kex_reset_dispatch(ssh);
646	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
647	    (r = sshpkt_send(ssh)) != 0)
648		return r;
649	debug("SSH2_MSG_NEWKEYS sent");
650	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
651	if ((r = kex_maybe_send_ext_info(ssh)) != 0)
652		return r;
653	debug("expecting SSH2_MSG_NEWKEYS");
654	return 0;
655}
656
657/* Check whether an ext_info value contains the expected version string */
658static int
659kex_ext_info_check_ver(struct kex *kex, const char *name,
660    const u_char *val, size_t len, const char *want_ver, u_int flag)
661{
662	if (memchr(val, '\0', len) != NULL) {
663		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
664		return SSH_ERR_INVALID_FORMAT;
665	}
666	debug_f("%s=<%s>", name, val);
667	if (strcmp(val, want_ver) == 0)
668		kex->flags |= flag;
669	else
670		debug_f("unsupported version of %s extension", name);
671	return 0;
672}
673
674static int
675kex_ext_info_client_parse(struct ssh *ssh, const char *name,
676    const u_char *value, size_t vlen)
677{
678	int r;
679
680	/* NB. some messages are only accepted in the initial EXT_INFO */
681	if (strcmp(name, "server-sig-algs") == 0) {
682		/* Ensure no \0 lurking in value */
683		if (memchr(value, '\0', vlen) != NULL) {
684			error_f("nul byte in %s", name);
685			return SSH_ERR_INVALID_FORMAT;
686		}
687		debug_f("%s=<%s>", name, value);
688		free(ssh->kex->server_sig_algs);
689		ssh->kex->server_sig_algs = xstrdup((const char *)value);
690	} else if (ssh->kex->ext_info_received == 1 &&
691	    strcmp(name, "publickey-hostbound@openssh.com") == 0) {
692		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
693		    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
694			return r;
695		}
696	} else if (ssh->kex->ext_info_received == 1 &&
697	    strcmp(name, "ping@openssh.com") == 0) {
698		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
699		    "0", KEX_HAS_PING)) != 0) {
700			return r;
701		}
702	} else
703		debug_f("%s (unrecognised)", name);
704
705	return 0;
706}
707
708static int
709kex_ext_info_server_parse(struct ssh *ssh, const char *name,
710    const u_char *value, size_t vlen)
711{
712	int r;
713
714	if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
715		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
716		    "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
717			return r;
718		}
719	} else
720		debug_f("%s (unrecognised)", name);
721	return 0;
722}
723
724int
725kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
726{
727	struct kex *kex = ssh->kex;
728	const int max_ext_info = kex->server ? 1 : 2;
729	u_int32_t i, ninfo;
730	char *name;
731	u_char *val;
732	size_t vlen;
733	int r;
734
735	debug("SSH2_MSG_EXT_INFO received");
736	if (++kex->ext_info_received > max_ext_info) {
737		error("too many SSH2_MSG_EXT_INFO messages sent by peer");
738		return dispatch_protocol_error(type, seq, ssh);
739	}
740	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
741	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
742		return r;
743	if (ninfo >= 1024) {
744		error("SSH2_MSG_EXT_INFO with too many entries, expected "
745		    "<=1024, received %u", ninfo);
746		return dispatch_protocol_error(type, seq, ssh);
747	}
748	for (i = 0; i < ninfo; i++) {
749		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
750			return r;
751		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
752			free(name);
753			return r;
754		}
755		debug3_f("extension %s", name);
756		if (kex->server) {
757			if ((r = kex_ext_info_server_parse(ssh, name,
758			    val, vlen)) != 0)
759				return r;
760		} else {
761			if ((r = kex_ext_info_client_parse(ssh, name,
762			    val, vlen)) != 0)
763				return r;
764		}
765		free(name);
766		free(val);
767	}
768	return sshpkt_get_end(ssh);
769}
770
/*
 * Dispatch handler for the peer's SSH2_MSG_NEWKEYS.
 *
 * First re-arms the dispatch table. A further NEWKEYS becomes a protocol
 * error and KEXINIT is accepted again for rekeying. EXT_INFO is also
 * accepted if this ends the initial exchange and kex->ext_info_c is set.
 * Then installs the new keys for the incoming direction (MODE_IN).
 * At the end of the initial exchange, the ext-info and kex-strict marker
 * names are removed from the KEX list stored in kex->my, which is the
 * buffer kex_send_kexinit() sends, so rekey KEXINITs omit them.
 * Finally the exchange is marked done and per-exchange state is cleared:
 * the peer's KEXINIT, KEX_INIT_SENT and the negotiated name.
 * Returns 0 or an SSH_ERR_* code.
 */
static int
kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	int r, initial = (kex->flags & KEX_INITIAL) != 0;
	char *cp, **prop;

	debug("SSH2_MSG_NEWKEYS received");
	if (kex->ext_info_c && initial)
		ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
	if ((r = sshpkt_get_end(ssh)) != 0)
		return r;
	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
		return r;
	if (initial) {
		/* Remove initial KEX signalling from proposal for rekeying */
		if ((r = kex_buf2prop(kex->my, NULL, &prop)) != 0)
			return r;
		if ((cp = match_filter_denylist(prop[PROPOSAL_KEX_ALGS],
		    kex->server ?
		    "ext-info-s,kex-strict-s-v00@openssh.com" :
		    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL) {
			error_f("match_filter_denylist failed");
			goto fail;
		}
		free(prop[PROPOSAL_KEX_ALGS]);
		prop[PROPOSAL_KEX_ALGS] = cp;
		if ((r = kex_prop2buf(ssh->kex->my, prop)) != 0) {
			error_f("kex_prop2buf failed");
			/*
			 * Shared error exit, also reached from the
			 * match_filter_denylist() failure above. NB. the
			 * specific error code in r is not propagated.
			 */
 fail:
			kex_proposal_free_entries(prop);
			free(prop);
			return SSH_ERR_INTERNAL_ERROR;
		}
		kex_proposal_free_entries(prop);
		free(prop);
	}
	kex->done = 1;
	kex->flags &= ~KEX_INITIAL;
	/* kex_input_kexinit() appends to kex->peer, so start it empty. */
	sshbuf_reset(kex->peer);
	kex->flags &= ~KEX_INIT_SENT;
	free(kex->name);
	kex->name = NULL;
	return 0;
}
818
819int
820kex_send_kexinit(struct ssh *ssh)
821{
822	u_char *cookie;
823	struct kex *kex = ssh->kex;
824	int r;
825
826	if (kex == NULL) {
827		error_f("no kex");
828		return SSH_ERR_INTERNAL_ERROR;
829	}
830	if (kex->flags & KEX_INIT_SENT)
831		return 0;
832	kex->done = 0;
833
834	/* generate a random cookie */
835	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
836		error_f("bad kex length: %zu < %d",
837		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
838		return SSH_ERR_INVALID_FORMAT;
839	}
840	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
841		error_f("buffer error");
842		return SSH_ERR_INTERNAL_ERROR;
843	}
844	arc4random_buf(cookie, KEX_COOKIE_LEN);
845
846	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
847	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
848	    (r = sshpkt_send(ssh)) != 0) {
849		error_fr(r, "compose reply");
850		return r;
851	}
852	debug("SSH2_MSG_KEXINIT sent");
853	kex->flags |= KEX_INIT_SENT;
854	return 0;
855}
856
/*
 * Dispatch handler for the peer's SSH2_MSG_KEXINIT.
 *
 * Appends the raw packet payload to kex->peer. Then walks the payload
 * (cookie, PROPOSAL_MAX name-lists, first_kex_follows, reserved) only to
 * check that it is well formed; the parsed values are discarded here.
 * Sends our own KEXINIT if we have not done so yet, negotiates
 * algorithms with kex_choose_conf(), and invokes the kex->kex[] handler
 * for the negotiated kex_type.
 * Returns 0 or an SSH_ERR_* code.
 */
int
kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	const u_char *ptr;
	u_int i;
	size_t dlen;
	int r;

	debug("SSH2_MSG_KEXINIT received");
	if (kex == NULL) {
		error_f("no kex");
		return SSH_ERR_INTERNAL_ERROR;
	}
	/*
	 * Another KEXINIT is treated as a protocol error until
	 * kex_input_newkeys() re-registers this handler.
	 */
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
	/*
	 * Keep the unparsed payload. kex_choose_conf() presumably reads the
	 * peer's proposal back out of kex->peer -- confirm.
	 */
	ptr = sshpkt_ptr(ssh, &dlen);
	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
		return r;

	/* discard packet */
	for (i = 0; i < KEX_COOKIE_LEN; i++) {
		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
			error_fr(r, "discard cookie");
			return r;
		}
	}
	for (i = 0; i < PROPOSAL_MAX; i++) {
		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
			error_fr(r, "discard proposal");
			return r;
		}
	}
	/*
	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
	 * KEX method has the server move first, but a server might be using
	 * a custom method or one that we otherwise don't support. We should
	 * be prepared to remember first_kex_follows here so we can eat a
	 * packet later.
	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
	 * for cases where the server *doesn't* go first. I guess we should
	 * ignore it when it is set for these cases, which is what we do now.
	 */
	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
	    (r = sshpkt_get_end(ssh)) != 0)
			return r;

	/* Reply with our own KEXINIT if the peer initiated this exchange. */
	if (!(kex->flags & KEX_INIT_SENT))
		if ((r = kex_send_kexinit(ssh)) != 0)
			return r;
	if ((r = kex_choose_conf(ssh, seq)) != 0)
		return r;

	/* Hand off to the method-specific handler, bounds-checked. */
	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
		return (kex->kex[kex->kex_type])(ssh);

	error_f("unknown kex type %u", kex->kex_type);
	return SSH_ERR_INTERNAL_ERROR;
}
916
917struct kex *
918kex_new(void)
919{
920	struct kex *kex;
921
922	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
923	    (kex->peer = sshbuf_new()) == NULL ||
924	    (kex->my = sshbuf_new()) == NULL ||
925	    (kex->client_version = sshbuf_new()) == NULL ||
926	    (kex->server_version = sshbuf_new()) == NULL ||
927	    (kex->session_id = sshbuf_new()) == NULL) {
928		kex_free(kex);
929		return NULL;
930	}
931	return kex;
932}
933
934void
935kex_free_newkeys(struct newkeys *newkeys)
936{
937	if (newkeys == NULL)
938		return;
939	if (newkeys->enc.key) {
940		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
941		free(newkeys->enc.key);
942		newkeys->enc.key = NULL;
943	}
944	if (newkeys->enc.iv) {
945		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
946		free(newkeys->enc.iv);
947		newkeys->enc.iv = NULL;
948	}
949	free(newkeys->enc.name);
950	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
951	free(newkeys->comp.name);
952	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
953	mac_clear(&newkeys->mac);
954	if (newkeys->mac.key) {
955		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
956		free(newkeys->mac.key);
957		newkeys->mac.key = NULL;
958	}
959	free(newkeys->mac.name);
960	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
961	freezero(newkeys, sizeof(*newkeys));
962}
963
964void
965kex_free(struct kex *kex)
966{
967	u_int mode;
968
969	if (kex == NULL)
970		return;
971
972#ifdef WITH_OPENSSL
973	DH_free(kex->dh);
974#ifdef OPENSSL_HAS_ECC
975	EC_KEY_free(kex->ec_client_key);
976#endif /* OPENSSL_HAS_ECC */
977#endif /* WITH_OPENSSL */
978	for (mode = 0; mode < MODE_MAX; mode++) {
979		kex_free_newkeys(kex->newkeys[mode]);
980		kex->newkeys[mode] = NULL;
981	}
982	sshbuf_free(kex->peer);
983	sshbuf_free(kex->my);
984	sshbuf_free(kex->client_version);
985	sshbuf_free(kex->server_version);
986	sshbuf_free(kex->client_pub);
987	sshbuf_free(kex->session_id);
988	sshbuf_free(kex->initial_sig);
989	sshkey_free(kex->initial_hostkey);
990	free(kex->failed_choice);
991	free(kex->hostkey_alg);
992	free(kex->name);
993	free(kex);
994}
995
996int
997kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
998{
999	int r;
1000
1001	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
1002		return r;
1003	ssh->kex->flags = KEX_INITIAL;
1004	kex_reset_dispatch(ssh);
1005	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
1006	return 0;
1007}
1008
1009int
1010kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
1011{
1012	int r;
1013
1014	if ((r = kex_ready(ssh, proposal)) != 0)
1015		return r;
1016	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
1017		kex_free(ssh->kex);
1018		ssh->kex = NULL;
1019		return r;
1020	}
1021	return 0;
1022}
1023
1024/*
1025 * Request key re-exchange, returns 0 on success or a ssherr.h error
1026 * code otherwise. Must not be called if KEX is incomplete or in-progress.
1027 */
1028int
1029kex_start_rekex(struct ssh *ssh)
1030{
1031	if (ssh->kex == NULL) {
1032		error_f("no kex");
1033		return SSH_ERR_INTERNAL_ERROR;
1034	}
1035	if (ssh->kex->done == 0) {
1036		error_f("requested twice");
1037		return SSH_ERR_INTERNAL_ERROR;
1038	}
1039	ssh->kex->done = 0;
1040	return kex_send_kexinit(ssh);
1041}
1042
1043static int
1044choose_enc(struct sshenc *enc, char *client, char *server)
1045{
1046	char *name = match_list(client, server, NULL);
1047
1048	if (name == NULL)
1049		return SSH_ERR_NO_CIPHER_ALG_MATCH;
1050	if ((enc->cipher = cipher_by_name(name)) == NULL) {
1051		error_f("unsupported cipher %s", name);
1052		free(name);
1053		return SSH_ERR_INTERNAL_ERROR;
1054	}
1055	enc->name = name;
1056	enc->enabled = 0;
1057	enc->iv = NULL;
1058	enc->iv_len = cipher_ivlen(enc->cipher);
1059	enc->key = NULL;
1060	enc->key_len = cipher_keylen(enc->cipher);
1061	enc->block_size = cipher_blocksize(enc->cipher);
1062	return 0;
1063}
1064
1065static int
1066choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
1067{
1068	char *name = match_list(client, server, NULL);
1069
1070	if (name == NULL)
1071		return SSH_ERR_NO_MAC_ALG_MATCH;
1072	if (mac_setup(mac, name) < 0) {
1073		error_f("unsupported MAC %s", name);
1074		free(name);
1075		return SSH_ERR_INTERNAL_ERROR;
1076	}
1077	mac->name = name;
1078	mac->key = NULL;
1079	mac->enabled = 0;
1080	return 0;
1081}
1082
1083static int
1084choose_comp(struct sshcomp *comp, char *client, char *server)
1085{
1086	char *name = match_list(client, server, NULL);
1087
1088	if (name == NULL)
1089		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
1090#ifdef WITH_ZLIB
1091	if (strcmp(name, "zlib@openssh.com") == 0) {
1092		comp->type = COMP_DELAYED;
1093	} else if (strcmp(name, "zlib") == 0) {
1094		comp->type = COMP_ZLIB;
1095	} else
1096#endif	/* WITH_ZLIB */
1097	if (strcmp(name, "none") == 0) {
1098		comp->type = COMP_NONE;
1099	} else {
1100		error_f("unsupported compression scheme %s", name);
1101		free(name);
1102		return SSH_ERR_INTERNAL_ERROR;
1103	}
1104	comp->name = name;
1105	return 0;
1106}
1107
1108static int
1109choose_kex(struct kex *k, char *client, char *server)
1110{
1111	const struct kexalg *kexalg;
1112
1113	k->name = match_list(client, server, NULL);
1114
1115	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
1116	if (k->name == NULL)
1117		return SSH_ERR_NO_KEX_ALG_MATCH;
1118	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
1119		error_f("unsupported KEX method %s", k->name);
1120		return SSH_ERR_INTERNAL_ERROR;
1121	}
1122	k->kex_type = kexalg->type;
1123	k->hash_alg = kexalg->hash_alg;
1124	k->ec_nid = kexalg->ec_nid;
1125	return 0;
1126}
1127
1128static int
1129choose_hostkeyalg(struct kex *k, char *client, char *server)
1130{
1131	free(k->hostkey_alg);
1132	k->hostkey_alg = match_list(client, server, NULL);
1133
1134	debug("kex: host key algorithm: %s",
1135	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
1136	if (k->hostkey_alg == NULL)
1137		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
1138	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
1139	if (k->hostkey_type == KEY_UNSPEC) {
1140		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
1141		return SSH_ERR_INTERNAL_ERROR;
1142	}
1143	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
1144	return 0;
1145}
1146
1147static int
1148proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
1149{
1150	static int check[] = {
1151		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
1152	};
1153	int *idx;
1154	char *p;
1155
1156	for (idx = &check[0]; *idx != -1; idx++) {
1157		if ((p = strchr(my[*idx], ',')) != NULL)
1158			*p = '\0';
1159		if ((p = strchr(peer[*idx], ',')) != NULL)
1160			*p = '\0';
1161		if (strcmp(my[*idx], peer[*idx]) != 0) {
1162			debug2("proposal mismatch: my %s peer %s",
1163			    my[*idx], peer[*idx]);
1164			return (0);
1165		}
1166	}
1167	debug2("proposals match");
1168	return (1);
1169}
1170
1171static int
1172kexalgs_contains(char **peer, const char *ext)
1173{
1174	return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1175}
1176
/*
 * Negotiate all algorithms for this key exchange from our KEXINIT
 * (kex->my) and the peer's (kex->peer).
 *
 * On the initial key exchange (KEX_INITIAL) this also records whether the
 * peer advertised ext-info and strict-KEX support and, on a server,
 * whether the client offered rsa-sha2 host key signature algorithms.
 * It then selects the KEX method and host key algorithm. For each
 * direction it allocates kex->newkeys[mode] and selects the cipher, the
 * MAC (skipped when the cipher reports a non-zero authlen) and the
 * compression. Finally it records the key sizes needed in kex->we_need
 * and kex->dh_need.
 *
 * seq: appears to be the sequence number of the peer's KEXINIT. Under
 * strict KEX a non-zero value makes this call ssh_packet_disconnect().
 *
 * When a negotiation step fails, ownership of the peer's list for that
 * field is moved to kex->failed_choice. Its slot in "peer" is cleared so
 * it is not freed at "out". Returns 0 or an SSH_ERR_* code.
 */
static int
kex_choose_conf(struct ssh *ssh, uint32_t seq)
{
	struct kex *kex = ssh->kex;
	struct newkeys *newkeys;
	char **my = NULL, **peer = NULL;
	char **cprop, **sprop;
	int nenc, nmac, ncomp;
	u_int mode, ctos, need, dh_need, authlen;
	int r, first_kex_follows;

	/* Split both raw KEXINIT payloads into per-field proposal arrays. */
	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
		goto out;
	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
		goto out;

	/* The choose_*() helpers always take the client's list first. */
	if (kex->server) {
		cprop=peer;
		sprop=my;
	} else {
		cprop=my;
		sprop=peer;
	}

	/* Check whether peer supports ext_info/kex_strict */
	if ((kex->flags & KEX_INITIAL) != 0) {
		if (kex->server) {
			kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
			kex->kex_strict = kexalgs_contains(peer,
			    "kex-strict-c-v00@openssh.com");
		} else {
			kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
			kex->kex_strict = kexalgs_contains(peer,
			    "kex-strict-s-v00@openssh.com");
		}
		if (kex->kex_strict) {
			debug3_f("will use strict KEX ordering");
			/* Under strict KEX, KEXINIT must be the first packet. */
			if (seq != 0)
				ssh_packet_disconnect(ssh,
				    "strict KEX violation: "
				    "KEXINIT was not the first packet");
		}
	}

	/* Check whether client supports rsa-sha2 algorithms */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
	}

	/*
	 * Algorithm Negotiation.
	 * On each failure the peer's list is handed to kex->failed_choice
	 * and its slot is NULLed so kex_prop_free(peer) at "out" skips it.
	 */
	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
		peer[PROPOSAL_KEX_ALGS] = NULL;
		goto out;
	}
	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
		goto out;
	}
	for (mode = 0; mode < MODE_MAX; mode++) {
		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto out;
		}
		kex->newkeys[mode] = newkeys;
		/* ctos: this mode carries client->server traffic. */
		ctos = (!kex->server && mode == MODE_OUT) ||
		    (kex->server && mode == MODE_IN);
		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
		    sprop[nenc])) != 0) {
			kex->failed_choice = peer[nenc];
			peer[nenc] = NULL;
			goto out;
		}
		authlen = cipher_authlen(newkeys->enc.cipher);
		/* ignore mac for authenticated encryption */
		if (authlen == 0 &&
		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
		    sprop[nmac])) != 0) {
			kex->failed_choice = peer[nmac];
			peer[nmac] = NULL;
			goto out;
		}
		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
		    sprop[ncomp])) != 0) {
			kex->failed_choice = peer[ncomp];
			peer[ncomp] = NULL;
			goto out;
		}
		debug("kex: %s cipher: %s MAC: %s compression: %s",
		    ctos ? "client->server" : "server->client",
		    newkeys->enc.name,
		    authlen == 0 ? newkeys->mac.name : "<implicit>",
		    newkeys->comp.name);
	}
	/*
	 * Work out how much key material is needed across both directions.
	 * we_need is later passed as the per-key length to derive_key().
	 * dh_need uses cipher_seclen() instead of the key length.
	 */
	need = dh_need = 0;
	for (mode = 0; mode < MODE_MAX; mode++) {
		newkeys = kex->newkeys[mode];
		need = MAXIMUM(need, newkeys->enc.key_len);
		need = MAXIMUM(need, newkeys->enc.block_size);
		need = MAXIMUM(need, newkeys->enc.iv_len);
		need = MAXIMUM(need, newkeys->mac.key_len);
		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
	}
	/* XXX need runden? */
	kex->we_need = need;
	kex->dh_need = dh_need;

	/* ignore the next message if the proposals do not match */
	if (first_kex_follows && !proposals_match(my, peer))
		ssh->dispatch_skip_packets = 1;
	r = 0;
 out:
	kex_prop_free(my);
	kex_prop_free(peer);
	return r;
}
1309
1310static int
1311derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1312    const struct sshbuf *shared_secret, u_char **keyp)
1313{
1314	struct kex *kex = ssh->kex;
1315	struct ssh_digest_ctx *hashctx = NULL;
1316	char c = id;
1317	u_int have;
1318	size_t mdsz;
1319	u_char *digest;
1320	int r;
1321
1322	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1323		return SSH_ERR_INVALID_ARGUMENT;
1324	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1325		r = SSH_ERR_ALLOC_FAIL;
1326		goto out;
1327	}
1328
1329	/* K1 = HASH(K || H || "A" || session_id) */
1330	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1331	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1332	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1333	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1334	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1335	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1336		r = SSH_ERR_LIBCRYPTO_ERROR;
1337		error_f("KEX hash failed");
1338		goto out;
1339	}
1340	ssh_digest_free(hashctx);
1341	hashctx = NULL;
1342
1343	/*
1344	 * expand key:
1345	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1346	 * Key = K1 || K2 || ... || Kn
1347	 */
1348	for (have = mdsz; need > have; have += mdsz) {
1349		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1350		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1351		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1352		    ssh_digest_update(hashctx, digest, have) != 0 ||
1353		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1354			error_f("KDF failed");
1355			r = SSH_ERR_LIBCRYPTO_ERROR;
1356			goto out;
1357		}
1358		ssh_digest_free(hashctx);
1359		hashctx = NULL;
1360	}
1361#ifdef DEBUG_KEX
1362	fprintf(stderr, "key '%c'== ", c);
1363	dump_digest("key", digest, need);
1364#endif
1365	*keyp = digest;
1366	digest = NULL;
1367	r = 0;
1368 out:
1369	free(digest);
1370	ssh_digest_free(hashctx);
1371	return r;
1372}
1373
1374#define NKEYS	6
1375int
1376kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1377    const struct sshbuf *shared_secret)
1378{
1379	struct kex *kex = ssh->kex;
1380	u_char *keys[NKEYS];
1381	u_int i, j, mode, ctos;
1382	int r;
1383
1384	/* save initial hash as session id */
1385	if ((kex->flags & KEX_INITIAL) != 0) {
1386		if (sshbuf_len(kex->session_id) != 0) {
1387			error_f("already have session ID at kex");
1388			return SSH_ERR_INTERNAL_ERROR;
1389		}
1390		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1391			return r;
1392	} else if (sshbuf_len(kex->session_id) == 0) {
1393		error_f("no session ID in rekex");
1394		return SSH_ERR_INTERNAL_ERROR;
1395	}
1396	for (i = 0; i < NKEYS; i++) {
1397		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1398		    shared_secret, &keys[i])) != 0) {
1399			for (j = 0; j < i; j++)
1400				free(keys[j]);
1401			return r;
1402		}
1403	}
1404	for (mode = 0; mode < MODE_MAX; mode++) {
1405		ctos = (!kex->server && mode == MODE_OUT) ||
1406		    (kex->server && mode == MODE_IN);
1407		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1408		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1409		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1410	}
1411	return 0;
1412}
1413
1414int
1415kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1416{
1417	struct kex *kex = ssh->kex;
1418
1419	*pubp = NULL;
1420	*prvp = NULL;
1421	if (kex->load_host_public_key == NULL ||
1422	    kex->load_host_private_key == NULL) {
1423		error_f("missing hostkey loader");
1424		return SSH_ERR_INVALID_ARGUMENT;
1425	}
1426	*pubp = kex->load_host_public_key(kex->hostkey_type,
1427	    kex->hostkey_nid, ssh);
1428	*prvp = kex->load_host_private_key(kex->hostkey_type,
1429	    kex->hostkey_nid, ssh);
1430	if (*pubp == NULL)
1431		return SSH_ERR_NO_HOSTKEY_LOADED;
1432	return 0;
1433}
1434
1435int
1436kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1437{
1438	struct kex *kex = ssh->kex;
1439
1440	if (kex->verify_host_key == NULL) {
1441		error_f("missing hostkey verifier");
1442		return SSH_ERR_INVALID_ARGUMENT;
1443	}
1444	if (server_host_key->type != kex->hostkey_type ||
1445	    (kex->hostkey_type == KEY_ECDSA &&
1446	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1447		return SSH_ERR_KEY_TYPE_MISMATCH;
1448	if (kex->verify_host_key(server_host_key, ssh) == -1)
1449		return  SSH_ERR_SIGNATURE_INVALID;
1450	return 0;
1451}
1452
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/* Debug helper: print the label "msg" on its own line to stderr, then
 * dump "len" bytes of "digest" there via sshbuf_dump_data(). */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1461
1462/*
1463 * Send a plaintext error message to the peer, suffixed by \r\n.
1464 * Only used during banner exchange, and there only for the server.
1465 */
1466static void
1467send_error(struct ssh *ssh, char *msg)
1468{
1469	char *crnl = "\r\n";
1470
1471	if (!ssh->kex->server)
1472		return;
1473
1474	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1475	    msg, strlen(msg)) != strlen(msg) ||
1476	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1477	    crnl, strlen(crnl)) != strlen(crnl))
1478		error_f("write: %.100s", strerror(errno));
1479}
1480
1481/*
1482 * Sends our identification string and waits for the peer's. Will block for
1483 * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1484 * Returns on 0 success or a ssherr.h code on failure.
1485 */
1486int
1487kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1488    const char *version_addendum)
1489{
1490	int remote_major, remote_minor, mismatch, oerrno = 0;
1491	size_t len, n;
1492	int r, expect_nl;
1493	u_char c;
1494	struct sshbuf *our_version = ssh->kex->server ?
1495	    ssh->kex->server_version : ssh->kex->client_version;
1496	struct sshbuf *peer_version = ssh->kex->server ?
1497	    ssh->kex->client_version : ssh->kex->server_version;
1498	char *our_version_string = NULL, *peer_version_string = NULL;
1499	char *cp, *remote_version = NULL;
1500
1501	/* Prepare and send our banner */
1502	sshbuf_reset(our_version);
1503	if (version_addendum != NULL && *version_addendum == '\0')
1504		version_addendum = NULL;
1505	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1506	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1507	    version_addendum == NULL ? "" : " ",
1508	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1509		oerrno = errno;
1510		error_fr(r, "sshbuf_putf");
1511		goto out;
1512	}
1513
1514	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1515	    sshbuf_mutable_ptr(our_version),
1516	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1517		oerrno = errno;
1518		debug_f("write: %.100s", strerror(errno));
1519		r = SSH_ERR_SYSTEM_ERROR;
1520		goto out;
1521	}
1522	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1523		oerrno = errno;
1524		error_fr(r, "sshbuf_consume_end");
1525		goto out;
1526	}
1527	our_version_string = sshbuf_dup_string(our_version);
1528	if (our_version_string == NULL) {
1529		error_f("sshbuf_dup_string failed");
1530		r = SSH_ERR_ALLOC_FAIL;
1531		goto out;
1532	}
1533	debug("Local version string %.100s", our_version_string);
1534
1535	/* Read other side's version identification. */
1536	for (n = 0; ; n++) {
1537		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1538			send_error(ssh, "No SSH identification string "
1539			    "received.");
1540			error_f("No SSH version received in first %u lines "
1541			    "from server", SSH_MAX_PRE_BANNER_LINES);
1542			r = SSH_ERR_INVALID_FORMAT;
1543			goto out;
1544		}
1545		sshbuf_reset(peer_version);
1546		expect_nl = 0;
1547		for (;;) {
1548			if (timeout_ms > 0) {
1549				r = waitrfd(ssh_packet_get_connection_in(ssh),
1550				    &timeout_ms, NULL);
1551				if (r == -1 && errno == ETIMEDOUT) {
1552					send_error(ssh, "Timed out waiting "
1553					    "for SSH identification string.");
1554					error("Connection timed out during "
1555					    "banner exchange");
1556					r = SSH_ERR_CONN_TIMEOUT;
1557					goto out;
1558				} else if (r == -1) {
1559					oerrno = errno;
1560					error_f("%s", strerror(errno));
1561					r = SSH_ERR_SYSTEM_ERROR;
1562					goto out;
1563				}
1564			}
1565
1566			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1567			    &c, 1);
1568			if (len != 1 && errno == EPIPE) {
1569				verbose_f("Connection closed by remote host");
1570				r = SSH_ERR_CONN_CLOSED;
1571				goto out;
1572			} else if (len != 1) {
1573				oerrno = errno;
1574				error_f("read: %.100s", strerror(errno));
1575				r = SSH_ERR_SYSTEM_ERROR;
1576				goto out;
1577			}
1578			if (c == '\r') {
1579				expect_nl = 1;
1580				continue;
1581			}
1582			if (c == '\n')
1583				break;
1584			if (c == '\0' || expect_nl) {
1585				verbose_f("banner line contains invalid "
1586				    "characters");
1587				goto invalid;
1588			}
1589			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1590				oerrno = errno;
1591				error_fr(r, "sshbuf_put");
1592				goto out;
1593			}
1594			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1595				verbose_f("banner line too long");
1596				goto invalid;
1597			}
1598		}
1599		/* Is this an actual protocol banner? */
1600		if (sshbuf_len(peer_version) > 4 &&
1601		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1602			break;
1603		/* If not, then just log the line and continue */
1604		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1605			error_f("sshbuf_dup_string failed");
1606			r = SSH_ERR_ALLOC_FAIL;
1607			goto out;
1608		}
1609		/* Do not accept lines before the SSH ident from a client */
1610		if (ssh->kex->server) {
1611			verbose_f("client sent invalid protocol identifier "
1612			    "\"%.256s\"", cp);
1613			free(cp);
1614			goto invalid;
1615		}
1616		debug_f("banner line %zu: %s", n, cp);
1617		free(cp);
1618	}
1619	peer_version_string = sshbuf_dup_string(peer_version);
1620	if (peer_version_string == NULL)
1621		fatal_f("sshbuf_dup_string failed");
1622	/* XXX must be same size for sscanf */
1623	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1624		error_f("calloc failed");
1625		r = SSH_ERR_ALLOC_FAIL;
1626		goto out;
1627	}
1628
1629	/*
1630	 * Check that the versions match.  In future this might accept
1631	 * several versions and set appropriate flags to handle them.
1632	 */
1633	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1634	    &remote_major, &remote_minor, remote_version) != 3) {
1635		error("Bad remote protocol version identification: '%.100s'",
1636		    peer_version_string);
1637 invalid:
1638		send_error(ssh, "Invalid SSH identification string.");
1639		r = SSH_ERR_INVALID_FORMAT;
1640		goto out;
1641	}
1642	debug("Remote protocol version %d.%d, remote software version %.100s",
1643	    remote_major, remote_minor, remote_version);
1644	compat_banner(ssh, remote_version);
1645
1646	mismatch = 0;
1647	switch (remote_major) {
1648	case 2:
1649		break;
1650	case 1:
1651		if (remote_minor != 99)
1652			mismatch = 1;
1653		break;
1654	default:
1655		mismatch = 1;
1656		break;
1657	}
1658	if (mismatch) {
1659		error("Protocol major versions differ: %d vs. %d",
1660		    PROTOCOL_MAJOR_2, remote_major);
1661		send_error(ssh, "Protocol major versions differ.");
1662		r = SSH_ERR_NO_PROTOCOL_VERSION;
1663		goto out;
1664	}
1665
1666	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1667		logit("probed from %s port %d with %s.  Don't panic.",
1668		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1669		    peer_version_string);
1670		r = SSH_ERR_CONN_CLOSED; /* XXX */
1671		goto out;
1672	}
1673	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1674		logit("scanned from %s port %d with %s.  Don't panic.",
1675		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1676		    peer_version_string);
1677		r = SSH_ERR_CONN_CLOSED; /* XXX */
1678		goto out;
1679	}
1680	/* success */
1681	r = 0;
1682 out:
1683	free(our_version_string);
1684	free(peer_version_string);
1685	free(remote_version);
1686	if (r == SSH_ERR_SYSTEM_ERROR)
1687		errno = oerrno;
1688	return r;
1689}
1690
1691