1/*-
2 * SPDX-License-Identifier: BSD-2-Clause
3 *
4 * Copyright (c) 2023 Alexander V. Chernikov
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 *    notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 *    notice, this list of conditions and the following disclaimer in the
13 *    documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28#include <sys/param.h>
29#include <sys/refcount.h>
30#include <sys/types.h>
31#include <sys/kernel.h>
32#include <sys/lock.h>
33#include <sys/mutex.h>
34#include <sys/malloc.h>
35#include <sys/module.h>
36#include <sys/socket.h>
37#include <sys/priv.h>
38
39#include <netlink/netlink.h>
40#include <netlink/netlink_ctl.h>
41#include <netlink/netlink_generic.h>
42#include <netlink/netlink_message_parser.h>
43
44#include <machine/stdarg.h>
45#include <tests/ktest.h>
46
/*
 * ktest_mtx serializes access to module_list: every traversal, insertion
 * and removal of the list in this file is done with it held.  It is
 * initialized by the MTX_SYSINIT() below.
 */
struct mtx ktest_mtx;
#define	KTEST_LOCK()		mtx_lock(&ktest_mtx)
#define	KTEST_UNLOCK()		mtx_unlock(&ktest_mtx)
#define	KTEST_LOCK_ASSERT()	mtx_assert(&ktest_mtx, MA_OWNED)

MTX_SYSINIT(ktest_mtx, &ktest_mtx, "ktest mutex", MTX_DEF);
53
/*
 * Bookkeeping entry created for each registered test module and linked
 * into module_list.  The module's own descriptor points back to it through
 * info->module_ptr.
 */
struct ktest_module {
	struct ktest_module_info	*info;		/* descriptor supplied by the test module */
	volatile u_int			refcount;	/* held by run_test() while a test executes; unload is refused while non-zero */
	TAILQ_ENTRY(ktest_module)	entries;	/* module_list linkage */
};
/* All registered test modules; accessed with ktest_mtx held. */
static TAILQ_HEAD(, ktest_module) module_list = TAILQ_HEAD_INITIALIZER(module_list);
60
/*
 * Attributes extracted from KTEST_CMD_LIST / KTEST_CMD_RUN requests.
 * Callers zero-initialize this structure, so an attribute missing from the
 * request leaves its field NULL.
 */
struct nl_ktest_parsed {
	char		*mod_name;	/* KTEST_ATTR_MOD_NAME */
	char		*test_name;	/* KTEST_ATTR_TEST_NAME */
	struct nlattr	*test_meta;	/* KTEST_ATTR_TEST_META, handed to the test's parse callback */
};

/* Offset helpers for the parser tables; _IN is unused since nlf_p_get is empty. */
#define	_IN(_field)	offsetof(struct genlmsghdr, _field)
#define	_OUT(_field)	offsetof(struct nl_ktest_parsed, _field)

/* Maps each netlink attribute type to its nl_ktest_parsed field. */
static const struct nlattr_parser nla_p_get[] = {
	{ .type = KTEST_ATTR_MOD_NAME, .off = _OUT(mod_name), .cb = nlattr_get_string },
	{ .type = KTEST_ATTR_TEST_NAME, .off = _OUT(test_name), .cb = nlattr_get_string },
	{ .type = KTEST_ATTR_TEST_META, .off = _OUT(test_meta), .cb = nlattr_get_nla },
};
/* No fields of the fixed genlmsghdr are extracted. */
static const struct nlfield_parser nlf_p_get[] = {
};
NL_DECLARE_PARSER(ktest_parser, struct genlmsghdr, nlf_p_get, nla_p_get);
#undef _IN
#undef _OUT
80
81static bool
82create_reply(struct nl_writer *nw, struct nlmsghdr *hdr, int cmd)
83{
84	if (!nlmsg_reply(nw, hdr, sizeof(struct genlmsghdr)))
85		return (false);
86
87	struct genlmsghdr *ghdr_new = nlmsg_reserve_object(nw, struct genlmsghdr);
88	ghdr_new->cmd = cmd;
89	ghdr_new->version = 0;
90	ghdr_new->reserved = 0;
91
92	return (true);
93}
94
95static int
96dump_mod_test(struct nlmsghdr *hdr, struct nl_pstate *npt,
97    struct ktest_module *mod, const struct ktest_test_info *test_info)
98{
99	struct nl_writer *nw = npt->nw;
100
101	if (!create_reply(nw, hdr, KTEST_CMD_NEWTEST))
102		goto enomem;
103
104	nlattr_add_string(nw, KTEST_ATTR_MOD_NAME, mod->info->name);
105	nlattr_add_string(nw, KTEST_ATTR_TEST_NAME, test_info->name);
106	nlattr_add_string(nw, KTEST_ATTR_TEST_DESCR, test_info->desc);
107
108	if (nlmsg_end(nw))
109		return (0);
110enomem:
111	nlmsg_abort(nw);
112	return (ENOMEM);
113}
114
115static int
116dump_mod_tests(struct nlmsghdr *hdr, struct nl_pstate *npt,
117    struct ktest_module *mod, struct nl_ktest_parsed *attrs)
118{
119	for (int i = 0; i < mod->info->num_tests; i++) {
120		const struct ktest_test_info *test_info = &mod->info->tests[i];
121		if (attrs->test_name != NULL && strcmp(attrs->test_name, test_info->name))
122			continue;
123		int error = dump_mod_test(hdr, npt, mod, test_info);
124		if (error != 0)
125			return (error);
126	}
127
128	return (0);
129}
130
131static int
132dump_tests(struct nlmsghdr *hdr, struct nl_pstate *npt)
133{
134	struct nl_ktest_parsed attrs = { };
135	struct ktest_module *mod;
136	int error;
137
138	error = nl_parse_nlmsg(hdr, &ktest_parser, npt, &attrs);
139	if (error != 0)
140		return (error);
141
142	hdr->nlmsg_flags |= NLM_F_MULTI;
143
144	KTEST_LOCK();
145	TAILQ_FOREACH(mod, &module_list, entries) {
146		if (attrs.mod_name && strcmp(attrs.mod_name, mod->info->name))
147			continue;
148		error = dump_mod_tests(hdr, npt, mod, &attrs);
149		if (error != 0)
150			break;
151	}
152	KTEST_UNLOCK();
153
154	if (!nlmsg_end_dump(npt->nw, error, hdr)) {
155		//NL_LOG(LOG_DEBUG, "Unable to finalize the dump");
156		return (ENOMEM);
157	}
158
159	return (error);
160}
161
162static int
163run_test(struct nlmsghdr *hdr, struct nl_pstate *npt)
164{
165	struct nl_ktest_parsed attrs = { };
166	struct ktest_module *mod;
167	int error;
168
169	error = nl_parse_nlmsg(hdr, &ktest_parser, npt, &attrs);
170	if (error != 0)
171		return (error);
172
173	if (attrs.mod_name == NULL) {
174		nlmsg_report_err_msg(npt, "KTEST_ATTR_MOD_NAME not set");
175		return (EINVAL);
176	}
177
178	if (attrs.test_name == NULL) {
179		nlmsg_report_err_msg(npt, "KTEST_ATTR_TEST_NAME not set");
180		return (EINVAL);
181	}
182
183	const struct ktest_test_info *test = NULL;
184
185	KTEST_LOCK();
186	TAILQ_FOREACH(mod, &module_list, entries) {
187		if (strcmp(attrs.mod_name, mod->info->name))
188			continue;
189
190		const struct ktest_module_info *info = mod->info;
191
192		for (int i = 0; i < info->num_tests; i++) {
193			const struct ktest_test_info *test_info = &info->tests[i];
194
195			if (!strcmp(attrs.test_name, test_info->name)) {
196				test = test_info;
197				break;
198			}
199		}
200		break;
201	}
202	if (test != NULL)
203		refcount_acquire(&mod->refcount);
204	KTEST_UNLOCK();
205
206	if (test == NULL)
207		return (ESRCH);
208
209	/* Run the test */
210	struct ktest_test_context ctx = {
211		.npt = npt,
212		.hdr = hdr,
213		.buf = npt_alloc(npt, KTEST_MAX_BUF),
214		.bufsize = KTEST_MAX_BUF,
215	};
216
217	if (ctx.buf == NULL) {
218		//NL_LOG(LOG_DEBUG, "unable to allocate temporary buffer");
219		return (ENOMEM);
220	}
221
222	if (test->parse != NULL && attrs.test_meta != NULL) {
223		error = test->parse(&ctx, attrs.test_meta);
224		if (error != 0)
225			return (error);
226	}
227
228	hdr->nlmsg_flags |= NLM_F_MULTI;
229
230	KTEST_LOG_LEVEL(&ctx, LOG_INFO, "start running %s", test->name);
231	error = test->func(&ctx);
232	KTEST_LOG_LEVEL(&ctx, LOG_INFO, "end running %s", test->name);
233
234	refcount_release(&mod->refcount);
235
236	if (!nlmsg_end_dump(npt->nw, error, hdr)) {
237		//NL_LOG(LOG_DEBUG, "Unable to finalize the dump");
238		return (ENOMEM);
239	}
240
241	return (error);
242}
243
244
245/* USER API */
246static void
247register_test_module(struct ktest_module_info *info)
248{
249	struct ktest_module *mod = malloc(sizeof(*mod), M_TEMP, M_WAITOK | M_ZERO);
250
251	mod->info = info;
252	info->module_ptr = mod;
253	KTEST_LOCK();
254	TAILQ_INSERT_TAIL(&module_list, mod, entries);
255	KTEST_UNLOCK();
256}
257
258static void
259unregister_test_module(struct ktest_module_info *info)
260{
261	struct ktest_module *mod = info->module_ptr;
262
263	info->module_ptr = NULL;
264
265	KTEST_LOCK();
266	TAILQ_REMOVE(&module_list, mod, entries);
267	KTEST_UNLOCK();
268
269	free(mod, M_TEMP);
270}
271
272static bool
273can_unregister(struct ktest_module_info *info)
274{
275	struct ktest_module *mod = info->module_ptr;
276
277	return (refcount_load(&mod->refcount) == 0);
278}
279
280int
281ktest_default_modevent(module_t mod, int type, void *arg)
282{
283	struct ktest_module_info *info = (struct ktest_module_info *)arg;
284	int error = 0;
285
286	switch (type) {
287	case MOD_LOAD:
288		register_test_module(info);
289		break;
290	case MOD_UNLOAD:
291		if (!can_unregister(info))
292			return (EBUSY);
293		unregister_test_module(info);
294		break;
295	default:
296		error = EOPNOTSUPP;
297		break;
298	}
299	return (error);
300}
301
302bool
303ktest_start_msg(struct ktest_test_context *ctx)
304{
305	return (create_reply(ctx->npt->nw, ctx->hdr, KTEST_CMD_NEWMESSAGE));
306}
307
308void
309ktest_add_msg_meta(struct ktest_test_context *ctx, const char *func,
310    const char *fname, int line)
311{
312	struct nl_writer *nw = ctx->npt->nw;
313	struct timespec ts;
314
315	nanouptime(&ts);
316	nlattr_add(nw, KTEST_MSG_ATTR_TS, sizeof(ts), &ts);
317
318	nlattr_add_string(nw, KTEST_MSG_ATTR_FUNC, func);
319	nlattr_add_string(nw, KTEST_MSG_ATTR_FILE, fname);
320	nlattr_add_u32(nw, KTEST_MSG_ATTR_LINE, line);
321}
322
323void
324ktest_add_msg_text(struct ktest_test_context *ctx, int msg_level,
325    const char *fmt, ...)
326{
327	va_list ap;
328
329	va_start(ap, fmt);
330	vsnprintf(ctx->buf, ctx->bufsize, fmt, ap);
331	va_end(ap);
332
333	nlattr_add_u8(ctx->npt->nw, KTEST_MSG_ATTR_LEVEL, msg_level);
334	nlattr_add_string(ctx->npt->nw, KTEST_MSG_ATTR_TEXT, ctx->buf);
335}
336
337void
338ktest_end_msg(struct ktest_test_context *ctx)
339{
340	nlmsg_end(ctx->npt->nw);
341}
342
343/* Module glue */
344
/* Parsers checked by NL_VERIFY_PARSERS() at registration time. */
static const struct nlhdr_parser *all_parsers[] = { &ktest_parser };

/*
 * Generic netlink commands of the ktest family.  Running a test executes
 * arbitrary kernel code, so KTEST_CMD_RUN requires PRIV_KLD_LOAD; listing
 * has no privilege requirement set here.
 */
static const struct genl_cmd ktest_cmds[] = {
	{
		.cmd_num = KTEST_CMD_LIST,
		.cmd_name = "KTEST_CMD_LIST",
		.cmd_cb = dump_tests,
		.cmd_flags = GENL_CMD_CAP_DO | GENL_CMD_CAP_DUMP | GENL_CMD_CAP_HASPOL,
	},
	{
		.cmd_num = KTEST_CMD_RUN,
		.cmd_name = "KTEST_CMD_RUN",
		.cmd_cb = run_test,
		.cmd_flags = GENL_CMD_CAP_DO | GENL_CMD_CAP_HASPOL,
		.cmd_priv = PRIV_KLD_LOAD,
	},
};
362
363static void
364ktest_nl_register(void)
365{
366	bool ret __diagused;
367	int family_id __diagused;
368
369	NL_VERIFY_PARSERS(all_parsers);
370	family_id = genl_register_family(KTEST_FAMILY_NAME, 0, 1, KTEST_CMD_MAX);
371	MPASS(family_id != 0);
372
373	ret = genl_register_cmds(KTEST_FAMILY_NAME, ktest_cmds, NL_ARRAY_LEN(ktest_cmds));
374	MPASS(ret);
375}
376
377static void
378ktest_nl_unregister(void)
379{
380	MPASS(TAILQ_EMPTY(&module_list));
381
382	genl_unregister_family(KTEST_FAMILY_NAME);
383}
384
385static int
386ktest_modevent(module_t mod, int type, void *unused)
387{
388	int error = 0;
389
390	switch (type) {
391	case MOD_LOAD:
392		ktest_nl_register();
393		break;
394	case MOD_UNLOAD:
395		ktest_nl_unregister();
396		break;
397	default:
398		error = EOPNOTSUPP;
399		break;
400	}
401	return (error);
402}
403
404static moduledata_t ktestmod = {
405        "ktest",
406        ktest_modevent,
407        0
408};
409
410DECLARE_MODULE(ktestmod, ktestmod, SI_SUB_PSEUDO, SI_ORDER_ANY);
411MODULE_VERSION(ktestmod, 1);
412MODULE_DEPEND(ktestmod, netlink, 1, 1, 1);
413
414