[InstSimplify] add constant folding for fdiv/frem

Also, add a helper function so we don't have to repeat this code for each binop.

llvm-svn: 299309
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 6987faf..6294d19 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -528,17 +528,26 @@
   return CommonValue;
 }
 
+static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,
+                                       Value *&Op0, Value *&Op1,
+                                       const Query &Q) {
+  if (auto *CLHS = dyn_cast<Constant>(Op0)) {
+    if (auto *CRHS = dyn_cast<Constant>(Op1)) // Both operands are constants.
+      return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL);
+
+    // Only the LHS is constant: move it to the RHS when the opcode commutes.
+    if (Instruction::isCommutative(Opcode))
+      std::swap(Op0, Op1); // Op0/Op1 are references, so the caller sees this.
+  }
+  return nullptr; // No fold; a lone LHS constant may have moved to the RHS.
+}
+
 /// Given operands for an Add, see if we can fold the result.
 /// If not, this returns null.
 static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
                               const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q))
+    return C;
 
   // X + undef -> undef
   if (match(Op1, m_Undef()))
@@ -674,9 +683,8 @@
 /// If not, this returns null.
 static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
                               const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0))
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL);
+  if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q))
+    return C;
 
   // X - undef -> undef
   // undef - X -> undef
@@ -816,13 +824,8 @@
 /// returns null.
 static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                               const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))
+    return C;
 
   // fadd X, -0 ==> X
   if (match(Op1, m_NegZero()))
@@ -855,10 +858,8 @@
 /// returns null.
 static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                               const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))
+    return C;
 
   // fsub X, 0 ==> X
   if (match(Op1, m_Zero()))
@@ -889,13 +890,8 @@
 /// Given the operands for an FMul, see if we can fold the result
 static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                                const Query &Q, unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
+    return C;
 
   // fmul X, 1.0 ==> X
   if (match(Op1, m_FPOne()))
@@ -912,13 +908,8 @@
 /// If not, this returns null.
 static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q))
+    return C;
 
   // X * undef -> 0
   if (match(Op1, m_Undef()))
@@ -1060,9 +1051,8 @@
 /// If not, this returns null.
 static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
                           const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0))
-    if (Constant *C1 = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
+  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+    return C;
 
   if (Value *V = simplifyDivRem(Op0, Op1, true))
     return V;
@@ -1162,6 +1152,9 @@
 
 static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                                const Query &Q, unsigned) {
+  if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
+    return C;
+
   // undef / X -> undef    (the undef could be a snan).
   if (match(Op0, m_Undef()))
     return Op0;
@@ -1211,9 +1204,8 @@
 /// If not, this returns null.
 static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
                           const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0))
-    if (Constant *C1 = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
+  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+    return C;
 
   if (Value *V = simplifyDivRem(Op0, Op1, false))
     return V;
@@ -1287,7 +1279,10 @@
 }
 
 static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
-                               const Query &, unsigned) {
+                               const Query &Q, unsigned) {
+  if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))
+    return C;
+
   // undef % X -> undef    (the undef could be a snan).
   if (match(Op0, m_Undef()))
     return Op0;
@@ -1343,11 +1338,10 @@
 
 /// Given operands for an Shl, LShr or AShr, see if we can fold the result.
 /// If not, this returns null.
-static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1,
-                            const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0))
-    if (Constant *C1 = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
+static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
+                            Value *Op1, const Query &Q, unsigned MaxRecurse) {
+  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
+    return C;
 
   // 0 shift by X -> 0
   if (match(Op0, m_Zero()))
@@ -1394,8 +1388,8 @@
 
 /// \brief Given operands for an Shl, LShr or AShr, see if we can
 /// fold the result.  If not, this returns null.
-static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1,
-                                 bool isExact, const Query &Q,
+static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
+                                 Value *Op1, bool isExact, const Query &Q,
                                  unsigned MaxRecurse) {
   if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
     return V;
@@ -1644,13 +1638,8 @@
 /// If not, this returns null.
 static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q))
+    return C;
 
   // X & undef -> 0
   if (match(Op1, m_Undef()))
@@ -1846,13 +1835,8 @@
 /// If not, this returns null.
 static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q,
                              unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q))
+    return C;
 
   // X | undef -> -1
   if (match(Op1, m_Undef()))
@@ -1979,13 +1963,8 @@
 /// If not, this returns null.
 static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
-  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1))
-      return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL);
-
-    // Canonicalize the constant to the RHS.
-    std::swap(Op0, Op1);
-  }
+  if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q))
+    return C;
 
   // A ^ undef -> undef
   if (match(Op1, m_Undef()))