Implement operands for the lower and upper bounds of the for statement.
This revamps the implementation of the loop bounds in ForStmt, using a general representation that supports operands. The frequent case of constant bounds is supported
via special access methods.
This also includes:
- Operand iterators for the Statement class.
- OpPointer::is() method to query the class of the Operation.
- Support for the bound shorthand notation parsing and printing.
- Validity checks for the bound operands used as dim ids and symbols.
I didn't mean this CL to be so large. It just happened this way, as one thing led to another.
PiperOrigin-RevId: 210204858
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index cd90e3e..2cc20ac 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -101,6 +101,18 @@
}
}
+ // Return true if this map could be printed using the shorthand form.
+ static bool hasShorthandForm(const AffineMap *boundMap) {
+ if (boundMap->isSingleConstant())
+ return true;
+
+ // Check if the affine map is single dim id or single symbol identity -
+ // (i)->(i) or ()[s]->(s)
+ return boundMap->getNumOperands() == 1 && boundMap->getNumResults() == 1 &&
+ (isa<AffineDimExpr>(boundMap->getResult(0)) ||
+ isa<AffineSymbolExpr>(boundMap->getResult(0)));
+ }
+
// Visit functions.
void visitFunction(const Function *fn);
void visitExtFunction(const ExtFunction *fn);
@@ -183,6 +195,14 @@
}
void ModuleState::visitForStmt(const ForStmt *forStmt) {
+ AffineMap *lbMap = forStmt->getLowerBoundMap();
+ if (!hasShorthandForm(lbMap)) // Shorthand bounds print without the map.
+ recordAffineMapReference(lbMap);
+
+ AffineMap *ubMap = forStmt->getUpperBoundMap();
+ if (!hasShorthandForm(ubMap)) // Same rule as for the lower bound.
+ recordAffineMapReference(ubMap);
+
for (auto &childStmt : *forStmt)
visitStatement(&childStmt);
}
@@ -1216,20 +1236,24 @@
const MLFunction *getFunction() const { return function; }
- // Prints ML function
+ // Prints ML function.
void print();
- // Prints ML function signature
+ // Prints ML function signature.
void printFunctionSignature();
- // Methods to print ML function statements
+ // Methods to print ML function statements.
void print(const Statement *stmt);
void print(const OperationStmt *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const StmtBlock *block);
- // Number of spaces used for indenting nested statements
+ // Methods to print loop bounds and their dim and symbol operand lists.
+ void printDimAndSymbolList(ArrayRef<StmtOperand> ops, unsigned numDims);
+ void printBound(AffineBound bound, const char *prefix);
+
+ // Number of spaces used for indenting nested statements.
const static unsigned indentWidth = 2;
private:
@@ -1249,7 +1273,7 @@
/// Number all of the SSA values in this ML function.
void MLFunctionPrinter::numberValues() {
- // Numbers ML function arguments
+ // Numbers ML function arguments.
for (auto *arg : function->getArguments())
numberValueID(arg);
@@ -1323,8 +1347,11 @@
void MLFunctionPrinter::print(const ForStmt *stmt) {
os.indent(numSpaces) << "for ";
printOperand(stmt);
- os << " = " << *stmt->getLowerBound();
- os << " to " << *stmt->getUpperBound();
+ os << " = ";
+ printBound(stmt->getLowerBound(), "max");
+ os << " to ";
+ printBound(stmt->getUpperBound(), "min");
+
if (stmt->getStep() != 1)
os << " step " << stmt->getStep();
@@ -1333,6 +1360,51 @@
os.indent(numSpaces) << "}";
}
+void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<StmtOperand> ops,
+                                              unsigned numDims) {
+  auto emitOperand = [&](const StmtOperand &v) { printOperand(v.get()); };
+  auto emitSeparator = [&]() { os << ", "; };
+  auto dimsEnd = ops.begin() + numDims;
+  // The leading 'numDims' operands are dims, always listed in parentheses.
+  os << '(';
+  interleave(ops.begin(), dimsEnd, emitOperand, emitSeparator);
+  os << ')';
+  // Any remaining operands are symbols, listed in square brackets.
+  if (numDims >= ops.size())
+    return;
+  os << '[';
+  interleave(dimsEnd, ops.end(), emitOperand, emitSeparator);
+  os << ']';
+}
+
+void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
+  // Prints 'bound' in short-hand form (a constant or one SSA id) when its map
+  // allows it; otherwise prints the map reference and operand list, preceded
+  // by 'prefix' when the map has multiple results.
+  AffineMap *map = bound.getMap();
+  if (map->getNumResults() == 1) {
+    AffineExpr *expr = map->getResult(0);
+    // Print constant bound.
+    if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      os << constExpr->getValue();
+      return;
+    }
+    // Print a single SSA id only if it is the map's sole operand, mirroring
+    // hasShorthandForm(): for e.g. (i, j)->(j), operand 0 is not the bound.
+    if (map->getNumOperands() == 1 &&
+        (isa<AffineDimExpr>(expr) || isa<AffineSymbolExpr>(expr))) {
+      printOperand(bound.getOperand(0));
+      return;
+    }
+  } else {
+    // Map has multiple results. Print 'min' or 'max' prefix.
+    os << prefix << ' ';
+  }
+  // Print the map and the operands.
+  printAffineMapReference(map);
+  printDimAndSymbolList(bound.getStmtOperands(), map->getNumDims());
+}
+
void MLFunctionPrinter::print(const IfStmt *stmt) {
os.indent(numSpaces) << "if (";
printIntegerSetReference(stmt->getCondition());