Allow "let AddedCost = n in" to increase pattern complexity.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@27834 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/utils/TableGen/DAGISelEmitter.cpp b/utils/TableGen/DAGISelEmitter.cpp
index 7a6097f..8d5186f 100644
--- a/utils/TableGen/DAGISelEmitter.cpp
+++ b/utils/TableGen/DAGISelEmitter.cpp
@@ -1504,7 +1504,8 @@
TreePatternNode *DstPattern = TheInst.getResultPattern();
PatternsToMatch.
push_back(PatternToMatch(Instr->getValueAsListInit("Predicates"),
- SrcPattern, DstPattern));
+ SrcPattern, DstPattern,
+ Instr->getValueAsInt("AddedCost")));
}
}
@@ -1580,7 +1581,8 @@
PatternsToMatch.
push_back(PatternToMatch(Patterns[i]->getValueAsListInit("Predicates"),
Pattern->getOnlyTree(),
- Temp.getOnlyTree()));
+ Temp.getOnlyTree(),
+ Patterns[i]->getValueAsInt("AddedCost")));
}
}
@@ -1823,7 +1825,8 @@
// Otherwise, add it to the list of patterns we have.
PatternsToMatch.
push_back(PatternToMatch(PatternsToMatch[i].getPredicates(),
- Variant, PatternsToMatch[i].getDstPattern()));
+ Variant, PatternsToMatch[i].getDstPattern(),
+ PatternsToMatch[i].getAddedCost()));
}
DEBUG(std::cerr << "\n");
@@ -1933,6 +1936,8 @@
PatternToMatch *RHS) {
unsigned LHSSize = getPatternSize(LHS->getSrcPattern(), ISE);
unsigned RHSSize = getPatternSize(RHS->getSrcPattern(), ISE);
+ LHSSize += LHS->getAddedCost();
+ RHSSize += RHS->getAddedCost();
if (LHSSize > RHSSize) return true; // LHS -> bigger -> less cost
if (LHSSize < RHSSize) return false;
@@ -2003,6 +2008,8 @@
// Predicates.
ListInit *Predicates;
+ // Pattern cost.
+ unsigned Cost;
// Instruction selector pattern.
TreePatternNode *Pattern;
// Matched instruction.
@@ -2939,8 +2946,10 @@
OS << "\n" << std::string(Indent, ' ') << "// Emits: ";
Pattern.getDstPattern()->print(OS);
OS << "\n";
+ unsigned AddedCost = Pattern.getAddedCost();
OS << std::string(Indent, ' ') << "// Pattern complexity = "
- << getPatternSize(Pattern.getSrcPattern(), *this) << " cost = "
+ << getPatternSize(Pattern.getSrcPattern(), *this) + AddedCost
+ << " cost = "
<< getResultPatternCost(Pattern.getDstPattern(), *this) << "\n";
}
if (!FirstCodeLine.first) {
@@ -2960,8 +2969,10 @@
OS << "\n" << std::string(Indent, ' ') << "// Emits: ";
Pattern.getDstPattern()->print(OS);
OS << "\n";
+ unsigned AddedCost = Pattern.getAddedCost();
OS << std::string(Indent, ' ') << "// Pattern complexity = "
- << getPatternSize(Pattern.getSrcPattern(), *this) << " cost = "
+ << getPatternSize(Pattern.getSrcPattern(), *this) + AddedCost
+ << " cost = "
<< getResultPatternCost(Pattern.getDstPattern(), *this) << "\n";
}
EmitPatterns(Other, Indent, OS);