1/*-
2 * SPDX-License-Identifier: BSD-2-Clause
3 *
4 * Copyright (c) 2009-2010 The FreeBSD Foundation
5 * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
6 * All rights reserved.
7 *
8 * This software was developed by Pawel Jakub Dawidek under sponsorship from
9 * the FreeBSD Foundation.
10 *
11 * Redistribution and use in source and binary forms, with or without
12 * modification, are permitted provided that the following conditions
13 * are met:
14 * 1. Redistributions of source code must retain the above copyright
15 *    notice, this list of conditions and the following disclaimer.
16 * 2. Redistributions in binary form must reproduce the above copyright
17 *    notice, this list of conditions and the following disclaimer in the
18 *    documentation and/or other materials provided with the distribution.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30 * SUCH DAMAGE.
31 */
32
#include <sys/param.h>	/* MAXHOSTNAMELEN */
#include <sys/socket.h>

#include <arpa/inet.h>

#include <netinet/in.h>
#include <netinet/tcp.h>

#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <netdb.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
49
50#include "pjdlog.h"
51#include "proto_impl.h"
52#include "subr.h"
53
/* Stored in tc_magic of every live context; tcp_close() resets it to 0. */
#define	TCP_CTX_MAGIC	0x7c41c

/* Possible values of tc_side. */
#define	TCP_SIDE_CLIENT		0	/* Client socket (tcp_client() or wrapped). */
#define	TCP_SIDE_SERVER_LISTEN	1	/* Listening socket (tcp_server()). */
#define	TCP_SIDE_SERVER_WORK	2	/* Accepted or wrapped server connection. */

/* Per-connection state of the TCP protocol backend. */
struct tcp_ctx {
	int			tc_magic;	/* TCP_CTX_MAGIC while valid. */
	struct sockaddr_storage	tc_sa;		/* Parsed address; AF_UNSPEC when wrapped. */
	int			tc_fd;		/* Socket descriptor. */
	int			tc_side;	/* One of TCP_SIDE_*. */
};

static int tcp_connect_wait(void *ctx, int timeout);
static void tcp_close(void *ctx);
67
/*
 * Convert the decimal digit string str to a number within the inclusive
 * range [minnum, maxnum] and store it in *nump.
 *
 * Only the characters '0'-'9' are accepted: no sign, no whitespace and no
 * base prefix.  Returns 0 on success.  On failure (empty string, non-digit
 * character, value outside [minnum, maxnum] or not representable in
 * intmax_t) returns -1 with errno set to EINVAL and leaves *nump untouched.
 */
static int
numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
{
	intmax_t digit, num;

	if (str[0] == '\0')
		goto invalid;	/* Empty string. */
	num = 0;
	for (; *str != '\0'; str++) {
		if (*str < '0' || *str > '9')
			goto invalid;	/* Non-digit character. */
		digit = *str - '0';
		/*
		 * Detect overflow before computing num * 10 + digit: signed
		 * overflow is undefined behavior, so it cannot be detected
		 * after the fact by comparing against the old value.
		 */
		if (num > (INTMAX_MAX - digit) / 10)
			goto invalid;	/* Overflow. */
		num = num * 10 + digit;
		if (num > maxnum)
			goto invalid;	/* Too big. */
	}
	if (num < minnum)
		goto invalid;	/* Too small. */
	*nump = num;
	return (0);
invalid:
	errno = EINVAL;
	return (-1);
}
97
/*
 * Parse a TCP address string and resolve it into *sap.
 *
 * The string may carry a "tcp://", "tcp4://" or "tcp6://" prefix (the latter
 * two restrict resolution to IPv4 or IPv6); without a prefix TCP is assumed.
 * Accepted forms after the prefix:
 *   1. hostname[:port]     eg. freefall.freebsd.org:8457
 *   2. IPv4[:port]         eg. 192.168.0.101:8457
 *   3. [IPv6][:port]       eg. [fe80::1]:8457 or [fe80::1]
 *   4. bare IPv6           eg. fe80::1 (no port possible without brackets)
 * When no port is given, defport is used; an explicit port must be 1-65535.
 *
 * Only the first getaddrinfo(3) result is used.  Returns 0 on success or an
 * errno value: EINVAL for a NULL or malformed address, an invalid port or a
 * resolution failure, ENAMETOOLONG if the host part does not fit in
 * MAXHOSTNAMELEN, ENOENT if resolution yields no result.
 */
static int
tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
{
	char iporhost[MAXHOSTNAMELEN], portstr[6];
	struct addrinfo hints;
	struct addrinfo *res;
	const char *host, *pp, *rb;
	intmax_t port;
	size_t hostlen;
	int error;

	if (addr == NULL)
		return (EINVAL);

	bzero(&hints, sizeof(hints));
	hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;

	if (strncasecmp(addr, "tcp4://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET;
	} else if (strncasecmp(addr, "tcp6://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET6;
	} else if (strncasecmp(addr, "tcp://", 6) == 0) {
		addr += 6;
	} else {
		/*
		 * Because TCP is the default assume IP or host is given without
		 * prefix.
		 */
	}

	/*
	 * Split the string into the host part (host, hostlen) and the
	 * optional port separator pp (pointing at the ':' before the port).
	 */
	pp = NULL;
	host = addr;
	if (addr[0] == '[') {
		/* Bracketed address: a port may only follow the closing ']'. */
		rb = strchr(addr, ']');
		if (rb == NULL)
			return (EINVAL);
		if (rb[1] == ':')
			pp = rb + 1;
		else if (rb[1] != '\0')
			return (EINVAL);
		host = addr + 1;
		hostlen = (size_t)(rb - host);
	} else if (strchr(addr, ':') != strrchr(addr, ':')) {
		/* Two or more colons without brackets: bare IPv6, no port. */
		hostlen = strlen(addr);
	} else {
		/* At most one colon: hostname or IPv4 with optional port. */
		pp = strchr(addr, ':');
		hostlen = (pp != NULL) ? (size_t)(pp - addr) : strlen(addr);
	}

	if (pp == NULL) {
		/* Port not given, use the default. */
		port = defport;
	} else if (numfromstr(pp + 1, 1, 65535, &port) == -1) {
		return (errno);
	}
	(void)snprintf(portstr, sizeof(portstr), "%jd", (intmax_t)port);

	/* Copy out the host name or IP address (without brackets). */
	if (hostlen >= sizeof(iporhost))
		return (ENAMETOOLONG);
	memcpy(iporhost, host, hostlen);
	iporhost[hostlen] = '\0';

	error = getaddrinfo(iporhost, portstr, &hints, &res);
	if (error != 0) {
		pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
		    portstr, gai_strerror(error));
		return (EINVAL);
	}
	if (res == NULL)
		return (ENOENT);

	memcpy(sap, res->ai_addr, res->ai_addrlen);

	freeaddrinfo(res);

	return (0);
}
189
190static int
191tcp_setup_new(const char *addr, int side, void **ctxp)
192{
193	struct tcp_ctx *tctx;
194	int ret, nodelay;
195
196	PJDLOG_ASSERT(addr != NULL);
197	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
198	    side == TCP_SIDE_SERVER_LISTEN);
199	PJDLOG_ASSERT(ctxp != NULL);
200
201	tctx = malloc(sizeof(*tctx));
202	if (tctx == NULL)
203		return (errno);
204
205	/* Parse given address. */
206	if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
207		free(tctx);
208		return (ret);
209	}
210
211	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
212
213	tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
214	if (tctx->tc_fd == -1) {
215		ret = errno;
216		free(tctx);
217		return (ret);
218	}
219
220	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
221
222	/* Socket settings. */
223	nodelay = 1;
224	if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
225	    sizeof(nodelay)) == -1) {
226		pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
227	}
228
229	tctx->tc_side = side;
230	tctx->tc_magic = TCP_CTX_MAGIC;
231	*ctxp = tctx;
232
233	return (0);
234}
235
236static int
237tcp_setup_wrap(int fd, int side, void **ctxp)
238{
239	struct tcp_ctx *tctx;
240
241	PJDLOG_ASSERT(fd >= 0);
242	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
243	    side == TCP_SIDE_SERVER_WORK);
244	PJDLOG_ASSERT(ctxp != NULL);
245
246	tctx = malloc(sizeof(*tctx));
247	if (tctx == NULL)
248		return (errno);
249
250	tctx->tc_fd = fd;
251	tctx->tc_sa.ss_family = AF_UNSPEC;
252	tctx->tc_side = side;
253	tctx->tc_magic = TCP_CTX_MAGIC;
254	*ctxp = tctx;
255
256	return (0);
257}
258
259static int
260tcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
261{
262	struct tcp_ctx *tctx;
263	struct sockaddr_storage sa;
264	int ret;
265
266	ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
267	if (ret != 0)
268		return (ret);
269	tctx = *ctxp;
270	if (srcaddr == NULL)
271		return (0);
272	ret = tcp_addr(srcaddr, 0, &sa);
273	if (ret != 0) {
274		tcp_close(tctx);
275		return (ret);
276	}
277	if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) == -1) {
278		ret = errno;
279		tcp_close(tctx);
280		return (ret);
281	}
282	return (0);
283}
284
285static int
286tcp_connect(void *ctx, int timeout)
287{
288	struct tcp_ctx *tctx = ctx;
289	int error, flags;
290
291	PJDLOG_ASSERT(tctx != NULL);
292	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
293	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
294	PJDLOG_ASSERT(tctx->tc_fd >= 0);
295	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
296	PJDLOG_ASSERT(timeout >= -1);
297
298	flags = fcntl(tctx->tc_fd, F_GETFL);
299	if (flags == -1) {
300		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
301		return (errno);
302	}
303	/*
304	 * We make socket non-blocking so we can handle connection timeout
305	 * manually.
306	 */
307	flags |= O_NONBLOCK;
308	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
309		pjdlog_common(LOG_DEBUG, 1, errno,
310		    "fcntl(F_SETFL, O_NONBLOCK) failed");
311		return (errno);
312	}
313
314	if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
315	    tctx->tc_sa.ss_len) == 0) {
316		if (timeout == -1)
317			return (0);
318		error = 0;
319		goto done;
320	}
321	if (errno != EINPROGRESS) {
322		error = errno;
323		pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
324		goto done;
325	}
326	if (timeout == -1)
327		return (0);
328	return (tcp_connect_wait(ctx, timeout));
329done:
330	flags &= ~O_NONBLOCK;
331	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
332		if (error == 0)
333			error = errno;
334		pjdlog_common(LOG_DEBUG, 1, errno,
335		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
336	}
337	return (error);
338}
339
340static int
341tcp_connect_wait(void *ctx, int timeout)
342{
343	struct tcp_ctx *tctx = ctx;
344	struct timeval tv;
345	fd_set fdset;
346	socklen_t esize;
347	int error, flags, ret;
348
349	PJDLOG_ASSERT(tctx != NULL);
350	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
351	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
352	PJDLOG_ASSERT(tctx->tc_fd >= 0);
353	PJDLOG_ASSERT(timeout >= 0);
354
355	tv.tv_sec = timeout;
356	tv.tv_usec = 0;
357again:
358	FD_ZERO(&fdset);
359	FD_SET(tctx->tc_fd, &fdset);
360	ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
361	if (ret == 0) {
362		error = ETIMEDOUT;
363		goto done;
364	} else if (ret == -1) {
365		if (errno == EINTR)
366			goto again;
367		error = errno;
368		pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
369		goto done;
370	}
371	PJDLOG_ASSERT(ret > 0);
372	PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
373	esize = sizeof(error);
374	if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
375	    &esize) == -1) {
376		error = errno;
377		pjdlog_common(LOG_DEBUG, 1, errno,
378		    "getsockopt(SO_ERROR) failed");
379		goto done;
380	}
381	if (error != 0) {
382		pjdlog_common(LOG_DEBUG, 1, error,
383		    "getsockopt(SO_ERROR) returned error");
384		goto done;
385	}
386	error = 0;
387done:
388	flags = fcntl(tctx->tc_fd, F_GETFL);
389	if (flags == -1) {
390		if (error == 0)
391			error = errno;
392		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
393		return (error);
394	}
395	flags &= ~O_NONBLOCK;
396	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
397		if (error == 0)
398			error = errno;
399		pjdlog_common(LOG_DEBUG, 1, errno,
400		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
401	}
402	return (error);
403}
404
405static int
406tcp_server(const char *addr, void **ctxp)
407{
408	struct tcp_ctx *tctx;
409	int ret, val;
410
411	ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
412	if (ret != 0)
413		return (ret);
414
415	tctx = *ctxp;
416
417	val = 1;
418	/* Ignore failure. */
419	(void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
420	   sizeof(val));
421
422	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
423
424	if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
425	    tctx->tc_sa.ss_len) == -1) {
426		ret = errno;
427		tcp_close(tctx);
428		return (ret);
429	}
430	if (listen(tctx->tc_fd, 8) == -1) {
431		ret = errno;
432		tcp_close(tctx);
433		return (ret);
434	}
435
436	return (0);
437}
438
439static int
440tcp_accept(void *ctx, void **newctxp)
441{
442	struct tcp_ctx *tctx = ctx;
443	struct tcp_ctx *newtctx;
444	socklen_t fromlen;
445	int ret;
446
447	PJDLOG_ASSERT(tctx != NULL);
448	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
449	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
450	PJDLOG_ASSERT(tctx->tc_fd >= 0);
451	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
452
453	newtctx = malloc(sizeof(*newtctx));
454	if (newtctx == NULL)
455		return (errno);
456
457	fromlen = tctx->tc_sa.ss_len;
458	newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
459	    &fromlen);
460	if (newtctx->tc_fd == -1) {
461		ret = errno;
462		free(newtctx);
463		return (ret);
464	}
465
466	newtctx->tc_side = TCP_SIDE_SERVER_WORK;
467	newtctx->tc_magic = TCP_CTX_MAGIC;
468	*newctxp = newtctx;
469
470	return (0);
471}
472
473static int
474tcp_wrap(int fd, bool client, void **ctxp)
475{
476
477	return (tcp_setup_wrap(fd,
478	    client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
479}
480
481static int
482tcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
483{
484	struct tcp_ctx *tctx = ctx;
485
486	PJDLOG_ASSERT(tctx != NULL);
487	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
488	PJDLOG_ASSERT(tctx->tc_fd >= 0);
489	PJDLOG_ASSERT(fd == -1);
490
491	return (proto_common_send(tctx->tc_fd, data, size, -1));
492}
493
494static int
495tcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
496{
497	struct tcp_ctx *tctx = ctx;
498
499	PJDLOG_ASSERT(tctx != NULL);
500	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
501	PJDLOG_ASSERT(tctx->tc_fd >= 0);
502	PJDLOG_ASSERT(fdp == NULL);
503
504	return (proto_common_recv(tctx->tc_fd, data, size, NULL));
505}
506
507static int
508tcp_descriptor(const void *ctx)
509{
510	const struct tcp_ctx *tctx = ctx;
511
512	PJDLOG_ASSERT(tctx != NULL);
513	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
514
515	return (tctx->tc_fd);
516}
517
518static bool
519tcp_address_match(const void *ctx, const char *addr)
520{
521	const struct tcp_ctx *tctx = ctx;
522	struct sockaddr_storage sa1, sa2;
523	socklen_t salen;
524
525	PJDLOG_ASSERT(tctx != NULL);
526	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
527
528	if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
529		return (false);
530
531	salen = sizeof(sa2);
532	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) == -1)
533		return (false);
534
535	if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
536		return (false);
537
538	switch (sa1.ss_family) {
539	case AF_INET:
540	    {
541		struct sockaddr_in *sin1, *sin2;
542
543		sin1 = (struct sockaddr_in *)&sa1;
544		sin2 = (struct sockaddr_in *)&sa2;
545
546		return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
547		    sizeof(sin1->sin_addr)) == 0);
548	    }
549	case AF_INET6:
550	    {
551		struct sockaddr_in6 *sin1, *sin2;
552
553		sin1 = (struct sockaddr_in6 *)&sa1;
554		sin2 = (struct sockaddr_in6 *)&sa2;
555
556		return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
557		    sizeof(sin1->sin6_addr)) == 0);
558	    }
559	default:
560		return (false);
561	}
562}
563
564static void
565tcp_local_address(const void *ctx, char *addr, size_t size)
566{
567	const struct tcp_ctx *tctx = ctx;
568	struct sockaddr_storage sa;
569	socklen_t salen;
570
571	PJDLOG_ASSERT(tctx != NULL);
572	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
573
574	salen = sizeof(sa);
575	if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
576		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
577		return;
578	}
579	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
580}
581
582static void
583tcp_remote_address(const void *ctx, char *addr, size_t size)
584{
585	const struct tcp_ctx *tctx = ctx;
586	struct sockaddr_storage sa;
587	socklen_t salen;
588
589	PJDLOG_ASSERT(tctx != NULL);
590	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
591
592	salen = sizeof(sa);
593	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
594		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
595		return;
596	}
597	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
598}
599
600static void
601tcp_close(void *ctx)
602{
603	struct tcp_ctx *tctx = ctx;
604
605	PJDLOG_ASSERT(tctx != NULL);
606	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
607
608	if (tctx->tc_fd >= 0)
609		close(tctx->tc_fd);
610	tctx->tc_magic = 0;
611	free(tctx);
612}
613
614static struct proto tcp_proto = {
615	.prt_name = "tcp",
616	.prt_client = tcp_client,
617	.prt_connect = tcp_connect,
618	.prt_connect_wait = tcp_connect_wait,
619	.prt_server = tcp_server,
620	.prt_accept = tcp_accept,
621	.prt_wrap = tcp_wrap,
622	.prt_send = tcp_send,
623	.prt_recv = tcp_recv,
624	.prt_descriptor = tcp_descriptor,
625	.prt_address_match = tcp_address_match,
626	.prt_local_address = tcp_local_address,
627	.prt_remote_address = tcp_remote_address,
628	.prt_close = tcp_close
629};
630
631static __constructor void
632tcp_ctor(void)
633{
634
635	proto_register(&tcp_proto, true);
636}
637