Added support for ComplexPattern. These are patterns that require C++
pattern-matching code that tblgen does not currently auto-generate, e.g. the
X86 addressing mode. Selection routines for complex patterns can return
multiple operands; e.g. the X86 addressing-mode selector returns 4.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@24634 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/utils/TableGen/DAGISelEmitter.cpp b/utils/TableGen/DAGISelEmitter.cpp
index 2642ec1..cf5abc2 100644
--- a/utils/TableGen/DAGISelEmitter.cpp
+++ b/utils/TableGen/DAGISelEmitter.cpp
@@ -477,6 +477,9 @@
   } else if (R->isSubClassOf("ValueType") || R->isSubClassOf("CondCode")) {
     // Using a VTSDNode or CondCodeSDNode.
     return MVT::Other;
+  } else if (R->isSubClassOf("ComplexPattern")) {
+    const CodeGenTarget &T = TP.getDAGISelEmitter().getTargetInfo();
+    return T.getPointerType();
   } else if (R->getName() == "node") {
     // Placeholder.
     return MVT::isUnknown;
@@ -609,7 +612,7 @@
   for (unsigned i = 0, e = getNumChildren(); i != e; ++i)
     if (!getChild(i)->canPatternMatch(Reason, ISE))
       return false;
-  
+
   // If this node is a commutative operator, check that the LHS isn't an
   // immediate.
   const SDNodeInfo &NodeInfo = ISE.getSDNodeInfo(getOperator());
@@ -833,6 +836,13 @@
   }
 }
 
+void DAGISelEmitter::ParseComplexPatterns() {
+  std::vector<Record*> AMs = Records.getAllDerivedDefinitions("ComplexPattern");
+  while (!AMs.empty()) {
+    ComplexPatterns.insert(std::make_pair(AMs.back(), AMs.back()));
+    AMs.pop_back();
+  }
+}
 
 
 /// ParsePatternFragments - Parse all of the PatFrag definitions in the .td
@@ -1204,9 +1214,10 @@
       if (InVal->isLeaf() &&
           dynamic_cast<DefInit*>(InVal->getLeafValue())) {
         Record *InRec = static_cast<DefInit*>(InVal->getLeafValue())->getDef();
-        if (CGI.OperandList[i].Rec != InRec)
+        if (CGI.OperandList[i].Rec != InRec &&
+            !InRec->isSubClassOf("ComplexPattern"))
           I->error("Operand $" + OpName +
-                 "'s register class disagrees between the operand and pattern");
+                   "'s register class disagrees between the operand and pattern");
       }
       Operands.push_back(CGI.OperandList[i].Rec);
       
@@ -1586,22 +1597,59 @@
 }
 
 
+// NodeIsComplexPattern - return true if N is a leaf node and a subclass of
+// ComplexPattern.
+static bool NodeIsComplexPattern(TreePatternNode *N)
+{
+  return (N->isLeaf() &&
+          dynamic_cast<DefInit*>(N->getLeafValue()) &&
+          static_cast<DefInit*>(N->getLeafValue())->getDef()->
+          isSubClassOf("ComplexPattern"));
+}
+
+// NodeGetComplexPattern - return the pointer to the ComplexPattern if N
+// is a leaf node and a subclass of ComplexPattern, else it returns NULL.
+static const ComplexPattern *NodeGetComplexPattern(TreePatternNode *N,
+                                                   DAGISelEmitter &ISE)
+{
+  if (N->isLeaf() &&
+      dynamic_cast<DefInit*>(N->getLeafValue()) &&
+      static_cast<DefInit*>(N->getLeafValue())->getDef()->
+      isSubClassOf("ComplexPattern")) {
+    return &ISE.getComplexPattern(static_cast<DefInit*>(N->getLeafValue())
+                                  ->getDef());
+  }
+  return NULL;
+}
+
 /// getPatternSize - Return the 'size' of this pattern.  We want to match large
 /// patterns before small ones.  This is used to determine the size of a
 /// pattern.
-static unsigned getPatternSize(TreePatternNode *P) {
+static unsigned getPatternSize(TreePatternNode *P, DAGISelEmitter &ISE) {
   assert(isExtIntegerVT(P->getExtType()) || 
          isExtFloatingPointVT(P->getExtType()) ||
          P->getExtType() == MVT::isVoid && "Not a valid pattern node to size!");
   unsigned Size = 1;  // The node itself.
-  
+
+  // FIXME: This is a hack to statically increase the priority of patterns
+  // which maps a sub-dag to a complex pattern. e.g. favors LEA over ADD.
+  // Later we can allow complexity / cost for each pattern to be (optionally)
+  // specified. To get best possible pattern match we'll need to dynamically
+  // calculate the complexity of all patterns a dag can potentially map to.
+  const ComplexPattern *AM = NodeGetComplexPattern(P, ISE);
+  if (AM)
+    Size += AM->getNumOperands();
+    
   // Count children in the count if they are also nodes.
   for (unsigned i = 0, e = P->getNumChildren(); i != e; ++i) {
     TreePatternNode *Child = P->getChild(i);
     if (!Child->isLeaf() && Child->getExtType() != MVT::Other)
-      Size += getPatternSize(Child);
-    else if (Child->isLeaf() && dynamic_cast<IntInit*>(Child->getLeafValue())) {
-      ++Size;  // Matches a ConstantSDNode.
+      Size += getPatternSize(Child, ISE);
+    else if (Child->isLeaf()) {
+      if (dynamic_cast<IntInit*>(Child->getLeafValue())) 
+        ++Size;  // Matches a ConstantSDNode.
+      else if (NodeIsComplexPattern(Child))
+        Size += getPatternSize(Child, ISE);
     }
   }
   
@@ -1624,10 +1672,13 @@
 // In particular, we want to match maximal patterns first and lowest cost within
 // a particular complexity first.
 struct PatternSortingPredicate {
+  PatternSortingPredicate(DAGISelEmitter &ise) : ISE(ise) {};
+  DAGISelEmitter &ISE;
+
   bool operator()(DAGISelEmitter::PatternToMatch *LHS,
                   DAGISelEmitter::PatternToMatch *RHS) {
-    unsigned LHSSize = getPatternSize(LHS->first);
-    unsigned RHSSize = getPatternSize(RHS->first);
+    unsigned LHSSize = getPatternSize(LHS->first, ISE);
+    unsigned RHSSize = getPatternSize(RHS->first, ISE);
     if (LHSSize > RHSSize) return true;   // LHS -> bigger -> less cost
     if (LHSSize < RHSSize) return false;
     
@@ -1649,9 +1700,10 @@
          << ")->getSignExtended() != " << II->getValue() << ")\n"
          << "        goto P" << PatternNo << "Fail;\n";
       return;
+    } else if (!NodeIsComplexPattern(N)) {
+      assert(0 && "Cannot match this as a leaf value!");
+      abort();
     }
-    assert(0 && "Cannot match this as a leaf value!");
-    abort();
   }
   
   // If this node has a name associated with it, capture it in VarMap.  If
@@ -1710,6 +1762,8 @@
         if (LeafRec->isSubClassOf("RegisterClass") ||
             LeafRec->isSubClassOf("Register")) {
           // Handle register references.  Nothing to do here.
+        } else if (LeafRec->isSubClassOf("ComplexPattern")) {
+          // Handle complex pattern. Nothing to do here.
         } else if (LeafRec->isSubClassOf("ValueType")) {
           // Make sure this is the specified value type.
           OS << "      if (cast<VTSDNode>(" << RootName << OpNo << ")->getVT() != "
@@ -1819,9 +1873,10 @@
 
 /// CodeGenPatternResult - Emit the action for a pattern.  Now that it has
 /// matched, we actually have to build a DAG!
-unsigned DAGISelEmitter::
+std::pair<unsigned, unsigned> DAGISelEmitter::
 CodeGenPatternResult(TreePatternNode *N, unsigned &Ctr,
                      std::map<std::string,std::string> &VariableMap, 
+                     unsigned PatternNo,
                      std::ostream &OS, bool &HasChain, bool InFlag,
                      bool isRoot) {
   // This is something selected from the pattern we matched.
@@ -1832,10 +1887,12 @@
            "Variable referenced but not defined and not caught earlier!");
     if (Val[0] == 'T' && Val[1] == 'm' && Val[2] == 'p') {
       // Already selected this operand, just return the tmpval.
-      return atoi(Val.c_str()+3);
+      return std::make_pair(1, atoi(Val.c_str()+3));
     }
-    
+
+    const ComplexPattern *CP;
     unsigned ResNo = Ctr++;
+    unsigned NumRes = 1;
     if (!N->isLeaf() && N->getOperator()->getName() == "imm") {
       switch (N->getType()) {
       default: assert(0 && "Unknown type for constant node!");
@@ -1850,13 +1907,27 @@
          << ResNo << "C, MVT::" << getEnumName(N->getType()) << ");\n";
     } else if (!N->isLeaf() && N->getOperator()->getName() == "tglobaladdr") {
       OS << "      SDOperand Tmp" << ResNo << " = " << Val << ";\n";
+    } else if (N->isLeaf() && (CP = NodeGetComplexPattern(N, *this))) {
+      std::string Fn = CP->getSelectFunc();
+      NumRes = CP->getNumOperands();
+      OS << "      SDOperand ";
+      for (unsigned i = 0; i < NumRes; i++) {
+        if (i != 0) OS << ", ";
+        OS << "Tmp" << i + ResNo;
+      }
+      OS << ";\n";
+      OS << "      if (!" << Fn << "(" << Val;
+      for (unsigned i = 0; i < NumRes; i++)
+        OS << " , Tmp" << i + ResNo;
+      OS << ")) goto P" << PatternNo << "Fail;\n";
+      Ctr = ResNo + NumRes;
     } else {
       OS << "      SDOperand Tmp" << ResNo << " = Select(" << Val << ");\n";
     }
     // Add Tmp<ResNo> to VariableMap, so that we don't multiply select this
     // value if used multiple times by this pattern result.
     Val = "Tmp"+utostr(ResNo);
-    return ResNo;
+    return std::make_pair(NumRes, ResNo);
   }
   
   if (N->isLeaf()) {
@@ -1868,7 +1939,7 @@
            << getQualifiedName(DI->getDef()) << ", MVT::"
            << getEnumName(N->getType())
            << ");\n";
-        return ResNo;
+        return std::make_pair(1, ResNo);
       }
     } else if (IntInit *II = dynamic_cast<IntInit*>(N->getLeafValue())) {
       unsigned ResNo = Ctr++;
@@ -1876,21 +1947,26 @@
          << II->getValue() << ", MVT::"
         << getEnumName(N->getType())
         << ");\n";
-      return ResNo;
+      return std::make_pair(1, ResNo);
     }
     
     N->dump();
     assert(0 && "Unknown leaf type!");
-    return ~0U;
+    return std::make_pair(1, ~0U);
   }
 
   Record *Op = N->getOperator();
   if (Op->isSubClassOf("Instruction")) {
     // Emit all of the operands.
     std::vector<unsigned> Ops;
-    for (unsigned i = 0, e = N->getNumChildren(); i != e; ++i)
-      Ops.push_back(CodeGenPatternResult(N->getChild(i), Ctr,
-                                         VariableMap, OS, HasChain, InFlag));
+    for (unsigned i = 0, e = N->getNumChildren(); i != e; ++i) {
+      TreePatternNode *Child = N->getChild(i);
+      std::pair<unsigned, unsigned> NOPair = 
+        CodeGenPatternResult(Child, Ctr,
+                             VariableMap, PatternNo, OS, HasChain, InFlag);
+      for (unsigned j = 0; j < NOPair.first; j++)
+        Ops.push_back(NOPair.second + j);
+    }
 
     CodeGenInstruction &II = Target.getInstruction(Op->getName());
     bool HasCtrlDep = II.hasCtrlDep;
@@ -1934,12 +2010,12 @@
         OS << "      CodeGenMap[N.getValue(0)] = Result;\n";
         OS << "      CodeGenMap[N.getValue(" << NumResults
            << ")] = Result.getValue(" << NumResults << ");\n";
-        OS << "      Chain = CodeGenMap[N].getValue(" << NumResults << ");\n";
+        OS << "      Chain = Result.getValue(" << NumResults << ");\n";
       }
       if (NumResults == 0)
         OS << "      return Chain;\n";
       else
-        OS << "      return (N.ResNo) ? Chain : CodeGenMap[N];\n";
+        OS << "      return (N.ResNo) ? Chain : Result.getValue(0);\n";
     } else {
       // If this instruction is the root, and if there is only one use of it,
       // use SelectNodeTo instead of getTargetNode to avoid an allocation.
@@ -1963,11 +2039,12 @@
       OS << ");\n";
       OS << "      }\n";
     }
-    return ResNo;
+    return std::make_pair(1, ResNo);
   } else if (Op->isSubClassOf("SDNodeXForm")) {
     assert(N->getNumChildren() == 1 && "node xform should have one child!");
     unsigned OpVal = CodeGenPatternResult(N->getChild(0), Ctr,
-                                          VariableMap, OS, HasChain, InFlag);
+                                          VariableMap, PatternNo, OS, HasChain, InFlag)
+      .second;
     
     unsigned ResNo = Ctr++;
     OS << "      SDOperand Tmp" << ResNo << " = Transform_" << Op->getName()
@@ -1976,11 +2053,11 @@
       OS << "      CodeGenMap[N] = Tmp" << ResNo << ";\n";
       OS << "      return Tmp" << ResNo << ";\n";
     }
-    return ResNo;
+    return std::make_pair(1, ResNo);
   } else {
     N->dump();
     assert(0 && "Unknown node in result pattern!");
-    return ~0U;
+    return std::make_pair(1, ~0U);
   }
 }
 
@@ -2005,10 +2082,13 @@
   if (!Pat->hasTypeSet()) {
     // Move a type over from 'other' to 'pat'.
     Pat->setType(Other->getType());
-    OS << "      if (" << Prefix << ".getValueType() != MVT::"
+    OS << "      if (" << Prefix << ".Val->getValueType(0) != MVT::"
        << getName(Pat->getType()) << ") goto P" << PatternNo << "Fail;\n";
     return true;
   } else if (Pat->isLeaf()) {
+    if (NodeIsComplexPattern(Pat))
+      OS << "      if (" << Prefix << ".Val->getValueType(0) != MVT::"
+         << getName(Pat->getType()) << ") goto P" << PatternNo << "Fail;\n";
     return false;
   }
   
@@ -2038,7 +2118,7 @@
   OS << "\n      // Emits: ";
   Pattern.second->print(OS);
   OS << "\n";
-  OS << "      // Pattern complexity = " << getPatternSize(Pattern.first)
+  OS << "      // Pattern complexity = " << getPatternSize(Pattern.first, *this)
      << "  cost = " << getResultPatternCost(Pattern.second) << "\n";
 
   // Emit the matcher, capturing named arguments in VariableMap.
@@ -2088,7 +2168,7 @@
   
   unsigned TmpNo = 0;
   CodeGenPatternResult(Pattern.second,
-                       TmpNo, VariableMap, OS, HasChain, InFlag, true /*the root*/);
+                       TmpNo, VariableMap, PatternNo, OS, HasChain, InFlag, true /*the root*/);
   delete Pat;
   
   OS << "    }\n  P" << PatternNo << "Fail:\n";
@@ -2166,23 +2246,31 @@
   // Group the patterns by their top-level opcodes.
   std::map<Record*, std::vector<PatternToMatch*>,
            CompareByRecordName> PatternsByOpcode;
-  for (unsigned i = 0, e = PatternsToMatch.size(); i != e; ++i)
-    if (!PatternsToMatch[i].first->isLeaf()) {
-      PatternsByOpcode[PatternsToMatch[i].first->getOperator()]
-         .push_back(&PatternsToMatch[i]);
+  for (unsigned i = 0, e = PatternsToMatch.size(); i != e; ++i) {
+    TreePatternNode *Node = PatternsToMatch[i].first;
+    if (!Node->isLeaf()) {
+      PatternsByOpcode[Node->getOperator()].push_back(&PatternsToMatch[i]);
     } else {
+      const ComplexPattern *CP;
       if (IntInit *II = 
-             dynamic_cast<IntInit*>(PatternsToMatch[i].first->getLeafValue())) {
+             dynamic_cast<IntInit*>(Node->getLeafValue())) {
         PatternsByOpcode[getSDNodeNamed("imm")].push_back(&PatternsToMatch[i]);
+      } else if ((CP = NodeGetComplexPattern(Node, *this))) {
+        std::vector<Record*> OpNodes = CP->getMatchingNodes();
+        for (unsigned j = 0, e = OpNodes.size(); j != e; j++) {
+          PatternsByOpcode[OpNodes[j]].insert(PatternsByOpcode[OpNodes[j]].begin(),
+                                              &PatternsToMatch[i]);
+        }
       } else {
         std::cerr << "Unrecognized opcode '";
-        PatternsToMatch[i].first->dump();
+        Node->dump();
         std::cerr << "' on tree pattern '";
         std::cerr << PatternsToMatch[i].second->getOperator()->getName();
         std::cerr << "'!\n";
         exit(1);
       }
     }
+  }
   
   // Loop over all of the case statements.
   for (std::map<Record*, std::vector<PatternToMatch*>,
@@ -2197,7 +2285,7 @@
     // the matches in order of minimal cost.  Sort the patterns so the least
     // cost one is at the start.
     std::stable_sort(Patterns.begin(), Patterns.end(),
-                     PatternSortingPredicate());
+                     PatternSortingPredicate(*this));
     
     for (unsigned i = 0, e = Patterns.size(); i != e; ++i)
       EmitCodeForPattern(*Patterns[i], OS);
@@ -2227,6 +2315,7 @@
   
   ParseNodeInfo();
   ParseNodeTransforms(OS);
+  ParseComplexPatterns();
   ParsePatternFragments(OS);
   ParseInstructions();
   ParsePatterns();