1/*-
2 * SPDX-License-Identifier: BSD-2-Clause
3 *
4 * Copyright (c) 2009-2010 The FreeBSD Foundation
5 * All rights reserved.
6 *
7 * This software was developed by Pawel Jakub Dawidek under sponsorship from
8 * the FreeBSD Foundation.
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 * 1. Redistributions of source code must retain the above copyright
14 *    notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29 * SUCH DAMAGE.
30 */
31
#include <sys/types.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto_impl.h"
44
/* Value kept in sp_magic for as long as a context is valid. */
#define	SP_CTX_MAGIC	0x50c3741

/* Possible values of sp_ctx.sp_side. */
#define	SP_SIDE_UNDEF		0	/* side not decided yet */
#define	SP_SIDE_CLIENT		1	/* talks over sp_fd[0] */
#define	SP_SIDE_SERVER		2	/* talks over sp_fd[1] */

/*
 * State of a single socketpair connection.
 */
struct sp_ctx {
	int			sp_magic;	/* SP_CTX_MAGIC while valid */
	int			sp_fd[2];	/* filled in by socketpair(2) */
	int			sp_side;	/* one of SP_SIDE_* */
};

/* Forward declaration; the definition appears later in this file. */
static void sp_close(void *ctx);
56
57static int
58sp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
59{
60	struct sp_ctx *spctx;
61	int ret;
62
63	if (strcmp(dstaddr, "socketpair://") != 0)
64		return (-1);
65
66	PJDLOG_ASSERT(srcaddr == NULL);
67
68	spctx = malloc(sizeof(*spctx));
69	if (spctx == NULL)
70		return (errno);
71
72	if (socketpair(PF_UNIX, SOCK_STREAM, 0, spctx->sp_fd) == -1) {
73		ret = errno;
74		free(spctx);
75		return (ret);
76	}
77
78	spctx->sp_side = SP_SIDE_UNDEF;
79	spctx->sp_magic = SP_CTX_MAGIC;
80	*ctxp = spctx;
81
82	return (0);
83}
84
85static int
86sp_send(void *ctx, const unsigned char *data, size_t size, int fd)
87{
88	struct sp_ctx *spctx = ctx;
89	int sock;
90
91	PJDLOG_ASSERT(spctx != NULL);
92	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
93
94	switch (spctx->sp_side) {
95	case SP_SIDE_UNDEF:
96		/*
97		 * If the first operation done by the caller is proto_send(),
98		 * we assume this is the client.
99		 */
100		/* FALLTHROUGH */
101		spctx->sp_side = SP_SIDE_CLIENT;
102		/* Close other end. */
103		close(spctx->sp_fd[1]);
104		spctx->sp_fd[1] = -1;
105	case SP_SIDE_CLIENT:
106		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
107		sock = spctx->sp_fd[0];
108		break;
109	case SP_SIDE_SERVER:
110		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
111		sock = spctx->sp_fd[1];
112		break;
113	default:
114		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
115	}
116
117	/* Someone is just trying to decide about side. */
118	if (data == NULL)
119		return (0);
120
121	return (proto_common_send(sock, data, size, fd));
122}
123
124static int
125sp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
126{
127	struct sp_ctx *spctx = ctx;
128	int fd;
129
130	PJDLOG_ASSERT(spctx != NULL);
131	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
132
133	switch (spctx->sp_side) {
134	case SP_SIDE_UNDEF:
135		/*
136		 * If the first operation done by the caller is proto_recv(),
137		 * we assume this is the server.
138		 */
139		/* FALLTHROUGH */
140		spctx->sp_side = SP_SIDE_SERVER;
141		/* Close other end. */
142		close(spctx->sp_fd[0]);
143		spctx->sp_fd[0] = -1;
144	case SP_SIDE_SERVER:
145		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
146		fd = spctx->sp_fd[1];
147		break;
148	case SP_SIDE_CLIENT:
149		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
150		fd = spctx->sp_fd[0];
151		break;
152	default:
153		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
154	}
155
156	/* Someone is just trying to decide about side. */
157	if (data == NULL)
158		return (0);
159
160	return (proto_common_recv(fd, data, size, fdp));
161}
162
163static int
164sp_descriptor(const void *ctx)
165{
166	const struct sp_ctx *spctx = ctx;
167
168	PJDLOG_ASSERT(spctx != NULL);
169	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
170	PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT ||
171	    spctx->sp_side == SP_SIDE_SERVER);
172
173	switch (spctx->sp_side) {
174	case SP_SIDE_CLIENT:
175		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
176		return (spctx->sp_fd[0]);
177	case SP_SIDE_SERVER:
178		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
179		return (spctx->sp_fd[1]);
180	}
181
182	PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
183}
184
185static void
186sp_close(void *ctx)
187{
188	struct sp_ctx *spctx = ctx;
189
190	PJDLOG_ASSERT(spctx != NULL);
191	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
192
193	switch (spctx->sp_side) {
194	case SP_SIDE_UNDEF:
195		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
196		close(spctx->sp_fd[0]);
197		spctx->sp_fd[0] = -1;
198		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
199		close(spctx->sp_fd[1]);
200		spctx->sp_fd[1] = -1;
201		break;
202	case SP_SIDE_CLIENT:
203		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
204		close(spctx->sp_fd[0]);
205		spctx->sp_fd[0] = -1;
206		PJDLOG_ASSERT(spctx->sp_fd[1] == -1);
207		break;
208	case SP_SIDE_SERVER:
209		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
210		close(spctx->sp_fd[1]);
211		spctx->sp_fd[1] = -1;
212		PJDLOG_ASSERT(spctx->sp_fd[0] == -1);
213		break;
214	default:
215		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
216	}
217
218	spctx->sp_magic = 0;
219	free(spctx);
220}
221
222static struct proto sp_proto = {
223	.prt_name = "socketpair",
224	.prt_client = sp_client,
225	.prt_send = sp_send,
226	.prt_recv = sp_recv,
227	.prt_descriptor = sp_descriptor,
228	.prt_close = sp_close
229};
230
231static __constructor void
232sp_ctor(void)
233{
234
235	proto_register(&sp_proto, false);
236}
237