1/*-
2 * SPDX-License-Identifier: BSD-2-Clause
3 *
4 * Copyright (c) 2009-2010 The FreeBSD Foundation
5 * All rights reserved.
6 *
7 * This software was developed by Pawel Jakub Dawidek under sponsorship from
8 * the FreeBSD Foundation.
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 * 1. Redistributions of source code must retain the above copyright
14 *    notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29 * SUCH DAMAGE.
30 */
31
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>
#include <sys/time.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto.h"
#include "proto_impl.h"
44
/*
 * Value kept in pc_magic while a proto_conn is live; proto_free() zeroes
 * the whole structure (and thus this field) before releasing it.
 */
enum {
	PROTO_CONN_MAGIC = 0x907041c
};

/* Role of a connection handle, stored in pc_side. */
enum {
	PROTO_SIDE_CLIENT = 0,		/* proto_client(), proto_connection_recv(client) */
	PROTO_SIDE_SERVER_LISTEN = 1,	/* proto_server() */
	PROTO_SIDE_SERVER_WORK = 2	/* proto_accept(), proto_connection_recv(!client) */
};

/*
 * Generic connection handle: ties a registered protocol backend to the
 * backend-private context that the backend handed back.
 */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while valid */
	struct proto	*pc_proto;	/* backend that owns pc_ctx */
	void		*pc_ctx;	/* backend-private state */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/*
 * Registered protocol backends in lookup order.  proto_register() puts
 * ordinary backends at the head and the single default backend at the
 * tail, so the default one is consulted last.
 */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
57
58void
59proto_register(struct proto *proto, bool isdefault)
60{
61	static bool seen_default = false;
62
63	if (!isdefault)
64		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
65	else {
66		PJDLOG_ASSERT(!seen_default);
67		seen_default = true;
68		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
69	}
70}
71
72static struct proto_conn *
73proto_alloc(struct proto *proto, int side)
74{
75	struct proto_conn *conn;
76
77	PJDLOG_ASSERT(proto != NULL);
78	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
79	    side == PROTO_SIDE_SERVER_LISTEN ||
80	    side == PROTO_SIDE_SERVER_WORK);
81
82	conn = malloc(sizeof(*conn));
83	if (conn != NULL) {
84		conn->pc_proto = proto;
85		conn->pc_side = side;
86		conn->pc_magic = PROTO_CONN_MAGIC;
87	}
88	return (conn);
89}
90
91static void
92proto_free(struct proto_conn *conn)
93{
94
95	PJDLOG_ASSERT(conn != NULL);
96	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
97	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
98	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
99	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
100	PJDLOG_ASSERT(conn->pc_proto != NULL);
101
102	bzero(conn, sizeof(*conn));
103	free(conn);
104}
105
106static int
107proto_common_setup(const char *srcaddr, const char *dstaddr,
108    struct proto_conn **connp, int side)
109{
110	struct proto *proto;
111	struct proto_conn *conn;
112	void *ctx;
113	int ret;
114
115	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
116	    side == PROTO_SIDE_SERVER_LISTEN);
117
118	TAILQ_FOREACH(proto, &protos, prt_next) {
119		if (side == PROTO_SIDE_CLIENT) {
120			if (proto->prt_client == NULL)
121				ret = -1;
122			else
123				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
124		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
125			if (proto->prt_server == NULL)
126				ret = -1;
127			else
128				ret = proto->prt_server(dstaddr, &ctx);
129		}
130		/*
131		 * ret == 0  - success
132		 * ret == -1 - dstaddr is not for this protocol
133		 * ret > 0   - right protocol, but an error occurred
134		 */
135		if (ret >= 0)
136			break;
137	}
138	if (proto == NULL) {
139		/* Unrecognized address. */
140		errno = EINVAL;
141		return (-1);
142	}
143	if (ret > 0) {
144		/* An error occurred. */
145		errno = ret;
146		return (-1);
147	}
148	conn = proto_alloc(proto, side);
149	if (conn == NULL) {
150		if (proto->prt_close != NULL)
151			proto->prt_close(ctx);
152		errno = ENOMEM;
153		return (-1);
154	}
155	conn->pc_ctx = ctx;
156	*connp = conn;
157
158	return (0);
159}
160
161int
162proto_client(const char *srcaddr, const char *dstaddr,
163    struct proto_conn **connp)
164{
165
166	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
167}
168
169int
170proto_connect(struct proto_conn *conn, int timeout)
171{
172	int ret;
173
174	PJDLOG_ASSERT(conn != NULL);
175	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
176	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
177	PJDLOG_ASSERT(conn->pc_proto != NULL);
178	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
179	PJDLOG_ASSERT(timeout >= -1);
180
181	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
182	if (ret != 0) {
183		errno = ret;
184		return (-1);
185	}
186
187	return (0);
188}
189
190int
191proto_connect_wait(struct proto_conn *conn, int timeout)
192{
193	int ret;
194
195	PJDLOG_ASSERT(conn != NULL);
196	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
197	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
198	PJDLOG_ASSERT(conn->pc_proto != NULL);
199	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
200	PJDLOG_ASSERT(timeout >= 0);
201
202	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
203	if (ret != 0) {
204		errno = ret;
205		return (-1);
206	}
207
208	return (0);
209}
210
211int
212proto_server(const char *addr, struct proto_conn **connp)
213{
214
215	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
216}
217
218int
219proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
220{
221	struct proto_conn *newconn;
222	int ret;
223
224	PJDLOG_ASSERT(conn != NULL);
225	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
226	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
227	PJDLOG_ASSERT(conn->pc_proto != NULL);
228	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
229
230	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
231	if (newconn == NULL)
232		return (-1);
233
234	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
235	if (ret != 0) {
236		proto_free(newconn);
237		errno = ret;
238		return (-1);
239	}
240
241	*newconnp = newconn;
242
243	return (0);
244}
245
246int
247proto_send(const struct proto_conn *conn, const void *data, size_t size)
248{
249	int ret;
250
251	PJDLOG_ASSERT(conn != NULL);
252	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
253	PJDLOG_ASSERT(conn->pc_proto != NULL);
254	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
255
256	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
257	if (ret != 0) {
258		errno = ret;
259		return (-1);
260	}
261	return (0);
262}
263
264int
265proto_recv(const struct proto_conn *conn, void *data, size_t size)
266{
267	int ret;
268
269	PJDLOG_ASSERT(conn != NULL);
270	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
271	PJDLOG_ASSERT(conn->pc_proto != NULL);
272	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
273
274	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
275	if (ret != 0) {
276		errno = ret;
277		return (-1);
278	}
279	return (0);
280}
281
282int
283proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
284{
285	const char *protoname;
286	int ret, fd;
287
288	PJDLOG_ASSERT(conn != NULL);
289	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
290	PJDLOG_ASSERT(conn->pc_proto != NULL);
291	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
292	PJDLOG_ASSERT(mconn != NULL);
293	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
294	PJDLOG_ASSERT(mconn->pc_proto != NULL);
295	fd = proto_descriptor(mconn);
296	PJDLOG_ASSERT(fd >= 0);
297	protoname = mconn->pc_proto->prt_name;
298	PJDLOG_ASSERT(protoname != NULL);
299
300	ret = conn->pc_proto->prt_send(conn->pc_ctx,
301	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
302	proto_close(mconn);
303	if (ret != 0) {
304		errno = ret;
305		return (-1);
306	}
307	return (0);
308}
309
310int
311proto_connection_recv(const struct proto_conn *conn, bool client,
312    struct proto_conn **newconnp)
313{
314	char protoname[128];
315	struct proto *proto;
316	struct proto_conn *newconn;
317	int ret, fd;
318
319	PJDLOG_ASSERT(conn != NULL);
320	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
321	PJDLOG_ASSERT(conn->pc_proto != NULL);
322	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
323	PJDLOG_ASSERT(newconnp != NULL);
324
325	bzero(protoname, sizeof(protoname));
326
327	ret = conn->pc_proto->prt_recv(conn->pc_ctx, (unsigned char *)protoname,
328	    sizeof(protoname) - 1, &fd);
329	if (ret != 0) {
330		errno = ret;
331		return (-1);
332	}
333
334	PJDLOG_ASSERT(fd >= 0);
335
336	TAILQ_FOREACH(proto, &protos, prt_next) {
337		if (strcmp(proto->prt_name, protoname) == 0)
338			break;
339	}
340	if (proto == NULL) {
341		errno = EINVAL;
342		return (-1);
343	}
344
345	newconn = proto_alloc(proto,
346	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
347	if (newconn == NULL)
348		return (-1);
349	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
350	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
351	if (ret != 0) {
352		proto_free(newconn);
353		errno = ret;
354		return (-1);
355	}
356
357	*newconnp = newconn;
358
359	return (0);
360}
361
362int
363proto_descriptor(const struct proto_conn *conn)
364{
365
366	PJDLOG_ASSERT(conn != NULL);
367	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
368	PJDLOG_ASSERT(conn->pc_proto != NULL);
369	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
370
371	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
372}
373
374bool
375proto_address_match(const struct proto_conn *conn, const char *addr)
376{
377
378	PJDLOG_ASSERT(conn != NULL);
379	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
380	PJDLOG_ASSERT(conn->pc_proto != NULL);
381	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
382
383	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
384}
385
386void
387proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
388{
389
390	PJDLOG_ASSERT(conn != NULL);
391	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
392	PJDLOG_ASSERT(conn->pc_proto != NULL);
393	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
394
395	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
396}
397
398void
399proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
400{
401
402	PJDLOG_ASSERT(conn != NULL);
403	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
404	PJDLOG_ASSERT(conn->pc_proto != NULL);
405	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
406
407	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
408}
409
410int
411proto_timeout(const struct proto_conn *conn, int timeout)
412{
413	struct timeval tv;
414	int fd;
415
416	PJDLOG_ASSERT(conn != NULL);
417	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
418	PJDLOG_ASSERT(conn->pc_proto != NULL);
419
420	fd = proto_descriptor(conn);
421	if (fd == -1)
422		return (-1);
423
424	tv.tv_sec = timeout;
425	tv.tv_usec = 0;
426	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
427		return (-1);
428	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
429		return (-1);
430
431	return (0);
432}
433
434void
435proto_close(struct proto_conn *conn)
436{
437
438	PJDLOG_ASSERT(conn != NULL);
439	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
440	PJDLOG_ASSERT(conn->pc_proto != NULL);
441	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
442
443	conn->pc_proto->prt_close(conn->pc_ctx);
444	proto_free(conn);
445}
446