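Added support for specifying predicates on instructions and patterns (via their "Predicates" list); each predicate's CondString is emitted as a match condition in the generated DAG instruction selector.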


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@24715 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/utils/TableGen/DAGISelEmitter.cpp b/utils/TableGen/DAGISelEmitter.cpp
index eca9ad1..07d35d4 100644
--- a/utils/TableGen/DAGISelEmitter.cpp
+++ b/utils/TableGen/DAGISelEmitter.cpp
@@ -1300,11 +1300,13 @@
     if (!SrcPattern->canPatternMatch(Reason, *this))
       I->error("Instruction can never match: " + Reason);
     
+    Record *Instr = II->first;
     TreePatternNode *DstPattern = TheInst.getResultPattern();
-    PatternsToMatch.push_back(std::make_pair(SrcPattern, DstPattern));
+    PatternsToMatch.
+      push_back(PatternToMatch(Instr->getValueAsListInit("Predicates"),
+                               SrcPattern, DstPattern));
 
     if (PatternHasCtrlDep(Pattern, *this)) {
-      Record *Instr = II->first;
       CodeGenInstruction &InstInfo = Target.getInstruction(Instr->getName());
       InstInfo.hasCtrlDep = true;
     }
@@ -1356,8 +1358,10 @@
     if (!Pattern->getOnlyTree()->canPatternMatch(Reason, *this))
       Pattern->error("Pattern can never match: " + Reason);
     
-    PatternsToMatch.push_back(std::make_pair(Pattern->getOnlyTree(),
-                                             Result->getOnlyTree()));
+    PatternsToMatch.
+      push_back(PatternToMatch(Patterns[i]->getValueAsListInit("Predicates"),
+                               Pattern->getOnlyTree(),
+                               Result->getOnlyTree()));
   }
 }
 
@@ -1565,7 +1569,7 @@
   //
   for (unsigned i = 0, e = PatternsToMatch.size(); i != e; ++i) {
     std::vector<TreePatternNode*> Variants;
-    GenerateVariantsOf(PatternsToMatch[i].first, Variants, *this);
+    GenerateVariantsOf(PatternsToMatch[i].getSrcPattern(), Variants, *this);
 
     assert(!Variants.empty() && "Must create at least original variant!");
     Variants.erase(Variants.begin());  // Remove the original pattern.
@@ -1574,7 +1578,7 @@
       continue;
 
     DEBUG(std::cerr << "FOUND VARIANTS OF: ";
-          PatternsToMatch[i].first->dump();
+          PatternsToMatch[i].getSrcPattern()->dump();
           std::cerr << "\n");
 
     for (unsigned v = 0, e = Variants.size(); v != e; ++v) {
@@ -1588,7 +1592,7 @@
       bool AlreadyExists = false;
       for (unsigned p = 0, e = PatternsToMatch.size(); p != e; ++p) {
         // Check to see if this variant already exists.
-        if (Variant->isIsomorphicTo(PatternsToMatch[p].first)) {
+        if (Variant->isIsomorphicTo(PatternsToMatch[p].getSrcPattern())) {
           DEBUG(std::cerr << "  *** ALREADY EXISTS, ignoring variant.\n");
           AlreadyExists = true;
           break;
@@ -1598,8 +1602,9 @@
       if (AlreadyExists) continue;
 
       // Otherwise, add it to the list of patterns we have.
-      PatternsToMatch.push_back(std::make_pair(Variant, 
-                                               PatternsToMatch[i].second));
+      PatternsToMatch.
+        push_back(PatternToMatch(PatternsToMatch[i].getPredicates(),
+                                 Variant, PatternsToMatch[i].getDstPattern()));
     }
 
     DEBUG(std::cerr << "\n");
@@ -1685,15 +1690,16 @@
   PatternSortingPredicate(DAGISelEmitter &ise) : ISE(ise) {};
   DAGISelEmitter &ISE;
 
-  bool operator()(DAGISelEmitter::PatternToMatch *LHS,
-                  DAGISelEmitter::PatternToMatch *RHS) {
-    unsigned LHSSize = getPatternSize(LHS->first, ISE);
-    unsigned RHSSize = getPatternSize(RHS->first, ISE);
+  bool operator()(PatternToMatch *LHS,
+                  PatternToMatch *RHS) {
+    unsigned LHSSize = getPatternSize(LHS->getSrcPattern(), ISE);
+    unsigned RHSSize = getPatternSize(RHS->getSrcPattern(), ISE);
     if (LHSSize > RHSSize) return true;   // LHS -> bigger -> less cost
     if (LHSSize < RHSSize) return false;
     
     // If the patterns have equal complexity, compare generated instruction cost
-    return getResultPatternCost(LHS->second) <getResultPatternCost(RHS->second);
+    return getResultPatternCost(LHS->getDstPattern()) <
+      getResultPatternCost(RHS->getDstPattern());
   }
 };
 
@@ -1725,8 +1731,12 @@
 private:
   DAGISelEmitter &ISE;
 
-  // LHS of the pattern being matched
-  TreePatternNode *LHS;
+  // Predicates.
+  ListInit *Predicates;
+  // Instruction selector pattern.
+  TreePatternNode *Pattern;
+  // Matched instruction.
+  TreePatternNode *Instruction;
   unsigned PatternNo;
   std::ostream &OS;
   // Node to name mapping
@@ -1738,16 +1748,39 @@
   unsigned TmpNo;
 
 public:
-  PatternCodeEmitter(DAGISelEmitter &ise, TreePatternNode *lhs,
+  PatternCodeEmitter(DAGISelEmitter &ise, ListInit *preds,
+                     TreePatternNode *pattern, TreePatternNode *instr,
                      unsigned PatNum, std::ostream &os) :
-    ISE(ise), LHS(lhs), PatternNo(PatNum), OS(os),
-    FoundChain(false), InFlag(false), TmpNo(0) {};
+    ISE(ise), Predicates(preds), Pattern(pattern), Instruction(instr),
+    PatternNo(PatNum), OS(os), FoundChain(false), InFlag(false), TmpNo(0) {};
 
   /// EmitMatchCode - Emit a matcher for N, going to the label for PatternNo
   /// if the match fails. At this point, we already know that the opcode for N
   /// matches, and the SDNode for the result has the RootName specified name.
   void EmitMatchCode(TreePatternNode *N, const std::string &RootName,
                      bool isRoot = false) {
+
+    // Emit instruction predicates. Each predicate is just a string for now.
+    if (isRoot) {
+      for (unsigned i = 0, e = Predicates->getSize(); i != e; ++i) {
+        if (DefInit *Pred = dynamic_cast<DefInit*>(Predicates->getElement(i))) {
+          Record *Def = Pred->getDef();
+          if (Def->isSubClassOf("Predicate")) {
+            if (i == 0)
+              OS << "      if (";
+            else
+              OS << " && ";
+            OS << "(" << Def->getValueAsString("CondString") << ")";
+            if (i == e-1)
+              OS << ") goto P" << PatternNo << "Fail;\n";
+          } else {
+            Def->dump();
+            assert(0 && "Unknown predicate type!");
+          }
+        }
+      }
+    }
+
     if (N->isLeaf()) {
       if (IntInit *II = dynamic_cast<IntInit*>(N->getLeafValue())) {
         OS << "      if (cast<ConstantSDNode>(" << RootName
@@ -1998,7 +2031,7 @@
       // Emit all the chain and CopyToReg stuff.
       if (II.hasCtrlDep)
         OS << "      Chain = Select(Chain);\n";
-      EmitCopyToRegs(LHS, "N", II.hasCtrlDep);
+      EmitCopyToRegs(Pattern, "N", II.hasCtrlDep);
 
       const DAGInstruction &Inst = ISE.getInstruction(Op);
       unsigned NumResults = Inst.getNumResults();    
@@ -2035,7 +2068,7 @@
           OS << "      CodeGenMap[N.getValue(0)] = Result;\n";
         }
         OS << "      Chain ";
-        if (NodeHasChain(LHS, ISE))
+        if (NodeHasChain(Pattern, ISE))
           OS << "= CodeGenMap[N.getValue(" << NumResults << ")] ";
         for (unsigned j = 0, e = FoldedChains.size(); j < e; j++)
           OS << "= CodeGenMap[" << FoldedChains[j] << ".getValue("
@@ -2071,9 +2104,7 @@
       return std::make_pair(1, ResNo);
     } else if (Op->isSubClassOf("SDNodeXForm")) {
       assert(N->getNumChildren() == 1 && "node xform should have one child!");
-      unsigned OpVal = EmitResultCode(N->getChild(0))
-        .second;
-    
+      unsigned OpVal = EmitResultCode(N->getChild(0)).second;
       unsigned ResNo = TmpNo++;
       OS << "      SDOperand Tmp" << ResNo << " = Transform_" << Op->getName()
          << "(Tmp" << OpVal << ".Val);\n";
@@ -2161,17 +2192,21 @@
   static unsigned PatternCount = 0;
   unsigned PatternNo = PatternCount++;
   OS << "    { // Pattern #" << PatternNo << ": ";
-  Pattern.first->print(OS);
+  Pattern.getSrcPattern()->print(OS);
   OS << "\n      // Emits: ";
-  Pattern.second->print(OS);
+  Pattern.getDstPattern()->print(OS);
   OS << "\n";
-  OS << "      // Pattern complexity = " << getPatternSize(Pattern.first, *this)
-     << "  cost = " << getResultPatternCost(Pattern.second) << "\n";
+  OS << "      // Pattern complexity = "
+     << getPatternSize(Pattern.getSrcPattern(), *this)
+     << "  cost = "
+     << getResultPatternCost(Pattern.getDstPattern()) << "\n";
 
-  PatternCodeEmitter Emitter(*this, Pattern.first, PatternNo, OS);
+  PatternCodeEmitter Emitter(*this, Pattern.getPredicates(),
+                             Pattern.getSrcPattern(), Pattern.getDstPattern(),
+                             PatternNo, OS);
 
   // Emit the matcher, capturing named arguments in VariableMap.
-  Emitter.EmitMatchCode(Pattern.first, "N", true /*the root*/);
+  Emitter.EmitMatchCode(Pattern.getSrcPattern(), "N", true /*the root*/);
 
   // TP - Get *SOME* tree pattern, we don't care which.
   TreePattern &TP = *PatternFragments.begin()->second;
@@ -2188,7 +2223,7 @@
   // apply the type to the tree, then rerun type inference.  Iterate until all
   // types are resolved.
   //
-  TreePatternNode *Pat = Pattern.first->clone();
+  TreePatternNode *Pat = Pattern.getSrcPattern()->clone();
   RemoveAllTypes(Pat);
   
   do {
@@ -2206,9 +2241,9 @@
     // Insert a check for an unresolved type and add it to the tree.  If we find
     // an unresolved type to add a check for, this returns true and we iterate,
     // otherwise we are done.
-  } while (Emitter.InsertOneTypeCheck(Pat, Pattern.first, "N"));
+  } while (Emitter.InsertOneTypeCheck(Pat, Pattern.getSrcPattern(), "N"));
 
-  Emitter.EmitResultCode(Pattern.second, true /*the root*/);
+  Emitter.EmitResultCode(Pattern.getDstPattern(), true /*the root*/);
 
   delete Pat;
   
@@ -2286,7 +2321,7 @@
   std::map<Record*, std::vector<PatternToMatch*>,
            CompareByRecordName> PatternsByOpcode;
   for (unsigned i = 0, e = PatternsToMatch.size(); i != e; ++i) {
-    TreePatternNode *Node = PatternsToMatch[i].first;
+    TreePatternNode *Node = PatternsToMatch[i].getSrcPattern();
     if (!Node->isLeaf()) {
       PatternsByOpcode[Node->getOperator()].push_back(&PatternsToMatch[i]);
     } else {
@@ -2304,7 +2339,7 @@
         std::cerr << "Unrecognized opcode '";
         Node->dump();
         std::cerr << "' on tree pattern '";
-        std::cerr << PatternsToMatch[i].second->getOperator()->getName();
+        std::cerr << PatternsToMatch[i].getDstPattern()->getOperator()->getName();
         std::cerr << "'!\n";
         exit(1);
       }
@@ -2366,8 +2401,8 @@
   
   DEBUG(std::cerr << "\n\nALL PATTERNS TO MATCH:\n\n";
         for (unsigned i = 0, e = PatternsToMatch.size(); i != e; ++i) {
-          std::cerr << "PATTERN: ";  PatternsToMatch[i].first->dump();
-          std::cerr << "\nRESULT:  ";PatternsToMatch[i].second->dump();
+          std::cerr << "PATTERN: ";  PatternsToMatch[i].getSrcPattern()->dump();
+          std::cerr << "\nRESULT:  ";PatternsToMatch[i].getDstPattern()->dump();
           std::cerr << "\n";
         });