x86: fix mmiotrace 8-bit register decoding

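When the SIL, DIL, BPL or SPL registers were used in an MMIO access,
the datum was incorrectly extracted from AH, BH, CH or DH instead.

On x86-64, 8-bit register encodings 4-7 select SPL, BPL, SIL and DIL
when any REX prefix is present, and AH, CH, DH and BH otherwise.
Record whether a REX prefix was seen while skipping prefixes, and use
it to pick the correct register in get_reg_w8().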

Signed-off-by: Pekka Paalanen <pq@iki.fi>
Cc: "Vegard Nossum" <vegard.nossum@gmail.com>
Cc: "Steven Rostedt" <srostedt@redhat.com>
Cc: proski@gnu.org
Cc: "Pekka Enberg" <penberg@cs.helsinki.fi>
Signed-off-by: Ingo Molnar <mingo@elte.hu>
diff --git a/arch/x86/mm/pf_in.c b/arch/x86/mm/pf_in.c
index efa1911..df3d5c8 100644
--- a/arch/x86/mm/pf_in.c
+++ b/arch/x86/mm/pf_in.c
@@ -79,25 +79,34 @@
 static unsigned int mw64[] = { 0x89, 0x8B };
 #endif /* not __i386__ */
 
-static int skip_prefix(unsigned char *addr, int *shorted, int *enlarged,
-								int *rexr)
+struct prefix_bits {
+	unsigned shorted:1;
+	unsigned enlarged:1;
+	unsigned rexr:1;
+	unsigned rex:1;
+};
+
+static int skip_prefix(unsigned char *addr, struct prefix_bits *prf)
 {
 	int i;
 	unsigned char *p = addr;
-	*shorted = 0;
-	*enlarged = 0;
-	*rexr = 0;
+	prf->shorted = 0;
+	prf->enlarged = 0;
+	prf->rexr = 0;
+	prf->rex = 0;
 
 restart:
 	for (i = 0; i < ARRAY_SIZE(prefix_codes); i++) {
 		if (*p == prefix_codes[i]) {
 			if (*p == 0x66)
-				*shorted = 1;
+				prf->shorted = 1;
 #ifdef __amd64__
 			if ((*p & 0xf8) == 0x48)
-				*enlarged = 1;
+				prf->enlarged = 1;
 			if ((*p & 0xf4) == 0x44)
-				*rexr = 1;
+				prf->rexr = 1;
+			if ((*p & 0xf0) == 0x40)
+				prf->rex = 1;
 #endif
 			p++;
 			goto restart;
@@ -135,12 +144,12 @@
 {
 	unsigned int opcode;
 	unsigned char *p;
-	int shorted, enlarged, rexr;
+	struct prefix_bits prf;
 	int i;
 	enum reason_type rv = OTHERS;
 
 	p = (unsigned char *)ins_addr;
-	p += skip_prefix(p, &shorted, &enlarged, &rexr);
+	p += skip_prefix(p, &prf);
 	p += get_opcode(p, &opcode);
 
 	CHECK_OP_TYPE(opcode, reg_rop, REG_READ);
@@ -156,10 +165,11 @@
 {
 	unsigned int opcode;
 	unsigned char *p;
-	int i, shorted, enlarged, rexr;
+	struct prefix_bits prf;
+	int i;
 
 	p = (unsigned char *)ins_addr;
-	p += skip_prefix(p, &shorted, &enlarged, &rexr);
+	p += skip_prefix(p, &prf);
 	p += get_opcode(p, &opcode);
 
 	for (i = 0; i < ARRAY_SIZE(rw8); i++)
@@ -168,7 +178,7 @@
 
 	for (i = 0; i < ARRAY_SIZE(rw32); i++)
 		if (rw32[i] == opcode)
-			return (shorted ? 2 : (enlarged ? 8 : 4));
+			return prf.shorted ? 2 : (prf.enlarged ? 8 : 4);
 
 	printk(KERN_ERR "mmiotrace: Unknown opcode 0x%02x\n", opcode);
 	return 0;
@@ -178,10 +188,11 @@
 {
 	unsigned int opcode;
 	unsigned char *p;
-	int i, shorted, enlarged, rexr;
+	struct prefix_bits prf;
+	int i;
 
 	p = (unsigned char *)ins_addr;
-	p += skip_prefix(p, &shorted, &enlarged, &rexr);
+	p += skip_prefix(p, &prf);
 	p += get_opcode(p, &opcode);
 
 	for (i = 0; i < ARRAY_SIZE(mw8); i++)
@@ -194,11 +205,11 @@
 
 	for (i = 0; i < ARRAY_SIZE(mw32); i++)
 		if (mw32[i] == opcode)
-			return shorted ? 2 : 4;
+			return prf.shorted ? 2 : 4;
 
 	for (i = 0; i < ARRAY_SIZE(mw64); i++)
 		if (mw64[i] == opcode)
-			return shorted ? 2 : (enlarged ? 8 : 4);
+			return prf.shorted ? 2 : (prf.enlarged ? 8 : 4);
 
 	printk(KERN_ERR "mmiotrace: Unknown opcode 0x%02x\n", opcode);
 	return 0;
@@ -238,7 +249,7 @@
 #endif
 };
 
-static unsigned char *get_reg_w8(int no, struct pt_regs *regs)
+static unsigned char *get_reg_w8(int no, int rex, struct pt_regs *regs)
 {
 	unsigned char *rv = NULL;
 
@@ -255,18 +266,6 @@
 	case arg_DL:
 		rv = (unsigned char *)&regs->dx;
 		break;
-	case arg_AH:
-		rv = 1 + (unsigned char *)&regs->ax;
-		break;
-	case arg_BH:
-		rv = 1 + (unsigned char *)&regs->bx;
-		break;
-	case arg_CH:
-		rv = 1 + (unsigned char *)&regs->cx;
-		break;
-	case arg_DH:
-		rv = 1 + (unsigned char *)&regs->dx;
-		break;
 #ifdef __amd64__
 	case arg_R8:
 		rv = (unsigned char *)&regs->r8;
@@ -294,9 +293,55 @@
 		break;
 #endif
 	default:
-		printk(KERN_ERR "mmiotrace: Error reg no# %d\n", no);
 		break;
 	}
+
+	if (rv)
+		return rv;
+
+	if (rex) {
+		/*
+		 * If REX prefix exists, access low bytes of SI etc.
+		 * instead of AH etc.
+		 */
+		switch (no) {
+		case arg_SI:
+			rv = (unsigned char *)&regs->si;
+			break;
+		case arg_DI:
+			rv = (unsigned char *)&regs->di;
+			break;
+		case arg_BP:
+			rv = (unsigned char *)&regs->bp;
+			break;
+		case arg_SP:
+			rv = (unsigned char *)&regs->sp;
+			break;
+		default:
+			break;
+		}
+	} else {
+		switch (no) {
+		case arg_AH:
+			rv = 1 + (unsigned char *)&regs->ax;
+			break;
+		case arg_BH:
+			rv = 1 + (unsigned char *)&regs->bx;
+			break;
+		case arg_CH:
+			rv = 1 + (unsigned char *)&regs->cx;
+			break;
+		case arg_DH:
+			rv = 1 + (unsigned char *)&regs->dx;
+			break;
+		default:
+			break;
+		}
+	}
+
+	if (!rv)
+		printk(KERN_ERR "mmiotrace: Error reg no# %d\n", no);
+
 	return rv;
 }
 
@@ -368,11 +413,12 @@
 	unsigned char mod_rm;
 	int reg;
 	unsigned char *p;
-	int i, shorted, enlarged, rexr;
+	struct prefix_bits prf;
+	int i;
 	unsigned long rv;
 
 	p = (unsigned char *)ins_addr;
-	p += skip_prefix(p, &shorted, &enlarged, &rexr);
+	p += skip_prefix(p, &prf);
 	p += get_opcode(p, &opcode);
 	for (i = 0; i < ARRAY_SIZE(reg_rop); i++)
 		if (reg_rop[i] == opcode) {
@@ -392,10 +438,10 @@
 
 do_work:
 	mod_rm = *p;
-	reg = ((mod_rm >> 3) & 0x7) | (rexr << 3);
+	reg = ((mod_rm >> 3) & 0x7) | (prf.rexr << 3);
 	switch (get_ins_reg_width(ins_addr)) {
 	case 1:
-		return *get_reg_w8(reg, regs);
+		return *get_reg_w8(reg, prf.rex, regs);
 
 	case 2:
 		return *(unsigned short *)get_reg_w32(reg, regs);
@@ -422,11 +468,12 @@
 	unsigned char mod_rm;
 	unsigned char mod;
 	unsigned char *p;
-	int i, shorted, enlarged, rexr;
+	struct prefix_bits prf;
+	int i;
 	unsigned long rv;
 
 	p = (unsigned char *)ins_addr;
-	p += skip_prefix(p, &shorted, &enlarged, &rexr);
+	p += skip_prefix(p, &prf);
 	p += get_opcode(p, &opcode);
 	for (i = 0; i < ARRAY_SIZE(imm_wop); i++)
 		if (imm_wop[i] == opcode) {