1/*      $NetBSD: sp_common.c,v 1.30 2011/03/08 12:39:28 pooka Exp $	*/
2
3/*
4 * Copyright (c) 2010, 2011 Antti Kantee.  All Rights Reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 *    notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 *    notice, this list of conditions and the following disclaimer in the
13 *    documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
16 * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 * DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28/*
29 * Common client/server sysproxy routines.  #included.
30 */
31
32#include <sys/cdefs.h>
33
34#include <sys/types.h>
35#include <sys/mman.h>
36#include <sys/queue.h>
37#include <sys/socket.h>
38#include <sys/un.h>
39#include <sys/syslimits.h>
40
41#include <arpa/inet.h>
42#include <netinet/in.h>
43#include <netinet/tcp.h>
44
45#include <assert.h>
46#include <errno.h>
47#include <fcntl.h>
48#include <inttypes.h>
49#include <poll.h>
50#include <pthread.h>
51#include <stdarg.h>
52#include <stddef.h>
53#include <stdio.h>
54#include <stdlib.h>
55#include <string.h>
56#include <unistd.h>
57
/*
 * Debug tracing.  Uncomment the DEBUG define to have DPRINTF() print
 * to stderr; without it DPRINTF() expands to nothing.  Call sites wrap
 * the printf-style arguments in an extra pair of parentheses:
 *	DPRINTF(("value %d\n", val));
 */
//#define DEBUG
#ifndef DEBUG
#define DPRINTF(x)
#else
static void
mydprintf(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vfprintf(stderr, fmt, args);
	va_end(args);
}
#define DPRINTF(x) mydprintf x
#endif /* DEBUG */
73
/*
 * When HOSTOPS is not defined, the host_*() names used in this file
 * resolve directly to the corresponding libc calls.
 */
#if !defined(HOSTOPS)
#define host_poll	poll
#define host_read	read
#define host_sendmsg	sendmsg
#define host_setsockopt	setsockopt
#endif /* !HOSTOPS */
80
/*
 * iovec helpers.
 *
 * IOVPUT points _io_ at the object _b_ itself (its address and sizeof
 * are taken here); IOVPUT_WITHSIZE points _io_ at the buffer _b_ of
 * _l_ bytes.  SENDIOV sends a real iovec array (not a pointer) with
 * dosend().
 *
 * All arguments are parenthesized so that expressions such as "*iovp"
 * or "buf + off" expand as intended.  IOVPUT and IOVPUT_WITHSIZE expand
 * to two statements and carry their own trailing semicolon: do not use
 * them as the unbraced body of an if/else/loop.  _io_ is evaluated
 * twice.
 */
#define IOVPUT(_io_, _b_) \
	(_io_).iov_base = &(_b_); (_io_).iov_len = sizeof(_b_);
#define IOVPUT_WITHSIZE(_io_, _b_, _l_) \
	(_io_).iov_base = (_b_); (_io_).iov_len = (_l_);
#define SENDIOV(_spc_, _iov_) dosend((_spc_), (_iov_), __arraycount(_iov_))
84
85/*
86 * Bah, I hate writing on-off-wire conversions in C
87 */
88
/* Frame classes; frame consumers compare rsp_class against these. */
enum {
	RUMPSP_REQ,
	RUMPSP_RESP,
	RUMPSP_ERROR
};

/* Operation codes (presumably carried in rsp_type -- confirm with users). */
enum {
	RUMPSP_HANDSHAKE,
	RUMPSP_SYSCALL,
	RUMPSP_COPYIN,
	RUMPSP_COPYINSTR,
	RUMPSP_COPYOUT,
	RUMPSP_COPYOUTSTR,
	RUMPSP_ANONMMAP,
	RUMPSP_PREFORK,
	RUMPSP_RAISE
};

/* Handshake flavours (presumably carried in rsp_handshake -- confirm). */
enum {
	HANDSHAKE_GUEST,
	HANDSHAKE_AUTH,
	HANDSHAKE_FORK,
	HANDSHAKE_EXEC
};

/* 128bit fork auth, counted in 32bit words */
#define AUTHLEN 4
101
/*
 * Header found at the start of every frame.
 */
struct rsp_hdr {
	uint64_t rsp_len;	/* length of whole frame, header included */
	uint64_t rsp_reqno;	/* pairs a response with its waiter */
	uint16_t rsp_class;	/* RUMPSP_REQ, RUMPSP_RESP or RUMPSP_ERROR */
	uint16_t rsp_type;

	/*
	 * The structure should be 64bit-aligned for typecast fun.
	 * Instead of plain padding, the remaining 32 bits carry a
	 * value whose meaning depends on the kind of frame.
	 */
	union {
		uint32_t sysnum;
		uint32_t error;
		uint32_t handshake;
		uint32_t signo;
	} u;
};

/* size of the frame header */
#define HDRSZ sizeof(struct rsp_hdr)

/* shorthands for the members of the union above */
#define rsp_sysnum	u.sysnum
#define rsp_error	u.error
#define rsp_handshake	u.handshake
#define rsp_signo	u.signo

#define MAXBANNER 96
125
126/*
127 * Data follows the header.  We have two types of structured data.
128 */
129
/*
 * copyin/copyout
 *
 * rcp_data is a C99 flexible array member: the payload bytes start
 * right after the fixed fields, and sizeof(struct rsp_copydata) does
 * not include them.  Allocate sizeof(struct rsp_copydata) plus the
 * payload length.
 */
struct rsp_copydata {
	size_t rcp_len;
	void *rcp_addr;
	uint8_t rcp_data[];
};
136
/*
 * syscall response: an error number and a pair of return values
 */
struct rsp_sysresp {
	int		rsys_error;
	register_t	rsys_retval[2];
};
142
/*
 * Fork handshake data: four 32bit auth words (128 bits in total)
 * followed by a cancel flag.
 */
struct handshake_fork {
	uint32_t	rf_auth[4];
	int		rf_cancel;
};
147
/*
 * One of these is queued on spc_respwait for every request that is
 * waiting for its response.
 */
struct respwait {
	uint64_t rw_reqno;		/* request number being waited for */
	void *rw_data;			/* response payload, set by kickwaiter() */
	size_t rw_dlen;			/* payload length, header excluded */
	int rw_done;			/* set to 1 when the response arrived */
	int rw_error;			/* error from a RUMPSP_ERROR frame */

	pthread_cond_t rw_cv;		/* signalled when the response arrives */

	TAILQ_ENTRY(respwait) rw_entries;
};
159
160struct prefork;
161struct spclient {
162	int spc_fd;
163	int spc_refcnt;
164	int spc_state;
165
166	pthread_mutex_t spc_mtx;
167	pthread_cond_t spc_cv;
168
169	struct lwp *spc_mainlwp;
170	pid_t spc_pid;
171
172	TAILQ_HEAD(, respwait) spc_respwait;
173
174	/* rest of the fields are zeroed upon disconnect */
175#define SPC_ZEROFF offsetof(struct spclient, spc_pfd)
176	struct pollfd *spc_pfd;
177
178	struct rsp_hdr spc_hdr;
179	uint8_t *spc_buf;
180	size_t spc_off;
181
182	uint64_t spc_nextreq;
183	uint64_t spc_syscallreq;
184	uint64_t spc_generation;
185	int spc_ostatus, spc_istatus;
186	int spc_reconnecting;
187	int spc_inexec;
188
189	LIST_HEAD(, prefork) spc_pflist;
190};
/* values of spc_ostatus (see the send lock routines) */
#define SPCSTATUS_FREE		0
#define SPCSTATUS_BUSY		1
#define SPCSTATUS_WANTED	2

/* connection states (presumably stored in spc_state -- confirm) */
#define SPCSTATE_NEW		0
#define SPCSTATE_RUNNING	1
#define SPCSTATE_DYING		2
198
/*
 * Per-protocol hooks stored in parsetab[]:
 *  addrparse_fn:   (address string, result sockaddr, allow_wildcard)
 *  connecthook_fn: takes a socket descriptor
 *  cleanup_fn:     takes the sockaddr produced by the parser
 */
typedef int	(*addrparse_fn)(const char *, struct sockaddr **, int);
typedef int	(*connecthook_fn)(int);
typedef void	(*cleanup_fn)(struct sockaddr *);
202
/* forward declarations */
static int	readframe(struct spclient *spc);
static void	handlereq(struct spclient *spc);
205
206static __inline void
207spcresetbuf(struct spclient *spc)
208{
209
210	spc->spc_buf = NULL;
211	spc->spc_off = 0;
212}
213
214static __inline void
215spcfreebuf(struct spclient *spc)
216{
217
218	free(spc->spc_buf);
219	spcresetbuf(spc);
220}
221
222static void
223sendlockl(struct spclient *spc)
224{
225
226	while (spc->spc_ostatus != SPCSTATUS_FREE) {
227		spc->spc_ostatus = SPCSTATUS_WANTED;
228		pthread_cond_wait(&spc->spc_cv, &spc->spc_mtx);
229	}
230	spc->spc_ostatus = SPCSTATUS_BUSY;
231}
232
233static void __unused
234sendlock(struct spclient *spc)
235{
236
237	pthread_mutex_lock(&spc->spc_mtx);
238	sendlockl(spc);
239	pthread_mutex_unlock(&spc->spc_mtx);
240}
241
242static void
243sendunlockl(struct spclient *spc)
244{
245
246	if (spc->spc_ostatus == SPCSTATUS_WANTED)
247		pthread_cond_broadcast(&spc->spc_cv);
248	spc->spc_ostatus = SPCSTATUS_FREE;
249}
250
251static void
252sendunlock(struct spclient *spc)
253{
254
255	pthread_mutex_lock(&spc->spc_mtx);
256	sendunlockl(spc);
257	pthread_mutex_unlock(&spc->spc_mtx);
258}
259
260static int
261dosend(struct spclient *spc, struct iovec *iov, size_t iovlen)
262{
263	struct msghdr msg;
264	struct pollfd pfd;
265	ssize_t n = 0;
266	int fd = spc->spc_fd;
267
268	pfd.fd = fd;
269	pfd.events = POLLOUT;
270
271	memset(&msg, 0, sizeof(msg));
272
273	for (;;) {
274		/* not first round?  poll */
275		if (n) {
276			if (host_poll(&pfd, 1, INFTIM) == -1) {
277				if (errno == EINTR)
278					continue;
279				return errno;
280			}
281		}
282
283		msg.msg_iov = iov;
284		msg.msg_iovlen = iovlen;
285		n = host_sendmsg(fd, &msg, MSG_NOSIGNAL);
286		if (n == -1)  {
287			if (errno == EPIPE)
288				return ENOTCONN;
289			if (errno != EAGAIN)
290				return errno;
291			continue;
292		}
293		if (n == 0) {
294			return ENOTCONN;
295		}
296
297		/* ok, need to adjust iovec for potential next round */
298		while (n >= (ssize_t)iov[0].iov_len && iovlen) {
299			n -= iov[0].iov_len;
300			iov++;
301			iovlen--;
302		}
303
304		if (iovlen == 0) {
305			_DIAGASSERT(n == 0);
306			break;
307		} else {
308			iov[0].iov_base = (uint8_t *)iov[0].iov_base + n;
309			iov[0].iov_len -= n;
310		}
311	}
312
313	return 0;
314}
315
316static void
317doputwait(struct spclient *spc, struct respwait *rw, struct rsp_hdr *rhdr)
318{
319
320	rw->rw_data = NULL;
321	rw->rw_dlen = rw->rw_done = rw->rw_error = 0;
322	pthread_cond_init(&rw->rw_cv, NULL);
323
324	pthread_mutex_lock(&spc->spc_mtx);
325	rw->rw_reqno = rhdr->rsp_reqno = spc->spc_nextreq++;
326	TAILQ_INSERT_TAIL(&spc->spc_respwait, rw, rw_entries);
327}
328
329static void __unused
330putwait_locked(struct spclient *spc, struct respwait *rw, struct rsp_hdr *rhdr)
331{
332
333	doputwait(spc, rw, rhdr);
334	pthread_mutex_unlock(&spc->spc_mtx);
335}
336
337static void
338putwait(struct spclient *spc, struct respwait *rw, struct rsp_hdr *rhdr)
339{
340
341	doputwait(spc, rw, rhdr);
342	sendlockl(spc);
343	pthread_mutex_unlock(&spc->spc_mtx);
344}
345
346static void
347dounputwait(struct spclient *spc, struct respwait *rw)
348{
349
350	TAILQ_REMOVE(&spc->spc_respwait, rw, rw_entries);
351	pthread_mutex_unlock(&spc->spc_mtx);
352	pthread_cond_destroy(&rw->rw_cv);
353
354}
355
356static void __unused
357unputwait_locked(struct spclient *spc, struct respwait *rw)
358{
359
360	pthread_mutex_lock(&spc->spc_mtx);
361	dounputwait(spc, rw);
362}
363
364static void
365unputwait(struct spclient *spc, struct respwait *rw)
366{
367
368	pthread_mutex_lock(&spc->spc_mtx);
369	sendunlockl(spc);
370
371	dounputwait(spc, rw);
372}
373
374static void
375kickwaiter(struct spclient *spc)
376{
377	struct respwait *rw;
378	int error = 0;
379
380	pthread_mutex_lock(&spc->spc_mtx);
381	TAILQ_FOREACH(rw, &spc->spc_respwait, rw_entries) {
382		if (rw->rw_reqno == spc->spc_hdr.rsp_reqno)
383			break;
384	}
385	if (rw == NULL) {
386		DPRINTF(("no waiter found, invalid reqno %" PRIu64 "?\n",
387		    spc->spc_hdr.rsp_reqno));
388		pthread_mutex_unlock(&spc->spc_mtx);
389		spcfreebuf(spc);
390		return;
391	}
392	DPRINTF(("rump_sp: client %p woke up waiter at %p\n", spc, rw));
393	rw->rw_data = spc->spc_buf;
394	rw->rw_done = 1;
395	rw->rw_dlen = (size_t)(spc->spc_off - HDRSZ);
396	if (spc->spc_hdr.rsp_class == RUMPSP_ERROR) {
397		error = rw->rw_error = spc->spc_hdr.rsp_error;
398	}
399	pthread_cond_signal(&rw->rw_cv);
400	pthread_mutex_unlock(&spc->spc_mtx);
401
402	if (error)
403		spcfreebuf(spc);
404	else
405		spcresetbuf(spc);
406}
407
408static void
409kickall(struct spclient *spc)
410{
411	struct respwait *rw;
412
413	/* DIAGASSERT(mutex_owned(spc_lock)) */
414	TAILQ_FOREACH(rw, &spc->spc_respwait, rw_entries)
415		pthread_cond_broadcast(&rw->rw_cv);
416}
417
418static int
419readframe(struct spclient *spc)
420{
421	int fd = spc->spc_fd;
422	size_t left;
423	size_t framelen;
424	ssize_t n;
425
426	/* still reading header? */
427	if (spc->spc_off < HDRSZ) {
428		DPRINTF(("rump_sp: readframe getting header at offset %zu\n",
429		    spc->spc_off));
430
431		left = HDRSZ - spc->spc_off;
432		/*LINTED: cast ok */
433		n = host_read(fd, (uint8_t*)&spc->spc_hdr + spc->spc_off, left);
434		if (n == 0) {
435			return -1;
436		}
437		if (n == -1) {
438			if (errno == EAGAIN)
439				return 0;
440			return -1;
441		}
442
443		spc->spc_off += n;
444		if (spc->spc_off < HDRSZ) {
445			return 0;
446		}
447
448		/*LINTED*/
449		framelen = spc->spc_hdr.rsp_len;
450
451		if (framelen < HDRSZ) {
452			return -1;
453		} else if (framelen == HDRSZ) {
454			return 1;
455		}
456
457		spc->spc_buf = malloc(framelen - HDRSZ);
458		if (spc->spc_buf == NULL) {
459			return -1;
460		}
461		memset(spc->spc_buf, 0, framelen - HDRSZ);
462
463		/* "fallthrough" */
464	} else {
465		/*LINTED*/
466		framelen = spc->spc_hdr.rsp_len;
467	}
468
469	left = framelen - spc->spc_off;
470
471	DPRINTF(("rump_sp: readframe getting body at offset %zu, left %zu\n",
472	    spc->spc_off, left));
473
474	if (left == 0)
475		return 1;
476	n = host_read(fd, spc->spc_buf + (spc->spc_off - HDRSZ), left);
477	if (n == 0) {
478		return -1;
479	}
480	if (n == -1) {
481		if (errno == EAGAIN)
482			return 0;
483		return -1;
484	}
485	spc->spc_off += n;
486	left -= n;
487
488	/* got everything? */
489	if (left == 0)
490		return 1;
491	else
492		return 0;
493}
494
495static int
496tcp_parse(const char *addr, struct sockaddr **sa, int allow_wildcard)
497{
498	struct sockaddr_in sin;
499	char buf[64];
500	const char *p;
501	size_t l;
502	int port;
503
504	memset(&sin, 0, sizeof(sin));
505	sin.sin_len = sizeof(sin);
506	sin.sin_family = AF_INET;
507
508	p = strchr(addr, ':');
509	if (!p) {
510		fprintf(stderr, "rump_sp_tcp: missing port specifier\n");
511		return EINVAL;
512	}
513
514	l = p - addr;
515	if (l > sizeof(buf)-1) {
516		fprintf(stderr, "rump_sp_tcp: address too long\n");
517		return EINVAL;
518	}
519	strncpy(buf, addr, l);
520	buf[l] = '\0';
521
522	/* special INADDR_ANY treatment */
523	if (strcmp(buf, "*") == 0 || strcmp(buf, "0") == 0) {
524		sin.sin_addr.s_addr = INADDR_ANY;
525	} else {
526		switch (inet_pton(AF_INET, buf, &sin.sin_addr)) {
527		case 1:
528			break;
529		case 0:
530			fprintf(stderr, "rump_sp_tcp: cannot parse %s\n", buf);
531			return EINVAL;
532		case -1:
533			fprintf(stderr, "rump_sp_tcp: inet_pton failed\n");
534			return errno;
535		default:
536			assert(/*CONSTCOND*/0);
537			return EINVAL;
538		}
539	}
540
541	if (!allow_wildcard && sin.sin_addr.s_addr == INADDR_ANY) {
542		fprintf(stderr, "rump_sp_tcp: client needs !INADDR_ANY\n");
543		return EINVAL;
544	}
545
546	/* advance to port number & parse */
547	p++;
548	l = strspn(p, "0123456789");
549	if (l == 0) {
550		fprintf(stderr, "rump_sp_tcp: port now found: %s\n", p);
551		return EINVAL;
552	}
553	strncpy(buf, p, l);
554	buf[l] = '\0';
555
556	if (*(p+l) != '/' && *(p+l) != '\0') {
557		fprintf(stderr, "rump_sp_tcp: junk at end of port: %s\n", addr);
558		return EINVAL;
559	}
560
561	port = atoi(buf);
562	if (port < 0 || port >= (1<<(8*sizeof(in_port_t)))) {
563		fprintf(stderr, "rump_sp_tcp: port %d out of range\n", port);
564		return ERANGE;
565	}
566	sin.sin_port = htons(port);
567
568	*sa = malloc(sizeof(sin));
569	if (*sa == NULL)
570		return errno;
571	memcpy(*sa, &sin, sizeof(sin));
572	return 0;
573}
574
/*
 * Connect hook for TCP sockets: request TCP_NODELAY on the socket.
 * The result of setting the option is not checked; the hook always
 * returns 0.
 */
static int
tcp_connecthook(int s)
{
	int enable = 1;

	host_setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable));
	return 0;
}
585
586static char parsedurl[256];
587
588/*ARGSUSED*/
589static int
590unix_parse(const char *addr, struct sockaddr **sa, int allow_wildcard)
591{
592	struct sockaddr_un sun;
593	size_t slen;
594	int savepath = 0;
595
596	if (strlen(addr) > sizeof(sun.sun_path))
597		return ENAMETOOLONG;
598
599	/*
600	 * The pathname can be all kinds of spaghetti elementals,
601	 * so meek and obidient we accept everything.  However, use
602	 * full path for easy cleanup in case someone gives a relative
603	 * one and the server does a chdir() between now than the
604	 * cleanup.
605	 */
606	memset(&sun, 0, sizeof(sun));
607	sun.sun_family = AF_LOCAL;
608	if (*addr != '/') {
609		char mywd[PATH_MAX];
610
611		if (getcwd(mywd, sizeof(mywd)) == NULL) {
612			fprintf(stderr, "warning: cannot determine cwd, "
613			    "omitting socket cleanup\n");
614		} else {
615			if (strlen(addr) + strlen(mywd) > sizeof(sun.sun_path))
616				return ENAMETOOLONG;
617			strlcpy(sun.sun_path, mywd, sizeof(sun.sun_path));
618			strlcat(sun.sun_path, "/", sizeof(sun.sun_path));
619			savepath = 1;
620		}
621	}
622	strlcat(sun.sun_path, addr, sizeof(sun.sun_path));
623	sun.sun_len = SUN_LEN(&sun);
624	slen = sun.sun_len+1; /* get the 0 too */
625
626	if (savepath && *parsedurl == '\0') {
627		snprintf(parsedurl, sizeof(parsedurl),
628		    "unix://%s", sun.sun_path);
629	}
630
631	*sa = malloc(slen);
632	if (*sa == NULL)
633		return errno;
634	memcpy(*sa, &sun, slen);
635
636	return 0;
637}
638
/*
 * Remove the socket file named by a sockaddr_un.  Only absolute paths
 * are removed (unix_parse() converts relative paths to absolute ones
 * when it can); the result of unlink() is not checked.
 */
static void
unix_cleanup(struct sockaddr *sa)
{
	struct sockaddr_un *sockun = (void *)sa;

	if (sockun->sun_path[0] != '/')
		return;
	unlink(sockun->sun_path);
}
651
/*
 * Stub for unimplemented protocols: prints a complaint to stderr
 * and fails with EOPNOTSUPP.
 */
/*ARGSUSED*/
static int
notsupp(void)
{
	const int rv = EOPNOTSUPP;

	fprintf(stderr, "rump_sp: support not yet implemented\n");
	return rv;
}
660
/*
 * Stub for hooks which have nothing to do: always returns 0.
 */
static int
success(void)
{
	const int rv = 0;

	return rv;
}
667
668struct {
669	const char *id;
670	int domain;
671	addrparse_fn ap;
672	connecthook_fn connhook;
673	cleanup_fn cleanup;
674} parsetab[] = {
675	{ "tcp", PF_INET, tcp_parse, tcp_connecthook, (cleanup_fn)success },
676	{ "unix", PF_LOCAL, unix_parse, (connecthook_fn)success, unix_cleanup },
677	{ "tcp6", PF_INET6, (addrparse_fn)notsupp, (connecthook_fn)success,
678			    (cleanup_fn)success },
679};
680#define NPARSE (sizeof(parsetab)/sizeof(parsetab[0]))
681
682static int
683parseurl(const char *url, struct sockaddr **sap, unsigned *idxp,
684	int allow_wildcard)
685{
686	char id[16];
687	const char *p, *p2;
688	size_t l;
689	unsigned i;
690	int error;
691
692	/*
693	 * Parse the url
694	 */
695
696	p = url;
697	p2 = strstr(p, "://");
698	if (!p2) {
699		fprintf(stderr, "rump_sp: invalid locator ``%s''\n", p);
700		return EINVAL;
701	}
702	l = p2-p;
703	if (l > sizeof(id)-1) {
704		fprintf(stderr, "rump_sp: identifier too long in ``%s''\n", p);
705		return EINVAL;
706	}
707
708	strncpy(id, p, l);
709	id[l] = '\0';
710	p2 += 3; /* beginning of address */
711
712	for (i = 0; i < NPARSE; i++) {
713		if (strcmp(id, parsetab[i].id) == 0) {
714			error = parsetab[i].ap(p2, sap, allow_wildcard);
715			if (error)
716				return error;
717			break;
718		}
719	}
720	if (i == NPARSE) {
721		fprintf(stderr, "rump_sp: invalid identifier ``%s''\n", p);
722		return EINVAL;
723	}
724
725	*idxp = i;
726	return 0;
727}
728