// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2023 Meta Platforms, Inc. and affiliates. */

#include <vmlinux.h>
#include <bpf/bpf_tracing.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_core_read.h>
#include "bpf_misc.h"
#include "bpf_experimental.h"

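/* RCU read-side critical section kfuncs, used by the sleepable programs at
 * the bottom of this file
 */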
extern void bpf_rcu_read_lock(void) __ksym;
extern void bpf_rcu_read_unlock(void) __ksym;

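/* Refcounted node which can be in an rbtree and a list at the same time:
 * the embedded bpf_refcount lets a program hold multiple owning references
 * via bpf_refcount_acquire, one per collection
 */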
struct node_data {
	long key;
	long list_data;
	struct bpf_rb_node r;
	struct bpf_list_node l;
	struct bpf_refcount ref;
};

struct map_value {
	struct node_data __kptr *node;
};

struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 2);
} stashed_nodes SEC(".maps");

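/* Simpler refcounted node, used only with aroot below */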
struct node_acquire {
	long key;
	long data;
	struct bpf_rb_node node;
	struct bpf_refcount refcount;
};

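/* private() places each group of globals in its own hidden .bss section.
 * Each spin lock protects the collections declared in the same group:
 * 'lock' guards root and head, 'alock' guards aroot, 'block' guards broot.
 */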
#define private(name) SEC(".bss." #name) __hidden __attribute__((aligned(8)))
private(A) struct bpf_spin_lock lock;
private(A) struct bpf_rb_root root __contains(node_data, r);
private(A) struct bpf_list_head head __contains(node_data, l);

private(B) struct bpf_spin_lock alock;
private(B) struct bpf_rb_root aroot __contains(node_acquire, node);

private(C) struct bpf_spin_lock block;
private(C) struct bpf_rb_root broot __contains(node_data, r);

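/* rbtree comparator for struct node_data, ordering by key */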
static bool less(struct bpf_rb_node *node_a, const struct bpf_rb_node *node_b)
{
	struct node_data *a;
	struct node_data *b;

	a = container_of(node_a, struct node_data, r);
	b = container_of(node_b, struct node_data, r);

	return a->key < b->key;
}

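/* rbtree comparator for struct node_acquire, ordering by key */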
static bool less_a(struct bpf_rb_node *a, const struct bpf_rb_node *b)
{
	struct node_acquire *node_a;
	struct node_acquire *node_b;

	node_a = container_of(a, struct node_acquire, node);
	node_b = container_of(b, struct node_acquire, node);

	return node_a->key < node_b->key;
}

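/* Allocate one refcounted node and insert it into both the rbtree and the
 * list, taking a second owning reference so that each collection gets its
 * own. On success the node's refcount is 2.
 */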
static long __insert_in_tree_and_list(struct bpf_list_head *head,
				      struct bpf_rb_root *root,
				      struct bpf_spin_lock *lock)
{
	struct node_data *n, *m;

	n = bpf_obj_new(typeof(*n));
	if (!n)
		return -1;

	m = bpf_refcount_acquire(n);
	m->key = 123;
	m->list_data = 456;

	bpf_spin_lock(lock);
	if (bpf_rbtree_add(root, &n->r, less)) {
		/* Failure to insert - unexpected */
		bpf_spin_unlock(lock);
		bpf_obj_drop(m);
		return -2;
	}
	bpf_spin_unlock(lock);

	bpf_spin_lock(lock);
	if (bpf_list_push_front(head, &m->l)) {
		/* Failure to insert - unexpected */
		bpf_spin_unlock(lock);
		return -3;
	}
	bpf_spin_unlock(lock);
	return 0;
}

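/* Allocate a refcounted node, stash one owning reference as a kptr in the
 * array map at idx, and insert the other into the rbtree
 */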
static long __stash_map_insert_tree(int idx, int val, struct bpf_rb_root *root,
				    struct bpf_spin_lock *lock)
{
	struct map_value *mapval;
	struct node_data *n, *m;

	mapval = bpf_map_lookup_elem(&stashed_nodes, &idx);
	if (!mapval)
		return -1;

	n = bpf_obj_new(typeof(*n));
	if (!n)
		return -2;

	n->key = val;
	m = bpf_refcount_acquire(n);

	n = bpf_kptr_xchg(&mapval->node, n);
	if (n) {
		bpf_obj_drop(n);
		bpf_obj_drop(m);
		return -3;
	}

	bpf_spin_lock(lock);
	if (bpf_rbtree_add(root, &m->r, less)) {
		/* Failure to insert - unexpected */
		bpf_spin_unlock(lock);
		return -4;
	}
	bpf_spin_unlock(lock);
	return 0;
}

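/* Return the key of the first tree node, optionally removing and freeing
 * that node. Negative return values signal failure.
 */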
static long __read_from_tree(struct bpf_rb_root *root,
			     struct bpf_spin_lock *lock,
			     bool remove_from_tree)
{
	struct bpf_rb_node *rb;
	struct node_data *n;
	long res = -99;

	bpf_spin_lock(lock);

	rb = bpf_rbtree_first(root);
	if (!rb) {
		bpf_spin_unlock(lock);
		return -1;
	}

	n = container_of(rb, struct node_data, r);
	res = n->key;

	if (!remove_from_tree) {
		bpf_spin_unlock(lock);
		return res;
	}

	rb = bpf_rbtree_remove(root, rb);
	bpf_spin_unlock(lock);
	if (!rb)
		return -2;
	n = container_of(rb, struct node_data, r);
	bpf_obj_drop(n);
	return res;
}

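/* Pop the front of the list and return its list_data. If remove_from_list
 * is false the node is pushed back onto the list, otherwise it is freed.
 */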
static long __read_from_list(struct bpf_list_head *head,
			     struct bpf_spin_lock *lock,
			     bool remove_from_list)
{
	struct bpf_list_node *l;
	struct node_data *n;
	long res = -99;

	bpf_spin_lock(lock);

	l = bpf_list_pop_front(head);
	if (!l) {
		bpf_spin_unlock(lock);
		return -1;
	}

	n = container_of(l, struct node_data, l);
	res = n->list_data;

	if (!remove_from_list) {
		if (bpf_list_push_back(head, &n->l)) {
			bpf_spin_unlock(lock);
			return -2;
		}
	}

	bpf_spin_unlock(lock);

	if (remove_from_list)
		bpf_obj_drop(n);
	return res;
}

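/* Swap the stashed kptr out of the map slot at idx and return its key,
 * freeing the node
 */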
static long __read_from_unstash(int idx)
{
	struct node_data *n = NULL;
	struct map_value *mapval;
	long val = -99;

	mapval = bpf_map_lookup_elem(&stashed_nodes, &idx);
	if (!mapval)
		return -1;

	n = bpf_kptr_xchg(&mapval->node, n);
	if (!n)
		return -2;

	val = n->key;
	bpf_obj_drop(n);
	return val;
}

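/* Stamp out programs which insert one node into both collections, then read
 * (and possibly remove) from the tree first and the list second. Expected
 * retval is 579: key (123) + list_data (456).
 */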
#define INSERT_READ_BOTH(rem_tree, rem_list, desc)			\
SEC("tc")								\
__description(desc)							\
__success __retval(579)							\
long insert_and_remove_tree_##rem_tree##_list_##rem_list(void *ctx)	\
{									\
	long err, tree_data, list_data;					\
									\
	err = __insert_in_tree_and_list(&head, &root, &lock);		\
	if (err)							\
		return err;						\
									\
	err = __read_from_tree(&root, &lock, rem_tree);			\
	if (err < 0)							\
		return err;						\
	else								\
		tree_data = err;					\
									\
	err = __read_from_list(&head, &lock, rem_list);			\
	if (err < 0)							\
		return err;						\
	else								\
		list_data = err;					\
									\
	return tree_data + list_data;					\
}

/* After successful insert of struct node_data into both collections:
 *   - it should have refcount = 2
 *   - removing / not removing the node_data from a collection after
 *     reading should have no effect on ability to read / remove from
 *     the other collection
 */
INSERT_READ_BOTH(true, true, "insert_read_both: remove from tree + list");
INSERT_READ_BOTH(false, false, "insert_read_both: remove from neither");
INSERT_READ_BOTH(true, false, "insert_read_both: remove from tree");
INSERT_READ_BOTH(false, true, "insert_read_both: remove from list");

#undef INSERT_READ_BOTH
#define INSERT_READ_BOTH(rem_tree, rem_list, desc)			\
SEC("tc")								\
__description(desc)							\
__success __retval(579)							\
long insert_and_remove_lf_tree_##rem_tree##_list_##rem_list(void *ctx)	\
{									\
	long err, tree_data, list_data;					\
									\
	err = __insert_in_tree_and_list(&head, &root, &lock);		\
	if (err)							\
		return err;						\
									\
	err = __read_from_list(&head, &lock, rem_list);			\
	if (err < 0)							\
		return err;						\
	else								\
		list_data = err;					\
									\
	err = __read_from_tree(&root, &lock, rem_tree);			\
	if (err < 0)							\
		return err;						\
	else								\
		tree_data = err;					\
									\
	return tree_data + list_data;					\
}

/* Similar to insert_read_both, but list data is read and possibly removed
 * first
 *
 * Results should be no different than reading and possibly removing rbtree
 * node first
 */
INSERT_READ_BOTH(true, true, "insert_read_both_list_first: remove from tree + list");
INSERT_READ_BOTH(false, false, "insert_read_both_list_first: remove from neither");
INSERT_READ_BOTH(true, false, "insert_read_both_list_first: remove from tree");
INSERT_READ_BOTH(false, true, "insert_read_both_list_first: remove from list");

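/* read_root is token-pasted into the program name and also expands to the
 * global collection (root or head) that both reads target
 */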
#define INSERT_DOUBLE_READ_AND_DEL(read_fn, read_root, desc)		\
SEC("tc")								\
__description(desc)							\
__success __retval(-1)							\
long insert_double_##read_fn##_and_del_##read_root(void *ctx)		\
{									\
	long err, list_data;						\
									\
	err = __insert_in_tree_and_list(&head, &root, &lock);		\
	if (err)							\
		return err;						\
									\
	err = read_fn(&read_root, &lock, true);				\
	if (err < 0)							\
		return err;						\
	else								\
		list_data = err;					\
									\
	err = read_fn(&read_root, &lock, true);				\
	if (err < 0)							\
		return err;						\
									\
	return err + list_data;						\
}

/* Insert into both tree and list, then try reading-and-removing from either twice
 *
 * The second read-and-remove should fail on read step since the node has
 * already been removed
 */
INSERT_DOUBLE_READ_AND_DEL(__read_from_tree, root, "insert_double_del: 2x read-and-del from tree");
INSERT_DOUBLE_READ_AND_DEL(__read_from_list, head, "insert_double_del: 2x read-and-del from list");

#define INSERT_STASH_READ(rem_tree, desc)				\
SEC("tc")								\
__description(desc)							\
__success __retval(84)							\
long insert_rbtree_and_stash__del_tree_##rem_tree(void *ctx)		\
{									\
	long err, tree_data, map_data;					\
									\
	err = __stash_map_insert_tree(0, 42, &root, &lock);		\
	if (err)							\
		return err;						\
									\
	err = __read_from_tree(&root, &lock, rem_tree);			\
	if (err < 0)							\
		return err;						\
	else								\
		tree_data = err;					\
									\
	err = __read_from_unstash(0);					\
	if (err < 0)							\
		return err;						\
	else								\
		map_data = err;						\
									\
	return tree_data + map_data;					\
}

/* Stash a refcounted node in map_val, insert same node into tree, then try
 * reading data from tree then unstashed map_val, possibly removing from tree
 *
 * Removing from tree should have no effect on map_val kptr validity
 */
INSERT_STASH_READ(true, "insert_stash_read: remove from tree");
INSERT_STASH_READ(false, "insert_stash_read: don't remove from tree");

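/* bpf_rbtree_add consumes the owning ref to n, but since the node is
 * refcounted it's safe to bpf_refcount_acquire a new owning ref from the
 * non-owning ref that remains, and that ref stays valid after the lock is
 * released
 */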
SEC("tc")
__success
long rbtree_refcounted_node_ref_escapes(void *ctx)
{
	struct node_acquire *n, *m;

	n = bpf_obj_new(typeof(*n));
	if (!n)
		return 1;

	bpf_spin_lock(&alock);
	bpf_rbtree_add(&aroot, &n->node, less_a);
	m = bpf_refcount_acquire(n);
	bpf_spin_unlock(&alock);
	if (!m)
		return 2;

	m->key = 2;
	bpf_obj_drop(m);
	return 0;
}

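/* Same as above, but bpf_refcount_acquire takes its reference from an
 * owning ref (before n is added to the tree) instead of a non-owning one
 */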
SEC("tc")
__success
long rbtree_refcounted_node_ref_escapes_owning_input(void *ctx)
{
	struct node_acquire *n, *m;

	n = bpf_obj_new(typeof(*n));
	if (!n)
		return 1;

	m = bpf_refcount_acquire(n);
	m->key = 2;

	bpf_spin_lock(&alock);
	bpf_rbtree_add(&aroot, &n->node, less_a);
	bpf_spin_unlock(&alock);

	bpf_obj_drop(m);

	return 0;
}

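/* Stash n in the map slot at idx. The slot is expected to be empty: if a
 * previous occupant comes back from bpf_kptr_xchg it is dropped and the
 * call reports failure. The reference to n is consumed on all paths.
 */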
static long __stash_map_empty_xchg(struct node_data *n, int idx)
{
	struct map_value *mapval = bpf_map_lookup_elem(&stashed_nodes, &idx);

	if (!mapval) {
		bpf_obj_drop(n);
		return 1;
	}
	n = bpf_kptr_xchg(&mapval->node, n);
	if (n) {
		bpf_obj_drop(n);
		return 2;
	}
	return 0;
}

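/* First of three programs (a1, b, a2) meant to run back-to-back, driven by
 * the userspace half of the test: take two owning refs to one node and
 * stash them in map slots 0 and 1
 */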
SEC("tc")
long rbtree_wrong_owner_remove_fail_a1(void *ctx)
{
	struct node_data *n, *m;

	n = bpf_obj_new(typeof(*n));
	if (!n)
		return 1;
	m = bpf_refcount_acquire(n);

	if (__stash_map_empty_xchg(n, 0)) {
		bpf_obj_drop(m);
		return 2;
	}

	if (__stash_map_empty_xchg(m, 1))
		return 3;

	return 0;
}

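/* Second program in the sequence: unstash slot 0 and add the node to broot,
 * making broot the node's owner
 */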
SEC("tc")
long rbtree_wrong_owner_remove_fail_b(void *ctx)
{
	struct map_value *mapval;
	struct node_data *n;
	int idx = 0;

	mapval = bpf_map_lookup_elem(&stashed_nodes, &idx);
	if (!mapval)
		return 1;

	n = bpf_kptr_xchg(&mapval->node, NULL);
	if (!n)
		return 2;

	bpf_spin_lock(&block);

	bpf_rbtree_add(&broot, &n->r, less);

	bpf_spin_unlock(&block);
	return 0;
}

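/* Third program in the sequence: unstash slot 1 and try to remove the node
 * from root. Since prog b added it to broot, root doesn't own it and
 * bpf_rbtree_remove should return NULL.
 */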
SEC("tc")
long rbtree_wrong_owner_remove_fail_a2(void *ctx)
{
	struct map_value *mapval;
	struct bpf_rb_node *res;
	struct node_data *m;
	int idx = 1;

	mapval = bpf_map_lookup_elem(&stashed_nodes, &idx);
	if (!mapval)
		return 1;

	m = bpf_kptr_xchg(&mapval->node, NULL);
	if (!m)
		return 2;
	bpf_spin_lock(&lock);

	/* make m non-owning ref */
	bpf_list_push_back(&head, &m->l);
	res = bpf_rbtree_remove(&root, &m->r);

	bpf_spin_unlock(&lock);
	if (res) {
		bpf_obj_drop(container_of(res, struct node_data, r));
		return 3;
	}
	return 0;
}

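/* Sleepable program exercising the rbtree API inside an explicit
 * bpf_rcu_read_lock critical section
 */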
SEC("?fentry.s/bpf_testmod_test_read")
__success
int BPF_PROG(rbtree_sleepable_rcu,
	     struct file *file, struct kobject *kobj,
	     struct bin_attribute *bin_attr, char *buf, loff_t off, size_t len)
{
	struct bpf_rb_node *rb;
	struct node_data *n, *m = NULL;

	n = bpf_obj_new(typeof(*n));
	if (!n)
		return 0;

	bpf_rcu_read_lock();
	bpf_spin_lock(&lock);
	bpf_rbtree_add(&root, &n->r, less);
	rb = bpf_rbtree_first(&root);
	if (!rb)
		goto err_out;

	rb = bpf_rbtree_remove(&root, rb);
	if (!rb)
		goto err_out;

	m = container_of(rb, struct node_data, r);

err_out:
	bpf_spin_unlock(&lock);
	bpf_rcu_read_unlock();
	if (m)
		bpf_obj_drop(m);
	return 0;
}

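/* Same as above, minus the explicit RCU read lock: the verifier is expected
 * to accept rbtree usage in a sleepable program as long as it happens
 * inside the spin lock critical section
 */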
SEC("?fentry.s/bpf_testmod_test_read")
__success
int BPF_PROG(rbtree_sleepable_rcu_no_explicit_rcu_lock,
	     struct file *file, struct kobject *kobj,
	     struct bin_attribute *bin_attr, char *buf, loff_t off, size_t len)
{
	struct bpf_rb_node *rb;
	struct node_data *n, *m = NULL;

	n = bpf_obj_new(typeof(*n));
	if (!n)
		return 0;

	/* No explicit bpf_rcu_read_lock */
	bpf_spin_lock(&lock);
	bpf_rbtree_add(&root, &n->r, less);
	rb = bpf_rbtree_first(&root);
	if (!rb)
		goto err_out;

	rb = bpf_rbtree_remove(&root, rb);
	if (!rb)
		goto err_out;

	m = container_of(rb, struct node_data, r);

err_out:
	bpf_spin_unlock(&lock);
	/* No explicit bpf_rcu_read_unlock */
	if (m)
		bpf_obj_drop(m);
	return 0;
}

char _license[] SEC("license") = "GPL";