Merge branch 'bpf_fanout'

Willem de Bruijn says:

====================
packet: add cBPF and eBPF fanout modes

Allow programmable fanout modes. Support both classical BPF programs
passed directly and extended BPF programs passed by file descriptor.

One use case is packet steering by deep packet inspection, for
instance steering on application layer header fields.

Separate the configuration of the fanout mode and the configuration
of the program, to allow dynamic updates to the latter at runtime.

Changes
  v1 -> v2:
    - follow SO_LOCK_FILTER semantics on filter updates
    - only accept eBPF programs of type BPF_PROG_TYPE_SOCKET_FILTER
    - rename PACKET_FANOUT_BPF to PACKET_FANOUT_CBPF to match
      man 2 bpf usage: "classic" vs. "extended" BPF.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/uapi/linux/if_packet.h b/include/uapi/linux/if_packet.h
index d3d715f8c..9e7edfd 100644
--- a/include/uapi/linux/if_packet.h
+++ b/include/uapi/linux/if_packet.h
@@ -55,6 +55,7 @@
 #define PACKET_TX_HAS_OFF		19
 #define PACKET_QDISC_BYPASS		20
 #define PACKET_ROLLOVER_STATS		21
+#define PACKET_FANOUT_DATA		22
 
 #define PACKET_FANOUT_HASH		0
 #define PACKET_FANOUT_LB		1
@@ -62,6 +63,8 @@
 #define PACKET_FANOUT_ROLLOVER		3
 #define PACKET_FANOUT_RND		4
 #define PACKET_FANOUT_QM		5
+#define PACKET_FANOUT_CBPF		6
+#define PACKET_FANOUT_EBPF		7
 #define PACKET_FANOUT_FLAG_ROLLOVER	0x1000
 #define PACKET_FANOUT_FLAG_DEFRAG	0x8000
 
diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c
index b5afe53..7b8e39a 100644
--- a/net/packet/af_packet.c
+++ b/net/packet/af_packet.c
@@ -92,6 +92,7 @@
 #ifdef CONFIG_INET
 #include <net/inet_common.h>
 #endif
+#include <linux/bpf.h>
 
 #include "internal.h"
 
@@ -1410,6 +1411,22 @@
 	return skb_get_queue_mapping(skb) % num;
 }
 
+static unsigned int fanout_demux_bpf(struct packet_fanout *f,
+				     struct sk_buff *skb,
+				     unsigned int num)
+{
+	struct bpf_prog *prog;
+	unsigned int ret = 0;	/* index 0 is returned while no program is attached */
+
+	rcu_read_lock();
+	prog = rcu_dereference(f->bpf_prog);
+	if (prog)	/* NOTE(review): assumes num != 0 - confirm in caller */
+		ret = BPF_PROG_RUN(prog, skb) % num;	/* fold result into [0, num) */
+	rcu_read_unlock();
+
+	return ret;
+}
+
 static bool fanout_has_flag(struct packet_fanout *f, u16 flag)
 {
 	return f->flags & (flag >> 8);
@@ -1454,6 +1471,10 @@
 	case PACKET_FANOUT_ROLLOVER:
 		idx = fanout_demux_rollover(f, skb, 0, false, num);
 		break;
+	case PACKET_FANOUT_CBPF:
+	case PACKET_FANOUT_EBPF:
+		idx = fanout_demux_bpf(f, skb, num);
+		break;
 	}
 
 	if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER))
@@ -1502,6 +1523,103 @@
 	return false;
 }
 
+static void fanout_init_data(struct packet_fanout *f)
+{
+	switch (f->type) {	/* rr_cur and bpf_prog share a union: init only the one in use */
+	case PACKET_FANOUT_LB:
+		atomic_set(&f->rr_cur, 0);
+		break;
+	case PACKET_FANOUT_CBPF:
+	case PACKET_FANOUT_EBPF:
+		RCU_INIT_POINTER(f->bpf_prog, NULL);	/* no program until PACKET_FANOUT_DATA */
+		break;
+	}
+}
+
+static void __fanout_set_data_bpf(struct packet_fanout *f, struct bpf_prog *new)
+{
+	struct bpf_prog *old;
+
+	spin_lock(&f->lock);	/* swap the program pointer under f->lock */
+	old = rcu_dereference_protected(f->bpf_prog, lockdep_is_held(&f->lock));
+	rcu_assign_pointer(f->bpf_prog, new);	/* new may be NULL to detach */
+	spin_unlock(&f->lock);
+
+	if (old) {
+		synchronize_net();	/* presumably lets RCU readers of the old program finish */
+		bpf_prog_destroy(old);
+	}
+}
+
+static int fanout_set_data_cbpf(struct packet_sock *po, char __user *data,
+				unsigned int len)
+{
+	struct bpf_prog *new;
+	struct sock_fprog fprog;
+	int ret;
+
+	if (sock_flag(&po->sk, SOCK_FILTER_LOCKED))	/* no updates once filters are locked */
+		return -EPERM;
+	if (len != sizeof(fprog))	/* optval must be exactly one struct sock_fprog */
+		return -EINVAL;
+	if (copy_from_user(&fprog, data, len))
+		return -EFAULT;
+
+	ret = bpf_prog_create_from_user(&new, &fprog, NULL);
+	if (ret)
+		return ret;
+
+	__fanout_set_data_bpf(po->fanout, new);	/* installs new, destroys any previous program */
+	return 0;
+}
+
+static int fanout_set_data_ebpf(struct packet_sock *po, char __user *data,
+				unsigned int len)
+{
+	struct bpf_prog *new;
+	u32 fd;	/* optval: file descriptor referring to an eBPF program */
+
+	if (sock_flag(&po->sk, SOCK_FILTER_LOCKED))	/* no updates once filters are locked */
+		return -EPERM;
+	if (len != sizeof(fd))
+		return -EINVAL;
+	if (copy_from_user(&fd, data, len))
+		return -EFAULT;
+
+	new = bpf_prog_get(fd);
+	if (IS_ERR(new))
+		return PTR_ERR(new);
+	if (new->type != BPF_PROG_TYPE_SOCKET_FILTER) {	/* only socket filter programs accepted */
+		bpf_prog_put(new);	/* release the program obtained above */
+		return -EINVAL;
+	}
+
+	__fanout_set_data_bpf(po->fanout, new);	/* installs new, destroys any previous program */
+	return 0;
+}
+
+static int fanout_set_data(struct packet_sock *po, char __user *data,
+			   unsigned int len)
+{
+	switch (po->fanout->type) {	/* dispatch on the fanout group's mode */
+	case PACKET_FANOUT_CBPF:
+		return fanout_set_data_cbpf(po, data, len);
+	case PACKET_FANOUT_EBPF:
+		return fanout_set_data_ebpf(po, data, len);
+	default:
+		return -EINVAL;	/* the other modes take no data */
+	}
+}
+
+static void fanout_release_data(struct packet_fanout *f)
+{
+	switch (f->type) {
+	case PACKET_FANOUT_CBPF:
+	case PACKET_FANOUT_EBPF:
+		__fanout_set_data_bpf(f, NULL);	/* detach and destroy any attached program */
+	}
+}
+
 static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
 {
 	struct packet_sock *po = pkt_sk(sk);
@@ -1519,6 +1637,8 @@
 	case PACKET_FANOUT_CPU:
 	case PACKET_FANOUT_RND:
 	case PACKET_FANOUT_QM:
+	case PACKET_FANOUT_CBPF:
+	case PACKET_FANOUT_EBPF:
 		break;
 	default:
 		return -EINVAL;
@@ -1561,10 +1681,10 @@
 		match->id = id;
 		match->type = type;
 		match->flags = flags;
-		atomic_set(&match->rr_cur, 0);
 		INIT_LIST_HEAD(&match->list);
 		spin_lock_init(&match->lock);
 		atomic_set(&match->sk_ref, 0);
+		fanout_init_data(match);
 		match->prot_hook.type = po->prot_hook.type;
 		match->prot_hook.dev = po->prot_hook.dev;
 		match->prot_hook.func = packet_rcv_fanout;
@@ -1610,6 +1730,7 @@
 	if (atomic_dec_and_test(&f->sk_ref)) {
 		list_del(&f->list);
 		dev_remove_pack(&f->prot_hook);
+		fanout_release_data(f);
 		kfree(f);
 	}
 	mutex_unlock(&fanout_mutex);
@@ -3529,6 +3650,13 @@
 
 		return fanout_add(sk, val & 0xffff, val >> 16);
 	}
+	case PACKET_FANOUT_DATA:
+	{
+		if (!po->fanout)
+			return -EINVAL;
+
+		return fanout_set_data(po, optval, optlen);
+	}
 	case PACKET_TX_HAS_OFF:
 	{
 		unsigned int val;
diff --git a/net/packet/internal.h b/net/packet/internal.h
index e20b3e8..9ee4631 100644
--- a/net/packet/internal.h
+++ b/net/packet/internal.h
@@ -79,7 +79,10 @@
 	u16			id;
 	u8			type;
 	u8			flags;
-	atomic_t		rr_cur;
+	union {
+		atomic_t		rr_cur;
+		struct bpf_prog __rcu	*bpf_prog;
+	};
 	struct list_head	list;
 	struct sock		*arr[PACKET_FANOUT_MAX];
 	spinlock_t		lock;
diff --git a/tools/testing/selftests/net/psock_fanout.c b/tools/testing/selftests/net/psock_fanout.c
index 08c2a36..4124593 100644
--- a/tools/testing/selftests/net/psock_fanout.c
+++ b/tools/testing/selftests/net/psock_fanout.c
@@ -19,6 +19,8 @@
  *   - PACKET_FANOUT_LB
  *   - PACKET_FANOUT_CPU
  *   - PACKET_FANOUT_ROLLOVER
+ *   - PACKET_FANOUT_CBPF
+ *   - PACKET_FANOUT_EBPF
  *
  * Todo:
  * - functionality: PACKET_FANOUT_FLAG_DEFRAG
@@ -44,7 +46,9 @@
 #include <arpa/inet.h>
 #include <errno.h>
 #include <fcntl.h>
+#include <linux/unistd.h>	/* for __NR_bpf */
 #include <linux/filter.h>
+#include <linux/bpf.h>
 #include <linux/if_packet.h>
 #include <net/ethernet.h>
 #include <netinet/ip.h>
@@ -91,6 +95,51 @@
 	return fd;
 }
 
+static void sock_fanout_set_ebpf(int fd)
+{
+	const int len_off = __builtin_offsetof(struct __sk_buff, len);
+	struct bpf_insn prog[] = {
+		{ BPF_ALU64 | BPF_MOV | BPF_X,   6, 1, 0, 0 },	/* r6 = r1 */
+		{ BPF_LDX   | BPF_W   | BPF_MEM, 0, 6, len_off, 0 },	/* r0 = len */
+		{ BPF_JMP   | BPF_JGE | BPF_K,   0, 0, 1, DATA_LEN },	/* long enough: skip next */
+		{ BPF_JMP   | BPF_JA  | BPF_K,   0, 0, 4, 0 },	/* too short: go to exit */
+		{ BPF_LD    | BPF_B   | BPF_ABS, 0, 0, 0, 0x50 },	/* r0 = byte at 0x50 */
+		{ BPF_JMP   | BPF_JEQ | BPF_K,   0, 0, 2, DATA_CHAR },	/* match: go to exit */
+		{ BPF_JMP   | BPF_JEQ | BPF_K,   0, 0, 1, DATA_CHAR_1 },	/* match: go to exit */
+		{ BPF_ALU   | BPF_MOV | BPF_K,   0, 0, 0, 0 },	/* no match: r0 = 0 */
+		{ BPF_JMP   | BPF_EXIT,          0, 0, 0, 0 }	/* exit with r0 */
+	};
+	char log_buf[512] = "";	/* stays printable if bpf() fails before writing a log */
+	union bpf_attr attr;
+	int pfd;
+
+	memset(&attr, 0, sizeof(attr));
+	attr.prog_type = BPF_PROG_TYPE_SOCKET_FILTER;
+	attr.insns = (unsigned long) prog;
+	attr.insn_cnt = sizeof(prog) / sizeof(prog[0]);
+	attr.license = (unsigned long) "GPL";
+	attr.log_buf = (unsigned long) log_buf;
+	attr.log_size = sizeof(log_buf);
+	attr.log_level = 1;
+
+	pfd = syscall(__NR_bpf, BPF_PROG_LOAD, &attr, sizeof(attr));
+	if (pfd < 0) {
+		perror("bpf");
+		fprintf(stderr, "bpf verifier:\n%s\n", log_buf);
+		exit(1);
+	}
+
+	if (setsockopt(fd, SOL_PACKET, PACKET_FANOUT_DATA, &pfd, sizeof(pfd))) {
+		perror("fanout data ebpf");
+		exit(1);
+	}
+
+	if (close(pfd)) {	/* fd is closed once the setsockopt above has succeeded */
+		perror("close ebpf");
+		exit(1);
+	}
+}
+
 static char *sock_fanout_open_ring(int fd)
 {
 	struct tpacket_req req = {
@@ -115,8 +164,8 @@
 
 	ring = mmap(0, req.tp_block_size * req.tp_block_nr,
 		    PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
-	if (!ring) {
-		fprintf(stderr, "packetsock ring mmap\n");
+	if (ring == MAP_FAILED) {
+		perror("packetsock ring mmap");
 		exit(1);
 	}
 
@@ -209,6 +258,7 @@
 {
 	const int expect0[] = { 0, 0 };
 	char *rings[2];
+	uint8_t type = typeflags & 0xFF;
 	int fds[2], fds_udp[2][2], ret;
 
 	fprintf(stderr, "test: datapath 0x%hx\n", typeflags);
@@ -219,6 +269,11 @@
 		fprintf(stderr, "ERROR: failed open\n");
 		exit(1);
 	}
+	if (type == PACKET_FANOUT_CBPF)
+		sock_setfilter(fds[0], SOL_PACKET, PACKET_FANOUT_DATA);
+	else if (type == PACKET_FANOUT_EBPF)
+		sock_fanout_set_ebpf(fds[0]);
+
 	rings[0] = sock_fanout_open_ring(fds[0]);
 	rings[1] = sock_fanout_open_ring(fds[1]);
 	pair_udp_open(fds_udp[0], PORT_BASE);
@@ -227,11 +282,11 @@
 
 	/* Send data, but not enough to overflow a queue */
 	pair_udp_send(fds_udp[0], 15);
-	pair_udp_send(fds_udp[1], 5);
+	pair_udp_send_char(fds_udp[1], 5, DATA_CHAR_1);
 	ret = sock_fanout_read(fds, rings, expect1);
 
 	/* Send more data, overflow the queue */
-	pair_udp_send(fds_udp[0], 15);
+	pair_udp_send_char(fds_udp[0], 15, DATA_CHAR_1);
 	/* TODO: ensure consistent order between expect1 and expect2 */
 	ret |= sock_fanout_read(fds, rings, expect2);
 
@@ -275,6 +330,7 @@
 	const int expect_rb[2][2]	= { { 15, 5 },  { 20, 15 } };
 	const int expect_cpu0[2][2]	= { { 20, 0 },  { 20, 0 } };
 	const int expect_cpu1[2][2]	= { { 0, 20 },  { 0, 20 } };
+	const int expect_bpf[2][2]	= { { 15, 5 },  { 15, 20 } };
 	int port_off = 2, tries = 5, ret;
 
 	test_control_single();
@@ -296,6 +352,11 @@
 	ret |= test_datapath(PACKET_FANOUT_ROLLOVER,
 			     port_off, expect_rb[0], expect_rb[1]);
 
+	ret |= test_datapath(PACKET_FANOUT_CBPF,
+			     port_off, expect_bpf[0], expect_bpf[1]);
+	ret |= test_datapath(PACKET_FANOUT_EBPF,
+			     port_off, expect_bpf[0], expect_bpf[1]);
+
 	set_cpuaffinity(0);
 	ret |= test_datapath(PACKET_FANOUT_CPU, port_off,
 			     expect_cpu0[0], expect_cpu0[1]);
diff --git a/tools/testing/selftests/net/psock_lib.h b/tools/testing/selftests/net/psock_lib.h
index 37da54a..24bc7ec 100644
--- a/tools/testing/selftests/net/psock_lib.h
+++ b/tools/testing/selftests/net/psock_lib.h
@@ -30,6 +30,7 @@
 
 #define DATA_LEN			100
 #define DATA_CHAR			'a'
+#define DATA_CHAR_1			'b'
 
 #define PORT_BASE			8000
 
@@ -37,29 +38,36 @@
 # define __maybe_unused		__attribute__ ((__unused__))
 #endif
 
-static __maybe_unused void pair_udp_setfilter(int fd)
+static __maybe_unused void sock_setfilter(int fd, int lvl, int optnum)
 {
 	struct sock_filter bpf_filter[] = {
 		{ 0x80, 0, 0, 0x00000000 },  /* LD  pktlen		      */
-		{ 0x35, 0, 5, DATA_LEN   },  /* JGE DATA_LEN  [f goto nomatch]*/
+		{ 0x35, 0, 4, DATA_LEN   },  /* JGE DATA_LEN  [f goto nomatch]*/
 		{ 0x30, 0, 0, 0x00000050 },  /* LD  ip[80]		      */
-		{ 0x15, 0, 3, DATA_CHAR  },  /* JEQ DATA_CHAR [f goto nomatch]*/
-		{ 0x30, 0, 0, 0x00000051 },  /* LD  ip[81]		      */
-		{ 0x15, 0, 1, DATA_CHAR  },  /* JEQ DATA_CHAR [f goto nomatch]*/
+		{ 0x15, 1, 0, DATA_CHAR  },  /* JEQ DATA_CHAR   [t goto match]*/
+		{ 0x15, 0, 1, DATA_CHAR_1},  /* JEQ DATA_CHAR_1 [t goto match]*/
 		{ 0x06, 0, 0, 0x00000060 },  /* RET match	              */
 		{ 0x06, 0, 0, 0x00000000 },  /* RET no match		      */
 	};
 	struct sock_fprog bpf_prog;
 
+	if (lvl == SOL_PACKET && optnum == PACKET_FANOUT_DATA)
+		bpf_filter[5].code = 0x16;   /* RET A			      */
+
 	bpf_prog.filter = bpf_filter;
 	bpf_prog.len = sizeof(bpf_filter) / sizeof(struct sock_filter);
-	if (setsockopt(fd, SOL_SOCKET, SO_ATTACH_FILTER, &bpf_prog,
+	if (setsockopt(fd, lvl, optnum, &bpf_prog,
 		       sizeof(bpf_prog))) {
 		perror("setsockopt SO_ATTACH_FILTER");
 		exit(1);
 	}
 }
 
+static __maybe_unused void pair_udp_setfilter(int fd)
+{
+	sock_setfilter(fd, SOL_SOCKET, SO_ATTACH_FILTER);	/* old name kept as a thin wrapper */
+}
+
 static __maybe_unused void pair_udp_open(int fds[], uint16_t port)
 {
 	struct sockaddr_in saddr, daddr;
@@ -96,11 +104,11 @@
 	}
 }
 
-static __maybe_unused void pair_udp_send(int fds[], int num)
+static __maybe_unused void pair_udp_send_char(int fds[], int num, char payload)
 {
 	char buf[DATA_LEN], rbuf[DATA_LEN];
 
-	memset(buf, DATA_CHAR, sizeof(buf));
+	memset(buf, payload, sizeof(buf));
 	while (num--) {
 		/* Should really handle EINTR and EAGAIN */
 		if (write(fds[0], buf, sizeof(buf)) != sizeof(buf)) {
@@ -118,6 +126,11 @@
 	}
 }
 
+static __maybe_unused void pair_udp_send(int fds[], int num)
+{
+	pair_udp_send_char(fds, num, DATA_CHAR);	/* send with the default payload byte */
+}
+
 static __maybe_unused void pair_udp_close(int fds[])
 {
 	close(fds[0]);