[X86] Allow X86ISD::Wrapper to be folded into the base of gather/scatter address

If the base of a gather/scatter is wrapped in an X86ISD::Wrapper node, we should be able to fold it into the address.

This patch refactors some of the address matching to more fully use the X86ISelAddressMode struct and the getAddressOperands helper. A new helper function matchVectorAddress is added to call matchWrapper or fall back to matchAddressBase.

We should also be able to support constant offsets from a wrapper, but I'll look into that in a future patch. We may even be able to completely reuse matchAddress here, but I wanted to start simple and work up to it.

Differential Revision: https://reviews.llvm.org/D39927

llvm-svn: 318057
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 9e1121a..71ae97d 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -194,6 +194,7 @@
     bool matchLoadInAddress(LoadSDNode *N, X86ISelAddressMode &AM);
     bool matchWrapper(SDValue N, X86ISelAddressMode &AM);
     bool matchAddress(SDValue N, X86ISelAddressMode &AM);
+    bool matchVectorAddress(SDValue N, X86ISelAddressMode &AM);
     bool matchAdd(SDValue N, X86ISelAddressMode &AM, unsigned Depth);
     bool matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
                                  unsigned Depth);
@@ -1502,22 +1503,34 @@
   return false;
 }
 
+/// Helper for selectVectorAddr. Tries to fold N into the base/displacement of
+/// a gather/scatter address; index register and scale should already be set.
+/// Returns true on failure: selectVectorAddr gives up when this returns true.
+bool X86DAGToDAGISel::matchVectorAddress(SDValue N, X86ISelAddressMode &AM) {
+  // TODO: Support other operations.
+  switch (N.getOpcode()) {
+  case X86ISD::Wrapper:
+    if (!matchWrapper(N, AM))
+      return false;
+    break;
+  }
+
+  return matchAddressBase(N, AM);
+}
+
 bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
                                        SDValue &Scale, SDValue &Index,
                                        SDValue &Disp, SDValue &Segment) {
-  unsigned ScalarSize;
+  X86ISelAddressMode AM;
   if (auto Mgs = dyn_cast<MaskedGatherScatterSDNode>(Parent)) {
-    Base = Mgs->getBasePtr();
-    Index = Mgs->getIndex();
-    ScalarSize = Mgs->getValue().getScalarValueSizeInBits();
+    AM.IndexReg = Mgs->getIndex();
+    AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
   } else {
     auto X86Gather = cast<X86MaskedGatherSDNode>(Parent);
-    Base = X86Gather->getBasePtr();
-    Index = X86Gather->getIndex();
-    ScalarSize = X86Gather->getValue().getScalarValueSizeInBits();
+    AM.IndexReg = X86Gather->getIndex();
+    AM.Scale = X86Gather->getValue().getScalarValueSizeInBits() / 8;
   }
 
-  X86ISelAddressMode AM;
   unsigned AddrSpace = cast<MemSDNode>(Parent)->getPointerInfo().getAddrSpace();
   // AddrSpace 256 -> GS, 257 -> FS, 258 -> SS.
   if (AddrSpace == 256)
@@ -1527,21 +1540,23 @@
   if (AddrSpace == 258)
     AM.Segment = CurDAG->getRegister(X86::SS, MVT::i16);
 
-  SDLoc DL(N);
-  Scale = getI8Imm(ScalarSize/8, DL);
-
   // If Base is 0, the whole address is in index and the Scale is 1
-  if (isa<ConstantSDNode>(Base)) {
-    assert(cast<ConstantSDNode>(Base)->isNullValue() &&
+  if (isa<ConstantSDNode>(N)) {
+    assert(cast<ConstantSDNode>(N)->isNullValue() &&
            "Unexpected base in gather/scatter");
-    Scale = getI8Imm(1, DL);
-    Base = CurDAG->getRegister(0, MVT::i32);
+    AM.Scale = 1;
   }
-  if (AM.Segment.getNode())
-    Segment = AM.Segment;
-  else
-    Segment = CurDAG->getRegister(0, MVT::i32);
-  Disp = CurDAG->getTargetConstant(0, DL, MVT::i32);
+  // Otherwise, try to match into the base and displacement fields.
+  else if (matchVectorAddress(N, AM))
+    return false;
+
+  MVT VT = N.getSimpleValueType();
+  if (AM.BaseType == X86ISelAddressMode::RegBase) {
+    if (!AM.Base_Reg.getNode())
+      AM.Base_Reg = CurDAG->getRegister(0, VT);
+  }
+
+  getAddressOperands(AM, SDLoc(N), Base, Scale, Index, Disp, Segment);
   return true;
 }