1/*-
2 * SPDX-License-Identifier: BSD-2-Clause
3 *
4 * Copyright (c) 2022 Alexander V. Chernikov <melifaro@FreeBSD.org>
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 *    notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 *    notice, this list of conditions and the following disclaimer in the
13 *    documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28#ifndef _NETLINK_NETLINK_MESSAGE_PARSER_H_
29#define _NETLINK_NETLINK_MESSAGE_PARSER_H_
30
31#ifdef _KERNEL
32
33#include <sys/bitset.h>
34
/*
 * This header is not meant to be included directly.
 */
38
/*
 * Parsing state: a linear ("bump") scratch allocator.
 *
 * lb_alloc() hands out consecutive chunks of the region at @base by
 * advancing @offset; there is no per-allocation free.  lb_clear() zeroes
 * the whole region and rewinds @offset to 0.  How @base/@size are set up
 * is not visible here (presumably by the netlink I/O code - confirm).
 */
struct linear_buffer {
	char		*base;	/* Base allocated memory pointer */
	uint32_t	offset;	/* Currently used offset (bytes handed out so far) */
	uint32_t	size;	/* Total buffer size in bytes */
} __aligned(_Alignof(__max_align_t));	/* the descriptor itself is max-aligned */
45
46static inline void *
47lb_alloc(struct linear_buffer *lb, int len)
48{
49	len = roundup2(len, _Alignof(__max_align_t));
50	if (lb->offset + len > lb->size)
51		return (NULL);
52	void *data = (void *)(lb->base + lb->offset);
53	lb->offset += len;
54	return (data);
55}
56
57static inline void
58lb_clear(struct linear_buffer *lb)
59{
60	memset(lb->base, 0, lb->size);
61	lb->offset = 0;
62}
63
64#define	NL_MAX_ERROR_BUF	128
65#define	SCRATCH_BUFFER_SIZE	(1024 + NL_MAX_ERROR_BUF)
66struct nl_pstate {
67        struct linear_buffer    lb;		/* Per-message scratch buffer */
68        struct nlpcb		*nlp;		/* Originator socket */
69	struct nl_writer	*nw;		/* Message writer to use */
70	struct nlmsghdr		*hdr;		/* Current parsed message header */
71	uint32_t		err_off;	/* error offset from hdr start */
72        int			error;		/* last operation error */
73	char			*err_msg;	/* Description of last error */
74	struct nlattr		*cookie;	/* NLA to return to the userspace */
75	bool			strict;		/* Strict parsing required */
76};
77
78static inline void *
79npt_alloc(struct nl_pstate *npt, int len)
80{
81	return (lb_alloc(&npt->lb, len));
82}
83#define npt_alloc_sockaddr(_npt, _len)  ((struct sockaddr *)(npt_alloc(_npt, _len)))
84
/*
 * Callback decoding one fixed-header field.  nl_parse_header() passes a
 * pointer to the field inside the input header as @hdr and a pointer to the
 * destination field inside the output structure as @target.  A non-zero
 * return aborts parsing and is propagated to the caller.
 */
typedef int parse_field_f(void *hdr, struct nl_pstate *npt,
    void *target);
/* Describes how to copy/convert a single fixed-header field. */
struct nlfield_parser {
	uint16_t	off_in;		/* field offset in the input header */
	uint16_t	off_out;	/* field offset in the target structure */
	parse_field_f	*cb;		/* decoder callback */
};
/*
 * Empty field-parser table (zero-length array, a compiler extension),
 * presumably for parsers with no header fields to extract.
 */
static const struct nlfield_parser nlf_p_empty[] = {};
93
/*
 * Stock parse_field_f implementations for struct nlfield_parser tables.
 * The suffix names the handled field type; exact conversions live in the
 * implementation file (e.g. nlf_get_u8_u32 presumably widens an 8-bit
 * header field into a 32-bit target - confirm there).
 */
int nlf_get_ifp(void *src, struct nl_pstate *npt, void *target);
int nlf_get_ifpz(void *src, struct nl_pstate *npt, void *target);
int nlf_get_u8(void *src, struct nl_pstate *npt, void *target);
int nlf_get_u16(void *src, struct nl_pstate *npt, void *target);
int nlf_get_u32(void *src, struct nl_pstate *npt, void *target);
int nlf_get_u8_u32(void *src, struct nl_pstate *npt, void *target);
100
101
struct nlattr_parser;
/*
 * Callback decoding one netlink attribute @attr.  @arg is the per-entry
 * argument from struct nlattr_parser.  @target is the destination
 * (presumably the target structure advanced by the entry's off - confirm
 * in nl_parse_attrs_raw()).  A non-zero return signals an error.
 */
typedef int parse_attr_f(struct nlattr *attr, struct nl_pstate *npt,
    const void *arg, void *target);
/* Maps one attribute type to its decoder and destination field. */
struct nlattr_parser {
	uint16_t			type;	/* Attribute type */
	uint16_t			off;	/* field offset in the target structure */
	parse_attr_f			*cb;	/* parser function to call */
	const void			*arg;	/* extra argument for cb, e.g. the nested
						 * struct nlhdr_parser for nlattr_get_nested
						 * and nlattr_get_nested_ptr */
};
111
/*
 * Strict-mode header validator: called by nl_parse_header() only when
 * npt->strict is set; returning false makes parsing fail with EINVAL.
 */
typedef bool strict_parser_f(void *hdr, struct nl_pstate *npt);
/*
 * Post-parse hook: called with the filled target after attributes were
 * parsed without error; returning false makes parsing fail with EINVAL.
 */
typedef bool post_parser_f(void *parsed_attrs, struct nl_pstate *npt);

/*
 * Complete parser description: size of the fixed header that precedes the
 * attributes, header field decoders, attribute decoders and optional hooks.
 * Normally declared through the NL_DECLARE_*PARSER* macros.
 */
struct nlhdr_parser {
	int				nl_hdr_off; /* aligned netlink header size */
	int				out_hdr_off; /* target header size */
	int				fp_size; /* number of entries in fp */
	int				np_size; /* number of entries in np */
	const struct nlfield_parser	*fp; /* array of header field parsers */
	const struct nlattr_parser	*np; /* array of attribute parsers */
	strict_parser_f			*sp; /* Pre-parse strict validation function */
	post_parser_f			*post_parse; /* Optional hook run after successful parsing */
};
125
/*
 * Declare a static parser @_name for messages whose fixed header has type
 * @_t.  @_fp/@_np are the field and attribute parser tables; NL_ARRAY_LEN()
 * is applied to them, so they presumably must be true arrays, not pointers.
 * @_sp (strict-mode validator) and @_pp (post-parse hook) may be NULL.
 */
#define	NL_DECLARE_PARSER_EXT(_name, _t, _sp, _fp, _np, _pp)	\
static const struct nlhdr_parser _name = {			\
	.nl_hdr_off = sizeof(_t),				\
	.fp_size = NL_ARRAY_LEN(_fp),				\
	.fp = &((_fp)[0]),					\
	.np_size = NL_ARRAY_LEN(_np),				\
	.np = &((_np)[0]),					\
	.sp = (_sp),						\
	.post_parse = (_pp),					\
}

/* Header parser with neither a strict validator nor a post-parse hook. */
#define	NL_DECLARE_PARSER(_name, _t, _fp, _np)			\
	NL_DECLARE_PARSER_EXT(_name, _t, NULL, _fp, _np, NULL)

/* Header parser with a strict-mode validator and no post-parse hook. */
#define	NL_DECLARE_STRICT_PARSER(_name, _t, _sp, _fp, _np)	\
	NL_DECLARE_PARSER_EXT(_name, _t, _sp, _fp, _np, NULL)
142
/*
 * Like NL_DECLARE_PARSER(), but additionally records sizeof(@_o) as the
 * target header size (out_hdr_off).  No strict or post-parse hooks.
 */
#define	NL_DECLARE_ARR_PARSER(_name, _t, _o, _fp, _np)	\
static const struct nlhdr_parser _name = {		\
	.nl_hdr_off = sizeof(_t),			\
	.out_hdr_off = sizeof(_o),			\
	.fp_size = NL_ARRAY_LEN(_fp),			\
	.fp = &((_fp)[0]),				\
	.np_size = NL_ARRAY_LEN(_np),			\
	.np = &((_np)[0]),				\
}
152
/*
 * Declare a parser for a bare attribute list: no fixed header (nl_hdr_off
 * stays 0) and no field parsers.  @_pp is an optional post-parse hook.
 */
#define	NL_DECLARE_ATTR_PARSER_EXT(_name, _np, _pp)	\
static const struct nlhdr_parser _name = {		\
	.np_size = NL_ARRAY_LEN(_np),			\
	.np = &((_np)[0]),				\
	.post_parse = (_pp),				\
}

/* Attribute-only parser without a post-parse hook. */
#define	NL_DECLARE_ATTR_PARSER(_name, _np)		\
	NL_DECLARE_ATTR_PARSER_EXT(_name, _np, NULL)
162
/* Number of bits in struct nlattr_bmask. */
#define	NL_ATTR_BMASK_SIZE	128
BITSET_DEFINE(nlattr_bmask, NL_ATTR_BMASK_SIZE);

/*
 * Fill @bm from the attributes found in the @len bytes at @nla_head
 * (presumably one bit per attribute type present - confirm in the .c file).
 */
void nl_get_attrs_bmask_raw(struct nlattr *nla_head, int len, struct nlattr_bmask *bm);
/* Presumably tests whether the bit for @nla_type is set in @bm. */
bool nl_has_attr(const struct nlattr_bmask *bm, unsigned int nla_type);

/*
 * Parse the @len bytes of attributes at @nla_head into @target using the
 * @pslen-entry parser table @ps.  Returns 0 on success; nl_parse_header()
 * propagates any non-zero value as the parse error.
 */
int nl_parse_attrs_raw(struct nlattr *nla_head, int len, const struct nlattr_parser *ps,
    int pslen, struct nl_pstate *npt, void *target);
171
/*
 * Stock parse_attr_f implementations for struct nlattr_parser tables.
 * The suffix names the decoded attribute payload; see the implementation
 * file for exact semantics.  nlattr_get_nested and nlattr_get_nested_ptr
 * take a struct nlhdr_parser for the nested attributes as @arg (this is
 * what nl_verify_parsers() relies on when recursing).
 */
int nlattr_get_flag(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_ip(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_bool(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_uint8(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_uint16(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_uint32(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_uint64(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_in_addr(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_in6_addr(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_ifp(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_ifpz(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_ipvia(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_chara(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_string(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_stringn(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_bytes(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_nla(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_nested(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
int nlattr_get_nested_ptr(struct nlattr *nla, struct nl_pstate *npt,
    const void *arg, void *target);
210
/* Store a printf-style error description in @npt. */
bool nlmsg_report_err_msg(struct nl_pstate *npt, const char *fmt, ...);

/*
 * Store an error description in @_npt via nlmsg_report_err_msg() and also
 * log it with NLP_LOG() at LOG_DEBUG for the originator socket (_npt)->nlp.
 *
 * Wrapped in do { } while (0) so the macro expands to a single statement
 * and is safe as an unbraced if/else body.  @_npt is evaluated twice.
 * @_fmt is deliberately left unparenthesized so it can still take part in
 * string-literal concatenation inside NLP_LOG().
 */
#define	NLMSG_REPORT_ERR_MSG(_npt, _fmt, ...) do {			\
	nlmsg_report_err_msg((_npt), _fmt, ## __VA_ARGS__);		\
	NLP_LOG(LOG_DEBUG, (_npt)->nlp, _fmt, ## __VA_ARGS__);		\
} while (0)
217
/*
 * Record @off as the error offset in @npt (presumably stored in
 * npt->err_off, i.e. relative to the message header - confirm).
 */
bool nlmsg_report_err_offset(struct nl_pstate *npt, uint32_t off);

/*
 * Set the cookie attribute to be returned to userspace (npt->cookie), either
 * from an existing attribute @nla or built from a u32 @val - presumably;
 * confirm in the implementation.
 */
void nlmsg_report_cookie(struct nl_pstate *npt, struct nlattr *nla);
void nlmsg_report_cookie_u32(struct nl_pstate *npt, uint32_t val);
222
223/*
224 * Have it inline so compiler can optimize field accesses into
225 * the list of direct function calls without iteration.
226 */
227static inline int
228nl_parse_header(void *hdr, int len, const struct nlhdr_parser *parser,
229    struct nl_pstate *npt, void *target)
230{
231	int error;
232
233	if (__predict_false(len < parser->nl_hdr_off)) {
234		if (npt->strict) {
235			nlmsg_report_err_msg(npt, "header too short: expected %d, got %d",
236			    parser->nl_hdr_off, len);
237			return (EINVAL);
238		}
239
240		/* Compat with older applications: pretend there's a full header */
241		void *tmp_hdr = npt_alloc(npt, parser->nl_hdr_off);
242		if (tmp_hdr == NULL)
243			return (EINVAL);
244		memcpy(tmp_hdr, hdr, len);
245		hdr = tmp_hdr;
246		len = parser->nl_hdr_off;
247	}
248
249	if (npt->strict && parser->sp != NULL && !parser->sp(hdr, npt))
250		return (EINVAL);
251
252	/* Extract fields first */
253	for (int i = 0; i < parser->fp_size; i++) {
254		const struct nlfield_parser *fp = &parser->fp[i];
255		void *src = (char *)hdr + fp->off_in;
256		void *dst = (char *)target + fp->off_out;
257
258		error = fp->cb(src, npt, dst);
259		if (error != 0)
260			return (error);
261	}
262
263	struct nlattr *nla_head = (struct nlattr *)((char *)hdr + parser->nl_hdr_off);
264	error = nl_parse_attrs_raw(nla_head, len - parser->nl_hdr_off, parser->np,
265	    parser->np_size, npt, target);
266
267	if (parser->post_parse != NULL && error == 0) {
268		if (!parser->post_parse(target, npt))
269			return (EINVAL);
270	}
271
272	return (error);
273}
274
275static inline int
276nl_parse_nested(struct nlattr *nla, const struct nlhdr_parser *parser,
277    struct nl_pstate *npt, void *target)
278{
279	struct nlattr *nla_head = (struct nlattr *)NLA_DATA(nla);
280
281	return (nl_parse_attrs_raw(nla_head, NLA_DATA_LEN(nla), parser->np,
282	    parser->np_size, npt, target));
283}
284
/*
 * Sanity-check parser descriptions; a no-op unless INVARIANTS is defined.
 *
 * For each of the @count parsers in @parser, assert (MPASS) that the
 * attribute types in its np table are strictly increasing and non-zero,
 * i.e. sorted with no duplicates.  Entries whose callback is
 * nlattr_get_nested or nlattr_get_nested_ptr are followed recursively
 * through their struct nlhdr_parser argument.
 */
static inline void
nl_verify_parsers(const struct nlhdr_parser **parser, int count)
{
#ifdef INVARIANTS
	for (int i = 0; i < count; i++) {
		const struct nlhdr_parser *p = parser[i];
		int attr_type = 0;
		for (int j = 0; j < p->np_size; j++) {
			MPASS(p->np[j].type > attr_type);
			attr_type = p->np[j].type;

			/* Recurse into nested objects. */
			if (p->np[j].cb == nlattr_get_nested ||
			    p->np[j].cb == nlattr_get_nested_ptr) {
				const struct nlhdr_parser *np =
				    (const struct nlhdr_parser *)p->np[j].arg;
				nl_verify_parsers(&np, 1);
			}
		}
	}
#endif
}
/*
 * Verify an array of parser pointers; @_p must be a true array so that
 * NL_ARRAY_LEN() yields its element count.
 */
#define	NL_VERIFY_PARSERS(_p)	nl_verify_parsers((_p), NL_ARRAY_LEN(_p))
312
313static inline int
314nl_parse_nlmsg(struct nlmsghdr *hdr, const struct nlhdr_parser *parser,
315    struct nl_pstate *npt, void *target)
316{
317	return (nl_parse_header(hdr + 1, hdr->nlmsg_len - sizeof(*hdr), parser, npt, target));
318}
319
320static inline void
321nl_get_attrs_bmask_nlmsg(struct nlmsghdr *hdr, const struct nlhdr_parser *parser,
322    struct nlattr_bmask *bm)
323{
324	struct nlattr *nla_head;
325
326	nla_head = (struct nlattr *)((char *)(hdr + 1) + parser->nl_hdr_off);
327	int len = hdr->nlmsg_len - sizeof(*hdr) - parser->nl_hdr_off;
328
329	nl_get_attrs_bmask_raw(nla_head, len, bm);
330}
331
332#endif
333#endif
334