1// SPDX-License-Identifier: GPL-2.0
2/* Converted from tools/testing/selftests/bpf/verifier/jeq_infer_not_null.c */
3
4#include <linux/bpf.h>
5#include <bpf/bpf_helpers.h>
6#include "bpf_misc.h"
7
/* Single-entry XSKMAP with int key and int value.  In this file it is only
 * the lookup target of null_ptr_to_map_value, which needs two
 * bpf_map_lookup_elem() results for the same key.
 */
struct {
	__uint(type, BPF_MAP_TYPE_XSKMAP);
	__uint(max_entries, 1);
	__type(key, int);
	__type(value, int);
} map_xskmap SEC(".maps");
14
15/* This is equivalent to the following program:
16 *
17 *   r6 = skb->sk;
18 *   r7 = sk_fullsock(r6);
19 *   r0 = sk_fullsock(r6);
20 *   if (r0 == 0) return 0;    (a)
21 *   if (r0 != r7) return 0;   (b)
22 *   *r7->type;                (c)
23 *   return 0;
24 *
25 * It is safe to dereference r7 at point (c), because of (a) and (b).
26 * The test verifies that relation r0 == r7 is propagated from (b) to (c).
27 */
SEC("cgroup/skb")
__description("jne/jeq infer not null, PTR_TO_SOCKET_OR_NULL -> PTR_TO_SOCKET for JNE false branch")
__success __failure_unpriv __msg_unpriv("R7 pointer comparison")
__retval(0)
__naked void socket_for_jne_false_branch(void)
{
	/*
	 * Register roles:
	 *   r6 = skb->sk, null-checked.
	 *   r7 = first bpf_sk_fullsock(r6) result, never null-checked directly.
	 *   r0 = second bpf_sk_fullsock(r6) result, null-checked.
	 *
	 * r7 is loaded only on the fall-through (false) path of
	 * "if r0 != r7", i.e. only when r7 == r0.
	 *
	 * Expected result: accepted when privileged, returning 0.  When
	 * unprivileged, rejected with "R7 pointer comparison".
	 *
	 * The instruction sequence inside the asm string is the test itself.
	 * Keep it unchanged.
	 */
	asm volatile ("					\
	/* r6 = skb->sk; */				\
	r6 = *(u64*)(r1 + %[__sk_buff_sk]);		\
	/* if (r6 == 0) return 0; */			\
	if r6 == 0 goto l0_%=;				\
	/* r7 = sk_fullsock(skb); */			\
	r1 = r6;					\
	call %[bpf_sk_fullsock];			\
	r7 = r0;					\
	/* r0 = sk_fullsock(skb); */			\
	r1 = r6;					\
	call %[bpf_sk_fullsock];			\
	/* if (r0 == null) return 0; */			\
	if r0 == 0 goto l0_%=;				\
	/* if (r0 == r7) r0 = *(r7->type); */		\
	if r0 != r7 goto l0_%=;		/* Use ! JNE ! */\
	r0 = *(u32*)(r7 + %[bpf_sock_type]);		\
l0_%=:	/* return 0 */					\
	r0 = 0;						\
	exit;						\
"	:
	: __imm(bpf_sk_fullsock),
	  __imm_const(__sk_buff_sk, offsetof(struct __sk_buff, sk)),
	  __imm_const(bpf_sock_type, offsetof(struct bpf_sock, type))
	: __clobber_all);
}
60
/* Same as above, but verify that the other (true) branch of JNE still
 * prohibits access to PTR_MAYBE_NULL.
 */
64SEC("cgroup/skb")
65__description("jne/jeq infer not null, PTR_TO_SOCKET_OR_NULL unchanged for JNE true branch")
66__failure __msg("R7 invalid mem access 'sock_or_null'")
67__failure_unpriv __msg_unpriv("R7 pointer comparison")
68__naked void unchanged_for_jne_true_branch(void)
69{
70	asm volatile ("					\
71	/* r6 = skb->sk */				\
72	r6 = *(u64*)(r1 + %[__sk_buff_sk]);		\
73	/* if (r6 == 0) return 0; */			\
74	if r6 == 0 goto l0_%=;				\
75	/* r7 = sk_fullsock(skb); */			\
76	r1 = r6;					\
77	call %[bpf_sk_fullsock];			\
78	r7 = r0;					\
79	/* r0 = sk_fullsock(skb); */			\
80	r1 = r6;					\
81	call %[bpf_sk_fullsock];			\
82	/* if (r0 == null) return 0; */			\
83	if r0 != 0 goto l0_%=;				\
84	/* if (r0 == r7) return 0; */			\
85	if r0 != r7 goto l1_%=;		/* Use ! JNE ! */\
86	goto l0_%=;					\
87l1_%=:	/* r0 = *(r7->type); */				\
88	r0 = *(u32*)(r7 + %[bpf_sock_type]);		\
89l0_%=:	/* return 0 */					\
90	r0 = 0;						\
91	exit;						\
92"	:
93	: __imm(bpf_sk_fullsock),
94	  __imm_const(__sk_buff_sk, offsetof(struct __sk_buff, sk)),
95	  __imm_const(bpf_sock_type, offsetof(struct bpf_sock, type))
96	: __clobber_all);
97}
98
/* Same as the first test, but non-null should be inferred on the JEQ true branch. */
SEC("cgroup/skb")
__description("jne/jeq infer not null, PTR_TO_SOCKET_OR_NULL -> PTR_TO_SOCKET for JEQ true branch")
__success __failure_unpriv __msg_unpriv("R7 pointer comparison")
__retval(0)
__naked void socket_for_jeq_true_branch(void)
{
	/*
	 * Register roles:
	 *   r6 = skb->sk, null-checked.
	 *   r7 = first bpf_sk_fullsock(r6) result, not null-checked.
	 *   r0 = second bpf_sk_fullsock(r6) result, null-checked.
	 *
	 * r7 is loaded only at l1, which is reached only when
	 * "if r0 == r7" is taken.  Every other path jumps to l0 and returns 0.
	 *
	 * Expected result: accepted when privileged, returning 0.  When
	 * unprivileged, rejected with "R7 pointer comparison".
	 *
	 * The instruction sequence inside the asm string is the test itself.
	 * Keep it unchanged.
	 */
	asm volatile ("					\
	/* r6 = skb->sk; */				\
	r6 = *(u64*)(r1 + %[__sk_buff_sk]);		\
	/* if (r6 == null) return 0; */			\
	if r6 == 0 goto l0_%=;				\
	/* r7 = sk_fullsock(skb); */			\
	r1 = r6;					\
	call %[bpf_sk_fullsock];			\
	r7 = r0;					\
	/* r0 = sk_fullsock(skb); */			\
	r1 = r6;					\
	call %[bpf_sk_fullsock];			\
	/* if (r0 == null) return 0; */			\
	if r0 == 0 goto l0_%=;				\
	/* if (r0 != r7) return 0; */			\
	if r0 == r7 goto l1_%=;		/* Use ! JEQ ! */\
	goto l0_%=;					\
l1_%=:	/* r0 = *(r7->type); */				\
	r0 = *(u32*)(r7 + %[bpf_sock_type]);		\
l0_%=:	/* return 0; */					\
	r0 = 0;						\
	exit;						\
"	:
	: __imm(bpf_sk_fullsock),
	  __imm_const(__sk_buff_sk, offsetof(struct __sk_buff, sk)),
	  __imm_const(bpf_sock_type, offsetof(struct bpf_sock, type))
	: __clobber_all);
}
134
/* Same as above, but verify that the other (false) branch of JEQ still
 * prohibits access to PTR_MAYBE_NULL.
 */
SEC("cgroup/skb")
__description("jne/jeq infer not null, PTR_TO_SOCKET_OR_NULL unchanged for JEQ false branch")
__failure __msg("R7 invalid mem access 'sock_or_null'")
__failure_unpriv __msg_unpriv("R7 pointer comparison")
__naked void unchanged_for_jeq_false_branch(void)
{
	/*
	 * Register roles:
	 *   r6 = skb->sk, null-checked.
	 *   r7 = first bpf_sk_fullsock(r6) result, not null-checked.
	 *   r0 = second bpf_sk_fullsock(r6) result, null-checked.
	 *
	 * r7 is loaded on the fall-through (false) path of "if r0 == r7",
	 * i.e. only when r0 != r7, so nothing is known about r7 there.
	 *
	 * Expected result: rejected with
	 * "R7 invalid mem access 'sock_or_null'" when privileged, and with
	 * "R7 pointer comparison" when unprivileged.
	 *
	 * The instruction sequence inside the asm string is the test itself.
	 * Keep it unchanged.
	 */
	asm volatile ("					\
	/* r6 = skb->sk; */				\
	r6 = *(u64*)(r1 + %[__sk_buff_sk]);		\
	/* if (r6 == null) return 0; */			\
	if r6 == 0 goto l0_%=;				\
	/* r7 = sk_fullsock(skb); */			\
	r1 = r6;					\
	call %[bpf_sk_fullsock];			\
	r7 = r0;					\
	/* r0 = sk_fullsock(skb); */			\
	r1 = r6;					\
	call %[bpf_sk_fullsock];			\
	/* if (r0 == null) return 0; */			\
	if r0 == 0 goto l0_%=;				\
	/* if (r0 != r7) r0 = *(r7->type); */		\
	if r0 == r7 goto l0_%=;		/* Use ! JEQ ! */\
	r0 = *(u32*)(r7 + %[bpf_sock_type]);		\
l0_%=:	/* return 0; */					\
	r0 = 0;						\
	exit;						\
"	:
	: __imm(bpf_sk_fullsock),
	  __imm_const(__sk_buff_sk, offsetof(struct __sk_buff, sk)),
	  __imm_const(bpf_sock_type, offsetof(struct bpf_sock, type))
	: __clobber_all);
}
170
/* Maps are handled in a different branch of `mark_ptr_not_null_reg`,
 * so the map case gets a separate test.
 */
SEC("xdp")
__description("jne/jeq infer not null, PTR_TO_MAP_VALUE_OR_NULL -> PTR_TO_MAP_VALUE")
__success __retval(0)
__naked void null_ptr_to_map_value(void)
{
	/*
	 * Register roles:
	 *   r9 = pointer to a zeroed u32 key at r10 - 8.
	 *   r8 = address of map_xskmap.
	 *   r6, r7 = two bpf_map_lookup_elem(r8, r9) results for the same key.
	 *
	 * Only r6 is null-checked.  r7 is read only on the fall-through path
	 * of "if r6 != r7", i.e. when r7 == r6.
	 *
	 * Expected result: accepted, returning 0.
	 *
	 * NOTE(review): the load uses the struct bpf_xdp_sock layout
	 * (queue_id).  That suggests the verifier is expected to turn the
	 * not-null XSKMAP lookup result into an XDP-socket pointer rather than
	 * a plain map value; confirm against mark_ptr_not_null_reg.
	 *
	 * The instruction sequence inside the asm string is the test itself.
	 * Keep it unchanged.
	 */
	asm volatile ("					\
	/* r9 = &some stack to use as key */		\
	r1 = 0;						\
	*(u32*)(r10 - 8) = r1;				\
	r9 = r10;					\
	r9 += -8;					\
	/* r8 = process local map */			\
	r8 = %[map_xskmap] ll;				\
	/* r6 = map_lookup_elem(r8, r9); */		\
	r1 = r8;					\
	r2 = r9;					\
	call %[bpf_map_lookup_elem];			\
	r6 = r0;					\
	/* r7 = map_lookup_elem(r8, r9); */		\
	r1 = r8;					\
	r2 = r9;					\
	call %[bpf_map_lookup_elem];			\
	r7 = r0;					\
	/* if (r6 == 0) return 0; */			\
	if r6 == 0 goto l0_%=;				\
	/* if (r6 != r7) return 0; */			\
	if r6 != r7 goto l0_%=;				\
	/* read *r7; */					\
	r0 = *(u32*)(r7 + %[bpf_xdp_sock_queue_id]);	\
l0_%=:	/* return 0; */					\
	r0 = 0;						\
	exit;						\
"	:
	: __imm(bpf_map_lookup_elem),
	  __imm_addr(map_xskmap),
	  __imm_const(bpf_xdp_sock_queue_id, offsetof(struct bpf_xdp_sock, queue_id))
	: __clobber_all);
}
212
/* Program license string, placed in the "license" section via SEC(). */
char _license[] SEC("license") = "GPL";
214