1/*-
2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3 *
4 * Copyright (c) 2009-2010 The FreeBSD Foundation
5 * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
6 * All rights reserved.
7 *
8 * This software was developed by Pawel Jakub Dawidek under sponsorship from
9 * the FreeBSD Foundation.
10 *
11 * Redistribution and use in source and binary forms, with or without
12 * modification, are permitted provided that the following conditions
13 * are met:
14 * 1. Redistributions of source code must retain the above copyright
15 *    notice, this list of conditions and the following disclaimer.
16 * 2. Redistributions in binary form must reproduce the above copyright
17 *    notice, this list of conditions and the following disclaimer in the
18 *    documentation and/or other materials provided with the distribution.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30 * SUCH DAMAGE.
31 */
32
33#include <sys/cdefs.h>
34__FBSDID("$FreeBSD: stable/11/sbin/hastd/proto_tcp.c 330449 2018-03-05 07:26:05Z eadler $");
35
36#include <sys/param.h>	/* MAXHOSTNAMELEN */
37#include <sys/socket.h>
38
39#include <arpa/inet.h>
40
41#include <netinet/in.h>
42#include <netinet/tcp.h>
43
44#include <errno.h>
45#include <fcntl.h>
46#include <netdb.h>
47#include <stdbool.h>
48#include <stdint.h>
49#include <stdio.h>
50#include <string.h>
51#include <unistd.h>
52
53#include "pjdlog.h"
54#include "proto_impl.h"
55#include "subr.h"
56
/* Value kept in tc_magic of a live context; tcp_close() resets it to 0. */
#define	TCP_CTX_MAGIC	0x7c41c
/*
 * State of one TCP transport endpoint (client socket, listening socket or
 * server-side connection).
 */
struct tcp_ctx {
	int			tc_magic;	/* TCP_CTX_MAGIC while valid. */
	/*
	 * Address parsed by tcp_addr() for contexts created from an address
	 * string (connect target for clients, bind address for listeners);
	 * AF_UNSPEC for contexts wrapping an existing descriptor.
	 */
	struct sockaddr_storage	tc_sa;
	int			tc_fd;		/* Socket descriptor. */
	int			tc_side;	/* One of TCP_SIDE_* below. */
#define	TCP_SIDE_CLIENT		0	/* Outgoing (client) side. */
#define	TCP_SIDE_SERVER_LISTEN	1	/* Listening socket. */
#define	TCP_SIDE_SERVER_WORK	2	/* Server side of a connection. */
};

/* Forward declarations for functions used before their definitions. */
static int tcp_connect_wait(void *ctx, int timeout);
static void tcp_close(void *ctx);
70
/*
 * Convert the decimal string 'str' to a non-negative number.
 *
 * Only the characters '0'-'9' are accepted (no sign, whitespace or base
 * prefix).  The value must lie within [minnum, maxnum].  On success the
 * value is stored in *nump and 0 is returned.  On failure (empty string,
 * non-digit character, overflow, value out of range) errno is set to
 * EINVAL, *nump is left untouched and -1 is returned.
 */
static int
numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
{
	intmax_t digit, num;

	if (str[0] == '\0')
		goto invalid;	/* Empty string. */
	num = 0;
	for (; *str != '\0'; str++) {
		if (*str < '0' || *str > '9')
			goto invalid;	/* Non-digit character. */
		digit = *str - '0';
		/*
		 * Detect overflow before it happens: signed overflow is
		 * undefined behaviour, so inspecting the result of
		 * num * 10 + digit afterwards cannot be relied upon.
		 */
		if (num > (INTMAX_MAX - digit) / 10)
			goto invalid;	/* Overflow. */
		num = num * 10 + digit;
		if (num > maxnum)
			goto invalid;	/* Too big. */
	}
	if (num < minnum)
		goto invalid;	/* Too small. */
	*nump = num;
	return (0);
invalid:
	errno = EINVAL;
	return (-1);
}
100
/*
 * Parse a TCP address string into *sap.
 *
 * An optional, case-insensitive scheme prefix restricts the address
 * family: "tcp4://" (IPv4 only), "tcp6://" (IPv6 only) or "tcp://" (any).
 * Without a prefix any family is accepted, since TCP is the default.
 * The remainder may take one of these forms:
 *
 *	hostname[:port]		eg. freefall.freebsd.org:8457
 *	IPv4[:port]		eg. 192.168.0.101:8457
 *	[address][:port]	eg. [fe80::1]:8457 or [fe80::1]
 *	IPv6			eg. fe80::1 (giving a port requires brackets)
 *
 * When no port is given, 'defport' is used.  The host part is resolved
 * with getaddrinfo(3) and the first result is copied into *sap.
 *
 * Returns 0 on success, -1 if addr is NULL, EINVAL for an invalid port,
 * malformed brackets or a resolution failure, ENAMETOOLONG if the host
 * part is not shorter than MAXHOSTNAMELEN, or ENOENT if getaddrinfo(3)
 * returned no address.
 * NOTE(review): the -1 for NULL differs from the errno-style codes; it is
 * kept as-is because callers in other files may rely on it.
 */
static int
tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
{
	char iporhost[MAXHOSTNAMELEN], portstr[6];
	struct addrinfo hints;
	struct addrinfo *res;
	const char *host, *pp, *rbracket;
	intmax_t port;
	size_t hostlen;
	int error;

	if (addr == NULL)
		return (-1);

	bzero(&hints, sizeof(hints));
	hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;

	if (strncasecmp(addr, "tcp4://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET;
	} else if (strncasecmp(addr, "tcp6://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET6;
	} else if (strncasecmp(addr, "tcp://", 6) == 0) {
		addr += 6;
	}
	/* Otherwise TCP is the default: assume a bare IP or host name. */

	/*
	 * Split the string into the host part [host, host + hostlen) and an
	 * optional port; pp points at the ':' in front of the port, or is
	 * NULL when no port was given.
	 */
	if (addr[0] == '[') {
		/* Bracketed address: "[host]" or "[host]:port". */
		rbracket = strchr(addr, ']');
		if (rbracket == NULL ||
		    (rbracket[1] != '\0' && rbracket[1] != ':'))
			return (EINVAL);
		host = addr + 1;
		hostlen = (size_t)(rbracket - host);
		pp = (rbracket[1] == ':') ? rbracket + 1 : NULL;
	} else if (strchr(addr, ':') != strrchr(addr, ':')) {
		/* Two or more colons: bare IPv6 address, no port possible. */
		host = addr;
		hostlen = strlen(addr);
		pp = NULL;
	} else {
		/* Host name or IPv4 address with an optional ":port". */
		host = addr;
		pp = strchr(addr, ':');
		hostlen = (pp == NULL) ? strlen(addr) : (size_t)(pp - addr);
	}

	if (pp == NULL) {
		/* Port not given, use the default. */
		port = defport;
	} else {
		if (numfromstr(pp + 1, 1, 65535, &port) == -1)
			return (errno);
	}
	(void)snprintf(portstr, sizeof(portstr), "%jd", port);

	/* Copy the host part out so it can be NUL-terminated. */
	if (hostlen >= sizeof(iporhost))
		return (ENAMETOOLONG);
	memcpy(iporhost, host, hostlen);
	iporhost[hostlen] = '\0';

	error = getaddrinfo(iporhost, portstr, &hints, &res);
	if (error != 0) {
		pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
		    portstr, gai_strerror(error));
		return (EINVAL);
	}
	if (res == NULL)
		return (ENOENT);

	memcpy(sap, res->ai_addr, res->ai_addrlen);

	freeaddrinfo(res);

	return (0);
}
192
193static int
194tcp_setup_new(const char *addr, int side, void **ctxp)
195{
196	struct tcp_ctx *tctx;
197	int ret, nodelay;
198
199	PJDLOG_ASSERT(addr != NULL);
200	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
201	    side == TCP_SIDE_SERVER_LISTEN);
202	PJDLOG_ASSERT(ctxp != NULL);
203
204	tctx = malloc(sizeof(*tctx));
205	if (tctx == NULL)
206		return (errno);
207
208	/* Parse given address. */
209	if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
210		free(tctx);
211		return (ret);
212	}
213
214	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
215
216	tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
217	if (tctx->tc_fd == -1) {
218		ret = errno;
219		free(tctx);
220		return (ret);
221	}
222
223	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
224
225	/* Socket settings. */
226	nodelay = 1;
227	if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
228	    sizeof(nodelay)) == -1) {
229		pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
230	}
231
232	tctx->tc_side = side;
233	tctx->tc_magic = TCP_CTX_MAGIC;
234	*ctxp = tctx;
235
236	return (0);
237}
238
239static int
240tcp_setup_wrap(int fd, int side, void **ctxp)
241{
242	struct tcp_ctx *tctx;
243
244	PJDLOG_ASSERT(fd >= 0);
245	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
246	    side == TCP_SIDE_SERVER_WORK);
247	PJDLOG_ASSERT(ctxp != NULL);
248
249	tctx = malloc(sizeof(*tctx));
250	if (tctx == NULL)
251		return (errno);
252
253	tctx->tc_fd = fd;
254	tctx->tc_sa.ss_family = AF_UNSPEC;
255	tctx->tc_side = side;
256	tctx->tc_magic = TCP_CTX_MAGIC;
257	*ctxp = tctx;
258
259	return (0);
260}
261
262static int
263tcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
264{
265	struct tcp_ctx *tctx;
266	struct sockaddr_storage sa;
267	int ret;
268
269	ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
270	if (ret != 0)
271		return (ret);
272	tctx = *ctxp;
273	if (srcaddr == NULL)
274		return (0);
275	ret = tcp_addr(srcaddr, 0, &sa);
276	if (ret != 0) {
277		tcp_close(tctx);
278		return (ret);
279	}
280	if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) == -1) {
281		ret = errno;
282		tcp_close(tctx);
283		return (ret);
284	}
285	return (0);
286}
287
288static int
289tcp_connect(void *ctx, int timeout)
290{
291	struct tcp_ctx *tctx = ctx;
292	int error, flags;
293
294	PJDLOG_ASSERT(tctx != NULL);
295	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
296	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
297	PJDLOG_ASSERT(tctx->tc_fd >= 0);
298	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
299	PJDLOG_ASSERT(timeout >= -1);
300
301	flags = fcntl(tctx->tc_fd, F_GETFL);
302	if (flags == -1) {
303		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
304		return (errno);
305	}
306	/*
307	 * We make socket non-blocking so we can handle connection timeout
308	 * manually.
309	 */
310	flags |= O_NONBLOCK;
311	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
312		pjdlog_common(LOG_DEBUG, 1, errno,
313		    "fcntl(F_SETFL, O_NONBLOCK) failed");
314		return (errno);
315	}
316
317	if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
318	    tctx->tc_sa.ss_len) == 0) {
319		if (timeout == -1)
320			return (0);
321		error = 0;
322		goto done;
323	}
324	if (errno != EINPROGRESS) {
325		error = errno;
326		pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
327		goto done;
328	}
329	if (timeout == -1)
330		return (0);
331	return (tcp_connect_wait(ctx, timeout));
332done:
333	flags &= ~O_NONBLOCK;
334	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
335		if (error == 0)
336			error = errno;
337		pjdlog_common(LOG_DEBUG, 1, errno,
338		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
339	}
340	return (error);
341}
342
343static int
344tcp_connect_wait(void *ctx, int timeout)
345{
346	struct tcp_ctx *tctx = ctx;
347	struct timeval tv;
348	fd_set fdset;
349	socklen_t esize;
350	int error, flags, ret;
351
352	PJDLOG_ASSERT(tctx != NULL);
353	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
354	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
355	PJDLOG_ASSERT(tctx->tc_fd >= 0);
356	PJDLOG_ASSERT(timeout >= 0);
357
358	tv.tv_sec = timeout;
359	tv.tv_usec = 0;
360again:
361	FD_ZERO(&fdset);
362	FD_SET(tctx->tc_fd, &fdset);
363	ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
364	if (ret == 0) {
365		error = ETIMEDOUT;
366		goto done;
367	} else if (ret == -1) {
368		if (errno == EINTR)
369			goto again;
370		error = errno;
371		pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
372		goto done;
373	}
374	PJDLOG_ASSERT(ret > 0);
375	PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
376	esize = sizeof(error);
377	if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
378	    &esize) == -1) {
379		error = errno;
380		pjdlog_common(LOG_DEBUG, 1, errno,
381		    "getsockopt(SO_ERROR) failed");
382		goto done;
383	}
384	if (error != 0) {
385		pjdlog_common(LOG_DEBUG, 1, error,
386		    "getsockopt(SO_ERROR) returned error");
387		goto done;
388	}
389	error = 0;
390done:
391	flags = fcntl(tctx->tc_fd, F_GETFL);
392	if (flags == -1) {
393		if (error == 0)
394			error = errno;
395		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
396		return (error);
397	}
398	flags &= ~O_NONBLOCK;
399	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
400		if (error == 0)
401			error = errno;
402		pjdlog_common(LOG_DEBUG, 1, errno,
403		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
404	}
405	return (error);
406}
407
408static int
409tcp_server(const char *addr, void **ctxp)
410{
411	struct tcp_ctx *tctx;
412	int ret, val;
413
414	ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
415	if (ret != 0)
416		return (ret);
417
418	tctx = *ctxp;
419
420	val = 1;
421	/* Ignore failure. */
422	(void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
423	   sizeof(val));
424
425	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
426
427	if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
428	    tctx->tc_sa.ss_len) == -1) {
429		ret = errno;
430		tcp_close(tctx);
431		return (ret);
432	}
433	if (listen(tctx->tc_fd, 8) == -1) {
434		ret = errno;
435		tcp_close(tctx);
436		return (ret);
437	}
438
439	return (0);
440}
441
442static int
443tcp_accept(void *ctx, void **newctxp)
444{
445	struct tcp_ctx *tctx = ctx;
446	struct tcp_ctx *newtctx;
447	socklen_t fromlen;
448	int ret;
449
450	PJDLOG_ASSERT(tctx != NULL);
451	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
452	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
453	PJDLOG_ASSERT(tctx->tc_fd >= 0);
454	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
455
456	newtctx = malloc(sizeof(*newtctx));
457	if (newtctx == NULL)
458		return (errno);
459
460	fromlen = tctx->tc_sa.ss_len;
461	newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
462	    &fromlen);
463	if (newtctx->tc_fd == -1) {
464		ret = errno;
465		free(newtctx);
466		return (ret);
467	}
468
469	newtctx->tc_side = TCP_SIDE_SERVER_WORK;
470	newtctx->tc_magic = TCP_CTX_MAGIC;
471	*newctxp = newtctx;
472
473	return (0);
474}
475
476static int
477tcp_wrap(int fd, bool client, void **ctxp)
478{
479
480	return (tcp_setup_wrap(fd,
481	    client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
482}
483
484static int
485tcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
486{
487	struct tcp_ctx *tctx = ctx;
488
489	PJDLOG_ASSERT(tctx != NULL);
490	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
491	PJDLOG_ASSERT(tctx->tc_fd >= 0);
492	PJDLOG_ASSERT(fd == -1);
493
494	return (proto_common_send(tctx->tc_fd, data, size, -1));
495}
496
497static int
498tcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
499{
500	struct tcp_ctx *tctx = ctx;
501
502	PJDLOG_ASSERT(tctx != NULL);
503	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
504	PJDLOG_ASSERT(tctx->tc_fd >= 0);
505	PJDLOG_ASSERT(fdp == NULL);
506
507	return (proto_common_recv(tctx->tc_fd, data, size, NULL));
508}
509
510static int
511tcp_descriptor(const void *ctx)
512{
513	const struct tcp_ctx *tctx = ctx;
514
515	PJDLOG_ASSERT(tctx != NULL);
516	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
517
518	return (tctx->tc_fd);
519}
520
521static bool
522tcp_address_match(const void *ctx, const char *addr)
523{
524	const struct tcp_ctx *tctx = ctx;
525	struct sockaddr_storage sa1, sa2;
526	socklen_t salen;
527
528	PJDLOG_ASSERT(tctx != NULL);
529	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
530
531	if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
532		return (false);
533
534	salen = sizeof(sa2);
535	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) == -1)
536		return (false);
537
538	if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
539		return (false);
540
541	switch (sa1.ss_family) {
542	case AF_INET:
543	    {
544		struct sockaddr_in *sin1, *sin2;
545
546		sin1 = (struct sockaddr_in *)&sa1;
547		sin2 = (struct sockaddr_in *)&sa2;
548
549		return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
550		    sizeof(sin1->sin_addr)) == 0);
551	    }
552	case AF_INET6:
553	    {
554		struct sockaddr_in6 *sin1, *sin2;
555
556		sin1 = (struct sockaddr_in6 *)&sa1;
557		sin2 = (struct sockaddr_in6 *)&sa2;
558
559		return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
560		    sizeof(sin1->sin6_addr)) == 0);
561	    }
562	default:
563		return (false);
564	}
565}
566
567static void
568tcp_local_address(const void *ctx, char *addr, size_t size)
569{
570	const struct tcp_ctx *tctx = ctx;
571	struct sockaddr_storage sa;
572	socklen_t salen;
573
574	PJDLOG_ASSERT(tctx != NULL);
575	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
576
577	salen = sizeof(sa);
578	if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
579		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
580		return;
581	}
582	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
583}
584
585static void
586tcp_remote_address(const void *ctx, char *addr, size_t size)
587{
588	const struct tcp_ctx *tctx = ctx;
589	struct sockaddr_storage sa;
590	socklen_t salen;
591
592	PJDLOG_ASSERT(tctx != NULL);
593	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
594
595	salen = sizeof(sa);
596	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
597		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
598		return;
599	}
600	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
601}
602
603static void
604tcp_close(void *ctx)
605{
606	struct tcp_ctx *tctx = ctx;
607
608	PJDLOG_ASSERT(tctx != NULL);
609	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
610
611	if (tctx->tc_fd >= 0)
612		close(tctx->tc_fd);
613	tctx->tc_magic = 0;
614	free(tctx);
615}
616
/*
 * Method table for the "tcp" protocol, mapping each generic proto
 * operation to its TCP implementation in this file.
 */
static struct proto tcp_proto = {
	.prt_name = "tcp",
	.prt_client = tcp_client,
	.prt_connect = tcp_connect,
	.prt_connect_wait = tcp_connect_wait,
	.prt_server = tcp_server,
	.prt_accept = tcp_accept,
	.prt_wrap = tcp_wrap,
	.prt_send = tcp_send,
	.prt_recv = tcp_recv,
	.prt_descriptor = tcp_descriptor,
	.prt_address_match = tcp_address_match,
	.prt_local_address = tcp_local_address,
	.prt_remote_address = tcp_remote_address,
	.prt_close = tcp_close
};

/*
 * Register tcp_proto with the proto layer.  __constructor presumably makes
 * this run before main() -- confirm its definition.  The 'true' argument
 * presumably marks TCP as the default protocol, which would match
 * tcp_addr() accepting addresses without a scheme prefix; verify against
 * proto_register().
 */
static __constructor void
tcp_ctor(void)
{

	proto_register(&tcp_proto, true);
}
640