1/*-
2 * SPDX-License-Identifier: BSD-2-Clause
3 *
4 * Copyright (c) 2022 Alexander V. Chernikov <melifaro@FreeBSD.org>
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 *    notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 *    notice, this list of conditions and the following disclaimer in the
13 *    documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28#include <sys/cdefs.h>
29#include "opt_inet.h"
30#include "opt_inet6.h"
31#include <sys/types.h>
32#include <sys/malloc.h>
33#include <sys/rmlock.h>
34#include <sys/socket.h>
35
36#include <machine/stdarg.h>
37
38#include <net/if.h>
39#include <net/route.h>
40#include <net/route/nhop.h>
41
42#include <net/route/route_ctl.h>
43#include <netinet/in.h>
44#include <netlink/netlink.h>
45#include <netlink/netlink_ctl.h>
46#include <netlink/netlink_var.h>
47#include <netlink/netlink_route.h>
48
49#define	DEBUG_MOD_NAME	nl_parser
50#define	DEBUG_MAX_LEVEL	LOG_DEBUG3
51#include <netlink/netlink_debug.h>
52_DECLARE_DEBUG(LOG_INFO);
53
54bool
55nlmsg_report_err_msg(struct nl_pstate *npt, const char *fmt, ...)
56{
57	va_list ap;
58
59	if (npt->err_msg != NULL)
60		return (false);
61	char *buf = npt_alloc(npt, NL_MAX_ERROR_BUF);
62	if (buf == NULL)
63		return (false);
64	va_start(ap, fmt);
65	vsnprintf(buf, NL_MAX_ERROR_BUF, fmt, ap);
66	va_end(ap);
67
68	npt->err_msg = buf;
69	return (true);
70}
71
72bool
73nlmsg_report_err_offset(struct nl_pstate *npt, uint32_t off)
74{
75	if (npt->err_off != 0)
76		return (false);
77	npt->err_off = off;
78	return (true);
79}
80
81void
82nlmsg_report_cookie(struct nl_pstate *npt, struct nlattr *nla)
83{
84	MPASS(nla->nla_type == NLMSGERR_ATTR_COOKIE);
85	MPASS(nla->nla_len >= sizeof(struct nlattr));
86	npt->cookie = nla;
87}
88
89void
90nlmsg_report_cookie_u32(struct nl_pstate *npt, uint32_t val)
91{
92	struct nlattr *nla = npt_alloc(npt, sizeof(*nla) + sizeof(uint32_t));
93
94	nla->nla_type = NLMSGERR_ATTR_COOKIE;
95	nla->nla_len = sizeof(*nla) + sizeof(uint32_t);
96	memcpy(nla + 1, &val, sizeof(uint32_t));
97	nlmsg_report_cookie(npt, nla);
98}
99
100static const struct nlattr_parser *
101search_states(const struct nlattr_parser *ps, int pslen, int key)
102{
103	int left_i = 0, right_i = pslen - 1;
104
105	if (key < ps[0].type || key > ps[pslen - 1].type)
106		return (NULL);
107
108	while (left_i + 1 < right_i) {
109		int mid_i = (left_i + right_i) / 2;
110		if (key < ps[mid_i].type)
111			right_i = mid_i;
112		else if (key > ps[mid_i].type)
113			left_i = mid_i + 1;
114		else
115			return (&ps[mid_i]);
116	}
117	if (ps[left_i].type == key)
118		return (&ps[left_i]);
119	else if (ps[right_i].type == key)
120		return (&ps[right_i]);
121	return (NULL);
122}
123
/*
 * Parse the attribute chain of len bytes starting at nla_head.
 *
 * Each attribute's type (with NLA_TYPE_MASK applied) is looked up in the
 * parser table ps[] of pslen entries via search_states().  On a match, the
 * entry's callback is invoked with s->arg and a pointer to target + s->off.
 * Attributes without a matching entry are skipped.
 *
 * Parsing stops at the first attribute that meets either condition:
 *  - its nla_len is smaller than the attribute header: an error message is
 *    reported and EINVAL is returned;
 *  - its callback returns non-zero: that error is returned.
 * In both cases the attribute's byte offset from npt->hdr is recorded with
 * nlmsg_report_err_offset().
 *
 * Returns 0 once the whole chain has been walked.
 */
int
nl_parse_attrs_raw(struct nlattr *nla_head, int len, const struct nlattr_parser *ps, int pslen,
    struct nl_pstate *npt, void *target)
{
	struct nlattr *nla = NULL;
	int error = 0;

	NL_LOG(LOG_DEBUG3, "parse %p remaining_len %d", nla_head, len);
	int orig_len = len;
	NLA_FOREACH(nla, nla_head, len) {
		NL_LOG(LOG_DEBUG3, ">> parsing %p attr_type %d len %d (rem %d)", nla, nla->nla_type, nla->nla_len, len);
		/* Reject attributes too short to hold their own header. */
		if (nla->nla_len < sizeof(struct nlattr)) {
			NLMSG_REPORT_ERR_MSG(npt, "Invalid attr %p type %d len: %d",
			    nla, nla->nla_type, nla->nla_len);
			uint32_t off = (char *)nla - (char *)npt->hdr;
			nlmsg_report_err_offset(npt, off);
			return (EINVAL);
		}

		int nla_type = nla->nla_type & NLA_TYPE_MASK;
		const struct nlattr_parser *s = search_states(ps, pslen, nla_type);
		if (s != NULL) {
			/* Each parser entry writes to its own field inside target. */
			void *ptr = (void *)((char *)target + s->off);
			error = s->cb(nla, npt, s->arg, ptr);
			if (error != 0) {
				uint32_t off = (char *)nla - (char *)npt->hdr;
				nlmsg_report_err_offset(npt, off);
				NL_LOG(LOG_DEBUG3, "parse failed at offset %u", off);
				return (error);
			}
		} else {
			/* Ignore non-specified attributes */
			NL_LOG(LOG_DEBUG3, "ignoring attr %d", nla->nla_type);
		}
	}
	/*
	 * Debug aid only: if at least a header's worth of bytes is left, log
	 * the attribute at the point where the walk stopped.
	 * NOTE(review): this assumes NLA_FOREACH updates `len` as it advances;
	 * confirm against the macro definition.
	 */
	if (len >= sizeof(struct nlattr)) {
		nla = (struct nlattr *)((char *)nla_head + (orig_len - len));
		NL_LOG(LOG_DEBUG3, " >>> end %p attr_type %d len %d", nla,
		    nla->nla_type, nla->nla_len);
	}
	NL_LOG(LOG_DEBUG3, "end parse: %p remaining_len %d", nla, len);

	return (0);
}
168
169void
170nl_get_attrs_bmask_raw(struct nlattr *nla_head, int len, struct nlattr_bmask *bm)
171{
172	struct nlattr *nla = NULL;
173
174	BIT_ZERO(NL_ATTR_BMASK_SIZE, bm);
175
176	NLA_FOREACH(nla, nla_head, len) {
177		if (nla->nla_len < sizeof(struct nlattr))
178			return;
179		int nla_type = nla->nla_type & NLA_TYPE_MASK;
180		if (nla_type < NL_ATTR_BMASK_SIZE)
181			BIT_SET(NL_ATTR_BMASK_SIZE, nla_type, bm);
182		else
183			NL_LOG(LOG_DEBUG2, "Skipping type %d in the mask: too short",
184			    nla_type);
185	}
186}
187
188bool
189nl_has_attr(const struct nlattr_bmask *bm, unsigned int nla_type)
190{
191	MPASS(nla_type < NL_ATTR_BMASK_SIZE);
192
193	return (BIT_ISSET(NL_ATTR_BMASK_SIZE, nla_type, bm));
194}
195
196int
197nlattr_get_flag(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
198{
199	if (__predict_false(NLA_DATA_LEN(nla) != 0)) {
200		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not a flag",
201		    nla->nla_type, NLA_DATA_LEN(nla));
202		return (EINVAL);
203	}
204
205	*((uint8_t *)target) = 1;
206	return (0);
207}
208
209static struct sockaddr *
210parse_rta_ip4(void *rta_data, struct nl_pstate *npt, int *perror)
211{
212	struct sockaddr_in *sin;
213
214	sin = (struct sockaddr_in *)npt_alloc_sockaddr(npt, sizeof(struct sockaddr_in));
215	if (__predict_false(sin == NULL)) {
216		*perror = ENOBUFS;
217		return (NULL);
218	}
219	sin->sin_len = sizeof(struct sockaddr_in);
220	sin->sin_family = AF_INET;
221	memcpy(&sin->sin_addr, rta_data, sizeof(struct in_addr));
222	return ((struct sockaddr *)sin);
223}
224
225static struct sockaddr *
226parse_rta_ip6(void *rta_data, struct nl_pstate *npt, int *perror)
227{
228	struct sockaddr_in6 *sin6;
229
230	sin6 = (struct sockaddr_in6 *)npt_alloc_sockaddr(npt, sizeof(struct sockaddr_in6));
231	if (__predict_false(sin6 == NULL)) {
232		*perror = ENOBUFS;
233		return (NULL);
234	}
235	sin6->sin6_len = sizeof(struct sockaddr_in6);
236	sin6->sin6_family = AF_INET6;
237	memcpy(&sin6->sin6_addr, rta_data, sizeof(struct in6_addr));
238	return ((struct sockaddr *)sin6);
239}
240
241static struct sockaddr *
242parse_rta_ip(struct rtattr *rta, struct nl_pstate *npt, int *perror)
243{
244	void *rta_data = NL_RTA_DATA(rta);
245	int rta_len = NL_RTA_DATA_LEN(rta);
246
247	if (rta_len == sizeof(struct in_addr)) {
248		return (parse_rta_ip4(rta_data, npt, perror));
249	} else if (rta_len == sizeof(struct in6_addr)) {
250		return (parse_rta_ip6(rta_data, npt, perror));
251	} else {
252		NLMSG_REPORT_ERR_MSG(npt, "unknown IP len: %d for rta type %d",
253		    rta_len, rta->rta_type);
254		*perror = ENOTSUP;
255		return (NULL);
256	}
257	return (NULL);
258}
259
/*
 * Attribute callback: parse a bare IPv4/IPv6 address attribute.  The
 * resulting sockaddr pointer is stored at target; NULL is stored on failure.
 * Returns 0 or the error reported by parse_rta_ip().
 */
int
nlattr_get_ip(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
{
	struct sockaddr **sap = target;
	int error = 0;

	*sap = parse_rta_ip((struct rtattr *)nla, npt, &error);
	return (error);
}
270
271static struct sockaddr *
272parse_rta_via(struct rtattr *rta, struct nl_pstate *npt, int *perror)
273{
274	struct rtvia *via = NL_RTA_DATA(rta);
275	int data_len = NL_RTA_DATA_LEN(rta);
276
277	if (__predict_false(data_len) < sizeof(struct rtvia)) {
278		NLMSG_REPORT_ERR_MSG(npt, "undersized RTA_VIA(%d) attr: len %d",
279		    rta->rta_type, data_len);
280		*perror = EINVAL;
281		return (NULL);
282	}
283	data_len -= offsetof(struct rtvia, rtvia_addr);
284
285	switch (via->rtvia_family) {
286	case AF_INET:
287		if (__predict_false(data_len < sizeof(struct in_addr))) {
288			*perror = EINVAL;
289			return (NULL);
290		}
291		return (parse_rta_ip4(via->rtvia_addr, npt, perror));
292	case AF_INET6:
293		if (__predict_false(data_len < sizeof(struct in6_addr))) {
294			*perror = EINVAL;
295			return (NULL);
296		}
297		return (parse_rta_ip6(via->rtvia_addr, npt, perror));
298	default:
299		*perror = ENOTSUP;
300		return (NULL);
301	}
302}
303
/*
 * Attribute callback: parse an RTA_VIA-style (family + address) attribute.
 * The resulting sockaddr pointer is stored at target; NULL is stored on
 * failure.  Returns 0 or the error reported by parse_rta_via().
 */
int
nlattr_get_ipvia(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
{
	struct sockaddr **sap = target;
	int error = 0;

	*sap = parse_rta_via((struct rtattr *)nla, npt, &error);
	return (error);
}
314
315int
316nlattr_get_bool(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
317{
318	if (__predict_false(NLA_DATA_LEN(nla) != sizeof(bool))) {
319		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not bool",
320		    nla->nla_type, NLA_DATA_LEN(nla));
321		return (EINVAL);
322	}
323	*((bool *)target) = *((const bool *)NL_RTA_DATA_CONST(nla));
324	return (0);
325}
326
327int
328nlattr_get_uint8(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
329{
330	if (__predict_false(NLA_DATA_LEN(nla) != sizeof(uint8_t))) {
331		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not uint8",
332		    nla->nla_type, NLA_DATA_LEN(nla));
333		return (EINVAL);
334	}
335	*((uint8_t *)target) = *((const uint8_t *)NL_RTA_DATA_CONST(nla));
336	return (0);
337}
338
339int
340nlattr_get_uint16(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
341{
342	if (__predict_false(NLA_DATA_LEN(nla) != sizeof(uint16_t))) {
343		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not uint16",
344		    nla->nla_type, NLA_DATA_LEN(nla));
345		return (EINVAL);
346	}
347	*((uint16_t *)target) = *((const uint16_t *)NL_RTA_DATA_CONST(nla));
348	return (0);
349}
350
351int
352nlattr_get_uint32(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
353{
354	if (__predict_false(NLA_DATA_LEN(nla) != sizeof(uint32_t))) {
355		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not uint32",
356		    nla->nla_type, NLA_DATA_LEN(nla));
357		return (EINVAL);
358	}
359	*((uint32_t *)target) = *((const uint32_t *)NL_RTA_DATA_CONST(nla));
360	return (0);
361}
362
363int
364nlattr_get_uint64(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
365{
366	if (__predict_false(NLA_DATA_LEN(nla) != sizeof(uint64_t))) {
367		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not uint64",
368		    nla->nla_type, NLA_DATA_LEN(nla));
369		return (EINVAL);
370	}
371	memcpy(target, NL_RTA_DATA_CONST(nla), sizeof(uint64_t));
372	return (0);
373}
374
375int
376nlattr_get_in_addr(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
377{
378	if (__predict_false(NLA_DATA_LEN(nla) != sizeof(in_addr_t))) {
379		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not in_addr_t",
380		    nla->nla_type, NLA_DATA_LEN(nla));
381		return (EINVAL);
382	}
383	memcpy(target, NLA_DATA_CONST(nla), sizeof(in_addr_t));
384	return (0);
385}
386
387int
388nlattr_get_in6_addr(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
389{
390	if (__predict_false(NLA_DATA_LEN(nla) != sizeof(struct in6_addr))) {
391		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not struct in6_addr",
392		    nla->nla_type, NLA_DATA_LEN(nla));
393		return (EINVAL);
394	}
395	memcpy(target, NLA_DATA_CONST(nla), sizeof(struct in6_addr));
396	return (0);
397}
398
399static int
400nlattr_get_ifp_internal(struct nlattr *nla, struct nl_pstate *npt,
401    void *target, bool zero_ok)
402{
403	if (__predict_false(NLA_DATA_LEN(nla) != sizeof(uint32_t))) {
404		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not uint32",
405		    nla->nla_type, NLA_DATA_LEN(nla));
406		return (EINVAL);
407	}
408	uint32_t ifindex = *((const uint32_t *)NLA_DATA_CONST(nla));
409
410	if (ifindex == 0 && zero_ok) {
411		*((struct ifnet **)target) = NULL;
412		return (0);
413	}
414
415	NET_EPOCH_ASSERT();
416
417	struct ifnet *ifp = ifnet_byindex(ifindex);
418	if (__predict_false(ifp == NULL)) {
419		NLMSG_REPORT_ERR_MSG(npt, "nla type %d: ifindex %u invalid",
420		    nla->nla_type, ifindex);
421		return (ENOENT);
422	}
423	*((struct ifnet **)target) = ifp;
424	NL_LOG(LOG_DEBUG3, "nla type %d: ifindex %u -> %s", nla->nla_type,
425	    ifindex, if_name(ifp));
426
427	return (0);
428}
429
/*
 * Attribute callback: resolve a uint32 ifindex into the struct ifnet pointer
 * at target.  Index 0 gets no special treatment and is looked up like any
 * other index.
 */
int
nlattr_get_ifp(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
{
	const bool allow_zero = false;

	return (nlattr_get_ifp_internal(nla, npt, target, allow_zero));
}
435
/*
 * Attribute callback: like nlattr_get_ifp(), but index 0 is accepted and
 * stores NULL at target.
 */
int
nlattr_get_ifpz(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
{
	const bool allow_zero = true;

	return (nlattr_get_ifp_internal(nla, npt, target, allow_zero));
}
441
442int
443nlattr_get_chara(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
444{
445	int maxlen = NLA_DATA_LEN(nla);
446	int target_size = (size_t)arg;
447	int len = strnlen((char *)NLA_DATA(nla), maxlen);
448
449	if (__predict_false(len >= maxlen) || __predict_false(len >= target_size)) {
450		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not NULL-terminated or longer than %u",
451		    nla->nla_type, maxlen, target_size);
452		return (EINVAL);
453	}
454
455	strncpy((char *)target, (char *)NLA_DATA(nla), target_size);
456	return (0);
457}
458
459int
460nlattr_get_string(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
461{
462	int maxlen = NLA_DATA_LEN(nla);
463
464	if (__predict_false(strnlen((char *)NLA_DATA(nla), maxlen) >= maxlen)) {
465		NLMSG_REPORT_ERR_MSG(npt, "nla type %d size(%u) is not NULL-terminated",
466		    nla->nla_type, maxlen);
467		return (EINVAL);
468	}
469
470	*((char **)target) = (char *)NLA_DATA(nla);
471	return (0);
472}
473
474int
475nlattr_get_stringn(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
476{
477	int maxlen = NLA_DATA_LEN(nla);
478
479	char *buf = npt_alloc(npt, maxlen + 1);
480	if (buf == NULL)
481		return (ENOMEM);
482	buf[maxlen] = '\0';
483	memcpy(buf, NLA_DATA(nla), maxlen);
484
485	*((char **)target) = buf;
486	return (0);
487}
488
489int
490nlattr_get_bytes(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
491{
492	size_t size = (size_t)arg;
493
494	if (NLA_DATA_LEN(nla) != size)
495		return (EINVAL);
496
497	memcpy(target, NLA_DATA(nla), size);
498
499	return (0);
500}
501
502int
503nlattr_get_nla(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
504{
505	NL_LOG(LOG_DEBUG3, "STORING %p len %d", nla, nla->nla_len);
506	*((struct nlattr **)target) = nla;
507	return (0);
508}
509
/*
 * Attribute callback: parse the attribute payload with the nested header
 * parser passed in arg.  Results are written into the structure that target
 * points to.  Returns nl_parse_header()'s result.
 */
int
nlattr_get_nested(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
{
	const struct nlhdr_parser *parser = arg;

	/* target is the beginning of the structure the nested parser fills. */
	return (nl_parse_header(NLA_DATA(nla), NLA_DATA_LEN(nla), parser, npt,
	    target));
}
520
/*
 * Attribute callback: like nlattr_get_nested(), but target holds a pointer to
 * the destination structure instead of being the structure itself.
 */
int
nlattr_get_nested_ptr(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target)
{
	const struct nlhdr_parser *parser = arg;
	void *dst = *(void **)target;

	return (nl_parse_header(NLA_DATA(nla), NLA_DATA_LEN(nla), parser, npt,
	    dst));
}
531
532int
533nlf_get_ifp(void *src, struct nl_pstate *npt, void *target)
534{
535	int ifindex = *((const int *)src);
536
537	NET_EPOCH_ASSERT();
538
539	struct ifnet *ifp = ifnet_byindex(ifindex);
540	if (ifp == NULL) {
541		NL_LOG(LOG_DEBUG, "ifindex %u invalid", ifindex);
542		return (ENOENT);
543	}
544	*((struct ifnet **)target) = ifp;
545
546	return (0);
547}
548
549int
550nlf_get_ifpz(void *src, struct nl_pstate *npt, void *target)
551{
552	int ifindex = *((const int *)src);
553
554	NET_EPOCH_ASSERT();
555
556	struct ifnet *ifp = ifnet_byindex(ifindex);
557	if (ifindex != 0 && ifp == NULL) {
558		NL_LOG(LOG_DEBUG, "ifindex %u invalid", ifindex);
559		return (ENOENT);
560	}
561	*((struct ifnet **)target) = ifp;
562
563	return (0);
564}
565
/* Header-field callback: copy the uint8_t at src into the uint8_t at target. */
int
nlf_get_u8(void *src, struct nl_pstate *npt, void *target)
{
	const uint8_t *in = src;
	uint8_t *out = target;

	*out = *in;
	return (0);
}
575
/*
 * Header-field callback: zero-extend the uint8_t at src into the uint32_t at
 * target.
 */
int
nlf_get_u8_u32(void *src, struct nl_pstate *npt, void *target)
{
	const uint8_t narrow = *((const uint8_t *)src);
	uint32_t *wide = target;

	*wide = (uint32_t)narrow;
	return (0);
}
582
/* Header-field callback: copy the uint16_t at src into the uint16_t at target. */
int
nlf_get_u16(void *src, struct nl_pstate *npt, void *target)
{
	const uint16_t *in = src;
	uint16_t *out = target;

	*out = *in;
	return (0);
}
589
/* Header-field callback: copy the uint32_t at src into the uint32_t at target. */
int
nlf_get_u32(void *src, struct nl_pstate *npt, void *target)
{
	const uint32_t *in = src;
	uint32_t *out = target;

	*out = *in;
	return (0);
}
596
597