1/*-
2 * Copyright (c) 2009-2010 The FreeBSD Foundation
3 * All rights reserved.
4 *
5 * This software was developed by Pawel Jakub Dawidek under sponsorship from
6 * the FreeBSD Foundation.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 * 1. Redistributions of source code must retain the above copyright
12 *    notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 */
29
30#include <sys/cdefs.h>
31__FBSDID("$FreeBSD$");
32
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>
#include <sys/time.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto.h"
#include "proto_impl.h"
45
/* Value stored in pc_magic while a struct proto_conn is allocated. */
enum {
	PROTO_CONN_MAGIC = 0x907041c
};

/* Role of a connection; stored in pc_side. */
enum {
	PROTO_SIDE_CLIENT = 0,		/* created by proto_client() */
	PROTO_SIDE_SERVER_LISTEN = 1,	/* created by proto_server() */
	PROTO_SIDE_SERVER_WORK = 2	/* from proto_accept()/proto_connection_recv() */
};

/*
 * Generic connection handle: ties a registered protocol implementation to
 * the opaque per-connection context that implementation created.
 */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while allocated */
	struct proto	*pc_proto;	/* owning protocol implementation */
	void		*pc_ctx;	/* protocol-private context */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/* Registered protocols, searched head to tail; see proto_register(). */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
58
59void
60proto_register(struct proto *proto, bool isdefault)
61{
62	static bool seen_default = false;
63
64	if (!isdefault)
65		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
66	else {
67		PJDLOG_ASSERT(!seen_default);
68		seen_default = true;
69		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
70	}
71}
72
73static struct proto_conn *
74proto_alloc(struct proto *proto, int side)
75{
76	struct proto_conn *conn;
77
78	PJDLOG_ASSERT(proto != NULL);
79	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
80	    side == PROTO_SIDE_SERVER_LISTEN ||
81	    side == PROTO_SIDE_SERVER_WORK);
82
83	conn = malloc(sizeof(*conn));
84	if (conn != NULL) {
85		conn->pc_proto = proto;
86		conn->pc_side = side;
87		conn->pc_magic = PROTO_CONN_MAGIC;
88	}
89	return (conn);
90}
91
92static void
93proto_free(struct proto_conn *conn)
94{
95
96	PJDLOG_ASSERT(conn != NULL);
97	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
98	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
99	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
100	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
101	PJDLOG_ASSERT(conn->pc_proto != NULL);
102
103	bzero(conn, sizeof(*conn));
104	free(conn);
105}
106
107static int
108proto_common_setup(const char *srcaddr, const char *dstaddr,
109    struct proto_conn **connp, int side)
110{
111	struct proto *proto;
112	struct proto_conn *conn;
113	void *ctx;
114	int ret;
115
116	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
117	    side == PROTO_SIDE_SERVER_LISTEN);
118
119	TAILQ_FOREACH(proto, &protos, prt_next) {
120		if (side == PROTO_SIDE_CLIENT) {
121			if (proto->prt_client == NULL)
122				ret = -1;
123			else
124				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
125		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
126			if (proto->prt_server == NULL)
127				ret = -1;
128			else
129				ret = proto->prt_server(dstaddr, &ctx);
130		}
131		/*
132		 * ret == 0  - success
133		 * ret == -1 - dstaddr is not for this protocol
134		 * ret > 0   - right protocol, but an error occurred
135		 */
136		if (ret >= 0)
137			break;
138	}
139	if (proto == NULL) {
140		/* Unrecognized address. */
141		errno = EINVAL;
142		return (-1);
143	}
144	if (ret > 0) {
145		/* An error occurred. */
146		errno = ret;
147		return (-1);
148	}
149	conn = proto_alloc(proto, side);
150	if (conn == NULL) {
151		if (proto->prt_close != NULL)
152			proto->prt_close(ctx);
153		errno = ENOMEM;
154		return (-1);
155	}
156	conn->pc_ctx = ctx;
157	*connp = conn;
158
159	return (0);
160}
161
162int
163proto_client(const char *srcaddr, const char *dstaddr,
164    struct proto_conn **connp)
165{
166
167	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
168}
169
170int
171proto_connect(struct proto_conn *conn, int timeout)
172{
173	int ret;
174
175	PJDLOG_ASSERT(conn != NULL);
176	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
177	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
178	PJDLOG_ASSERT(conn->pc_proto != NULL);
179	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
180	PJDLOG_ASSERT(timeout >= -1);
181
182	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
183	if (ret != 0) {
184		errno = ret;
185		return (-1);
186	}
187
188	return (0);
189}
190
191int
192proto_connect_wait(struct proto_conn *conn, int timeout)
193{
194	int ret;
195
196	PJDLOG_ASSERT(conn != NULL);
197	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
198	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
199	PJDLOG_ASSERT(conn->pc_proto != NULL);
200	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
201	PJDLOG_ASSERT(timeout >= 0);
202
203	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
204	if (ret != 0) {
205		errno = ret;
206		return (-1);
207	}
208
209	return (0);
210}
211
212int
213proto_server(const char *addr, struct proto_conn **connp)
214{
215
216	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
217}
218
219int
220proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
221{
222	struct proto_conn *newconn;
223	int ret;
224
225	PJDLOG_ASSERT(conn != NULL);
226	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
227	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
228	PJDLOG_ASSERT(conn->pc_proto != NULL);
229	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
230
231	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
232	if (newconn == NULL)
233		return (-1);
234
235	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
236	if (ret != 0) {
237		proto_free(newconn);
238		errno = ret;
239		return (-1);
240	}
241
242	*newconnp = newconn;
243
244	return (0);
245}
246
247int
248proto_send(const struct proto_conn *conn, const void *data, size_t size)
249{
250	int ret;
251
252	PJDLOG_ASSERT(conn != NULL);
253	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
254	PJDLOG_ASSERT(conn->pc_proto != NULL);
255	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
256
257	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
258	if (ret != 0) {
259		errno = ret;
260		return (-1);
261	}
262	return (0);
263}
264
265int
266proto_recv(const struct proto_conn *conn, void *data, size_t size)
267{
268	int ret;
269
270	PJDLOG_ASSERT(conn != NULL);
271	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
272	PJDLOG_ASSERT(conn->pc_proto != NULL);
273	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
274
275	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
276	if (ret != 0) {
277		errno = ret;
278		return (-1);
279	}
280	return (0);
281}
282
283int
284proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
285{
286	const char *protoname;
287	int ret, fd;
288
289	PJDLOG_ASSERT(conn != NULL);
290	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
291	PJDLOG_ASSERT(conn->pc_proto != NULL);
292	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
293	PJDLOG_ASSERT(mconn != NULL);
294	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
295	PJDLOG_ASSERT(mconn->pc_proto != NULL);
296	fd = proto_descriptor(mconn);
297	PJDLOG_ASSERT(fd >= 0);
298	protoname = mconn->pc_proto->prt_name;
299	PJDLOG_ASSERT(protoname != NULL);
300
301	ret = conn->pc_proto->prt_send(conn->pc_ctx, protoname,
302	    strlen(protoname) + 1, fd);
303	proto_close(mconn);
304	if (ret != 0) {
305		errno = ret;
306		return (-1);
307	}
308	return (0);
309}
310
311int
312proto_connection_recv(const struct proto_conn *conn, bool client,
313    struct proto_conn **newconnp)
314{
315	char protoname[128];
316	struct proto *proto;
317	struct proto_conn *newconn;
318	int ret, fd;
319
320	PJDLOG_ASSERT(conn != NULL);
321	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
322	PJDLOG_ASSERT(conn->pc_proto != NULL);
323	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
324	PJDLOG_ASSERT(newconnp != NULL);
325
326	bzero(protoname, sizeof(protoname));
327
328	ret = conn->pc_proto->prt_recv(conn->pc_ctx, protoname,
329	    sizeof(protoname) - 1, &fd);
330	if (ret != 0) {
331		errno = ret;
332		return (-1);
333	}
334
335	PJDLOG_ASSERT(fd >= 0);
336
337	TAILQ_FOREACH(proto, &protos, prt_next) {
338		if (strcmp(proto->prt_name, protoname) == 0)
339			break;
340	}
341	if (proto == NULL) {
342		errno = EINVAL;
343		return (-1);
344	}
345
346	newconn = proto_alloc(proto,
347	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
348	if (newconn == NULL)
349		return (-1);
350	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
351	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
352	if (ret != 0) {
353		proto_free(newconn);
354		errno = ret;
355		return (-1);
356	}
357
358	*newconnp = newconn;
359
360	return (0);
361}
362
363int
364proto_descriptor(const struct proto_conn *conn)
365{
366
367	PJDLOG_ASSERT(conn != NULL);
368	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
369	PJDLOG_ASSERT(conn->pc_proto != NULL);
370	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
371
372	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
373}
374
375bool
376proto_address_match(const struct proto_conn *conn, const char *addr)
377{
378
379	PJDLOG_ASSERT(conn != NULL);
380	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
381	PJDLOG_ASSERT(conn->pc_proto != NULL);
382	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
383
384	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
385}
386
387void
388proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
389{
390
391	PJDLOG_ASSERT(conn != NULL);
392	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
393	PJDLOG_ASSERT(conn->pc_proto != NULL);
394	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
395
396	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
397}
398
399void
400proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
401{
402
403	PJDLOG_ASSERT(conn != NULL);
404	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
405	PJDLOG_ASSERT(conn->pc_proto != NULL);
406	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
407
408	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
409}
410
411int
412proto_timeout(const struct proto_conn *conn, int timeout)
413{
414	struct timeval tv;
415	int fd;
416
417	PJDLOG_ASSERT(conn != NULL);
418	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
419	PJDLOG_ASSERT(conn->pc_proto != NULL);
420
421	fd = proto_descriptor(conn);
422	if (fd == -1)
423		return (-1);
424
425	tv.tv_sec = timeout;
426	tv.tv_usec = 0;
427	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
428		return (-1);
429	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
430		return (-1);
431
432	return (0);
433}
434
435void
436proto_close(struct proto_conn *conn)
437{
438
439	PJDLOG_ASSERT(conn != NULL);
440	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
441	PJDLOG_ASSERT(conn->pc_proto != NULL);
442	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
443
444	conn->pc_proto->prt_close(conn->pc_ctx);
445	proto_free(conn);
446}
447