rdar://12753946

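Implement rule: "x * (select cond 1.0, 0.0) -> select cond x, 0.0" (and the mirrored form "x * (select cond 0.0, 1.0) -> select cond 0.0, x"). The fold is applied only when the fmul carries both the no-NaNs and no-signed-zeros fast-math flags.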


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@170226 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index b95da85..964297a 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -341,6 +341,38 @@
     }
   }
 
+  // X * cond ? 1.0 : 0.0 => cond ? X : 0.0
+  if (I.hasNoNaNs() && I.hasNoSignedZeros()) {
+    Value *V0 = I.getOperand(0);
+    Value *V1 = I.getOperand(1);
+    Value *Cond, *SLHS, *SRHS;
+    bool Match = false;
+
+    if (match(V0, m_Select(m_Value(Cond), m_Value(SLHS), m_Value(SRHS)))) {
+      Match = true;
+    } else if (match(V1, m_Select(m_Value(Cond), m_Value(SLHS), 
+                     m_Value(SRHS)))) {
+      Match = true;
+      std::swap(V0, V1);
+    }
+
+    if (Match) {
+      ConstantFP *C0 = dyn_cast<ConstantFP>(SLHS);
+      ConstantFP *C1 = dyn_cast<ConstantFP>(SRHS);
+
+      if (C0 && C1 &&
+          ((C0->isZero() && C1->isExactlyValue(1.0)) ||
+           (C1->isZero() && C0->isExactlyValue(1.0)))) {
+        Value *T;
+        if (C0->isZero())
+          T = Builder->CreateSelect(Cond, SLHS, V1);
+        else
+          T = Builder->CreateSelect(Cond, V1, SRHS);
+        return ReplaceInstUsesWith(I, T);
+      }
+    }
+  }
+
   return Changed ? &I : 0;
 }