1/*-
2 * SPDX-License-Identifier: BSD-2-Clause
3 *
4 * Copyright (c) 2022 Alexander V. Chernikov <melifaro@FreeBSD.org>
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 *    notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 *    notice, this list of conditions and the following disclaimer in the
13 *    documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28#include <sys/param.h>
29#include <sys/malloc.h>
30#include <sys/lock.h>
31#include <sys/rmlock.h>
32#include <sys/mbuf.h>
33#include <sys/socket.h>
34#include <sys/socketvar.h>
35#include <sys/syslog.h>
36
37#include <netlink/netlink.h>
38#include <netlink/netlink_ctl.h>
39#include <netlink/netlink_linux.h>
40#include <netlink/netlink_var.h>
41
42#define	DEBUG_MOD_NAME	nl_writer
43#define	DEBUG_MAX_LEVEL	LOG_DEBUG3
44#include <netlink/netlink_debug.h>
45_DECLARE_DEBUG(LOG_INFO);
46
47static bool
48nlmsg_get_buf(struct nl_writer *nw, u_int len, bool waitok)
49{
50	const int mflag = waitok ? M_WAITOK : M_NOWAIT;
51
52	MPASS(nw->buf == NULL);
53
54	NL_LOG(LOG_DEBUG3, "Setting up nw %p len %u %s", nw, len,
55	    waitok ? "wait" : "nowait");
56
57	nw->buf = nl_buf_alloc(len, mflag);
58	if (__predict_false(nw->buf == NULL))
59		return (false);
60	nw->hdr = NULL;
61	nw->malloc_flag = mflag;
62	nw->num_messages = 0;
63	nw->enomem = false;
64
65	return (true);
66}
67
68static bool
69nl_send_one(struct nl_writer *nw)
70{
71
72	return (nl_send(nw, nw->nlp));
73}
74
75bool
76_nlmsg_get_unicast_writer(struct nl_writer *nw, int size, struct nlpcb *nlp)
77{
78	nw->nlp = nlp;
79	nw->cb = nl_send_one;
80
81	return (nlmsg_get_buf(nw, size, false));
82}
83
84bool
85_nlmsg_get_group_writer(struct nl_writer *nw, int size, int protocol, int group_id)
86{
87	nw->group.proto = protocol;
88	nw->group.id = group_id;
89	nw->cb = nl_send_group;
90
91	return (nlmsg_get_buf(nw, size, false));
92}
93
94void
95_nlmsg_ignore_limit(struct nl_writer *nw)
96{
97	nw->ignore_limit = true;
98}
99
100bool
101_nlmsg_flush(struct nl_writer *nw)
102{
103	bool result;
104
105	if (__predict_false(nw->hdr != NULL)) {
106		/* Last message has not been completed, skip it. */
107		int completed_len = (char *)nw->hdr - nw->buf->data;
108		/* Send completed messages */
109		nw->buf->datalen -= nw->buf->datalen - completed_len;
110		nw->hdr = NULL;
111        }
112
113	if (nw->buf->datalen == 0) {
114		MPASS(nw->num_messages == 0);
115		nl_buf_free(nw->buf);
116		nw->buf = NULL;
117		return (true);
118	}
119
120	result = nw->cb(nw);
121	nw->num_messages = 0;
122
123	if (!result) {
124		NL_LOG(LOG_DEBUG, "nw %p flush with %p() failed", nw, nw->cb);
125	}
126
127	return (result);
128}
129
130/*
131 * Flushes previous data and allocates new underlying storage
132 *  sufficient for holding at least @required_len bytes.
133 * Return true on success.
134 */
135bool
136_nlmsg_refill_buffer(struct nl_writer *nw, u_int required_len)
137{
138	struct nl_buf *new;
139	u_int completed_len, new_len, last_len;
140
141	MPASS(nw->buf != NULL);
142
143	if (nw->enomem)
144		return (false);
145
146	NL_LOG(LOG_DEBUG3, "no space at offset %u/%u (want %u), trying to "
147	    "reclaim", nw->buf->datalen, nw->buf->buflen, required_len);
148
149	/* Calculate new buffer size and allocate it. */
150	completed_len = (nw->hdr != NULL) ?
151	    (char *)nw->hdr - nw->buf->data : nw->buf->datalen;
152	if (completed_len > 0 && required_len < NLMBUFSIZE) {
153		/* We already ran out of space, use largest effective size. */
154		new_len = max(nw->buf->buflen, NLMBUFSIZE);
155	} else {
156		if (nw->buf->buflen < NLMBUFSIZE)
157			/* XXXGL: does this happen? */
158			new_len = NLMBUFSIZE;
159		else
160			new_len = nw->buf->buflen * 2;
161		while (new_len < required_len)
162			new_len *= 2;
163	}
164
165	new = nl_buf_alloc(new_len, nw->malloc_flag | M_ZERO);
166	if (__predict_false(new == NULL)) {
167		nw->enomem = true;
168		NL_LOG(LOG_DEBUG, "getting new buf failed, setting ENOMEM");
169		return (false);
170	}
171
172	/* Copy last (unfinished) header to the new storage. */
173	last_len = nw->buf->datalen - completed_len;
174	if (last_len > 0) {
175		memcpy(new->data, nw->hdr, last_len);
176		new->datalen = last_len;
177	}
178
179	NL_LOG(LOG_DEBUG2, "completed: %u bytes, copied: %u bytes",
180	    completed_len, last_len);
181
182	if (completed_len > 0) {
183		nlmsg_flush(nw);
184		MPASS(nw->buf == NULL);
185	} else
186		nl_buf_free(nw->buf);
187	nw->buf = new;
188	nw->hdr = (last_len > 0) ? (struct nlmsghdr *)new->data : NULL;
189	NL_LOG(LOG_DEBUG2, "switched buffer: used %u/%u bytes",
190	    new->datalen, new->buflen);
191
192	return (true);
193}
194
195bool
196_nlmsg_add(struct nl_writer *nw, uint32_t portid, uint32_t seq, uint16_t type,
197    uint16_t flags, uint32_t len)
198{
199	struct nl_buf *nb = nw->buf;
200	struct nlmsghdr *hdr;
201	u_int required_len;
202
203	MPASS(nw->hdr == NULL);
204
205	required_len = NETLINK_ALIGN(len + sizeof(struct nlmsghdr));
206	if (__predict_false(nb->datalen + required_len > nb->buflen)) {
207		if (!nlmsg_refill_buffer(nw, required_len))
208			return (false);
209		nb = nw->buf;
210	}
211
212	hdr = (struct nlmsghdr *)(&nb->data[nb->datalen]);
213
214	hdr->nlmsg_len = len;
215	hdr->nlmsg_type = type;
216	hdr->nlmsg_flags = flags;
217	hdr->nlmsg_seq = seq;
218	hdr->nlmsg_pid = portid;
219
220	nw->hdr = hdr;
221	nb->datalen += sizeof(struct nlmsghdr);
222
223	return (true);
224}
225
226bool
227_nlmsg_end(struct nl_writer *nw)
228{
229	struct nl_buf *nb = nw->buf;
230
231	MPASS(nw->hdr != NULL);
232
233	if (nw->enomem) {
234		NL_LOG(LOG_DEBUG, "ENOMEM when dumping message");
235		nlmsg_abort(nw);
236		return (false);
237	}
238
239	nw->hdr->nlmsg_len = nb->data + nb->datalen - (char *)nw->hdr;
240	NL_LOG(LOG_DEBUG2, "wrote msg len: %u type: %d: flags: 0x%X seq: %u pid: %u",
241	    nw->hdr->nlmsg_len, nw->hdr->nlmsg_type, nw->hdr->nlmsg_flags,
242	    nw->hdr->nlmsg_seq, nw->hdr->nlmsg_pid);
243	nw->hdr = NULL;
244	nw->num_messages++;
245	return (true);
246}
247
248void
249_nlmsg_abort(struct nl_writer *nw)
250{
251	struct nl_buf *nb = nw->buf;
252
253	if (nw->hdr != NULL) {
254		nb->datalen = (char *)nw->hdr - nb->data;
255		nw->hdr = NULL;
256	}
257}
258
259void
260nlmsg_ack(struct nlpcb *nlp, int error, struct nlmsghdr *hdr,
261    struct nl_pstate *npt)
262{
263	struct nlmsgerr *errmsg;
264	int payload_len;
265	uint32_t flags = nlp->nl_flags;
266	struct nl_writer *nw = npt->nw;
267	bool cap_ack;
268
269	payload_len = sizeof(struct nlmsgerr);
270
271	/*
272	 * The only case when we send the full message in the
273	 * reply is when there is an error and NETLINK_CAP_ACK
274	 * is not set.
275	 */
276	cap_ack = (error == 0) || (flags & NLF_CAP_ACK);
277	if (!cap_ack)
278		payload_len += hdr->nlmsg_len - sizeof(struct nlmsghdr);
279	payload_len = NETLINK_ALIGN(payload_len);
280
281	uint16_t nl_flags = cap_ack ? NLM_F_CAPPED : 0;
282	if ((npt->err_msg || npt->err_off) && nlp->nl_flags & NLF_EXT_ACK)
283		nl_flags |= NLM_F_ACK_TLVS;
284
285	NL_LOG(LOG_DEBUG3, "acknowledging message type %d seq %d",
286	    hdr->nlmsg_type, hdr->nlmsg_seq);
287
288	if (!nlmsg_add(nw, nlp->nl_port, hdr->nlmsg_seq, NLMSG_ERROR, nl_flags, payload_len))
289		goto enomem;
290
291	errmsg = nlmsg_reserve_data(nw, payload_len, struct nlmsgerr);
292	errmsg->error = error;
293	/* In case of error copy the whole message, else just the header */
294	memcpy(&errmsg->msg, hdr, cap_ack ? sizeof(*hdr) : hdr->nlmsg_len);
295
296	if (npt->err_msg != NULL && nlp->nl_flags & NLF_EXT_ACK)
297		nlattr_add_string(nw, NLMSGERR_ATTR_MSG, npt->err_msg);
298	if (npt->err_off != 0 && nlp->nl_flags & NLF_EXT_ACK)
299		nlattr_add_u32(nw, NLMSGERR_ATTR_OFFS, npt->err_off);
300	if (npt->cookie != NULL)
301		nlattr_add_raw(nw, npt->cookie);
302
303	if (nlmsg_end(nw))
304		return;
305enomem:
306	NLP_LOG(LOG_DEBUG, nlp, "error allocating ack data for message %d seq %u",
307	    hdr->nlmsg_type, hdr->nlmsg_seq);
308	nlmsg_abort(nw);
309}
310
311bool
312_nlmsg_end_dump(struct nl_writer *nw, int error, struct nlmsghdr *hdr)
313{
314	if (!nlmsg_add(nw, hdr->nlmsg_pid, hdr->nlmsg_seq, NLMSG_DONE, 0, sizeof(int))) {
315		NL_LOG(LOG_DEBUG, "Error finalizing table dump");
316		return (false);
317	}
318	/* Save operation result */
319	int *perror = nlmsg_reserve_object(nw, int);
320	NL_LOG(LOG_DEBUG2, "record error=%d at off %d (%p)", error,
321	    nw->buf->datalen, perror);
322	*perror = error;
323	nlmsg_end(nw);
324	nw->suppress_ack = true;
325
326	return (true);
327}
328
329/*
330 * KPI functions.
331 */
332
333u_int
334nlattr_save_offset(const struct nl_writer *nw)
335{
336	return (nw->buf->datalen - ((char *)nw->hdr - nw->buf->data));
337}
338
339void *
340nlmsg_reserve_data_raw(struct nl_writer *nw, size_t sz)
341{
342	struct nl_buf *nb = nw->buf;
343	void *data;
344
345	sz = NETLINK_ALIGN(sz);
346	if (__predict_false(nb->datalen + sz > nb->buflen)) {
347		if (!nlmsg_refill_buffer(nw, sz))
348			return (NULL);
349		nb = nw->buf;
350	}
351
352	data = &nb->data[nb->datalen];
353	bzero(data, sz);
354	nb->datalen += sz;
355
356	return (data);
357}
358
359bool
360nlattr_add(struct nl_writer *nw, int attr_type, int attr_len, const void *data)
361{
362	struct nl_buf *nb = nw->buf;
363	struct nlattr *nla;
364	u_int required_len;
365
366	required_len = NLA_ALIGN(attr_len + sizeof(struct nlattr));
367	if (__predict_false(nb->datalen + required_len > nb->buflen)) {
368		if (!nlmsg_refill_buffer(nw, required_len))
369			return (false);
370		nb = nw->buf;
371	}
372
373	nla = (struct nlattr *)(&nb->data[nb->datalen]);
374
375	nla->nla_len = attr_len + sizeof(struct nlattr);
376	nla->nla_type = attr_type;
377	if (attr_len > 0) {
378		if ((attr_len % 4) != 0) {
379			/* clear padding bytes */
380			bzero((char *)nla + required_len - 4, 4);
381		}
382		memcpy((nla + 1), data, attr_len);
383	}
384	nb->datalen += required_len;
385	return (true);
386}
387
388#include <netlink/ktest_netlink_message_writer.h>
389