1268745Sbapt/*-
2268745Sbapt * Copyright (c) 2014 Baptiste Daroussin <bapt@FreeBSD.org>
3268745Sbapt * Copyright (c) 2014 Vsevolod Stakhov <vsevolod@FreeBSD.org>
4268745Sbapt * All rights reserved.
5287392Sbapt *
6268745Sbapt * Redistribution and use in source and binary forms, with or without
7268745Sbapt * modification, are permitted provided that the following conditions
8268745Sbapt * are met:
9268745Sbapt * 1. Redistributions of source code must retain the above copyright
10268745Sbapt *    notice, this list of conditions and the following disclaimer
11268745Sbapt *    in this position and unchanged.
12268745Sbapt * 2. Redistributions in binary form must reproduce the above copyright
13268745Sbapt *    notice, this list of conditions and the following disclaimer in the
14268745Sbapt *    documentation and/or other materials provided with the distribution.
15287392Sbapt *
16268745Sbapt * THIS SOFTWARE IS PROVIDED BY THE AUTHOR(S) ``AS IS'' AND ANY EXPRESS OR
17268745Sbapt * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18268745Sbapt * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19268745Sbapt * IN NO EVENT SHALL THE AUTHOR(S) BE LIABLE FOR ANY DIRECT, INDIRECT,
20268745Sbapt * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21268745Sbapt * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22268745Sbapt * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23268745Sbapt * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24268745Sbapt * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25268745Sbapt * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26268745Sbapt */
27268745Sbapt
28268745Sbapt#include <sys/cdefs.h>
29268745Sbapt__FBSDID("$FreeBSD$");
30268745Sbapt
31287392Sbapt#include <sys/procctl.h>
32268745Sbapt#include <sys/time.h>
33268745Sbapt#include <sys/wait.h>
34287392Sbapt
35287392Sbapt#include <err.h>
36287392Sbapt#include <errno.h>
37287392Sbapt#include <getopt.h>
38268745Sbapt#include <signal.h>
39287392Sbapt#include <stdbool.h>
40268745Sbapt#include <stdio.h>
41268745Sbapt#include <stdlib.h>
42268745Sbapt#include <string.h>
43268745Sbapt#include <sysexits.h>
44268745Sbapt#include <unistd.h>
45268745Sbapt
#define EXIT_TIMEOUT 124

/*
 * Flags written by sig_handler() and read/cleared by the supervision loop
 * in main().  They must be volatile so the compiler re-reads them after
 * every sigsuspend() instead of caching a stale value.
 */
static volatile sig_atomic_t sig_chld = 0;	/* a child changed state */
static volatile sig_atomic_t sig_term = 0;	/* signal number to forward */
static volatile sig_atomic_t sig_alrm = 0;	/* the interval timer fired */
static volatile sig_atomic_t sig_ign = 0;	/* swallow one delivery of this */
52268745Sbapt
/*
 * Print the command synopsis to stderr and exit with EX_USAGE.
 * Never returns.
 */
static void
usage(void)
{
	const char *prog = getprogname();

	fprintf(stderr,
	    "Usage: %s [--signal sig | -s sig] [--preserve-status]"
	    " [--kill-after time | -k time]"
	    " [--foreground] <duration> <command> <arg ...>\n", prog);
	exit(EX_USAGE);
}
63268745Sbapt
/*
 * Parse a duration operand and return it in seconds.
 *
 * The operand is a floating-point number (as accepted by strtod(3))
 * optionally followed by exactly one unit suffix:
 *   's' seconds (the default), 'm' minutes, 'h' hours, 'd' days.
 *
 * Any malformed value, any unknown or multi-character suffix, and any
 * result that is negative, NaN or not below 100000000 seconds terminates
 * the program with exit status 125.
 */
static double
parse_duration(const char *duration)
{
	double ret;
	char *end;

	ret = strtod(duration, &end);
	if (end == duration)		/* no digits were converted at all */
		errx(125, "invalid duration");

	if (*end != '\0') {
		/* Only a single suffix character may follow the number. */
		if (end[1] != '\0')
			errx(125, "invalid duration");

		switch (*end) {
		case 's':
			break;
		case 'm':
			ret *= 60;
			break;
		case 'h':
			ret *= 60 * 60;
			break;
		case 'd':
			ret *= 60 * 60 * 24;
			break;
		default:
			errx(125, "invalid duration");
		}
	}

	/*
	 * Validate every value, with or without a suffix.  The negated
	 * comparison also rejects NaN.  NOTE(review): the upper bound
	 * presumably matches the largest timer setitimer(2) accepts.
	 */
	if (!(ret >= 0 && ret < 100000000))
		errx(125, "invalid duration");

	return (ret);
}
101268745Sbapt
102268745Sbaptstatic int
103268745Sbaptparse_signal(const char *str)
104268745Sbapt{
105268745Sbapt	int sig, i;
106287392Sbapt	const char *errstr;
107268745Sbapt
108290040Sbapt	sig = strtonum(str, 1, sys_nsig - 1, &errstr);
109268745Sbapt
110287392Sbapt	if (errstr == NULL)
111268745Sbapt		return (sig);
112268745Sbapt	if (strncasecmp(str, "SIG", 3) == 0)
113268745Sbapt		str += 3;
114268745Sbapt
115268745Sbapt	for (i = 1; i < sys_nsig; i++) {
116268745Sbapt		if (strcasecmp(str, sys_signame[i]) == 0)
117268745Sbapt			return (i);
118268745Sbapt	}
119287392Sbapt
120287392Sbapt	errx(125, "invalid signal");
121268745Sbapt}
122268745Sbapt
123268745Sbaptstatic void
124268745Sbaptsig_handler(int signo)
125268745Sbapt{
126268745Sbapt	if (sig_ign != 0 && signo == sig_ign) {
127268745Sbapt		sig_ign = 0;
128268745Sbapt		return;
129268745Sbapt	}
130268745Sbapt
131268745Sbapt	switch(signo) {
132268745Sbapt	case 0:
133268745Sbapt	case SIGINT:
134268745Sbapt	case SIGHUP:
135268745Sbapt	case SIGQUIT:
136268745Sbapt	case SIGTERM:
137268745Sbapt		sig_term = signo;
138268745Sbapt		break;
139268745Sbapt	case SIGCHLD:
140268745Sbapt		sig_chld = 1;
141268745Sbapt		break;
142268745Sbapt	case SIGALRM:
143268745Sbapt		sig_alrm = 1;
144268745Sbapt		break;
145268745Sbapt	}
146268745Sbapt}
147268745Sbapt
/*
 * Arm the ITIMER_REAL timer so that SIGALRM is delivered once, iv seconds
 * from now.  The fractional part of iv is kept at microsecond resolution,
 * and the reload interval is zero, so the timer does not repeat.  Passing
 * 0 leaves the timer disarmed.  Exits with EX_OSERR if setitimer(2) fails.
 */
static void
set_interval(double iv)
{
	time_t whole = (time_t)iv;
	struct itimerval tim = {
		.it_interval = { .tv_sec = 0, .tv_usec = 0 },
		.it_value = {
			.tv_sec = whole,
			.tv_usec = (suseconds_t)((iv - whole) * 1000000UL),
		},
	};

	if (setitimer(ITIMER_REAL, &tim, NULL) == -1)
		err(EX_OSERR, "setitimer()");
}
161268745Sbapt
/*
 * Send signo to the supervised command.
 *
 * In foreground mode only the direct child (pid) is signalled with
 * kill(2).  Otherwise the signal is sent with procctl(PROC_REAP_KILL)
 * from this process, which acquired reaper status in main().
 * Errors from kill(2) and procctl(2) are not checked.
 */
static void
send_sig(pid_t pid, int signo, bool foreground)
{
	struct procctl_reaper_kill killemall;

	if (foreground) {
		kill(pid, signo);
		return;
	}

	/* Zero the whole request so no field is passed uninitialized. */
	memset(&killemall, 0, sizeof(killemall));
	killemall.rk_sig = signo;
	killemall.rk_flags = 0;
	procctl(P_PID, getpid(), PROC_REAP_KILL, &killemall);
}

/*
 * Run <command> and signal it if it is still running once <duration>
 * has elapsed.
 *
 * Unless --foreground is given, this process acquires reaper status.
 * Signals then go out through PROC_REAP_KILL, and after the direct child
 * exits the loop keeps running until PROC_REAP_STATUS reports no
 * remaining children.  With --kill-after, killsig is replaced by SIGKILL
 * and a second timer is armed after the first signal is sent.
 *
 * Exit status:
 *   124 (EXIT_TIMEOUT) if the timer fired, unless --preserve-status;
 *   otherwise the child's exit status, or 128 + signal number if the
 *   child was killed by a signal.
 */
int
main(int argc, char **argv)
{
	int ch;
	unsigned long i;
	int foreground, preserve;
	int error, pstat, status;
	int killsig = SIGTERM;
	int signo;
	pid_t pid, cpid;
	double first_kill;
	double second_kill;
	bool timedout = false;
	bool do_second_kill = false;
	bool second_kill_armed = false;
	bool child_done = false;
	struct sigaction signals;
	sigset_t oldmask, emptymask;
	struct procctl_reaper_status info;
	int signums[] = {
		-1,		/* replaced by killsig when it can be caught */
		SIGTERM,
		SIGINT,
		SIGHUP,
		SIGCHLD,
		SIGALRM,
		SIGQUIT,
	};

	foreground = preserve = 0;
	second_kill = 0;
	pstat = 0;

	const struct option longopts[] = {
		{ "preserve-status", no_argument,       &preserve,    1 },
		{ "foreground",      no_argument,       &foreground,  1 },
		{ "kill-after",      required_argument, NULL,        'k'},
		{ "signal",          required_argument, NULL,        's'},
		{ "help",            no_argument,       NULL,        'h'},
		{ NULL,              0,                 NULL,         0 }
	};

	while ((ch = getopt_long(argc, argv, "+k:s:h", longopts, NULL)) != -1) {
		switch (ch) {
		case 'k':
			do_second_kill = true;
			second_kill = parse_duration(optarg);
			break;
		case 's':
			killsig = parse_signal(optarg);
			break;
		case 0:
			/* A flag-setting long option; getopt_long stored it. */
			break;
		case 'h':
		default:
			usage();
			break;
		}
	}

	argc -= optind;
	argv += optind;

	if (argc < 2)
		usage();

	first_kill = parse_duration(argv[0]);
	argc--;
	argv++;

	if (!foreground) {
		/* Acquire a reaper */
		if (procctl(P_PID, getpid(), PROC_REAP_ACQUIRE, NULL) == -1)
			err(EX_OSERR, "Fail to acquire the reaper");
	}

	memset(&signals, 0, sizeof(signals));
	sigemptyset(&signals.sa_mask);

	/* SIGKILL and SIGSTOP cannot be caught, so they stay out of the set. */
	if (killsig != SIGKILL && killsig != SIGSTOP)
		signums[0] = killsig;

	for (i = 0; i < sizeof(signums) / sizeof(signums[0]); i++)
		if (signums[i] != -1)
			sigaddset(&signals.sa_mask, signums[i]);

	signals.sa_handler = sig_handler;
	signals.sa_flags = SA_RESTART;

	for (i = 0; i < sizeof(signums) / sizeof(signums[0]); i++)
		if (signums[i] != -1 && signums[i] != 0 &&
		    sigaction(signums[i], &signals, NULL) == -1)
			err(EX_OSERR, "sigaction()");

	signal(SIGTTIN, SIG_IGN);
	signal(SIGTTOU, SIG_IGN);

	/*
	 * Block the handled signals BEFORE forking.  A SIGCHLD (or any other
	 * handled signal) that arrives before the supervision loop then stays
	 * pending until sigsuspend() unblocks it.  If the signals were blocked
	 * after fork(), a child that exits immediately could raise its SIGCHLD
	 * unnoticed, leaving the loop waiting for a signal that never comes.
	 */
	if (sigprocmask(SIG_BLOCK, &signals.sa_mask, &oldmask) == -1)
		err(EX_OSERR, "sigprocmask()");

	pid = fork();
	if (pid == -1)
		err(EX_OSERR, "fork()");
	else if (pid == 0) {
		/* child process */
		signal(SIGTTIN, SIG_DFL);
		signal(SIGTTOU, SIG_DFL);

		/* Do not hand the supervisor's blocked mask down to the command. */
		if (sigprocmask(SIG_SETMASK, &oldmask, NULL) == -1)
			err(EX_OSERR, "sigprocmask()");

		error = execvp(argv[0], argv);
		if (error == -1) {
			if (errno == ENOENT)
				err(127, "exec(%s)", argv[0]);
			else
				err(126, "exec(%s)", argv[0]);
		}
	}

	/* parent continues here */
	set_interval(first_kill);

	sigemptyset(&emptymask);
	for (;;) {
		/* Sleep with every signal unblocked until a handler has run. */
		sigsuspend(&emptymask);

		/*
		 * One wakeup may have raised several flags, so each flag is
		 * examined in turn rather than only the first one set.
		 */
		if (sig_chld) {
			sig_chld = 0;

			while ((cpid = waitpid(-1, &status, WNOHANG)) != 0) {
				if (cpid < 0) {
					if (errno == EINTR)
						continue;
					else
						break;
				} else if (cpid == pid) {
					pstat = status;
					child_done = true;
				}
			}
			if (child_done) {
				if (foreground)
					break;
				/*
				 * As a reaper, keep running while descendants
				 * remain.  If the status query itself fails,
				 * stop rather than read an unset `info`.
				 */
				if (procctl(P_PID, getpid(), PROC_REAP_STATUS,
				    &info) == -1 || info.rs_children == 0)
					break;
			}
		}

		if (sig_alrm) {
			sig_alrm = 0;

			timedout = true;
			send_sig(pid, killsig, foreground);

			if (!do_second_kill)
				break;
			/* Arm --kill-after only once; never reset it to 0. */
			if (!second_kill_armed) {
				set_interval(second_kill);
				second_kill_armed = true;
			}
			sig_ign = killsig;
			killsig = SIGKILL;
		}

		if (sig_term) {
			/* Consume the request so it is forwarded only once. */
			signo = sig_term;
			sig_term = 0;
			send_sig(pid, signo, foreground);

			if (!do_second_kill)
				break;
			/*
			 * A repeated termination signal must not re-arm (and
			 * thereby cancel) an already pending --kill-after.
			 */
			if (!second_kill_armed) {
				set_interval(second_kill);
				second_kill_armed = true;
			}
			sig_ign = killsig;
			killsig = SIGKILL;
		}
	}

	/*
	 * Wait specifically for the direct child: as a reaper this process
	 * may also be the parent of adopted grandchildren, whose status must
	 * not be reported as the command's.
	 */
	while (!child_done && waitpid(pid, &pstat, 0) == -1) {
		if (errno != EINTR)
			err(EX_OSERR, "waitpid()");
	}

	if (!foreground)
		procctl(P_PID, getpid(), PROC_REAP_RELEASE, NULL);

	if (WIFEXITED(pstat))
		pstat = WEXITSTATUS(pstat);
	else if (WIFSIGNALED(pstat))
		pstat = 128 + WTERMSIG(pstat);

	if (timedout && !preserve)
		pstat = EXIT_TIMEOUT;

	return (pstat);
}
363