1204076Spjd/*-
2330449Seadler * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3330449Seadler *
4204076Spjd * Copyright (c) 2009-2010 The FreeBSD Foundation
5222118Spjd * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
6204076Spjd * All rights reserved.
7204076Spjd *
8204076Spjd * This software was developed by Pawel Jakub Dawidek under sponsorship from
9204076Spjd * the FreeBSD Foundation.
10204076Spjd *
11204076Spjd * Redistribution and use in source and binary forms, with or without
12204076Spjd * modification, are permitted provided that the following conditions
13204076Spjd * are met:
14204076Spjd * 1. Redistributions of source code must retain the above copyright
15204076Spjd *    notice, this list of conditions and the following disclaimer.
16204076Spjd * 2. Redistributions in binary form must reproduce the above copyright
17204076Spjd *    notice, this list of conditions and the following disclaimer in the
18204076Spjd *    documentation and/or other materials provided with the distribution.
19204076Spjd *
20204076Spjd * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21204076Spjd * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22204076Spjd * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23204076Spjd * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24204076Spjd * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25204076Spjd * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26204076Spjd * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27204076Spjd * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28204076Spjd * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29204076Spjd * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30204076Spjd * SUCH DAMAGE.
31204076Spjd */
32204076Spjd
33204076Spjd#include <sys/cdefs.h>
34204076Spjd__FBSDID("$FreeBSD: stable/11/sbin/hastd/proto_tcp.c 330449 2018-03-05 07:26:05Z eadler $");
35204076Spjd
36204076Spjd#include <sys/param.h>	/* MAXHOSTNAMELEN */
37219873Spjd#include <sys/socket.h>
38204076Spjd
39219873Spjd#include <arpa/inet.h>
40219873Spjd
41204076Spjd#include <netinet/in.h>
42204076Spjd#include <netinet/tcp.h>
43204076Spjd
44204076Spjd#include <errno.h>
45207390Spjd#include <fcntl.h>
46204076Spjd#include <netdb.h>
47204076Spjd#include <stdbool.h>
48204076Spjd#include <stdint.h>
49204076Spjd#include <stdio.h>
50204076Spjd#include <string.h>
51204076Spjd#include <unistd.h>
52204076Spjd
53204076Spjd#include "pjdlog.h"
54204076Spjd#include "proto_impl.h"
55207390Spjd#include "subr.h"
56204076Spjd
/* Sanity marker stored in tc_magic while a context is alive. */
#define	TCP_CTX_MAGIC	0x7c41c

/* Roles a context can have; stored in tc_side. */
#define	TCP_SIDE_CLIENT		0	/* Created by tcp_client() or wrapped as client. */
#define	TCP_SIDE_SERVER_LISTEN	1	/* Listening socket created by tcp_server(). */
#define	TCP_SIDE_SERVER_WORK	2	/* Accepted or wrapped server-side connection. */

/*
 * Per-connection state of the TCP protocol backend.
 */
struct tcp_ctx {
	/* TCP_CTX_MAGIC while valid; zeroed by tcp_close() before freeing. */
	int			tc_magic;
	/* Address parsed by tcp_addr(); AF_UNSPEC for wrapped descriptors. */
	struct sockaddr_storage	tc_sa;
	/* Socket descriptor. */
	int			tc_fd;
	/* One of the TCP_SIDE_* values. */
	int			tc_side;
};
67204076Spjd
68222116Spjdstatic int tcp_connect_wait(void *ctx, int timeout);
69222116Spjdstatic void tcp_close(void *ctx);
70204076Spjd
/*
 * Convert a string consisting solely of decimal digits into a number.
 *
 * The string must be non-empty and must not contain a sign, whitespace or
 * any other non-digit character.  The resulting value has to fall within
 * [minnum, maxnum].
 *
 * On success the value is stored in *nump and 0 is returned.
 * On failure -1 is returned, errno is set to EINVAL and *nump is untouched.
 */
static int
numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
{
	intmax_t digit, num;

	if (str[0] == '\0')
		goto invalid;	/* Empty string. */
	num = 0;
	for (; *str != '\0'; str++) {
		if (*str < '0' || *str > '9')
			goto invalid;	/* Non-digit character. */
		digit = *str - '0';
		/*
		 * Check for overflow before doing the arithmetic: signed
		 * overflow is undefined behaviour, so it cannot be detected
		 * reliably after the fact.
		 */
		if (num > (INTMAX_MAX - digit) / 10)
			goto invalid;	/* Overflow. */
		num = num * 10 + digit;
		if (num > maxnum)
			goto invalid;	/* Too big. */
	}
	if (num < minnum)
		goto invalid;	/* Too small. */
	*nump = num;
	return (0);
invalid:
	errno = EINVAL;
	return (-1);
}
100204076Spjd
/*
 * Parse a textual address and resolve it into a socket address.
 *
 * addr may start with "tcp://", "tcp4://" or "tcp6://" (matched without
 * regard to case); the latter two restrict resolution to IPv4 or IPv6.
 * Without a prefix the string is taken as a plain host name or IP address.
 * An optional ":port" suffix (1-65535) may follow; when absent, defport is
 * used.  An IPv6 address with a port has to be written as [address]:port.
 *
 * NOTE: a bracketed IPv6 address without a port (eg. "[fe80::1]") is
 * rejected, because the text after the last colon is parsed as the port.
 *
 * The first result returned by getaddrinfo() is copied into *sap.
 *
 * Returns 0 on success or an errno-style value on failure:
 * EINVAL (NULL address, bad port, resolution failure, oversized result),
 * ENAMETOOLONG (host part does not fit) or ENOENT (no results).
 */
static int
tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
{
	char iporhost[MAXHOSTNAMELEN], portstr[6];
	struct addrinfo hints;
	struct addrinfo *res;
	const char *pp;
	intmax_t port;
	size_t size;
	int error;

	/* Callers propagate our return value as an errno, so never use -1. */
	if (addr == NULL)
		return (EINVAL);

	memset(&hints, 0, sizeof(hints));
	hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;

	/*
	 * Strip the optional protocol prefix.  Because TCP is the default,
	 * an address without any prefix is accepted as-is.
	 */
	if (strncasecmp(addr, "tcp4://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET;
	} else if (strncasecmp(addr, "tcp6://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET6;
	} else if (strncasecmp(addr, "tcp://", 6) == 0) {
		addr += 6;
	}

	/*
	 * Extract optional port.
	 * There are three cases to consider.
	 * 1. hostname with port, eg. freefall.freebsd.org:8457
	 * 2. IPv4 address with port, eg. 192.168.0.101:8457
	 * 3. IPv6 address with port, eg. [fe80::1]:8457
	 * We discover IPv6 address by checking for two colons and if port is
	 * given, the address has to start with [.
	 */
	pp = NULL;
	if (strchr(addr, ':') != strrchr(addr, ':')) {
		if (addr[0] == '[')
			pp = strrchr(addr, ':');
	} else {
		pp = strrchr(addr, ':');
	}
	if (pp == NULL) {
		/* Port not given, use the default. */
		port = defport;
	} else {
		if (numfromstr(pp + 1, 1, 65535, &port) == -1)
			return (errno);
	}
	(void)snprintf(portstr, sizeof(portstr), "%jd", (intmax_t)port);

	/* Extract host name or IP address. */
	if (pp == NULL) {
		size = sizeof(iporhost);
		if (strlcpy(iporhost, addr, size) >= size)
			return (ENAMETOOLONG);
	} else if (addr[0] == '[' && pp[-1] == ']') {
		/* Copy what is between the brackets (plus the terminator). */
		size = (size_t)(pp - addr - 2 + 1);
		if (size > sizeof(iporhost))
			return (ENAMETOOLONG);
		(void)strlcpy(iporhost, addr + 1, size);
	} else {
		/* Copy everything in front of the colon (plus the terminator). */
		size = (size_t)(pp - addr + 1);
		if (size > sizeof(iporhost))
			return (ENAMETOOLONG);
		(void)strlcpy(iporhost, addr, size);
	}

	error = getaddrinfo(iporhost, portstr, &hints, &res);
	if (error != 0) {
		pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
		    portstr, gai_strerror(error));
		return (EINVAL);
	}
	if (res == NULL)
		return (ENOENT);

	/* Never write past the caller's storage. */
	if (res->ai_addrlen > sizeof(*sap)) {
		freeaddrinfo(res);
		return (EINVAL);
	}
	memcpy(sap, res->ai_addr, res->ai_addrlen);

	freeaddrinfo(res);

	return (0);
}
192204076Spjd
193204076Spjdstatic int
194222116Spjdtcp_setup_new(const char *addr, int side, void **ctxp)
195204076Spjd{
196222116Spjd	struct tcp_ctx *tctx;
197218158Spjd	int ret, nodelay;
198204076Spjd
199218194Spjd	PJDLOG_ASSERT(addr != NULL);
200222116Spjd	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
201222116Spjd	    side == TCP_SIDE_SERVER_LISTEN);
202218194Spjd	PJDLOG_ASSERT(ctxp != NULL);
203218194Spjd
204204076Spjd	tctx = malloc(sizeof(*tctx));
205204076Spjd	if (tctx == NULL)
206204076Spjd		return (errno);
207204076Spjd
208204076Spjd	/* Parse given address. */
209222118Spjd	if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
210204076Spjd		free(tctx);
211204076Spjd		return (ret);
212204076Spjd	}
213204076Spjd
214222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
215218194Spjd
216222118Spjd	tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
217204076Spjd	if (tctx->tc_fd == -1) {
218204076Spjd		ret = errno;
219204076Spjd		free(tctx);
220204076Spjd		return (ret);
221204076Spjd	}
222204076Spjd
223222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
224219818Spjd
225204076Spjd	/* Socket settings. */
226218158Spjd	nodelay = 1;
227218158Spjd	if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
228218158Spjd	    sizeof(nodelay)) == -1) {
229218194Spjd		pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
230204076Spjd	}
231204076Spjd
232204076Spjd	tctx->tc_side = side;
233222116Spjd	tctx->tc_magic = TCP_CTX_MAGIC;
234204076Spjd	*ctxp = tctx;
235204076Spjd
236204076Spjd	return (0);
237204076Spjd}
238204076Spjd
239204076Spjdstatic int
240222116Spjdtcp_setup_wrap(int fd, int side, void **ctxp)
241218194Spjd{
242222116Spjd	struct tcp_ctx *tctx;
243218194Spjd
244218194Spjd	PJDLOG_ASSERT(fd >= 0);
245222116Spjd	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
246222116Spjd	    side == TCP_SIDE_SERVER_WORK);
247218194Spjd	PJDLOG_ASSERT(ctxp != NULL);
248218194Spjd
249218194Spjd	tctx = malloc(sizeof(*tctx));
250218194Spjd	if (tctx == NULL)
251218194Spjd		return (errno);
252218194Spjd
253218194Spjd	tctx->tc_fd = fd;
254222118Spjd	tctx->tc_sa.ss_family = AF_UNSPEC;
255218194Spjd	tctx->tc_side = side;
256222116Spjd	tctx->tc_magic = TCP_CTX_MAGIC;
257218194Spjd	*ctxp = tctx;
258218194Spjd
259218194Spjd	return (0);
260218194Spjd}
261218194Spjd
262218194Spjdstatic int
263222116Spjdtcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
264204076Spjd{
265222116Spjd	struct tcp_ctx *tctx;
266222118Spjd	struct sockaddr_storage sa;
267219818Spjd	int ret;
268204076Spjd
269222116Spjd	ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
270219818Spjd	if (ret != 0)
271219818Spjd		return (ret);
272219818Spjd	tctx = *ctxp;
273219818Spjd	if (srcaddr == NULL)
274219818Spjd		return (0);
275222118Spjd	ret = tcp_addr(srcaddr, 0, &sa);
276219818Spjd	if (ret != 0) {
277222116Spjd		tcp_close(tctx);
278219818Spjd		return (ret);
279219818Spjd	}
280229945Spjd	if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) == -1) {
281219818Spjd		ret = errno;
282222116Spjd		tcp_close(tctx);
283219818Spjd		return (ret);
284219818Spjd	}
285219818Spjd	return (0);
286204076Spjd}
287204076Spjd
288204076Spjdstatic int
289222116Spjdtcp_connect(void *ctx, int timeout)
290204076Spjd{
291222116Spjd	struct tcp_ctx *tctx = ctx;
292218193Spjd	int error, flags;
293204076Spjd
294218138Spjd	PJDLOG_ASSERT(tctx != NULL);
295222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
296222116Spjd	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
297218138Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
298222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
299218193Spjd	PJDLOG_ASSERT(timeout >= -1);
300204076Spjd
301207390Spjd	flags = fcntl(tctx->tc_fd, F_GETFL);
302207390Spjd	if (flags == -1) {
303225781Spjd		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
304204076Spjd		return (errno);
305204076Spjd	}
306207390Spjd	/*
307211875Spjd	 * We make socket non-blocking so we can handle connection timeout
308211875Spjd	 * manually.
309207390Spjd	 */
310207390Spjd	flags |= O_NONBLOCK;
311207390Spjd	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
312225781Spjd		pjdlog_common(LOG_DEBUG, 1, errno,
313225781Spjd		    "fcntl(F_SETFL, O_NONBLOCK) failed");
314207390Spjd		return (errno);
315207390Spjd	}
316204076Spjd
317222118Spjd	if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
318222118Spjd	    tctx->tc_sa.ss_len) == 0) {
319218193Spjd		if (timeout == -1)
320218193Spjd			return (0);
321207390Spjd		error = 0;
322207390Spjd		goto done;
323207390Spjd	}
324207390Spjd	if (errno != EINPROGRESS) {
325207390Spjd		error = errno;
326207390Spjd		pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
327207390Spjd		goto done;
328207390Spjd	}
329218193Spjd	if (timeout == -1)
330218193Spjd		return (0);
331222116Spjd	return (tcp_connect_wait(ctx, timeout));
332218193Spjddone:
333218193Spjd	flags &= ~O_NONBLOCK;
334218193Spjd	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
335218193Spjd		if (error == 0)
336218193Spjd			error = errno;
337218193Spjd		pjdlog_common(LOG_DEBUG, 1, errno,
338218193Spjd		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
339218193Spjd	}
340218193Spjd	return (error);
341218193Spjd}
342218193Spjd
343218193Spjdstatic int
344222116Spjdtcp_connect_wait(void *ctx, int timeout)
345218193Spjd{
346222116Spjd	struct tcp_ctx *tctx = ctx;
347218193Spjd	struct timeval tv;
348218193Spjd	fd_set fdset;
349218193Spjd	socklen_t esize;
350218193Spjd	int error, flags, ret;
351218193Spjd
352218193Spjd	PJDLOG_ASSERT(tctx != NULL);
353222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
354222116Spjd	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
355218193Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
356218193Spjd	PJDLOG_ASSERT(timeout >= 0);
357218193Spjd
358218192Spjd	tv.tv_sec = timeout;
359207390Spjd	tv.tv_usec = 0;
360207390Spjdagain:
361207390Spjd	FD_ZERO(&fdset);
362219864Spjd	FD_SET(tctx->tc_fd, &fdset);
363207390Spjd	ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
364207390Spjd	if (ret == 0) {
365207390Spjd		error = ETIMEDOUT;
366207390Spjd		goto done;
367207390Spjd	} else if (ret == -1) {
368207390Spjd		if (errno == EINTR)
369207390Spjd			goto again;
370207390Spjd		error = errno;
371207390Spjd		pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
372207390Spjd		goto done;
373207390Spjd	}
374218138Spjd	PJDLOG_ASSERT(ret > 0);
375218138Spjd	PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
376207390Spjd	esize = sizeof(error);
377207390Spjd	if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
378207390Spjd	    &esize) == -1) {
379207390Spjd		error = errno;
380207390Spjd		pjdlog_common(LOG_DEBUG, 1, errno,
381207390Spjd		    "getsockopt(SO_ERROR) failed");
382207390Spjd		goto done;
383207390Spjd	}
384207390Spjd	if (error != 0) {
385207390Spjd		pjdlog_common(LOG_DEBUG, 1, error,
386207390Spjd		    "getsockopt(SO_ERROR) returned error");
387207390Spjd		goto done;
388207390Spjd	}
389207390Spjd	error = 0;
390207390Spjddone:
391218193Spjd	flags = fcntl(tctx->tc_fd, F_GETFL);
392218193Spjd	if (flags == -1) {
393218193Spjd		if (error == 0)
394218193Spjd			error = errno;
395218193Spjd		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
396218193Spjd		return (error);
397218193Spjd	}
398207390Spjd	flags &= ~O_NONBLOCK;
399207390Spjd	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
400207390Spjd		if (error == 0)
401207390Spjd			error = errno;
402207390Spjd		pjdlog_common(LOG_DEBUG, 1, errno,
403207390Spjd		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
404207390Spjd	}
405207390Spjd	return (error);
406204076Spjd}
407204076Spjd
408204076Spjdstatic int
409222116Spjdtcp_server(const char *addr, void **ctxp)
410204076Spjd{
411222116Spjd	struct tcp_ctx *tctx;
412204076Spjd	int ret, val;
413204076Spjd
414222116Spjd	ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
415204076Spjd	if (ret != 0)
416204076Spjd		return (ret);
417204076Spjd
418204076Spjd	tctx = *ctxp;
419204076Spjd
420204076Spjd	val = 1;
421204076Spjd	/* Ignore failure. */
422204076Spjd	(void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
423204076Spjd	   sizeof(val));
424204076Spjd
425222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
426218194Spjd
427222118Spjd	if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
428229945Spjd	    tctx->tc_sa.ss_len) == -1) {
429204076Spjd		ret = errno;
430222116Spjd		tcp_close(tctx);
431204076Spjd		return (ret);
432204076Spjd	}
433229945Spjd	if (listen(tctx->tc_fd, 8) == -1) {
434204076Spjd		ret = errno;
435222116Spjd		tcp_close(tctx);
436204076Spjd		return (ret);
437204076Spjd	}
438204076Spjd
439204076Spjd	return (0);
440204076Spjd}
441204076Spjd
442204076Spjdstatic int
443222116Spjdtcp_accept(void *ctx, void **newctxp)
444204076Spjd{
445222116Spjd	struct tcp_ctx *tctx = ctx;
446222116Spjd	struct tcp_ctx *newtctx;
447204076Spjd	socklen_t fromlen;
448204076Spjd	int ret;
449204076Spjd
450218138Spjd	PJDLOG_ASSERT(tctx != NULL);
451222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
452222116Spjd	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
453218138Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
454222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
455204076Spjd
456204076Spjd	newtctx = malloc(sizeof(*newtctx));
457204076Spjd	if (newtctx == NULL)
458204076Spjd		return (errno);
459204076Spjd
460222118Spjd	fromlen = tctx->tc_sa.ss_len;
461222118Spjd	newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
462204076Spjd	    &fromlen);
463229945Spjd	if (newtctx->tc_fd == -1) {
464204076Spjd		ret = errno;
465204076Spjd		free(newtctx);
466204076Spjd		return (ret);
467204076Spjd	}
468204076Spjd
469222116Spjd	newtctx->tc_side = TCP_SIDE_SERVER_WORK;
470222116Spjd	newtctx->tc_magic = TCP_CTX_MAGIC;
471204076Spjd	*newctxp = newtctx;
472204076Spjd
473204076Spjd	return (0);
474204076Spjd}
475204076Spjd
476204076Spjdstatic int
477222116Spjdtcp_wrap(int fd, bool client, void **ctxp)
478204076Spjd{
479218194Spjd
480222116Spjd	return (tcp_setup_wrap(fd,
481222116Spjd	    client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
482218194Spjd}
483218194Spjd
484218194Spjdstatic int
485222116Spjdtcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
486218194Spjd{
487222116Spjd	struct tcp_ctx *tctx = ctx;
488204076Spjd
489218138Spjd	PJDLOG_ASSERT(tctx != NULL);
490222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
491218138Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
492218194Spjd	PJDLOG_ASSERT(fd == -1);
493204076Spjd
494218194Spjd	return (proto_common_send(tctx->tc_fd, data, size, -1));
495204076Spjd}
496204076Spjd
497204076Spjdstatic int
498222116Spjdtcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
499204076Spjd{
500222116Spjd	struct tcp_ctx *tctx = ctx;
501204076Spjd
502218138Spjd	PJDLOG_ASSERT(tctx != NULL);
503222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
504218138Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
505218194Spjd	PJDLOG_ASSERT(fdp == NULL);
506204076Spjd
507218194Spjd	return (proto_common_recv(tctx->tc_fd, data, size, NULL));
508204076Spjd}
509204076Spjd
510204076Spjdstatic int
511222116Spjdtcp_descriptor(const void *ctx)
512204076Spjd{
513222116Spjd	const struct tcp_ctx *tctx = ctx;
514204076Spjd
515218138Spjd	PJDLOG_ASSERT(tctx != NULL);
516222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
517204076Spjd
518204076Spjd	return (tctx->tc_fd);
519204076Spjd}
520204076Spjd
521204076Spjdstatic bool
522222116Spjdtcp_address_match(const void *ctx, const char *addr)
523204076Spjd{
524222116Spjd	const struct tcp_ctx *tctx = ctx;
525222118Spjd	struct sockaddr_storage sa1, sa2;
526222118Spjd	socklen_t salen;
527204076Spjd
528218138Spjd	PJDLOG_ASSERT(tctx != NULL);
529222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
530204076Spjd
531222118Spjd	if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
532204076Spjd		return (false);
533204076Spjd
534222118Spjd	salen = sizeof(sa2);
535229945Spjd	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) == -1)
536204076Spjd		return (false);
537204076Spjd
538222118Spjd	if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
539222118Spjd		return (false);
540222118Spjd
541222118Spjd	switch (sa1.ss_family) {
542222118Spjd	case AF_INET:
543222118Spjd	    {
544222118Spjd		struct sockaddr_in *sin1, *sin2;
545222118Spjd
546222118Spjd		sin1 = (struct sockaddr_in *)&sa1;
547222118Spjd		sin2 = (struct sockaddr_in *)&sa2;
548222118Spjd
549222118Spjd		return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
550222118Spjd		    sizeof(sin1->sin_addr)) == 0);
551222118Spjd	    }
552222118Spjd	case AF_INET6:
553222118Spjd	    {
554222118Spjd		struct sockaddr_in6 *sin1, *sin2;
555222118Spjd
556222118Spjd		sin1 = (struct sockaddr_in6 *)&sa1;
557222118Spjd		sin2 = (struct sockaddr_in6 *)&sa2;
558222118Spjd
559222118Spjd		return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
560222118Spjd		    sizeof(sin1->sin6_addr)) == 0);
561222118Spjd	    }
562222118Spjd	default:
563222118Spjd		return (false);
564222118Spjd	}
565204076Spjd}
566204076Spjd
567204076Spjdstatic void
568222116Spjdtcp_local_address(const void *ctx, char *addr, size_t size)
569204076Spjd{
570222116Spjd	const struct tcp_ctx *tctx = ctx;
571222118Spjd	struct sockaddr_storage sa;
572222118Spjd	socklen_t salen;
573204076Spjd
574218138Spjd	PJDLOG_ASSERT(tctx != NULL);
575222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
576204076Spjd
577222118Spjd	salen = sizeof(sa);
578229945Spjd	if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
579210876Spjd		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
580204076Spjd		return;
581204076Spjd	}
582222118Spjd	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
583204076Spjd}
584204076Spjd
585204076Spjdstatic void
586222116Spjdtcp_remote_address(const void *ctx, char *addr, size_t size)
587204076Spjd{
588222116Spjd	const struct tcp_ctx *tctx = ctx;
589222118Spjd	struct sockaddr_storage sa;
590222118Spjd	socklen_t salen;
591204076Spjd
592218138Spjd	PJDLOG_ASSERT(tctx != NULL);
593222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
594204076Spjd
595222118Spjd	salen = sizeof(sa);
596229945Spjd	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
597210876Spjd		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
598204076Spjd		return;
599204076Spjd	}
600222118Spjd	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
601204076Spjd}
602204076Spjd
603204076Spjdstatic void
604222116Spjdtcp_close(void *ctx)
605204076Spjd{
606222116Spjd	struct tcp_ctx *tctx = ctx;
607204076Spjd
608218138Spjd	PJDLOG_ASSERT(tctx != NULL);
609222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
610204076Spjd
611204076Spjd	if (tctx->tc_fd >= 0)
612204076Spjd		close(tctx->tc_fd);
613204076Spjd	tctx->tc_magic = 0;
614204076Spjd	free(tctx);
615204076Spjd}
616204076Spjd
617222116Spjdstatic struct proto tcp_proto = {
618222116Spjd	.prt_name = "tcp",
619222116Spjd	.prt_client = tcp_client,
620222116Spjd	.prt_connect = tcp_connect,
621222116Spjd	.prt_connect_wait = tcp_connect_wait,
622222116Spjd	.prt_server = tcp_server,
623222116Spjd	.prt_accept = tcp_accept,
624222116Spjd	.prt_wrap = tcp_wrap,
625222116Spjd	.prt_send = tcp_send,
626222116Spjd	.prt_recv = tcp_recv,
627222116Spjd	.prt_descriptor = tcp_descriptor,
628222116Spjd	.prt_address_match = tcp_address_match,
629222116Spjd	.prt_local_address = tcp_local_address,
630222116Spjd	.prt_remote_address = tcp_remote_address,
631222116Spjd	.prt_close = tcp_close
632204076Spjd};
633204076Spjd
634204076Spjdstatic __constructor void
635222116Spjdtcp_ctor(void)
636204076Spjd{
637204076Spjd
638222116Spjd	proto_register(&tcp_proto, true);
639204076Spjd}
640