proto.c revision 330449
1/*-
2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3 *
4 * Copyright (c) 2009-2010 The FreeBSD Foundation
5 * All rights reserved.
6 *
7 * This software was developed by Pawel Jakub Dawidek under sponsorship from
8 * the FreeBSD Foundation.
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 * 1. Redistributions of source code must retain the above copyright
14 *    notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29 * SUCH DAMAGE.
30 */
31
32#include <sys/cdefs.h>
33__FBSDID("$FreeBSD: stable/11/sbin/hastd/proto.c 330449 2018-03-05 07:26:05Z eadler $");
34
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto.h"
#include "proto_impl.h"
47
/*
 * Value stored in pc_magic of every live connection handle; the public
 * entry points assert on it to catch bogus or already-freed handles.
 */
#define	PROTO_CONN_MAGIC	0x907041c

/*
 * Possible values of pc_side: an outgoing client handle, a listening
 * server handle, or a server handle produced by accepting/receiving a
 * connection.
 */
#define	PROTO_SIDE_CLIENT		0
#define	PROTO_SIDE_SERVER_LISTEN	1
#define	PROTO_SIDE_SERVER_WORK		2

/* Protocol-independent connection handle wrapping a backend context. */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while valid */
	struct proto	*pc_proto;	/* backend that owns pc_ctx */
	void		*pc_ctx;	/* backend-private state */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/*
 * Registered protocol backends, searched from head to tail.  The default
 * backend (if any) is appended at the tail by proto_register().
 */
static TAILQ_HEAD(, proto) protos =
    TAILQ_HEAD_INITIALIZER(protos);
60
61void
62proto_register(struct proto *proto, bool isdefault)
63{
64	static bool seen_default = false;
65
66	if (!isdefault)
67		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
68	else {
69		PJDLOG_ASSERT(!seen_default);
70		seen_default = true;
71		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
72	}
73}
74
75static struct proto_conn *
76proto_alloc(struct proto *proto, int side)
77{
78	struct proto_conn *conn;
79
80	PJDLOG_ASSERT(proto != NULL);
81	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
82	    side == PROTO_SIDE_SERVER_LISTEN ||
83	    side == PROTO_SIDE_SERVER_WORK);
84
85	conn = malloc(sizeof(*conn));
86	if (conn != NULL) {
87		conn->pc_proto = proto;
88		conn->pc_side = side;
89		conn->pc_magic = PROTO_CONN_MAGIC;
90	}
91	return (conn);
92}
93
94static void
95proto_free(struct proto_conn *conn)
96{
97
98	PJDLOG_ASSERT(conn != NULL);
99	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
100	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
101	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
102	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
103	PJDLOG_ASSERT(conn->pc_proto != NULL);
104
105	bzero(conn, sizeof(*conn));
106	free(conn);
107}
108
109static int
110proto_common_setup(const char *srcaddr, const char *dstaddr,
111    struct proto_conn **connp, int side)
112{
113	struct proto *proto;
114	struct proto_conn *conn;
115	void *ctx;
116	int ret;
117
118	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
119	    side == PROTO_SIDE_SERVER_LISTEN);
120
121	TAILQ_FOREACH(proto, &protos, prt_next) {
122		if (side == PROTO_SIDE_CLIENT) {
123			if (proto->prt_client == NULL)
124				ret = -1;
125			else
126				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
127		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
128			if (proto->prt_server == NULL)
129				ret = -1;
130			else
131				ret = proto->prt_server(dstaddr, &ctx);
132		}
133		/*
134		 * ret == 0  - success
135		 * ret == -1 - dstaddr is not for this protocol
136		 * ret > 0   - right protocol, but an error occurred
137		 */
138		if (ret >= 0)
139			break;
140	}
141	if (proto == NULL) {
142		/* Unrecognized address. */
143		errno = EINVAL;
144		return (-1);
145	}
146	if (ret > 0) {
147		/* An error occurred. */
148		errno = ret;
149		return (-1);
150	}
151	conn = proto_alloc(proto, side);
152	if (conn == NULL) {
153		if (proto->prt_close != NULL)
154			proto->prt_close(ctx);
155		errno = ENOMEM;
156		return (-1);
157	}
158	conn->pc_ctx = ctx;
159	*connp = conn;
160
161	return (0);
162}
163
164int
165proto_client(const char *srcaddr, const char *dstaddr,
166    struct proto_conn **connp)
167{
168
169	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
170}
171
172int
173proto_connect(struct proto_conn *conn, int timeout)
174{
175	int ret;
176
177	PJDLOG_ASSERT(conn != NULL);
178	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
179	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
180	PJDLOG_ASSERT(conn->pc_proto != NULL);
181	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
182	PJDLOG_ASSERT(timeout >= -1);
183
184	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
185	if (ret != 0) {
186		errno = ret;
187		return (-1);
188	}
189
190	return (0);
191}
192
193int
194proto_connect_wait(struct proto_conn *conn, int timeout)
195{
196	int ret;
197
198	PJDLOG_ASSERT(conn != NULL);
199	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
200	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
201	PJDLOG_ASSERT(conn->pc_proto != NULL);
202	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
203	PJDLOG_ASSERT(timeout >= 0);
204
205	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
206	if (ret != 0) {
207		errno = ret;
208		return (-1);
209	}
210
211	return (0);
212}
213
214int
215proto_server(const char *addr, struct proto_conn **connp)
216{
217
218	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
219}
220
221int
222proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
223{
224	struct proto_conn *newconn;
225	int ret;
226
227	PJDLOG_ASSERT(conn != NULL);
228	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
229	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
230	PJDLOG_ASSERT(conn->pc_proto != NULL);
231	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
232
233	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
234	if (newconn == NULL)
235		return (-1);
236
237	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
238	if (ret != 0) {
239		proto_free(newconn);
240		errno = ret;
241		return (-1);
242	}
243
244	*newconnp = newconn;
245
246	return (0);
247}
248
249int
250proto_send(const struct proto_conn *conn, const void *data, size_t size)
251{
252	int ret;
253
254	PJDLOG_ASSERT(conn != NULL);
255	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
256	PJDLOG_ASSERT(conn->pc_proto != NULL);
257	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
258
259	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
260	if (ret != 0) {
261		errno = ret;
262		return (-1);
263	}
264	return (0);
265}
266
267int
268proto_recv(const struct proto_conn *conn, void *data, size_t size)
269{
270	int ret;
271
272	PJDLOG_ASSERT(conn != NULL);
273	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
274	PJDLOG_ASSERT(conn->pc_proto != NULL);
275	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
276
277	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
278	if (ret != 0) {
279		errno = ret;
280		return (-1);
281	}
282	return (0);
283}
284
285int
286proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
287{
288	const char *protoname;
289	int ret, fd;
290
291	PJDLOG_ASSERT(conn != NULL);
292	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
293	PJDLOG_ASSERT(conn->pc_proto != NULL);
294	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
295	PJDLOG_ASSERT(mconn != NULL);
296	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
297	PJDLOG_ASSERT(mconn->pc_proto != NULL);
298	fd = proto_descriptor(mconn);
299	PJDLOG_ASSERT(fd >= 0);
300	protoname = mconn->pc_proto->prt_name;
301	PJDLOG_ASSERT(protoname != NULL);
302
303	ret = conn->pc_proto->prt_send(conn->pc_ctx,
304	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
305	proto_close(mconn);
306	if (ret != 0) {
307		errno = ret;
308		return (-1);
309	}
310	return (0);
311}
312
313int
314proto_connection_recv(const struct proto_conn *conn, bool client,
315    struct proto_conn **newconnp)
316{
317	char protoname[128];
318	struct proto *proto;
319	struct proto_conn *newconn;
320	int ret, fd;
321
322	PJDLOG_ASSERT(conn != NULL);
323	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
324	PJDLOG_ASSERT(conn->pc_proto != NULL);
325	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
326	PJDLOG_ASSERT(newconnp != NULL);
327
328	bzero(protoname, sizeof(protoname));
329
330	ret = conn->pc_proto->prt_recv(conn->pc_ctx, (unsigned char *)protoname,
331	    sizeof(protoname) - 1, &fd);
332	if (ret != 0) {
333		errno = ret;
334		return (-1);
335	}
336
337	PJDLOG_ASSERT(fd >= 0);
338
339	TAILQ_FOREACH(proto, &protos, prt_next) {
340		if (strcmp(proto->prt_name, protoname) == 0)
341			break;
342	}
343	if (proto == NULL) {
344		errno = EINVAL;
345		return (-1);
346	}
347
348	newconn = proto_alloc(proto,
349	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
350	if (newconn == NULL)
351		return (-1);
352	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
353	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
354	if (ret != 0) {
355		proto_free(newconn);
356		errno = ret;
357		return (-1);
358	}
359
360	*newconnp = newconn;
361
362	return (0);
363}
364
365int
366proto_descriptor(const struct proto_conn *conn)
367{
368
369	PJDLOG_ASSERT(conn != NULL);
370	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
371	PJDLOG_ASSERT(conn->pc_proto != NULL);
372	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
373
374	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
375}
376
377bool
378proto_address_match(const struct proto_conn *conn, const char *addr)
379{
380
381	PJDLOG_ASSERT(conn != NULL);
382	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
383	PJDLOG_ASSERT(conn->pc_proto != NULL);
384	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
385
386	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
387}
388
389void
390proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
391{
392
393	PJDLOG_ASSERT(conn != NULL);
394	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
395	PJDLOG_ASSERT(conn->pc_proto != NULL);
396	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
397
398	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
399}
400
401void
402proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
403{
404
405	PJDLOG_ASSERT(conn != NULL);
406	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
407	PJDLOG_ASSERT(conn->pc_proto != NULL);
408	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
409
410	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
411}
412
413int
414proto_timeout(const struct proto_conn *conn, int timeout)
415{
416	struct timeval tv;
417	int fd;
418
419	PJDLOG_ASSERT(conn != NULL);
420	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
421	PJDLOG_ASSERT(conn->pc_proto != NULL);
422
423	fd = proto_descriptor(conn);
424	if (fd == -1)
425		return (-1);
426
427	tv.tv_sec = timeout;
428	tv.tv_usec = 0;
429	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
430		return (-1);
431	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
432		return (-1);
433
434	return (0);
435}
436
437void
438proto_close(struct proto_conn *conn)
439{
440
441	PJDLOG_ASSERT(conn != NULL);
442	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
443	PJDLOG_ASSERT(conn->pc_proto != NULL);
444	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
445
446	conn->pc_proto->prt_close(conn->pc_ctx);
447	proto_free(conn);
448}
449