Fix assertion failures in AMDGCN fmed3 folding by handling more NaN cases

Better NaN handling for AMDGCN fmed3.

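All three operands are now checked for NaN (or undef); previously
only the third operand was. The checks were moved before operand
canonicalization so the original operand order is preserved, which
gives a better mapping from fclamp. In light of this, fmed3(x, y, NaN)
now folds to maxnum(x, y) instead of minnum(x, y), while a NaN in the
first or second operand folds to minnum of the other two operands.
Updated the existing tests accordingly and added new cases to cover
the fix.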

Patch by Alan Baker

llvm-svn: 336375
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 92533ff..78c2d31 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3424,6 +3424,24 @@
     Value *Src1 = II->getArgOperand(1);
     Value *Src2 = II->getArgOperand(2);
 
+    // Checking for NaN before canonicalization provides better fidelity when
+    // mapping other operations onto fmed3 since the order of operands is
+    // unchanged.
+    CallInst *NewCall = nullptr;
+    if (match(Src0, m_NaN()) || isa<UndefValue>(Src0)) {
+      NewCall = Builder.CreateMinNum(Src1, Src2);
+    } else if (match(Src1, m_NaN()) || isa<UndefValue>(Src1)) {
+      NewCall = Builder.CreateMinNum(Src0, Src2);
+    } else if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) {
+      NewCall = Builder.CreateMaxNum(Src0, Src1);
+    }
+
+    if (NewCall) {
+      NewCall->copyFastMathFlags(II);
+      NewCall->takeName(II);
+      return replaceInstUsesWith(*II, NewCall);
+    }
+
     bool Swap = false;
     // Canonicalize constants to RHS operands.
     //
@@ -3450,13 +3468,6 @@
       return II;
     }
 
-    if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) {
-      CallInst *NewCall = Builder.CreateMinNum(Src0, Src1);
-      NewCall->copyFastMathFlags(II);
-      NewCall->takeName(II);
-      return replaceInstUsesWith(*II, NewCall);
-    }
-
     if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
       if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
         if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) {
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-intrinsics.ll b/llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-intrinsics.ll
index 5ef3f5d..1fad1d8 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-intrinsics.ll
@@ -1229,7 +1229,7 @@
 }
 
 ; CHECK-LABEL: @fmed3_x_y_undef_f32(
-; CHECK: call float @llvm.minnum.f32(float %x, float %y)
+; CHECK: call float @llvm.maxnum.f32(float %x, float %y)
 define float @fmed3_x_y_undef_f32(float %x, float %y) {
   %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float undef)
   ret float %med3
@@ -1250,7 +1250,7 @@
 }
 
 ; CHECK-LABEL: @fmed3_x_y_qnan0_f32(
-; CHECK: call float @llvm.minnum.f32(float %x, float %y)
+; CHECK: call float @llvm.maxnum.f32(float %x, float %y)
 define float @fmed3_x_y_qnan0_f32(float %x, float %y) {
   %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 0x7FF8000000000000)
   ret float %med3
@@ -1265,7 +1265,7 @@
 
 ; This can return any of the qnans.
 ; CHECK-LABEL: @fmed3_qnan0_qnan1_qnan2_f32(
-; CHECK: ret float 0x7FF8002000000000
+; CHECK: ret float 0x7FF8030000000000
 define float @fmed3_qnan0_qnan1_qnan2_f32(float %x, float %y) {
   %med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF8000100000000, float 0x7FF8002000000000, float 0x7FF8030000000000)
   ret float %med3
@@ -1334,6 +1334,48 @@
   ret float %med3
 }
 
+; CHECK-LABEL: @fmed3_nan_0_1_f32(
+; CHECK: ret float 0.0
+define float @fmed3_nan_0_1_f32() {
+  %med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF8001000000000, float 0.0, float 1.0)
+  ret float %med3
+}
+
+; CHECK-LABEL: @fmed3_0_nan_1_f32(
+; CHECK: ret float 0.0
+define float @fmed3_0_nan_1_f32() {
+  %med = call float @llvm.amdgcn.fmed3.f32(float 0.0, float 0x7FF8001000000000, float 1.0)
+  ret float %med
+}
+
+; CHECK-LABEL: @fmed3_0_1_nan_f32(
+; CHECK: ret float 1.0
+define float @fmed3_0_1_nan_f32() {
+  %med = call float @llvm.amdgcn.fmed3.f32(float 0.0, float 1.0, float 0x7FF8001000000000)
+  ret float %med
+}
+
+; CHECK-LABEL: @fmed3_undef_0_1_f32(
+; CHECK: ret float 0.0
+define float @fmed3_undef_0_1_f32() {
+  %med3 = call float @llvm.amdgcn.fmed3.f32(float undef, float 0.0, float 1.0)
+  ret float %med3
+}
+
+; CHECK-LABEL: @fmed3_0_undef_1_f32(
+; CHECK: ret float 0.0
+define float @fmed3_0_undef_1_f32() {
+  %med = call float @llvm.amdgcn.fmed3.f32(float 0.0, float undef, float 1.0)
+  ret float %med
+}
+
+; CHECK-LABEL: @fmed3_0_1_undef_f32(
+; CHECK: ret float 1.0
+define float @fmed3_0_1_undef_f32() {
+  %med = call float @llvm.amdgcn.fmed3.f32(float 0.0, float 1.0, float undef)
+  ret float %med
+}
+
 ; --------------------------------------------------------------------
 ; llvm.amdgcn.icmp
 ; --------------------------------------------------------------------