1// SPDX-License-Identifier: GPL-2.0
2// Copyright (c) 2024 Meta
3
4#include "vmlinux.h"
5#include <bpf/bpf_helpers.h>
6#include <bpf/bpf_core_read.h>
7#include <bpf/bpf_endian.h>
8#include "bpf_tracing_net.h"
9#include "bpf_kfuncs.h"
10
11#define ATTR __always_inline
12#include "test_jhash.h"
13
14static bool ipv6_addr_loopback(const struct in6_addr *a)
15{
16	return (a->s6_addr32[0] | a->s6_addr32[1] |
17		a->s6_addr32[2] | (a->s6_addr32[3] ^ bpf_htonl(1))) == 0;
18}
19
/*
 * The two local ports the iterators below filter on.
 * `volatile` forces every read to load from memory, so the compiler cannot
 * fold the (zero) initial value of this const array into the code.
 * NOTE(review): presumably populated by user space before the program is
 * loaded — confirm against the test harness.
 */
volatile const __u16 ports[2];
/*
 * Output: bucket[i] is set to the hash bucket index computed for the
 * matching socket whose local port equals ports[i].
 */
unsigned int bucket[2];
22
23SEC("iter/tcp")
24int iter_tcp_soreuse(struct bpf_iter__tcp *ctx)
25{
26	struct sock *sk = (struct sock *)ctx->sk_common;
27	struct inet_hashinfo *hinfo;
28	unsigned int hash;
29	struct net *net;
30	int idx;
31
32	if (!sk)
33		return 0;
34
35	sk = bpf_core_cast(sk, struct sock);
36	if (sk->sk_family != AF_INET6 ||
37	    sk->sk_state != TCP_LISTEN ||
38	    !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
39		return 0;
40
41	if (sk->sk_num == ports[0])
42		idx = 0;
43	else if (sk->sk_num == ports[1])
44		idx = 1;
45	else
46		return 0;
47
48	/* bucket selection as in inet_lhash2_bucket_sk() */
49	net = sk->sk_net.net;
50	hash = jhash2(sk->sk_v6_rcv_saddr.s6_addr32, 4, net->hash_mix);
51	hash ^= sk->sk_num;
52	hinfo = net->ipv4.tcp_death_row.hashinfo;
53	bucket[idx] = hash & hinfo->lhash2_mask;
54	bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));
55
56	return 0;
57}
58
59#define udp_sk(ptr) container_of(ptr, struct udp_sock, inet.sk)
60
61SEC("iter/udp")
62int iter_udp_soreuse(struct bpf_iter__udp *ctx)
63{
64	struct sock *sk = (struct sock *)ctx->udp_sk;
65	struct udp_table *udptable;
66	int idx;
67
68	if (!sk)
69		return 0;
70
71	sk = bpf_core_cast(sk, struct sock);
72	if (sk->sk_family != AF_INET6 ||
73	    !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
74		return 0;
75
76	if (sk->sk_num == ports[0])
77		idx = 0;
78	else if (sk->sk_num == ports[1])
79		idx = 1;
80	else
81		return 0;
82
83	/* bucket selection as in udp_hashslot2() */
84	udptable = sk->sk_net.net->ipv4.udp_table;
85	bucket[idx] = udp_sk(sk)->udp_portaddr_hash & udptable->mask;
86	bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));
87
88	return 0;
89}
90
/*
 * License string placed in the "license" ELF section, where the BPF loader
 * reads the program's license from.
 * NOTE(review): presumably GPL is needed for GPL-only helpers/kfuncs used
 * above (e.g. bpf_seq_write) — confirm.
 */
char _license[] SEC("license") = "GPL";
92