net: filter: x86: split bpf_jit_compile()

Split bpf_jit_compile() into two functions to improve readability
of the for(pass++) loop. The change follows the style of the JIT
compilers for arm, powerpc and s390.

The body of the new do_jit() was not re-indented, to reduce noise
in this patch, since the following patch replaces most of it.

Tested with BPF testsuite.

Signed-off-by: Alexei Starovoitov <ast@plumgrid.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index dc01773..c5fa7c9 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -178,41 +178,26 @@
 	return header;
 }
 
-void bpf_jit_compile(struct sk_filter *fp)
+struct jit_context {
+	unsigned int cleanup_addr; /* epilogue code offset, refined each pass */
+	int pc_ret0; /* bpf index of first RET #0 instruction, or -1 if none */
+	u8 seen; /* SEEN_* flags from the previous pass (all set before pass 0) */
+};
+
+static int do_jit(struct sk_filter *bpf_prog, int *addrs, u8 *image,
+		  int oldproglen, struct jit_context *ctx)
 {
+	const struct sock_filter *filter = bpf_prog->insns;
+	int flen = bpf_prog->len;
 	u8 temp[64];
 	u8 *prog;
-	unsigned int proglen, oldproglen = 0;
-	int ilen, i;
+	int ilen, i, proglen;
 	int t_offset, f_offset;
-	u8 t_op, f_op, seen = 0, pass;
-	u8 *image = NULL;
-	struct bpf_binary_header *header = NULL;
+	u8 t_op, f_op, seen = 0;
 	u8 *func;
-	int pc_ret0 = -1; /* bpf index of first RET #0 instruction (if any) */
-	unsigned int cleanup_addr; /* epilogue code offset */
-	unsigned int *addrs;
-	const struct sock_filter *filter = fp->insns;
-	int flen = fp->len;
+	unsigned int cleanup_addr = ctx->cleanup_addr;
+	u8 seen_or_pass0 = ctx->seen;
 
-	if (!bpf_jit_enable)
-		return;
-
-	addrs = kmalloc(flen * sizeof(*addrs), GFP_KERNEL);
-	if (addrs == NULL)
-		return;
-
-	/* Before first pass, make a rough estimation of addrs[]
-	 * each bpf instruction is translated to less than 64 bytes
-	 */
-	for (proglen = 0, i = 0; i < flen; i++) {
-		proglen += 64;
-		addrs[i] = proglen;
-	}
-	cleanup_addr = proglen; /* epilogue address */
-
-	for (pass = 0; pass < 10; pass++) {
-		u8 seen_or_pass0 = (pass == 0) ? (SEEN_XREG | SEEN_DATAREF | SEEN_MEM) : seen;
 		/* no prologue/epilogue for trivial filters (RET something) */
 		proglen = 0;
 		prog = temp;
@@ -325,12 +310,12 @@
 			case BPF_S_ALU_DIV_X: /* A /= X; */
 				seen |= SEEN_XREG;
 				EMIT2(0x85, 0xdb);	/* test %ebx,%ebx */
-				if (pc_ret0 > 0) {
+				if (ctx->pc_ret0 > 0) {
 					/* addrs[pc_ret0 - 1] is start address of target
 					 * (addrs[i] - 4) is the address following this jmp
 					 * ("xor %edx,%edx; div %ebx" being 4 bytes long)
 					 */
-					EMIT_COND_JMP(X86_JE, addrs[pc_ret0 - 1] -
+					EMIT_COND_JMP(X86_JE, addrs[ctx->pc_ret0 - 1] -
 								(addrs[i] - 4));
 				} else {
 					EMIT_COND_JMP(X86_JNE, 2 + 5);
@@ -342,12 +327,12 @@
 			case BPF_S_ALU_MOD_X: /* A %= X; */
 				seen |= SEEN_XREG;
 				EMIT2(0x85, 0xdb);	/* test %ebx,%ebx */
-				if (pc_ret0 > 0) {
+				if (ctx->pc_ret0 > 0) {
 					/* addrs[pc_ret0 - 1] is start address of target
 					 * (addrs[i] - 6) is the address following this jmp
 					 * ("xor %edx,%edx; div %ebx;mov %edx,%eax" being 6 bytes long)
 					 */
-					EMIT_COND_JMP(X86_JE, addrs[pc_ret0 - 1] -
+					EMIT_COND_JMP(X86_JE, addrs[ctx->pc_ret0 - 1] -
 								(addrs[i] - 6));
 				} else {
 					EMIT_COND_JMP(X86_JNE, 2 + 5);
@@ -441,8 +426,8 @@
 				break;
 			case BPF_S_RET_K:
 				if (!K) {
-					if (pc_ret0 == -1)
-						pc_ret0 = i;
+					if (ctx->pc_ret0 == -1)
+						ctx->pc_ret0 = i;
 					CLEAR_A();
 				} else {
 					EMIT1_off32(0xb8, K);	/* mov $imm32,%eax */
@@ -603,7 +588,7 @@
 				int off = pkt_type_offset();
 
 				if (off < 0)
-					goto out;
+					return -EINVAL;
 				if (is_imm8(off)) {
 					/* movzbl off8(%rdi),%eax */
 					EMIT4(0x0f, 0xb6, 0x47, off);
@@ -725,36 +710,79 @@
 				}
 				EMIT_COND_JMP(f_op, f_offset);
 				break;
-			default:
-				/* hmm, too complex filter, give up with jit compiler */
-				goto out;
-			}
-			ilen = prog - temp;
-			if (image) {
-				if (unlikely(proglen + ilen > oldproglen)) {
-					pr_err("bpb_jit_compile fatal error\n");
-					kfree(addrs);
-					module_free(NULL, header);
-					return;
-				}
-				memcpy(image + proglen, temp, ilen);
-			}
-			proglen += ilen;
-			addrs[i] = proglen;
-			prog = temp;
+		default:
+			/* hmm, too complex filter, give up with jit compiler */
+			return -EINVAL;
 		}
-		/* last bpf instruction is always a RET :
-		 * use it to give the cleanup instruction(s) addr
-		 */
-		cleanup_addr = proglen - 1; /* ret */
-		if (seen_or_pass0)
-			cleanup_addr -= 1; /* leaveq */
-		if (seen_or_pass0 & SEEN_XREG)
-			cleanup_addr -= 4; /* mov  -8(%rbp),%rbx */
+		ilen = prog - temp;
+		if (image) {
+			if (unlikely(proglen + ilen > oldproglen)) {
+				pr_err("bpb_jit_compile fatal error\n");
+				return -EFAULT;
+			}
+			memcpy(image + proglen, temp, ilen);
+		}
+		proglen += ilen;
+		addrs[i] = proglen;
+		prog = temp;
+	}
+	/* last bpf instruction is always a RET :
+	 * use it to give the cleanup instruction(s) addr
+	 */
+	ctx->cleanup_addr = proglen - 1; /* ret */
+	if (seen_or_pass0)
+		ctx->cleanup_addr -= 1; /* leaveq */
+	if (seen_or_pass0 & SEEN_XREG)
+		ctx->cleanup_addr -= 4; /* mov  -8(%rbp),%rbx */
 
+	ctx->seen = seen;
+
+	return proglen;
+}
+
+void bpf_jit_compile(struct sk_filter *prog)
+{
+	struct bpf_binary_header *header = NULL;
+	int proglen, oldproglen = 0;
+	struct jit_context ctx = {};
+	u8 *image = NULL;
+	int *addrs;
+	int pass;
+	int i;
+
+	if (!bpf_jit_enable)
+		return;
+
+	if (!prog || !prog->len)
+		return;
+
+	addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
+	if (!addrs)
+		return;
+
+	/* Before first pass, make a rough estimation of addrs[]
+	 * each bpf instruction is translated to less than 64 bytes
+	 */
+	for (proglen = 0, i = 0; i < prog->len; i++) {
+		proglen += 64;
+		addrs[i] = proglen;
+	}
+	ctx.cleanup_addr = proglen;
+	ctx.seen = SEEN_XREG | SEEN_DATAREF | SEEN_MEM;
+	ctx.pc_ret0 = -1;
+
+	for (pass = 0; pass < 10; pass++) {
+		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
+		if (proglen <= 0) {
+			image = NULL;
+			if (header)
+				module_free(NULL, header);
+			goto out;
+		}
 		if (image) {
 			if (proglen != oldproglen)
-				pr_err("bpb_jit_compile proglen=%u != oldproglen=%u\n", proglen, oldproglen);
+				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
+				       proglen, oldproglen);
 			break;
 		}
 		if (proglen == oldproglen) {
@@ -766,17 +794,16 @@
 	}
 
 	if (bpf_jit_enable > 1)
-		bpf_jit_dump(flen, proglen, pass, image);
+		bpf_jit_dump(prog->len, proglen, 0, image);
 
 	if (image) {
 		bpf_flush_icache(header, image + proglen);
 		set_memory_ro((unsigned long)header, header->pages);
-		fp->bpf_func = (void *)image;
-		fp->jited = 1;
+		prog->bpf_func = (void *)image;
+		prog->jited = 1;
 	}
 out:
 	kfree(addrs);
-	return;
 }
 
 static void bpf_jit_free_deferred(struct work_struct *work)