AMDGPU: Avoid overwriting saved PC

Summary:
An outstanding load with the same destination SGPR as a call could cause the PC
to be updated with a junk value on return.

Reviewers: arsenm, rampitec

Reviewed By: arsenm

Subscribers: kzhuravl, jvesely, wdng, nhaehnle, yaxunl, dstuttard, tpr, t-tye, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D69474
diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index dcb04e4..e84e948 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -939,19 +939,33 @@
     }
 
     if (MI.isCall() && callWaitsOnFunctionEntry(MI)) {
-      // Don't bother waiting on anything except the call address. The function
-      // is going to insert a wait on everything in its prolog. This still needs
-      // to be careful if the call target is a load (e.g. a GOT load).
+      // The function is going to insert a wait on everything in its prolog.
+      // This still needs to be careful if the call target is a load (e.g. a GOT
+      // load). We also need to check the WAW dependency with the saved PC.
       Wait = AMDGPU::Waitcnt();
 
       int CallAddrOpIdx =
           AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src0);
-      RegInterval Interval = ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI,
-                                                          CallAddrOpIdx, false);
-      for (signed RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+      RegInterval CallAddrOpInterval = ScoreBrackets.getRegInterval(
+          &MI, TII, MRI, TRI, CallAddrOpIdx, false);
+
+      for (signed RegNo = CallAddrOpInterval.first;
+           RegNo < CallAddrOpInterval.second; ++RegNo)
         ScoreBrackets.determineWait(
             LGKM_CNT, ScoreBrackets.getRegScore(RegNo, LGKM_CNT), Wait);
+
+      int RtnAddrOpIdx =
+            AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::dst);
+      if (RtnAddrOpIdx != -1) {
+        RegInterval RtnAddrOpInterval = ScoreBrackets.getRegInterval(
+            &MI, TII, MRI, TRI, RtnAddrOpIdx, false);
+
+        for (signed RegNo = RtnAddrOpInterval.first;
+             RegNo < RtnAddrOpInterval.second; ++RegNo)
+          ScoreBrackets.determineWait(
+              LGKM_CNT, ScoreBrackets.getRegScore(RegNo, LGKM_CNT), Wait);
       }
+
     } else {
       // FIXME: Should not be relying on memoperands.
       // Look at the source operands of every instruction to see if