Add a DAGCombine for (ext (binop (load x), cst)).


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@133124 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b6d926e..4ac590a 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -138,6 +138,10 @@
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
 
+    void ExtendSetCCUses(SmallVector<SDNode*, 4> SetCCs,
+                         SDValue Trunc, SDValue ExtLoad, DebugLoc DL,
+                         ISD::NodeType ExtType);  // e.g. ISD::SIGN_EXTEND
+
     /// combine - call the node-specific routine that knows how to fold each
     /// particular type of node. If that doesn't do anything, try the
     /// target-specific DAG combines.
@@ -3699,6 +3703,28 @@
   return true;
 }
 
+void DAGCombiner::ExtendSetCCUses(SmallVector<SDNode*, 4> SetCCs,
+                                  SDValue Trunc, SDValue ExtLoad, DebugLoc DL,
+                                  ISD::NodeType ExtType) {
+  // Rebuild each node in SetCCs as a SETCC over ExtLoad-typed operands.
+  for (unsigned i = 0, e = SetCCs.size(); i != e; ++i) {
+    SDNode *SetCC = SetCCs[i];
+    SmallVector<SDValue, 4> Ops;
+
+    for (unsigned j = 0; j != 2; ++j) {  // Rewrite operands 0 and 1.
+      SDValue SOp = SetCC->getOperand(j);
+      if (SOp == Trunc)  // Trunc is replaced by ExtLoad itself.
+        Ops.push_back(ExtLoad);
+      else  // Any other operand is extended to ExtLoad's type.
+        Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
+    }
+
+    Ops.push_back(SetCC->getOperand(2));  // Operand 2 is kept unchanged.
+    CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0),
+                                 &Ops[0], Ops.size()));
+  }
+}
+
 SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -3787,27 +3813,8 @@
       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, N0.getDebugLoc(),
                                   N0.getValueType(), ExtLoad);
       CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
-
-      // Extend SetCC uses if necessary.
-      for (unsigned i = 0, e = SetCCs.size(); i != e; ++i) {
-        SDNode *SetCC = SetCCs[i];
-        SmallVector<SDValue, 4> Ops;
-
-        for (unsigned j = 0; j != 2; ++j) {
-          SDValue SOp = SetCC->getOperand(j);
-          if (SOp == Trunc)
-            Ops.push_back(ExtLoad);
-          else
-            Ops.push_back(DAG.getNode(ISD::SIGN_EXTEND,
-                                      N->getDebugLoc(), VT, SOp));
-        }
-
-        Ops.push_back(SetCC->getOperand(2));
-        CombineTo(SetCC, DAG.getNode(ISD::SETCC, N->getDebugLoc(),
-                                     SetCC->getValueType(0),
-                                     &Ops[0], Ops.size()));
-      }
-
+      ExtendSetCCUses(SetCCs, Trunc, ExtLoad, N->getDebugLoc(),
+                      ISD::SIGN_EXTEND);
       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
     }
   }
@@ -3835,6 +3842,45 @@
     }
   }
 
+  // fold (sext (and/or/xor (load x), cst)) ->
+  //      (and/or/xor (sextload x), (sext cst))
+  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
+       N0.getOpcode() == ISD::XOR) &&
+      isa<LoadSDNode>(N0.getOperand(0)) &&
+      N0.getOperand(1).getOpcode() == ISD::Constant &&
+      TLI.isLoadExtLegal(ISD::SEXTLOAD, N0.getValueType()) &&
+      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
+    LoadSDNode *LN0 = cast<LoadSDNode>(N0.getOperand(0));
+    if (LN0->getExtensionType() != ISD::ZEXTLOAD) {
+      bool DoXform = true;
+      SmallVector<SDNode*, 4> SetCCs;
+      if (!N0.hasOneUse())
+        DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), ISD::SIGN_EXTEND,
+                                          SetCCs, TLI);
+      if (DoXform) {
+        SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, LN0->getDebugLoc(), VT,
+                                         LN0->getChain(), LN0->getBasePtr(),
+                                         LN0->getPointerInfo(),
+                                         LN0->getMemoryVT(),
+                                         LN0->isVolatile(),
+                                         LN0->isNonTemporal(),
+                                         LN0->getAlignment());
+        APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
+        Mask = Mask.sext(VT.getSizeInBits());
+        SDValue And = DAG.getNode(N0.getOpcode(), N->getDebugLoc(), VT,
+                                  ExtLoad, DAG.getConstant(Mask, VT));
+        SDValue Trunc = DAG.getNode(ISD::TRUNCATE,
+                                    N0.getOperand(0).getDebugLoc(),
+                                    N0.getOperand(0).getValueType(), ExtLoad);
+        CombineTo(N, And);
+        CombineTo(N0.getOperand(0).getNode(), Trunc, ExtLoad.getValue(1));
+        ExtendSetCCUses(SetCCs, Trunc, ExtLoad, N->getDebugLoc(),
+                        ISD::SIGN_EXTEND);
+        return SDValue(N, 0);   // Return N so it doesn't get rechecked!
+      }
+    }
+  }
+
   if (N0.getOpcode() == ISD::SETCC) {
     // sext(setcc) -> sext_in_reg(vsetcc) for vectors.
     // Only do this before legalize for now.
@@ -3993,30 +4039,51 @@
                                   N0.getValueType(), ExtLoad);
       CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
 
-      // Extend SetCC uses if necessary.
-      for (unsigned i = 0, e = SetCCs.size(); i != e; ++i) {
-        SDNode *SetCC = SetCCs[i];
-        SmallVector<SDValue, 4> Ops;
-
-        for (unsigned j = 0; j != 2; ++j) {
-          SDValue SOp = SetCC->getOperand(j);
-          if (SOp == Trunc)
-            Ops.push_back(ExtLoad);
-          else
-            Ops.push_back(DAG.getNode(ISD::ZERO_EXTEND,
-                                      N->getDebugLoc(), VT, SOp));
-        }
-
-        Ops.push_back(SetCC->getOperand(2));
-        CombineTo(SetCC, DAG.getNode(ISD::SETCC, N->getDebugLoc(),
-                                     SetCC->getValueType(0),
-                                     &Ops[0], Ops.size()));
-      }
-
+      ExtendSetCCUses(SetCCs, Trunc, ExtLoad, N->getDebugLoc(),
+                      ISD::ZERO_EXTEND);
       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
     }
   }
 
+  // fold (zext (and/or/xor (load x), cst)) ->
+  //      (and/or/xor (zextload x), (zext cst))
+  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
+       N0.getOpcode() == ISD::XOR) &&
+      isa<LoadSDNode>(N0.getOperand(0)) &&
+      N0.getOperand(1).getOpcode() == ISD::Constant &&
+      TLI.isLoadExtLegal(ISD::ZEXTLOAD, N0.getValueType()) &&
+      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
+    LoadSDNode *LN0 = cast<LoadSDNode>(N0.getOperand(0));
+    if (LN0->getExtensionType() != ISD::SEXTLOAD) {
+      bool DoXform = true;
+      SmallVector<SDNode*, 4> SetCCs;
+      if (!N0.hasOneUse())
+        DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), ISD::ZERO_EXTEND,
+                                          SetCCs, TLI);
+      if (DoXform) {
+        SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, LN0->getDebugLoc(), VT,
+                                         LN0->getChain(), LN0->getBasePtr(),
+                                         LN0->getPointerInfo(),
+                                         LN0->getMemoryVT(),
+                                         LN0->isVolatile(),
+                                         LN0->isNonTemporal(),
+                                         LN0->getAlignment());
+        APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
+        Mask = Mask.zext(VT.getSizeInBits());
+        SDValue And = DAG.getNode(N0.getOpcode(), N->getDebugLoc(), VT,
+                                  ExtLoad, DAG.getConstant(Mask, VT));
+        SDValue Trunc = DAG.getNode(ISD::TRUNCATE,
+                                    N0.getOperand(0).getDebugLoc(),
+                                    N0.getOperand(0).getValueType(), ExtLoad);
+        CombineTo(N, And);
+        CombineTo(N0.getOperand(0).getNode(), Trunc, ExtLoad.getValue(1));
+        ExtendSetCCUses(SetCCs, Trunc, ExtLoad, N->getDebugLoc(),
+                        ISD::ZERO_EXTEND);
+        return SDValue(N, 0);   // Return N so it doesn't get rechecked!
+      }
+    }
+  }
+
   // fold (zext (zextload x)) -> (zext (truncate (zextload x)))
   // fold (zext ( extload x)) -> (zext (truncate (zextload x)))
   if ((ISD::isZEXTLoad(N0.getNode()) || ISD::isEXTLoad(N0.getNode())) &&
@@ -4201,27 +4268,8 @@
       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, N0.getDebugLoc(),
                                   N0.getValueType(), ExtLoad);
       CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
-
-      // Extend SetCC uses if necessary.
-      for (unsigned i = 0, e = SetCCs.size(); i != e; ++i) {
-        SDNode *SetCC = SetCCs[i];
-        SmallVector<SDValue, 4> Ops;
-
-        for (unsigned j = 0; j != 2; ++j) {
-          SDValue SOp = SetCC->getOperand(j);
-          if (SOp == Trunc)
-            Ops.push_back(ExtLoad);
-          else
-            Ops.push_back(DAG.getNode(ISD::ANY_EXTEND,
-                                      N->getDebugLoc(), VT, SOp));
-        }
-
-        Ops.push_back(SetCC->getOperand(2));
-        CombineTo(SetCC, DAG.getNode(ISD::SETCC, N->getDebugLoc(),
-                                     SetCC->getValueType(0),
-                                     &Ops[0], Ops.size()));
-      }
-
+      ExtendSetCCUses(SetCCs, Trunc, ExtLoad, N->getDebugLoc(),
+                      ISD::ANY_EXTEND);
       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
     }
   }