X86/KVM: Properly update 'tsc_offset' to represent the running guest

Update 'tsc_offset' on vmentry/vmexit of L2 guests to ensure that it always
captures the TSC_OFFSET of the running guest whether it is the L1 or L2
guest.

Cc: Paolo Bonzini <pbonzini@redhat.com>
Cc: Radim Krčmář <rkrcmar@redhat.com>
Cc: kvm@vger.kernel.org
Cc: linux-kernel@vger.kernel.org
Reviewed-by: Jim Mattson <jmattson@google.com>
Suggested-by: Paolo Bonzini <pbonzini@redhat.com>
Signed-off-by: KarimAllah Ahmed <karahmed@amazon.de>
[AMD changes, fix update_ia32_tsc_adjust_msr. - Paolo]
Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 949c977..c25775f 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1013,6 +1013,7 @@ struct kvm_x86_ops {
 
 	bool (*has_wbinvd_exit)(void);
 
+	u64 (*read_l1_tsc_offset)(struct kvm_vcpu *vcpu);
 	void (*write_tsc_offset)(struct kvm_vcpu *vcpu, u64 offset);
 
 	void (*get_exit_info)(struct kvm_vcpu *vcpu, u64 *info1, u64 *info2);
diff --git a/arch/x86/kvm/svm.c b/arch/x86/kvm/svm.c
index b3ebc8a..e77a536 100644
--- a/arch/x86/kvm/svm.c
+++ b/arch/x86/kvm/svm.c
@@ -1424,12 +1424,23 @@ static void init_sys_seg(struct vmcb_seg *seg, uint32_t type)
 	seg->base = 0;
 }
 
+static u64 svm_read_l1_tsc_offset(struct kvm_vcpu *vcpu)
+{
+	struct vcpu_svm *svm = to_svm(vcpu);
+
+	if (is_guest_mode(vcpu))
+		return svm->nested.hsave->control.tsc_offset;
+
+	return vcpu->arch.tsc_offset;
+}
+
 static void svm_write_tsc_offset(struct kvm_vcpu *vcpu, u64 offset)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
 	u64 g_tsc_offset = 0;
 
 	if (is_guest_mode(vcpu)) {
+		/* Write L1's TSC offset.  */
 		g_tsc_offset = svm->vmcb->control.tsc_offset -
 			       svm->nested.hsave->control.tsc_offset;
 		svm->nested.hsave->control.tsc_offset = offset;
@@ -3323,6 +3334,7 @@ static int nested_svm_vmexit(struct vcpu_svm *svm)
 	/* Restore the original control entries */
 	copy_vmcb_control_area(vmcb, hsave);
 
+	svm->vcpu.arch.tsc_offset = svm->vmcb->control.tsc_offset;
 	kvm_clear_exception_queue(&svm->vcpu);
 	kvm_clear_interrupt_queue(&svm->vcpu);
 
@@ -3483,10 +3495,12 @@ static void enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb_gpa,
 	/* We don't want to see VMMCALLs from a nested guest */
 	clr_intercept(svm, INTERCEPT_VMMCALL);
 
+	svm->vcpu.arch.tsc_offset += nested_vmcb->control.tsc_offset;
+	svm->vmcb->control.tsc_offset = svm->vcpu.arch.tsc_offset;
+
 	svm->vmcb->control.virt_ext = nested_vmcb->control.virt_ext;
 	svm->vmcb->control.int_vector = nested_vmcb->control.int_vector;
 	svm->vmcb->control.int_state = nested_vmcb->control.int_state;
-	svm->vmcb->control.tsc_offset += nested_vmcb->control.tsc_offset;
 	svm->vmcb->control.event_inj = nested_vmcb->control.event_inj;
 	svm->vmcb->control.event_inj_err = nested_vmcb->control.event_inj_err;
 
@@ -7102,6 +7116,7 @@ static struct kvm_x86_ops svm_x86_ops __ro_after_init = {
 
 	.has_wbinvd_exit = svm_has_wbinvd_exit,
 
+	.read_l1_tsc_offset = svm_read_l1_tsc_offset,
 	.write_tsc_offset = svm_write_tsc_offset,
 
 	.set_tdp_cr3 = set_tdp_cr3,
diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index a13c603..7207e6c 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -2874,6 +2874,17 @@ static void setup_msrs(struct vcpu_vmx *vmx)
 		vmx_update_msr_bitmap(&vmx->vcpu);
 }
 
+static u64 vmx_read_l1_tsc_offset(struct kvm_vcpu *vcpu)
+{
+	struct vmcs12 *vmcs12 = get_vmcs12(vcpu);
+
+	if (is_guest_mode(vcpu) &&
+	    (vmcs12->cpu_based_vm_exec_control & CPU_BASED_USE_TSC_OFFSETING))
+		return vcpu->arch.tsc_offset - vmcs12->tsc_offset;
+
+	return vcpu->arch.tsc_offset;
+}
+
 /*
  * reads and returns guest's timestamp counter "register"
  * guest_tsc = (host_tsc * tsc multiplier) >> 48 + tsc_offset
@@ -11175,11 +11186,8 @@ static int prepare_vmcs02(struct kvm_vcpu *vcpu, struct vmcs12 *vmcs12,
 		vmcs_write64(GUEST_IA32_PAT, vmx->vcpu.arch.pat);
 	}
 
-	if (vmcs12->cpu_based_vm_exec_control & CPU_BASED_USE_TSC_OFFSETING)
-		vmcs_write64(TSC_OFFSET,
-			vcpu->arch.tsc_offset + vmcs12->tsc_offset);
-	else
-		vmcs_write64(TSC_OFFSET, vcpu->arch.tsc_offset);
+	vmcs_write64(TSC_OFFSET, vcpu->arch.tsc_offset);
+
 	if (kvm_has_tsc_control)
 		decache_tsc_multiplier(vmx);
 
@@ -11427,6 +11435,7 @@ static int enter_vmx_non_root_mode(struct kvm_vcpu *vcpu, bool from_vmentry)
 	struct vmcs12 *vmcs12 = get_vmcs12(vcpu);
 	u32 msr_entry_idx;
 	u32 exit_qual;
+	int r;
 
 	enter_guest_mode(vcpu);
 
@@ -11436,26 +11445,21 @@ static int enter_vmx_non_root_mode(struct kvm_vcpu *vcpu, bool from_vmentry)
 	vmx_switch_vmcs(vcpu, &vmx->nested.vmcs02);
 	vmx_segment_cache_clear(vmx);
 
-	if (prepare_vmcs02(vcpu, vmcs12, from_vmentry, &exit_qual)) {
-		leave_guest_mode(vcpu);
-		vmx_switch_vmcs(vcpu, &vmx->vmcs01);
-		nested_vmx_entry_failure(vcpu, vmcs12,
-					 EXIT_REASON_INVALID_STATE, exit_qual);
-		return 1;
-	}
+	if (vmcs12->cpu_based_vm_exec_control & CPU_BASED_USE_TSC_OFFSETING)
+		vcpu->arch.tsc_offset += vmcs12->tsc_offset;
+
+	r = EXIT_REASON_INVALID_STATE;
+	if (prepare_vmcs02(vcpu, vmcs12, from_vmentry, &exit_qual))
+		goto fail;
 
 	nested_get_vmcs12_pages(vcpu, vmcs12);
 
+	r = EXIT_REASON_MSR_LOAD_FAIL;
 	msr_entry_idx = nested_vmx_load_msr(vcpu,
 					    vmcs12->vm_entry_msr_load_addr,
 					    vmcs12->vm_entry_msr_load_count);
-	if (msr_entry_idx) {
-		leave_guest_mode(vcpu);
-		vmx_switch_vmcs(vcpu, &vmx->vmcs01);
-		nested_vmx_entry_failure(vcpu, vmcs12,
-				EXIT_REASON_MSR_LOAD_FAIL, msr_entry_idx);
-		return 1;
-	}
+	if (msr_entry_idx)
+		goto fail;
 
 	/*
 	 * Note no nested_vmx_succeed or nested_vmx_fail here. At this point
@@ -11464,6 +11468,14 @@ static int enter_vmx_non_root_mode(struct kvm_vcpu *vcpu, bool from_vmentry)
 	 * the success flag) when L2 exits (see nested_vmx_vmexit()).
 	 */
 	return 0;
+
+fail:
+	if (vmcs12->cpu_based_vm_exec_control & CPU_BASED_USE_TSC_OFFSETING)
+		vcpu->arch.tsc_offset -= vmcs12->tsc_offset;
+	leave_guest_mode(vcpu);
+	vmx_switch_vmcs(vcpu, &vmx->vmcs01);
+	nested_vmx_entry_failure(vcpu, vmcs12, r, exit_qual);
+	return 1;
 }
 
 /*
@@ -12035,6 +12047,9 @@ static void nested_vmx_vmexit(struct kvm_vcpu *vcpu, u32 exit_reason,
 
 	leave_guest_mode(vcpu);
 
+	if (vmcs12->cpu_based_vm_exec_control & CPU_BASED_USE_TSC_OFFSETING)
+		vcpu->arch.tsc_offset -= vmcs12->tsc_offset;
+
 	if (likely(!vmx->fail)) {
 		if (exit_reason == -1)
 			sync_vmcs12(vcpu, vmcs12);
@@ -12725,6 +12740,7 @@ static struct kvm_x86_ops vmx_x86_ops __ro_after_init = {
 
 	.has_wbinvd_exit = cpu_has_vmx_wbinvd_exit,
 
+	.read_l1_tsc_offset = vmx_read_l1_tsc_offset,
 	.write_tsc_offset = vmx_write_tsc_offset,
 
 	.set_tdp_cr3 = vmx_set_cr3,
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index 0334b25..3f3fba5 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -1490,7 +1490,7 @@ static void kvm_track_tsc_matching(struct kvm_vcpu *vcpu)
 
 static void update_ia32_tsc_adjust_msr(struct kvm_vcpu *vcpu, s64 offset)
 {
-	u64 curr_offset = vcpu->arch.tsc_offset;
+	u64 curr_offset = kvm_x86_ops->read_l1_tsc_offset(vcpu);
 	vcpu->arch.ia32_tsc_adjust_msr += offset - curr_offset;
 }
 
@@ -1532,7 +1532,9 @@ static u64 kvm_compute_tsc_offset(struct kvm_vcpu *vcpu, u64 target_tsc)
 
 u64 kvm_read_l1_tsc(struct kvm_vcpu *vcpu, u64 host_tsc)
 {
-	return vcpu->arch.tsc_offset + kvm_scale_tsc(vcpu, host_tsc);
+	u64 tsc_offset = kvm_x86_ops->read_l1_tsc_offset(vcpu);
+
+	return tsc_offset + kvm_scale_tsc(vcpu, host_tsc);
 }
 EXPORT_SYMBOL_GPL(kvm_read_l1_tsc);