1204076Spjd/*-
2330449Seadler * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3330449Seadler *
4204076Spjd * Copyright (c) 2009-2010 The FreeBSD Foundation
5204076Spjd * All rights reserved.
6204076Spjd *
7204076Spjd * This software was developed by Pawel Jakub Dawidek under sponsorship from
8204076Spjd * the FreeBSD Foundation.
9204076Spjd *
10204076Spjd * Redistribution and use in source and binary forms, with or without
11204076Spjd * modification, are permitted provided that the following conditions
12204076Spjd * are met:
13204076Spjd * 1. Redistributions of source code must retain the above copyright
14204076Spjd *    notice, this list of conditions and the following disclaimer.
15204076Spjd * 2. Redistributions in binary form must reproduce the above copyright
16204076Spjd *    notice, this list of conditions and the following disclaimer in the
17204076Spjd *    documentation and/or other materials provided with the distribution.
18204076Spjd *
19204076Spjd * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20204076Spjd * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21204076Spjd * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22204076Spjd * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23204076Spjd * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24204076Spjd * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25204076Spjd * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26204076Spjd * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27204076Spjd * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28204076Spjd * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29204076Spjd * SUCH DAMAGE.
30204076Spjd */
31204076Spjd
32204076Spjd#include <sys/cdefs.h>
33204076Spjd__FBSDID("$FreeBSD: stable/11/sbin/hastd/proto.c 330449 2018-03-05 07:26:05Z eadler $");
34204076Spjd
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto.h"
#include "proto_impl.h"
47204076Spjd
/* Stamped into every live connection handle; cleared when it is freed. */
#define	PROTO_CONN_MAGIC	0x907041c

/* Values for pc_side. */
#define	PROTO_SIDE_CLIENT		0	/* proto_client() endpoint */
#define	PROTO_SIDE_SERVER_LISTEN	1	/* proto_server() endpoint */
#define	PROTO_SIDE_SERVER_WORK		2	/* proto_accept() result */

/*
 * Generic connection handle: pairs a registered protocol implementation
 * with that protocol's private context.
 */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while valid */
	struct proto	*pc_proto;	/* protocol implementation */
	void		*pc_ctx;	/* protocol-private state */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};
58204076Spjd
/*
 * All registered protocols.  Ordinary protocols are inserted at the head and
 * the default one at the tail, so lookups that walk from the head try the
 * default protocol last.
 */
static TAILQ_HEAD(, proto) protos =
    TAILQ_HEAD_INITIALIZER(protos);
60204076Spjd
61204076Spjdvoid
62219873Spjdproto_register(struct proto *proto, bool isdefault)
63204076Spjd{
64210869Spjd	static bool seen_default = false;
65204076Spjd
66210869Spjd	if (!isdefault)
67219873Spjd		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
68210869Spjd	else {
69218138Spjd		PJDLOG_ASSERT(!seen_default);
70210869Spjd		seen_default = true;
71219873Spjd		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
72210869Spjd	}
73204076Spjd}
74204076Spjd
75218191Spjdstatic struct proto_conn *
76219873Spjdproto_alloc(struct proto *proto, int side)
77218191Spjd{
78218191Spjd	struct proto_conn *conn;
79218191Spjd
80218191Spjd	PJDLOG_ASSERT(proto != NULL);
81218191Spjd	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
82218191Spjd	    side == PROTO_SIDE_SERVER_LISTEN ||
83218191Spjd	    side == PROTO_SIDE_SERVER_WORK);
84218191Spjd
85218191Spjd	conn = malloc(sizeof(*conn));
86218191Spjd	if (conn != NULL) {
87218191Spjd		conn->pc_proto = proto;
88218191Spjd		conn->pc_side = side;
89218191Spjd		conn->pc_magic = PROTO_CONN_MAGIC;
90218191Spjd	}
91218191Spjd	return (conn);
92218191Spjd}
93218191Spjd
94218191Spjdstatic void
95218191Spjdproto_free(struct proto_conn *conn)
96218191Spjd{
97218191Spjd
98218191Spjd	PJDLOG_ASSERT(conn != NULL);
99218191Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
100218191Spjd	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
101218191Spjd	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
102218191Spjd	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
103218191Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
104218191Spjd
105218191Spjd	bzero(conn, sizeof(*conn));
106218191Spjd	free(conn);
107218191Spjd}
108218191Spjd
109204076Spjdstatic int
110219818Spjdproto_common_setup(const char *srcaddr, const char *dstaddr,
111219818Spjd    struct proto_conn **connp, int side)
112204076Spjd{
113219873Spjd	struct proto *proto;
114204076Spjd	struct proto_conn *conn;
115204076Spjd	void *ctx;
116204076Spjd	int ret;
117204076Spjd
118218191Spjd	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
119218191Spjd	    side == PROTO_SIDE_SERVER_LISTEN);
120204076Spjd
121219873Spjd	TAILQ_FOREACH(proto, &protos, prt_next) {
122218185Spjd		if (side == PROTO_SIDE_CLIENT) {
123219873Spjd			if (proto->prt_client == NULL)
124218185Spjd				ret = -1;
125218185Spjd			else
126219873Spjd				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
127218185Spjd		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
128219873Spjd			if (proto->prt_server == NULL)
129218185Spjd				ret = -1;
130218185Spjd			else
131219873Spjd				ret = proto->prt_server(dstaddr, &ctx);
132218185Spjd		}
133204076Spjd		/*
134204076Spjd		 * ret == 0  - success
135219818Spjd		 * ret == -1 - dstaddr is not for this protocol
136229778Suqs		 * ret > 0   - right protocol, but an error occurred
137204076Spjd		 */
138204076Spjd		if (ret >= 0)
139204076Spjd			break;
140204076Spjd	}
141204076Spjd	if (proto == NULL) {
142204076Spjd		/* Unrecognized address. */
143204076Spjd		errno = EINVAL;
144204076Spjd		return (-1);
145204076Spjd	}
146204076Spjd	if (ret > 0) {
147229778Suqs		/* An error occurred. */
148204076Spjd		errno = ret;
149204076Spjd		return (-1);
150204076Spjd	}
151218191Spjd	conn = proto_alloc(proto, side);
152218191Spjd	if (conn == NULL) {
153219873Spjd		if (proto->prt_close != NULL)
154219873Spjd			proto->prt_close(ctx);
155218191Spjd		errno = ENOMEM;
156218191Spjd		return (-1);
157218191Spjd	}
158204076Spjd	conn->pc_ctx = ctx;
159204076Spjd	*connp = conn;
160218191Spjd
161204076Spjd	return (0);
162204076Spjd}
163204076Spjd
164204076Spjdint
165219818Spjdproto_client(const char *srcaddr, const char *dstaddr,
166219818Spjd    struct proto_conn **connp)
167204076Spjd{
168204076Spjd
169219818Spjd	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
170204076Spjd}
171204076Spjd
172204076Spjdint
173218192Spjdproto_connect(struct proto_conn *conn, int timeout)
174204076Spjd{
175204076Spjd	int ret;
176204076Spjd
177218138Spjd	PJDLOG_ASSERT(conn != NULL);
178218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
179218138Spjd	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
180218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
181219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
182218193Spjd	PJDLOG_ASSERT(timeout >= -1);
183204076Spjd
184219873Spjd	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
185204076Spjd	if (ret != 0) {
186204076Spjd		errno = ret;
187204076Spjd		return (-1);
188204076Spjd	}
189204076Spjd
190204076Spjd	return (0);
191204076Spjd}
192204076Spjd
193204076Spjdint
194218193Spjdproto_connect_wait(struct proto_conn *conn, int timeout)
195218193Spjd{
196218193Spjd	int ret;
197218193Spjd
198218193Spjd	PJDLOG_ASSERT(conn != NULL);
199218193Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
200218193Spjd	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
201218193Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
202219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
203218193Spjd	PJDLOG_ASSERT(timeout >= 0);
204218193Spjd
205219873Spjd	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
206218193Spjd	if (ret != 0) {
207218193Spjd		errno = ret;
208218193Spjd		return (-1);
209218193Spjd	}
210218193Spjd
211218193Spjd	return (0);
212218193Spjd}
213218193Spjd
214218193Spjdint
215204076Spjdproto_server(const char *addr, struct proto_conn **connp)
216204076Spjd{
217204076Spjd
218219818Spjd	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
219204076Spjd}
220204076Spjd
221204076Spjdint
222204076Spjdproto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
223204076Spjd{
224204076Spjd	struct proto_conn *newconn;
225204076Spjd	int ret;
226204076Spjd
227218138Spjd	PJDLOG_ASSERT(conn != NULL);
228218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
229218138Spjd	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
230218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
231219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
232204076Spjd
233218191Spjd	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
234204076Spjd	if (newconn == NULL)
235204076Spjd		return (-1);
236204076Spjd
237219873Spjd	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
238204076Spjd	if (ret != 0) {
239218191Spjd		proto_free(newconn);
240204076Spjd		errno = ret;
241204076Spjd		return (-1);
242204076Spjd	}
243204076Spjd
244204076Spjd	*newconnp = newconn;
245204076Spjd
246204076Spjd	return (0);
247204076Spjd}
248204076Spjd
249204076Spjdint
250212033Spjdproto_send(const struct proto_conn *conn, const void *data, size_t size)
251204076Spjd{
252204076Spjd	int ret;
253204076Spjd
254218138Spjd	PJDLOG_ASSERT(conn != NULL);
255218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
256218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
257219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
258204076Spjd
259219873Spjd	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
260204076Spjd	if (ret != 0) {
261204076Spjd		errno = ret;
262204076Spjd		return (-1);
263204076Spjd	}
264204076Spjd	return (0);
265204076Spjd}
266204076Spjd
267204076Spjdint
268212033Spjdproto_recv(const struct proto_conn *conn, void *data, size_t size)
269204076Spjd{
270204076Spjd	int ret;
271204076Spjd
272218138Spjd	PJDLOG_ASSERT(conn != NULL);
273218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
274218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
275219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
276204076Spjd
277219873Spjd	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
278204076Spjd	if (ret != 0) {
279204076Spjd		errno = ret;
280204076Spjd		return (-1);
281204076Spjd	}
282204076Spjd	return (0);
283204076Spjd}
284204076Spjd
285204076Spjdint
286218194Spjdproto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
287218139Spjd{
288218194Spjd	const char *protoname;
289218194Spjd	int ret, fd;
290218139Spjd
291218139Spjd	PJDLOG_ASSERT(conn != NULL);
292218139Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
293218139Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
294219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
295218194Spjd	PJDLOG_ASSERT(mconn != NULL);
296218194Spjd	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
297218194Spjd	PJDLOG_ASSERT(mconn->pc_proto != NULL);
298218194Spjd	fd = proto_descriptor(mconn);
299218194Spjd	PJDLOG_ASSERT(fd >= 0);
300219873Spjd	protoname = mconn->pc_proto->prt_name;
301218194Spjd	PJDLOG_ASSERT(protoname != NULL);
302218139Spjd
303259193Strociny	ret = conn->pc_proto->prt_send(conn->pc_ctx,
304259193Strociny	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
305218194Spjd	proto_close(mconn);
306218139Spjd	if (ret != 0) {
307218139Spjd		errno = ret;
308218139Spjd		return (-1);
309218139Spjd	}
310218139Spjd	return (0);
311218139Spjd}
312218139Spjd
313218139Spjdint
314218194Spjdproto_connection_recv(const struct proto_conn *conn, bool client,
315218194Spjd    struct proto_conn **newconnp)
316218139Spjd{
317218194Spjd	char protoname[128];
318219873Spjd	struct proto *proto;
319218194Spjd	struct proto_conn *newconn;
320218194Spjd	int ret, fd;
321218139Spjd
322218139Spjd	PJDLOG_ASSERT(conn != NULL);
323218139Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
324218139Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
325219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
326218194Spjd	PJDLOG_ASSERT(newconnp != NULL);
327218139Spjd
328218194Spjd	bzero(protoname, sizeof(protoname));
329218194Spjd
330259193Strociny	ret = conn->pc_proto->prt_recv(conn->pc_ctx, (unsigned char *)protoname,
331218194Spjd	    sizeof(protoname) - 1, &fd);
332218139Spjd	if (ret != 0) {
333218139Spjd		errno = ret;
334218139Spjd		return (-1);
335218139Spjd	}
336218194Spjd
337218194Spjd	PJDLOG_ASSERT(fd >= 0);
338218194Spjd
339219873Spjd	TAILQ_FOREACH(proto, &protos, prt_next) {
340219873Spjd		if (strcmp(proto->prt_name, protoname) == 0)
341218194Spjd			break;
342218194Spjd	}
343218194Spjd	if (proto == NULL) {
344218194Spjd		errno = EINVAL;
345218194Spjd		return (-1);
346218194Spjd	}
347218194Spjd
348218194Spjd	newconn = proto_alloc(proto,
349218194Spjd	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
350218194Spjd	if (newconn == NULL)
351218194Spjd		return (-1);
352219873Spjd	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
353219873Spjd	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
354218194Spjd	if (ret != 0) {
355218194Spjd		proto_free(newconn);
356218194Spjd		errno = ret;
357218194Spjd		return (-1);
358218194Spjd	}
359218194Spjd
360218194Spjd	*newconnp = newconn;
361218194Spjd
362218139Spjd	return (0);
363218139Spjd}
364218139Spjd
365218139Spjdint
366204076Spjdproto_descriptor(const struct proto_conn *conn)
367204076Spjd{
368204076Spjd
369218138Spjd	PJDLOG_ASSERT(conn != NULL);
370218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
371218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
372219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
373204076Spjd
374219873Spjd	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
375204076Spjd}
376204076Spjd
377204076Spjdbool
378204076Spjdproto_address_match(const struct proto_conn *conn, const char *addr)
379204076Spjd{
380204076Spjd
381218138Spjd	PJDLOG_ASSERT(conn != NULL);
382218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
383218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
384219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
385204076Spjd
386219873Spjd	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
387204076Spjd}
388204076Spjd
389204076Spjdvoid
390204076Spjdproto_local_address(const struct proto_conn *conn, char *addr, size_t size)
391204076Spjd{
392204076Spjd
393218138Spjd	PJDLOG_ASSERT(conn != NULL);
394218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
395218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
396219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
397204076Spjd
398219873Spjd	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
399204076Spjd}
400204076Spjd
401204076Spjdvoid
402204076Spjdproto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
403204076Spjd{
404204076Spjd
405218138Spjd	PJDLOG_ASSERT(conn != NULL);
406218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
407218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
408219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
409204076Spjd
410219873Spjd	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
411204076Spjd}
412204076Spjd
413207371Spjdint
414207371Spjdproto_timeout(const struct proto_conn *conn, int timeout)
415207371Spjd{
416207371Spjd	struct timeval tv;
417207371Spjd	int fd;
418207371Spjd
419218138Spjd	PJDLOG_ASSERT(conn != NULL);
420218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
421218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
422207371Spjd
423207371Spjd	fd = proto_descriptor(conn);
424229945Spjd	if (fd == -1)
425207371Spjd		return (-1);
426207371Spjd
427207371Spjd	tv.tv_sec = timeout;
428207371Spjd	tv.tv_usec = 0;
429229945Spjd	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
430207371Spjd		return (-1);
431229945Spjd	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
432207371Spjd		return (-1);
433207371Spjd
434207371Spjd	return (0);
435207371Spjd}
436207371Spjd
437204076Spjdvoid
438204076Spjdproto_close(struct proto_conn *conn)
439204076Spjd{
440204076Spjd
441218138Spjd	PJDLOG_ASSERT(conn != NULL);
442218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
443218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
444219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
445204076Spjd
446219873Spjd	conn->pc_proto->prt_close(conn->pc_ctx);
447218191Spjd	proto_free(conn);
448204076Spjd}
449