proto_tls.c revision 293161
1/*-
2 * Copyright (c) 2011 The FreeBSD Foundation
3 * All rights reserved.
4 *
5 * This software was developed by Pawel Jakub Dawidek under sponsorship from
6 * the FreeBSD Foundation.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 * 1. Redistributions of source code must retain the above copyright
12 *    notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 */
29
30#include <config/config.h>
31
32#include <sys/param.h>	/* MAXHOSTNAMELEN */
33#include <sys/socket.h>
34
35#include <arpa/inet.h>
36
37#include <netinet/in.h>
38#include <netinet/tcp.h>
39
40#include <errno.h>
41#include <fcntl.h>
42#include <netdb.h>
43#include <signal.h>
44#include <stdbool.h>
45#include <stdint.h>
46#include <stdio.h>
47#include <string.h>
48#include <unistd.h>
49
50#include <openssl/err.h>
51#include <openssl/ssl.h>
52
53#include <compat/compat.h>
54#ifndef HAVE_CLOSEFROM
55#include <compat/closefrom.h>
56#endif
57#ifndef HAVE_STRLCPY
58#include <compat/strlcpy.h>
59#endif
60
61#include "pjdlog.h"
62#include "proto_impl.h"
63#include "sandbox.h"
64#include "subr.h"
65
/* Tag kept in every live struct tls_ctx; tls_close() clears it. */
#define	TLS_CTX_MAGIC	0x715c7

/*
 * State of one "tls" protocol endpoint.  The TLS work itself happens in a
 * separate, sandboxed process (see tls_connect() and tls_accept()); this
 * structure only tracks how we talk to it.
 */
struct tls_ctx {
	int		 tls_magic;	/* TLS_CTX_MAGIC while valid. */
	struct proto_conn *tls_sock;	/* Socketpair to the sandbox, or NULL. */
	struct proto_conn *tls_tcp;	/* Listening TCP socket, or NULL. */
	char		 tls_laddr[256];	/* "tls://" local address (accepted). */
	char		 tls_raddr[256];	/* "tls://" remote address (accepted). */
	int		 tls_side;	/* Which endpoint; see TLS_SIDE_*. */
#define	TLS_SIDE_CLIENT		0	/* Created by tls_connect()/tls_wrap(). */
#define	TLS_SIDE_SERVER_LISTEN	1	/* Created by tls_server(). */
#define	TLS_SIDE_SERVER_WORK	2	/* Created by tls_accept()/tls_wrap(). */
	bool		 tls_wait_called; /* Set once the endpoint may do I/O. */
};

/* Timeout handed to the sandboxed client when the caller passed -1. */
#define	TLS_DEFAULT_TIMEOUT	30

static int tls_connect_wait(void *ctx, int timeout);
static void tls_close(void *ctx);
84
85static void
86block(int fd)
87{
88	int flags;
89
90	flags = fcntl(fd, F_GETFL);
91	if (flags == -1)
92		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
93	flags &= ~O_NONBLOCK;
94	if (fcntl(fd, F_SETFL, flags) == -1)
95		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
96}
97
98static void
99nonblock(int fd)
100{
101	int flags;
102
103	flags = fcntl(fd, F_GETFL);
104	if (flags == -1)
105		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
106	flags |= O_NONBLOCK;
107	if (fcntl(fd, F_SETFL, flags) == -1)
108		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
109}
110
111static int
112wait_for_fd(int fd, int timeout)
113{
114	struct timeval tv;
115	fd_set fdset;
116	int error, ret;
117
118	error = 0;
119
120	for (;;) {
121		FD_ZERO(&fdset);
122		FD_SET(fd, &fdset);
123
124		tv.tv_sec = timeout;
125		tv.tv_usec = 0;
126
127		ret = select(fd + 1, NULL, &fdset, NULL,
128		    timeout == -1 ? NULL : &tv);
129		if (ret == 0) {
130			error = ETIMEDOUT;
131			break;
132		} else if (ret == -1) {
133			if (errno == EINTR)
134				continue;
135			error = errno;
136			break;
137		}
138		PJDLOG_ASSERT(ret > 0);
139		PJDLOG_ASSERT(FD_ISSET(fd, &fdset));
140		break;
141	}
142
143	return (error);
144}
145
/*
 * Drain the OpenSSL error queue, logging every queued error with
 * pjdlog_error().  ERR_error_string_n() formats into a local buffer;
 * ERR_error_string(error, NULL) would use a static buffer shared by all
 * callers.
 */
static void
ssl_log_errors(void)
{
	char errstr[256];
	unsigned long error;

	while ((error = ERR_get_error()) != 0) {
		ERR_error_string_n(error, errstr, sizeof(errstr));
		pjdlog_error("SSL error: %s", errstr);
	}
}
153}
154
155static int
156ssl_check_error(SSL *ssl, int ret)
157{
158	int error;
159
160	error = SSL_get_error(ssl, ret);
161
162	switch (error) {
163	case SSL_ERROR_NONE:
164		return (0);
165	case SSL_ERROR_WANT_READ:
166		pjdlog_debug(2, "SSL_ERROR_WANT_READ");
167		return (-1);
168	case SSL_ERROR_WANT_WRITE:
169		pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
170		return (-1);
171	case SSL_ERROR_ZERO_RETURN:
172		pjdlog_exitx(EX_OK, "Connection closed.");
173	case SSL_ERROR_SYSCALL:
174		ssl_log_errors();
175		pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
176	case SSL_ERROR_SSL:
177		ssl_log_errors();
178		pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
179	default:
180		ssl_log_errors();
181		pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
182	}
183}
184
185static void
186tcp_recv_ssl_send(int recvfd, SSL *sendssl)
187{
188	static unsigned char buf[65536];
189	ssize_t tcpdone;
190	int sendfd, ssldone;
191
192	sendfd = SSL_get_fd(sendssl);
193	PJDLOG_ASSERT(sendfd >= 0);
194	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
195	for (;;) {
196		tcpdone = recv(recvfd, buf, sizeof(buf), 0);
197		pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
198		if (tcpdone == 0) {
199			pjdlog_debug(1, "Connection terminated.");
200			exit(0);
201		} else if (tcpdone == -1) {
202			if (errno == EINTR)
203				continue;
204			else if (errno == EAGAIN)
205				break;
206			pjdlog_exit(EX_TEMPFAIL, "recv() failed");
207		}
208		for (;;) {
209			ssldone = SSL_write(sendssl, buf, (int)tcpdone);
210			pjdlog_debug(2, "%s: send() returned %d", __func__,
211			    ssldone);
212			if (ssl_check_error(sendssl, ssldone) == -1) {
213				(void)wait_for_fd(sendfd, -1);
214				continue;
215			}
216			PJDLOG_ASSERT(ssldone == tcpdone);
217			break;
218		}
219	}
220	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
221}
222
223static void
224ssl_recv_tcp_send(SSL *recvssl, int sendfd)
225{
226	static unsigned char buf[65536];
227	unsigned char *ptr;
228	ssize_t tcpdone;
229	size_t todo;
230	int recvfd, ssldone;
231
232	recvfd = SSL_get_fd(recvssl);
233	PJDLOG_ASSERT(recvfd >= 0);
234	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
235	for (;;) {
236		ssldone = SSL_read(recvssl, buf, sizeof(buf));
237		pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
238		    ssldone);
239		if (ssl_check_error(recvssl, ssldone) == -1)
240			break;
241		todo = (size_t)ssldone;
242		ptr = buf;
243		do {
244			tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
245			pjdlog_debug(2, "%s: send() returned %zd", __func__,
246			    tcpdone);
247			if (tcpdone == 0) {
248				pjdlog_debug(1, "Connection terminated.");
249				exit(0);
250			} else if (tcpdone == -1) {
251				if (errno == EINTR || errno == ENOBUFS)
252					continue;
253				if (errno == EAGAIN) {
254					(void)wait_for_fd(sendfd, -1);
255					continue;
256				}
257				pjdlog_exit(EX_TEMPFAIL, "send() failed");
258			}
259			todo -= tcpdone;
260			ptr += tcpdone;
261		} while (todo > 0);
262	}
263	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
264}
265
266static void
267tls_loop(int sockfd, SSL *tcpssl)
268{
269	fd_set fds;
270	int maxfd, tcpfd;
271
272	tcpfd = SSL_get_fd(tcpssl);
273	PJDLOG_ASSERT(tcpfd >= 0);
274
275	for (;;) {
276		FD_ZERO(&fds);
277		FD_SET(sockfd, &fds);
278		FD_SET(tcpfd, &fds);
279		maxfd = MAX(sockfd, tcpfd);
280
281		PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
282		if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
283			if (errno == EINTR)
284				continue;
285			pjdlog_exit(EX_TEMPFAIL, "select() failed");
286		}
287		if (FD_ISSET(sockfd, &fds))
288			tcp_recv_ssl_send(sockfd, tcpssl);
289		if (FD_ISSET(tcpfd, &fds))
290			ssl_recv_tcp_send(tcpssl, sockfd);
291	}
292}
293
294static void
295tls_certificate_verify(SSL *ssl, const char *fingerprint)
296{
297	unsigned char md[EVP_MAX_MD_SIZE];
298	char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
299	char *mdstrp;
300	unsigned int i, mdsize;
301	X509 *cert;
302
303	if (fingerprint[0] == '\0') {
304		pjdlog_debug(1, "No fingerprint verification requested.");
305		return;
306	}
307
308	cert = SSL_get_peer_certificate(ssl);
309	if (cert == NULL)
310		pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
311
312	if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
313		pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
314	PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
315
316	X509_free(cert);
317
318	(void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
319	mdstrp = mdstr + strlen(mdstr);
320	for (i = 0; i < mdsize; i++) {
321		PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
322		(void)sprintf(mdstrp, "%02hhX:", md[i]);
323		mdstrp += 3;
324	}
325	/* Clear last colon. */
326	mdstrp[-1] = '\0';
327	if (strcasecmp(mdstr, fingerprint) != 0) {
328		pjdlog_exitx(EX_NOPERM,
329		    "Finger print doesn't match. Received \"%s\", expected \"%s\"",
330		    mdstr, fingerprint);
331	}
332}
333
334static void
335tls_exec_client(const char *user, int startfd, const char *srcaddr,
336    const char *dstaddr, const char *fingerprint, const char *defport,
337    int timeout, int debuglevel)
338{
339	struct proto_conn *tcp;
340	char *saddr, *daddr;
341	SSL_CTX *sslctx;
342	SSL *ssl;
343	long ret;
344	int sockfd, tcpfd;
345	uint8_t connected;
346
347	pjdlog_debug_set(debuglevel);
348	pjdlog_prefix_set("[TLS sandbox] (client) ");
349#ifdef HAVE_SETPROCTITLE
350	setproctitle("[TLS sandbox] (client) ");
351#endif
352	proto_set("tcp:port", defport);
353
354	sockfd = startfd;
355
356	/* Change tls:// to tcp://. */
357	if (srcaddr == NULL) {
358		saddr = NULL;
359	} else {
360		saddr = strdup(srcaddr);
361		if (saddr == NULL)
362			pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
363		bcopy("tcp://", saddr, 6);
364	}
365	daddr = strdup(dstaddr);
366	if (daddr == NULL)
367		pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
368	bcopy("tcp://", daddr, 6);
369
370	/* Establish TCP connection. */
371	if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
372		exit(EX_TEMPFAIL);
373
374	SSL_load_error_strings();
375	SSL_library_init();
376
377	/*
378	 * TODO: On FreeBSD we could move this below sandbox() once libc and
379	 *       libcrypto use sysctl kern.arandom to obtain random data
380	 *       instead of /dev/urandom and friends.
381	 */
382	sslctx = SSL_CTX_new(TLSv1_client_method());
383	if (sslctx == NULL)
384		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
385
386	if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
387		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
388	pjdlog_debug(1, "Privileges successfully dropped.");
389
390	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
391
392	/* Load CA certs. */
393	/* TODO */
394	//SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);
395
396	ssl = SSL_new(sslctx);
397	if (ssl == NULL)
398		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
399
400	tcpfd = proto_descriptor(tcp);
401
402	block(tcpfd);
403
404	if (SSL_set_fd(ssl, tcpfd) != 1)
405		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
406
407	ret = SSL_connect(ssl);
408	ssl_check_error(ssl, (int)ret);
409
410	nonblock(sockfd);
411	nonblock(tcpfd);
412
413	tls_certificate_verify(ssl, fingerprint);
414
415	/*
416	 * The following byte is send to make proto_connect_wait() to work.
417	 */
418	connected = 1;
419	for (;;) {
420		switch (send(sockfd, &connected, sizeof(connected), 0)) {
421		case -1:
422			if (errno == EINTR || errno == ENOBUFS)
423				continue;
424			if (errno == EAGAIN) {
425				(void)wait_for_fd(sockfd, -1);
426				continue;
427			}
428			pjdlog_exit(EX_TEMPFAIL, "send() failed");
429		case 0:
430			pjdlog_debug(1, "Connection terminated.");
431			exit(0);
432		case 1:
433			break;
434		}
435		break;
436	}
437
438	tls_loop(sockfd, ssl);
439}
440
441static void
442tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
443    const char *dstaddr, int timeout)
444{
445	char *timeoutstr, *startfdstr, *debugstr;
446	int startfd;
447
448	/* Declare that we are receiver. */
449	proto_recv(sock, NULL, 0);
450
451	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
452		startfd = 3;
453	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
454		startfd = 0;
455
456	if (proto_descriptor(sock) != startfd) {
457		/* Move socketpair descriptor to descriptor number startfd. */
458		if (dup2(proto_descriptor(sock), startfd) == -1)
459			pjdlog_exit(EX_OSERR, "dup2() failed");
460		proto_close(sock);
461	} else {
462		/*
463		 * The FD_CLOEXEC is cleared by dup2(2), so when we not
464		 * call it, we have to clear it by hand in case it is set.
465		 */
466		if (fcntl(startfd, F_SETFD, 0) == -1)
467			pjdlog_exit(EX_OSERR, "fcntl() failed");
468	}
469
470	closefrom(startfd + 1);
471
472	if (asprintf(&startfdstr, "%d", startfd) == -1)
473		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
474	if (timeout == -1)
475		timeout = TLS_DEFAULT_TIMEOUT;
476	if (asprintf(&timeoutstr, "%d", timeout) == -1)
477		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
478	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
479		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
480
481	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
482	    proto_get("user"), "client", startfdstr,
483	    srcaddr == NULL ? "" : srcaddr, dstaddr,
484	    proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
485	    debugstr, NULL);
486	pjdlog_exit(EX_SOFTWARE, "execl() failed");
487}
488
489static int
490tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
491{
492	struct tls_ctx *tlsctx;
493	struct proto_conn *sock;
494	pid_t pid;
495	int error;
496
497	PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
498	PJDLOG_ASSERT(dstaddr != NULL);
499	PJDLOG_ASSERT(timeout >= -1);
500	PJDLOG_ASSERT(ctxp != NULL);
501
502	if (strncmp(dstaddr, "tls://", 6) != 0)
503		return (-1);
504	if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
505		return (-1);
506
507	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
508		return (errno);
509
510#if 0
511	/*
512	 * We use rfork() with the following flags to disable SIGCHLD
513	 * delivery upon the sandbox process exit.
514	 */
515	pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
516#else
517	/*
518	 * We don't use rfork() to be able to log information about sandbox
519	 * process exiting.
520	 */
521	pid = fork();
522#endif
523	switch (pid) {
524	case -1:
525		/* Failure. */
526		error = errno;
527		proto_close(sock);
528		return (error);
529	case 0:
530		/* Child. */
531		pjdlog_prefix_set("[TLS sandbox] (client) ");
532#ifdef HAVE_SETPROCTITLE
533		setproctitle("[TLS sandbox] (client) ");
534#endif
535		tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
536		/* NOTREACHED */
537	default:
538		/* Parent. */
539		tlsctx = calloc(1, sizeof(*tlsctx));
540		if (tlsctx == NULL) {
541			error = errno;
542			proto_close(sock);
543			(void)kill(pid, SIGKILL);
544			return (error);
545		}
546		proto_send(sock, NULL, 0);
547		tlsctx->tls_sock = sock;
548		tlsctx->tls_tcp = NULL;
549		tlsctx->tls_side = TLS_SIDE_CLIENT;
550		tlsctx->tls_wait_called = false;
551		tlsctx->tls_magic = TLS_CTX_MAGIC;
552		if (timeout >= 0) {
553			error = tls_connect_wait(tlsctx, timeout);
554			if (error != 0) {
555				(void)kill(pid, SIGKILL);
556				tls_close(tlsctx);
557				return (error);
558			}
559		}
560		*ctxp = tlsctx;
561		return (0);
562	}
563}
564
565static int
566tls_connect_wait(void *ctx, int timeout)
567{
568	struct tls_ctx *tlsctx = ctx;
569	int error, sockfd;
570	uint8_t connected;
571
572	PJDLOG_ASSERT(tlsctx != NULL);
573	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
574	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
575	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
576	PJDLOG_ASSERT(!tlsctx->tls_wait_called);
577	PJDLOG_ASSERT(timeout >= 0);
578
579	sockfd = proto_descriptor(tlsctx->tls_sock);
580	error = wait_for_fd(sockfd, timeout);
581	if (error != 0)
582		return (error);
583
584	for (;;) {
585		switch (recv(sockfd, &connected, sizeof(connected),
586		    MSG_WAITALL)) {
587		case -1:
588			if (errno == EINTR || errno == ENOBUFS)
589				continue;
590			error = errno;
591			break;
592		case 0:
593			pjdlog_debug(1, "Connection terminated.");
594			error = ENOTCONN;
595			break;
596		case 1:
597			tlsctx->tls_wait_called = true;
598			break;
599		}
600		break;
601	}
602
603	return (error);
604}
605
606static int
607tls_server(const char *lstaddr, void **ctxp)
608{
609	struct proto_conn *tcp;
610	struct tls_ctx *tlsctx;
611	char *laddr;
612	int error;
613
614	if (strncmp(lstaddr, "tls://", 6) != 0)
615		return (-1);
616
617	tlsctx = malloc(sizeof(*tlsctx));
618	if (tlsctx == NULL) {
619		pjdlog_warning("Unable to allocate memory.");
620		return (ENOMEM);
621	}
622
623	laddr = strdup(lstaddr);
624	if (laddr == NULL) {
625		free(tlsctx);
626		pjdlog_warning("Unable to allocate memory.");
627		return (ENOMEM);
628	}
629	bcopy("tcp://", laddr, 6);
630
631	if (proto_server(laddr, &tcp) == -1) {
632		error = errno;
633		free(tlsctx);
634		free(laddr);
635		return (error);
636	}
637	free(laddr);
638
639	tlsctx->tls_sock = NULL;
640	tlsctx->tls_tcp = tcp;
641	tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
642	tlsctx->tls_wait_called = true;
643	tlsctx->tls_magic = TLS_CTX_MAGIC;
644	*ctxp = tlsctx;
645
646	return (0);
647}
648
649static void
650tls_exec_server(const char *user, int startfd, const char *privkey,
651    const char *cert, int debuglevel)
652{
653	SSL_CTX *sslctx;
654	SSL *ssl;
655	int sockfd, tcpfd, ret;
656
657	pjdlog_debug_set(debuglevel);
658	pjdlog_prefix_set("[TLS sandbox] (server) ");
659#ifdef HAVE_SETPROCTITLE
660	setproctitle("[TLS sandbox] (server) ");
661#endif
662
663	sockfd = startfd;
664	tcpfd = startfd + 1;
665
666	SSL_load_error_strings();
667	SSL_library_init();
668
669	sslctx = SSL_CTX_new(TLSv1_server_method());
670	if (sslctx == NULL)
671		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
672
673	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
674
675	ssl = SSL_new(sslctx);
676	if (ssl == NULL)
677		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
678
679	if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
680		ssl_log_errors();
681		pjdlog_exitx(EX_CONFIG,
682		    "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
683	}
684
685	if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
686		ssl_log_errors();
687		pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
688		    cert);
689	}
690
691	if (sandbox(user, true, "proto_tls server") != 0)
692		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
693	pjdlog_debug(1, "Privileges successfully dropped.");
694
695	nonblock(sockfd);
696	nonblock(tcpfd);
697
698	if (SSL_set_fd(ssl, tcpfd) != 1)
699		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
700
701	ret = SSL_accept(ssl);
702	ssl_check_error(ssl, ret);
703
704	tls_loop(sockfd, ssl);
705}
706
707static void
708tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
709{
710	int startfd, sockfd, tcpfd, safefd;
711	char *startfdstr, *debugstr;
712
713	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
714		startfd = 3;
715	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
716		startfd = 0;
717
718	/* Declare that we are receiver. */
719	proto_send(sock, NULL, 0);
720
721	sockfd = proto_descriptor(sock);
722	tcpfd = proto_descriptor(tcp);
723
724	safefd = MAX(sockfd, tcpfd);
725	safefd = MAX(safefd, startfd);
726	safefd++;
727
728	/* Move sockfd and tcpfd to safe numbers first. */
729	if (dup2(sockfd, safefd) == -1)
730		pjdlog_exit(EX_OSERR, "dup2() failed");
731	proto_close(sock);
732	sockfd = safefd;
733	if (dup2(tcpfd, safefd + 1) == -1)
734		pjdlog_exit(EX_OSERR, "dup2() failed");
735	proto_close(tcp);
736	tcpfd = safefd + 1;
737
738	/* Move socketpair descriptor to descriptor number startfd. */
739	if (dup2(sockfd, startfd) == -1)
740		pjdlog_exit(EX_OSERR, "dup2() failed");
741	(void)close(sockfd);
742	/* Move tcp descriptor to descriptor number startfd + 1. */
743	if (dup2(tcpfd, startfd + 1) == -1)
744		pjdlog_exit(EX_OSERR, "dup2() failed");
745	(void)close(tcpfd);
746
747	closefrom(startfd + 2);
748
749	/*
750	 * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
751	 * have been cleared on dup2(), but better be safe than sorry.
752	 */
753	if (fcntl(startfd, F_SETFD, 0) == -1)
754		pjdlog_exit(EX_OSERR, "fcntl() failed");
755	if (fcntl(startfd + 1, F_SETFD, 0) == -1)
756		pjdlog_exit(EX_OSERR, "fcntl() failed");
757
758	if (asprintf(&startfdstr, "%d", startfd) == -1)
759		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
760	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
761		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
762
763	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
764	    proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
765	    proto_get("tls:certfile"), debugstr, NULL);
766	pjdlog_exit(EX_SOFTWARE, "execl() failed");
767}
768
769static int
770tls_accept(void *ctx, void **newctxp)
771{
772	struct tls_ctx *tlsctx = ctx;
773	struct tls_ctx *newtlsctx;
774	struct proto_conn *sock, *tcp;
775	pid_t pid;
776	int error;
777
778	PJDLOG_ASSERT(tlsctx != NULL);
779	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
780	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
781
782	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
783		return (errno);
784
785	/* Accept TCP connection. */
786	if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
787		error = errno;
788		proto_close(sock);
789		return (error);
790	}
791
792	pid = fork();
793	switch (pid) {
794	case -1:
795		/* Failure. */
796		error = errno;
797		proto_close(sock);
798		return (error);
799	case 0:
800		/* Child. */
801		pjdlog_prefix_set("[TLS sandbox] (server) ");
802#ifdef HAVE_SETPROCTITLE
803		setproctitle("[TLS sandbox] (server) ");
804#endif
805		/* Close listen socket. */
806		proto_close(tlsctx->tls_tcp);
807		tls_call_exec_server(sock, tcp);
808		/* NOTREACHED */
809		PJDLOG_ABORT("Unreachable.");
810	default:
811		/* Parent. */
812		newtlsctx = calloc(1, sizeof(*tlsctx));
813		if (newtlsctx == NULL) {
814			error = errno;
815			proto_close(sock);
816			proto_close(tcp);
817			(void)kill(pid, SIGKILL);
818			return (error);
819		}
820		proto_local_address(tcp, newtlsctx->tls_laddr,
821		    sizeof(newtlsctx->tls_laddr));
822		PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
823		bcopy("tls://", newtlsctx->tls_laddr, 6);
824		*strrchr(newtlsctx->tls_laddr, ':') = '\0';
825		proto_remote_address(tcp, newtlsctx->tls_raddr,
826		    sizeof(newtlsctx->tls_raddr));
827		PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
828		bcopy("tls://", newtlsctx->tls_raddr, 6);
829		*strrchr(newtlsctx->tls_raddr, ':') = '\0';
830		proto_close(tcp);
831		proto_recv(sock, NULL, 0);
832		newtlsctx->tls_sock = sock;
833		newtlsctx->tls_tcp = NULL;
834		newtlsctx->tls_wait_called = true;
835		newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
836		newtlsctx->tls_magic = TLS_CTX_MAGIC;
837		*newctxp = newtlsctx;
838		return (0);
839	}
840}
841
842static int
843tls_wrap(int fd, bool client, void **ctxp)
844{
845	struct tls_ctx *tlsctx;
846	struct proto_conn *sock;
847	int error;
848
849	tlsctx = calloc(1, sizeof(*tlsctx));
850	if (tlsctx == NULL)
851		return (errno);
852
853	if (proto_wrap("socketpair", client, fd, &sock) == -1) {
854		error = errno;
855		free(tlsctx);
856		return (error);
857	}
858
859	tlsctx->tls_sock = sock;
860	tlsctx->tls_tcp = NULL;
861	tlsctx->tls_wait_called = (client ? false : true);
862	tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
863	tlsctx->tls_magic = TLS_CTX_MAGIC;
864	*ctxp = tlsctx;
865
866	return (0);
867}
868
869static int
870tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
871{
872	struct tls_ctx *tlsctx = ctx;
873
874	PJDLOG_ASSERT(tlsctx != NULL);
875	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
876	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
877	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
878	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
879	PJDLOG_ASSERT(tlsctx->tls_wait_called);
880	PJDLOG_ASSERT(fd == -1);
881
882	if (proto_send(tlsctx->tls_sock, data, size) == -1)
883		return (errno);
884
885	return (0);
886}
887
888static int
889tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
890{
891	struct tls_ctx *tlsctx = ctx;
892
893	PJDLOG_ASSERT(tlsctx != NULL);
894	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
895	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
896	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
897	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
898	PJDLOG_ASSERT(tlsctx->tls_wait_called);
899	PJDLOG_ASSERT(fdp == NULL);
900
901	if (proto_recv(tlsctx->tls_sock, data, size) == -1)
902		return (errno);
903
904	return (0);
905}
906
907static int
908tls_descriptor(const void *ctx)
909{
910	const struct tls_ctx *tlsctx = ctx;
911
912	PJDLOG_ASSERT(tlsctx != NULL);
913	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
914
915	switch (tlsctx->tls_side) {
916	case TLS_SIDE_CLIENT:
917	case TLS_SIDE_SERVER_WORK:
918		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
919
920		return (proto_descriptor(tlsctx->tls_sock));
921	case TLS_SIDE_SERVER_LISTEN:
922		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
923
924		return (proto_descriptor(tlsctx->tls_tcp));
925	default:
926		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
927	}
928}
929
930static bool
931tcp_address_match(const void *ctx, const char *addr)
932{
933	const struct tls_ctx *tlsctx = ctx;
934
935	PJDLOG_ASSERT(tlsctx != NULL);
936	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
937
938	return (strcmp(tlsctx->tls_raddr, addr) == 0);
939}
940
941static void
942tls_local_address(const void *ctx, char *addr, size_t size)
943{
944	const struct tls_ctx *tlsctx = ctx;
945
946	PJDLOG_ASSERT(tlsctx != NULL);
947	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
948	PJDLOG_ASSERT(tlsctx->tls_wait_called);
949
950	switch (tlsctx->tls_side) {
951	case TLS_SIDE_CLIENT:
952		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
953
954		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
955		break;
956	case TLS_SIDE_SERVER_WORK:
957		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
958
959		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
960		break;
961	case TLS_SIDE_SERVER_LISTEN:
962		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
963
964		proto_local_address(tlsctx->tls_tcp, addr, size);
965		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
966		/* Replace tcp:// prefix with tls:// */
967		bcopy("tls://", addr, 6);
968		break;
969	default:
970		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
971	}
972}
973
974static void
975tls_remote_address(const void *ctx, char *addr, size_t size)
976{
977	const struct tls_ctx *tlsctx = ctx;
978
979	PJDLOG_ASSERT(tlsctx != NULL);
980	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
981	PJDLOG_ASSERT(tlsctx->tls_wait_called);
982
983	switch (tlsctx->tls_side) {
984	case TLS_SIDE_CLIENT:
985		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
986
987		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
988		break;
989	case TLS_SIDE_SERVER_WORK:
990		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
991
992		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
993		break;
994	case TLS_SIDE_SERVER_LISTEN:
995		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
996
997		proto_remote_address(tlsctx->tls_tcp, addr, size);
998		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
999		/* Replace tcp:// prefix with tls:// */
1000		bcopy("tls://", addr, 6);
1001		break;
1002	default:
1003		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
1004	}
1005}
1006
1007static void
1008tls_close(void *ctx)
1009{
1010	struct tls_ctx *tlsctx = ctx;
1011
1012	PJDLOG_ASSERT(tlsctx != NULL);
1013	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
1014
1015	if (tlsctx->tls_sock != NULL) {
1016		proto_close(tlsctx->tls_sock);
1017		tlsctx->tls_sock = NULL;
1018	}
1019	if (tlsctx->tls_tcp != NULL) {
1020		proto_close(tlsctx->tls_tcp);
1021		tlsctx->tls_tcp = NULL;
1022	}
1023	tlsctx->tls_side = 0;
1024	tlsctx->tls_magic = 0;
1025	free(tlsctx);
1026}
1027
1028static int
1029tls_exec(int argc, char *argv[])
1030{
1031
1032	PJDLOG_ASSERT(argc > 3);
1033	PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
1034
1035	pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
1036
1037	if (strcmp(argv[2], "client") == 0) {
1038		if (argc != 10)
1039			return (EINVAL);
1040		tls_exec_client(argv[1], atoi(argv[3]),
1041		    argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
1042		    argv[7], atoi(argv[8]), atoi(argv[9]));
1043	} else if (strcmp(argv[2], "server") == 0) {
1044		if (argc != 7)
1045			return (EINVAL);
1046		tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
1047		    atoi(argv[6]));
1048	}
1049	return (EINVAL);
1050}
1051
1052static struct proto tls_proto = {
1053	.prt_name = "tls",
1054	.prt_connect = tls_connect,
1055	.prt_connect_wait = tls_connect_wait,
1056	.prt_server = tls_server,
1057	.prt_accept = tls_accept,
1058	.prt_wrap = tls_wrap,
1059	.prt_send = tls_send,
1060	.prt_recv = tls_recv,
1061	.prt_descriptor = tls_descriptor,
1062	.prt_address_match = tcp_address_match,
1063	.prt_local_address = tls_local_address,
1064	.prt_remote_address = tls_remote_address,
1065	.prt_close = tls_close,
1066	.prt_exec = tls_exec
1067};
1068
1069static __constructor void
1070tls_ctor(void)
1071{
1072
1073	proto_register(&tls_proto, false);
1074}
1075