KVM: x86 emulator: Decode memory operands directly into a 'struct operand'

Since modrm operand can be either register or memory, decoding it into
a 'struct operand', which can represent both, is simpler.

Signed-off-by: Avi Kivity <avi@redhat.com>
diff --git a/arch/x86/include/asm/kvm_emulate.h b/arch/x86/include/asm/kvm_emulate.h
index e425444..1e4a72c 100644
--- a/arch/x86/include/asm/kvm_emulate.h
+++ b/arch/x86/include/asm/kvm_emulate.h
@@ -203,9 +203,6 @@
 	u8 modrm_rm;
 	u8 modrm_seg;
 	bool rip_relative;
-	unsigned long modrm_ea;
-	void *modrm_ptr;
-	unsigned long modrm_val;
 	struct fetch_cache fetch;
 	struct read_cache io_read;
 	struct read_cache mem_read;
diff --git a/arch/x86/kvm/emulate.c b/arch/x86/kvm/emulate.c
index eda6941..955d480 100644
--- a/arch/x86/kvm/emulate.c
+++ b/arch/x86/kvm/emulate.c
@@ -581,12 +581,14 @@
 }
 
 static int decode_modrm(struct x86_emulate_ctxt *ctxt,
-			struct x86_emulate_ops *ops)
+			struct x86_emulate_ops *ops,
+			struct operand *op)
 {
 	struct decode_cache *c = &ctxt->decode;
 	u8 sib;
 	int index_reg = 0, base_reg = 0, scale;
 	int rc = X86EMUL_CONTINUE;
+	ulong modrm_ea = 0;
 
 	if (c->rex_prefix) {
 		c->modrm_reg = (c->rex_prefix & 4) << 1;	/* REX.R */
@@ -598,16 +600,19 @@
 	c->modrm_mod |= (c->modrm & 0xc0) >> 6;
 	c->modrm_reg |= (c->modrm & 0x38) >> 3;
 	c->modrm_rm |= (c->modrm & 0x07);
-	c->modrm_ea = 0;
 	c->modrm_seg = VCPU_SREG_DS;
 
 	if (c->modrm_mod == 3) {
-		c->modrm_ptr = decode_register(c->modrm_rm,
+		op->type = OP_REG;
+		op->bytes = (c->d & ByteOp) ? 1 : c->op_bytes;
+		op->addr.reg = decode_register(c->modrm_rm,
 					       c->regs, c->d & ByteOp);
-		c->modrm_val = *(unsigned long *)c->modrm_ptr;
+		fetch_register_operand(op);
 		return rc;
 	}
 
+	op->type = OP_MEM;
+
 	if (c->ad_bytes == 2) {
 		unsigned bx = c->regs[VCPU_REGS_RBX];
 		unsigned bp = c->regs[VCPU_REGS_RBP];
@@ -618,46 +623,46 @@
 		switch (c->modrm_mod) {
 		case 0:
 			if (c->modrm_rm == 6)
-				c->modrm_ea += insn_fetch(u16, 2, c->eip);
+				modrm_ea += insn_fetch(u16, 2, c->eip);
 			break;
 		case 1:
-			c->modrm_ea += insn_fetch(s8, 1, c->eip);
+			modrm_ea += insn_fetch(s8, 1, c->eip);
 			break;
 		case 2:
-			c->modrm_ea += insn_fetch(u16, 2, c->eip);
+			modrm_ea += insn_fetch(u16, 2, c->eip);
 			break;
 		}
 		switch (c->modrm_rm) {
 		case 0:
-			c->modrm_ea += bx + si;
+			modrm_ea += bx + si;
 			break;
 		case 1:
-			c->modrm_ea += bx + di;
+			modrm_ea += bx + di;
 			break;
 		case 2:
-			c->modrm_ea += bp + si;
+			modrm_ea += bp + si;
 			break;
 		case 3:
-			c->modrm_ea += bp + di;
+			modrm_ea += bp + di;
 			break;
 		case 4:
-			c->modrm_ea += si;
+			modrm_ea += si;
 			break;
 		case 5:
-			c->modrm_ea += di;
+			modrm_ea += di;
 			break;
 		case 6:
 			if (c->modrm_mod != 0)
-				c->modrm_ea += bp;
+				modrm_ea += bp;
 			break;
 		case 7:
-			c->modrm_ea += bx;
+			modrm_ea += bx;
 			break;
 		}
 		if (c->modrm_rm == 2 || c->modrm_rm == 3 ||
 		    (c->modrm_rm == 6 && c->modrm_mod != 0))
 			c->modrm_seg = VCPU_SREG_SS;
-		c->modrm_ea = (u16)c->modrm_ea;
+		modrm_ea = (u16)modrm_ea;
 	} else {
 		/* 32/64-bit ModR/M decode. */
 		if ((c->modrm_rm & 7) == 4) {
@@ -667,48 +672,51 @@
 			scale = sib >> 6;
 
 			if ((base_reg & 7) == 5 && c->modrm_mod == 0)
-				c->modrm_ea += insn_fetch(s32, 4, c->eip);
+				modrm_ea += insn_fetch(s32, 4, c->eip);
 			else
-				c->modrm_ea += c->regs[base_reg];
+				modrm_ea += c->regs[base_reg];
 			if (index_reg != 4)
-				c->modrm_ea += c->regs[index_reg] << scale;
+				modrm_ea += c->regs[index_reg] << scale;
 		} else if ((c->modrm_rm & 7) == 5 && c->modrm_mod == 0) {
 			if (ctxt->mode == X86EMUL_MODE_PROT64)
 				c->rip_relative = 1;
 		} else
-			c->modrm_ea += c->regs[c->modrm_rm];
+			modrm_ea += c->regs[c->modrm_rm];
 		switch (c->modrm_mod) {
 		case 0:
 			if (c->modrm_rm == 5)
-				c->modrm_ea += insn_fetch(s32, 4, c->eip);
+				modrm_ea += insn_fetch(s32, 4, c->eip);
 			break;
 		case 1:
-			c->modrm_ea += insn_fetch(s8, 1, c->eip);
+			modrm_ea += insn_fetch(s8, 1, c->eip);
 			break;
 		case 2:
-			c->modrm_ea += insn_fetch(s32, 4, c->eip);
+			modrm_ea += insn_fetch(s32, 4, c->eip);
 			break;
 		}
 	}
+	op->addr.mem = modrm_ea;
 done:
 	return rc;
 }
 
 static int decode_abs(struct x86_emulate_ctxt *ctxt,
-		      struct x86_emulate_ops *ops)
+		      struct x86_emulate_ops *ops,
+		      struct operand *op)
 {
 	struct decode_cache *c = &ctxt->decode;
 	int rc = X86EMUL_CONTINUE;
 
+	op->type = OP_MEM;
 	switch (c->ad_bytes) {
 	case 2:
-		c->modrm_ea = insn_fetch(u16, 2, c->eip);
+		op->addr.mem = insn_fetch(u16, 2, c->eip);
 		break;
 	case 4:
-		c->modrm_ea = insn_fetch(u32, 4, c->eip);
+		op->addr.mem = insn_fetch(u32, 4, c->eip);
 		break;
 	case 8:
-		c->modrm_ea = insn_fetch(u64, 8, c->eip);
+		op->addr.mem = insn_fetch(u64, 8, c->eip);
 		break;
 	}
 done:
@@ -2280,6 +2288,7 @@
 	int mode = ctxt->mode;
 	int def_op_bytes, def_ad_bytes, dual, goffset;
 	struct opcode opcode, *g_mod012, *g_mod3;
+	struct operand memop = { .type = OP_NONE };
 
 	/* we cannot decode insn before we complete previous rep insn */
 	WARN_ON(ctxt->restart);
@@ -2418,25 +2427,25 @@
 
 	/* ModRM and SIB bytes. */
 	if (c->d & ModRM) {
-		rc = decode_modrm(ctxt, ops);
+		rc = decode_modrm(ctxt, ops, &memop);
 		if (!c->has_seg_override)
 			set_seg_override(c, c->modrm_seg);
 	} else if (c->d & MemAbs)
-		rc = decode_abs(ctxt, ops);
+		rc = decode_abs(ctxt, ops, &memop);
 	if (rc != X86EMUL_CONTINUE)
 		goto done;
 
 	if (!c->has_seg_override)
 		set_seg_override(c, VCPU_SREG_DS);
 
-	if (!(!c->twobyte && c->b == 0x8d))
-		c->modrm_ea += seg_override_base(ctxt, ops, c);
+	if (memop.type == OP_MEM && !(!c->twobyte && c->b == 0x8d))
+		memop.addr.mem += seg_override_base(ctxt, ops, c);
 
-	if (c->ad_bytes != 8)
-		c->modrm_ea = (u32)c->modrm_ea;
+	if (memop.type == OP_MEM && c->ad_bytes != 8)
+		memop.addr.mem = (u32)memop.addr.mem;
 
-	if (c->rip_relative)
-		c->modrm_ea += c->eip;
+	if (memop.type == OP_MEM && c->rip_relative)
+		memop.addr.mem += c->eip;
 
 	/*
 	 * Decode and fetch the source operand: register, memory
@@ -2449,31 +2458,16 @@
 		decode_register_operand(&c->src, c, 0);
 		break;
 	case SrcMem16:
-		c->src.bytes = 2;
+		memop.bytes = 2;
 		goto srcmem_common;
 	case SrcMem32:
-		c->src.bytes = 4;
+		memop.bytes = 4;
 		goto srcmem_common;
 	case SrcMem:
-		c->src.bytes = (c->d & ByteOp) ? 1 :
+		memop.bytes = (c->d & ByteOp) ? 1 :
 							   c->op_bytes;
-		/* Don't fetch the address for invlpg: it could be unmapped. */
-		if (c->d & NoAccess)
-			break;
 	srcmem_common:
-		/*
-		 * For instructions with a ModR/M byte, switch to register
-		 * access if Mod = 3.
-		 */
-		if ((c->d & ModRM) && c->modrm_mod == 3) {
-			c->src.type = OP_REG;
-			c->src.val = c->modrm_val;
-			c->src.addr.reg = c->modrm_ptr;
-			break;
-		}
-		c->src.type = OP_MEM;
-		c->src.addr.mem = c->modrm_ea;
-		c->src.val = 0;
+		c->src = memop;
 		break;
 	case SrcImm:
 	case SrcImmU:
@@ -2543,9 +2537,8 @@
 		insn_fetch_arr(c->src.valptr, c->src.bytes, c->eip);
 		break;
 	case SrcMemFAddr:
-		c->src.type = OP_MEM;
-		c->src.addr.mem = c->modrm_ea;
-		c->src.bytes = c->op_bytes + 2;
+		memop.bytes = c->op_bytes + 2;
+		goto srcmem_common;
 		break;
 	}
 
@@ -2583,26 +2576,18 @@
 		break;
 	case DstMem:
 	case DstMem64:
-		if ((c->d & ModRM) && c->modrm_mod == 3) {
-			c->dst.bytes = (c->d & ByteOp) ? 1 : c->op_bytes;
-			c->dst.type = OP_REG;
-			c->dst.val = c->dst.orig_val = c->modrm_val;
-			c->dst.addr.reg = c->modrm_ptr;
-			break;
-		}
-		c->dst.type = OP_MEM;
-		c->dst.addr.mem = c->modrm_ea;
+		c->dst = memop;
 		if ((c->d & DstMask) == DstMem64)
 			c->dst.bytes = 8;
 		else
 			c->dst.bytes = (c->d & ByteOp) ? 1 : c->op_bytes;
-		c->dst.val = 0;
-		if (c->d & BitOp) {
+		if (c->dst.type == OP_MEM && (c->d & BitOp)) {
 			unsigned long mask = ~(c->dst.bytes * 8 - 1);
 
 			c->dst.addr.mem = c->dst.addr.mem +
 						   (c->src.val & mask) / 8;
 		}
+		c->dst.orig_val = c->dst.val;
 		break;
 	case DstAcc:
 		c->dst.type = OP_REG;
@@ -2682,11 +2667,15 @@
 	}
 
 	if (c->src.type == OP_MEM) {
+		if (c->d & NoAccess)
+			goto no_fetch;
 		rc = read_emulated(ctxt, ops, c->src.addr.mem,
 					c->src.valptr, c->src.bytes);
 		if (rc != X86EMUL_CONTINUE)
 			goto done;
 		c->src.orig_val64 = c->src.val64;
+	no_fetch:
+		;
 	}
 
 	if (c->src2.type == OP_MEM) {