KVM: MMU: Make cmpxchg_gpte aware of nesting too

This patch makes the cmpxchg_gpte() function aware of the
difference between l1-gfns and l2-gfns when nested
virtualization is in use.  This fixes a potential
data-corruption problem in the l1-guest and makes the code
work correct (at least as correct as the hardware which is
emulated in this code) again.

Cc: stable@kernel.org
Signed-off-by: Joerg Roedel <joerg.roedel@amd.com>
Signed-off-by: Avi Kivity <avi@redhat.com>
diff --git a/arch/x86/kvm/paging_tmpl.h b/arch/x86/kvm/paging_tmpl.h
index 74f8567..1b68990 100644
--- a/arch/x86/kvm/paging_tmpl.h
+++ b/arch/x86/kvm/paging_tmpl.h
@@ -78,15 +78,21 @@
 	return (gpte & PT_LVL_ADDR_MASK(lvl)) >> PAGE_SHIFT;
 }
 
-static bool FNAME(cmpxchg_gpte)(struct kvm *kvm,
+static int FNAME(cmpxchg_gpte)(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
 			 gfn_t table_gfn, unsigned index,
 			 pt_element_t orig_pte, pt_element_t new_pte)
 {
 	pt_element_t ret;
 	pt_element_t *table;
 	struct page *page;
+	gpa_t gpa;
 
-	page = gfn_to_page(kvm, table_gfn);
+	gpa = mmu->translate_gpa(vcpu, table_gfn << PAGE_SHIFT,
+				 PFERR_USER_MASK|PFERR_WRITE_MASK);
+	if (gpa == UNMAPPED_GVA)
+		return -EFAULT;
+
+	page = gfn_to_page(vcpu->kvm, gpa_to_gfn(gpa));
 
 	table = kmap_atomic(page, KM_USER0);
 	ret = CMPXCHG(&table[index], orig_pte, new_pte);
@@ -192,11 +198,17 @@
 #endif
 
 		if (!eperm && !rsvd_fault && !(pte & PT_ACCESSED_MASK)) {
+			int ret;
 			trace_kvm_mmu_set_accessed_bit(table_gfn, index,
 						       sizeof(pte));
-			if (FNAME(cmpxchg_gpte)(vcpu->kvm, table_gfn,
-			    index, pte, pte|PT_ACCESSED_MASK))
+			ret = FNAME(cmpxchg_gpte)(vcpu, mmu, table_gfn,
+					index, pte, pte|PT_ACCESSED_MASK);
+			if (ret < 0) {
+				present = false;
+				break;
+			} else if (ret)
 				goto walk;
+
 			mark_page_dirty(vcpu->kvm, table_gfn);
 			pte |= PT_ACCESSED_MASK;
 		}
@@ -245,13 +257,17 @@
 		goto error;
 
 	if (write_fault && !is_dirty_gpte(pte)) {
-		bool ret;
+		int ret;
 
 		trace_kvm_mmu_set_dirty_bit(table_gfn, index, sizeof(pte));
-		ret = FNAME(cmpxchg_gpte)(vcpu->kvm, table_gfn, index, pte,
+		ret = FNAME(cmpxchg_gpte)(vcpu, mmu, table_gfn, index, pte,
 			    pte|PT_DIRTY_MASK);
-		if (ret)
+		if (ret < 0) {
+			present = false;
+			goto error;
+		} else if (ret)
 			goto walk;
+
 		mark_page_dirty(vcpu->kvm, table_gfn);
 		pte |= PT_DIRTY_MASK;
 		walker->ptes[walker->level - 1] = pte;