Implement operands for the lower and upper bounds of the for statement.
This revamps implementation of the loop bounds in the ForStmt, using general representation that supports operands. The frequent case of constant bounds is supported
via special access methods.
This also includes:
- Operand iterators for the Statement class.
- OpPointer::is() method to query the class of the Operation.
- Support for the bound shorthand notation parsing and printing.
- Validity checks for the bound operands used as dim ids and symbols
I didn't mean this CL to be so large. It just happened this way, as one thing led to another.
PiperOrigin-RevId: 210204858
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 8a6248e..3746f71 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -15,12 +15,14 @@
// limitations under the License.
// =============================================================================
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
-#include "mlir/IR/Types.h"
#include "llvm/ADT/DenseMap.h"
+
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -66,7 +68,7 @@
case Kind::Operation:
return cast<OperationStmt>(this)->getContext();
case Kind::For:
- return cast<ForStmt>(this)->getType()->getContext();
+ return cast<ForStmt>(this)->getContext();
case Kind::If:
// TODO(shpeisman): When if statement has value operands, we can get a
// context from their type.
@@ -94,6 +96,42 @@
return nlc.numNestedLoops == 1;
}
+MLValue *Statement::getOperand(unsigned idx) {
+ return getStmtOperand(idx).get();
+}
+
+const MLValue *Statement::getOperand(unsigned idx) const {
+ return getStmtOperand(idx).get();
+}
+
+void Statement::setOperand(unsigned idx, MLValue *value) {
+ getStmtOperand(idx).set(value);
+}
+
+unsigned Statement::getNumOperands() const {
+ switch (getKind()) {
+ case Kind::Operation:
+ return cast<OperationStmt>(this)->getNumOperands();
+ case Kind::For:
+ return cast<ForStmt>(this)->getNumOperands();
+ case Kind::If:
+ // TODO: query IfStmt once it has operands.
+ return 0;
+ }
+}
+
+MutableArrayRef<StmtOperand> Statement::getStmtOperands() {
+ switch (getKind()) {
+ case Kind::Operation:
+ return cast<OperationStmt>(this)->getStmtOperands();
+ case Kind::For:
+ return cast<ForStmt>(this)->getStmtOperands();
+ case Kind::If:
+ // TODO: query IfStmt once it has operands.
+ return {};
+ }
+}
+
/// Emit a note about this statement, reporting up to any diagnostic
/// handlers that may be listening.
void Statement::emitNote(const Twine &message) const {
@@ -231,17 +269,89 @@
return findFunction()->getContext();
}
+bool OperationStmt::isReturn() const { return is<ReturnOp>(); }
+
//===----------------------------------------------------------------------===//
// ForStmt
//===----------------------------------------------------------------------===//
-ForStmt::ForStmt(Attribute *location, AffineConstantExpr *lowerBound,
- AffineConstantExpr *upperBound, int64_t step,
- MLIRContext *context)
+ForStmt *ForStmt::create(Attribute *location, ArrayRef<MLValue *> lbOperands,
+ AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
+ AffineMap *ubMap, int64_t step, MLIRContext *context) {
+ assert(lbOperands.size() == lbMap->getNumOperands() &&
+ "lower bound operand count does not match the affine map");
+ assert(ubOperands.size() == ubMap->getNumOperands() &&
+ "upper bound operand count does not match the affine map");
+
+ unsigned numOperands = lbOperands.size() + ubOperands.size();
+ ForStmt *stmt =
+ new ForStmt(location, numOperands, lbMap, ubMap, step, context);
+
+ unsigned i = 0;
+ for (unsigned e = lbOperands.size(); i != e; ++i)
+ stmt->operands.emplace_back(StmtOperand(stmt, lbOperands[i]));
+
+ for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j)
+ stmt->operands.emplace_back(StmtOperand(stmt, ubOperands[j]));
+
+ return stmt;
+}
+
+ForStmt::ForStmt(Attribute *location, unsigned numOperands, AffineMap *lbMap,
+ AffineMap *ubMap, int64_t step, MLIRContext *context)
: Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
- StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
- upperBound(upperBound), step(step) {}
+ StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {}
+
+const AffineBound ForStmt::getLowerBound() const {
+ return AffineBound(*this, 0, lbMap->getNumOperands(), lbMap);
+}
+
+const AffineBound ForStmt::getUpperBound() const {
+ return AffineBound(*this, lbMap->getNumOperands(), getNumOperands(), ubMap);
+}
+
+void ForStmt::setLowerBound(ArrayRef<MLValue *> operands, AffineMap *map) {
+ // TODO: handle the case when number of existing or new operands is non-zero.
+ assert(getNumOperands() == 0 && operands.empty());
+
+ this->lbMap = map;
+}
+
+void ForStmt::setUpperBound(ArrayRef<MLValue *> operands, AffineMap *map) {
+ // TODO: handle the case when number of existing or new operands is non-zero.
+ assert(getNumOperands() == 0 && operands.empty());
+
+ this->ubMap = map;
+}
+
+bool ForStmt::hasConstantLowerBound() const {
+ return lbMap->isSingleConstant();
+}
+
+bool ForStmt::hasConstantUpperBound() const {
+ return ubMap->isSingleConstant();
+}
+
+int64_t ForStmt::getConstantLowerBound() const {
+ return lbMap->getSingleConstantValue();
+}
+
+int64_t ForStmt::getConstantUpperBound() const {
+ return ubMap->getSingleConstantValue();
+}
+
+void ForStmt::setConstantLowerBound(int64_t value) {
+ MLIRContext *context = getContext();
+ auto *expr = AffineConstantExpr::get(value, context);
+ setLowerBound({}, AffineMap::get(0, 0, expr, {}, context));
+}
+
+void ForStmt::setConstantUpperBound(int64_t value) {
+ MLIRContext *context = getContext();
+ auto *expr = AffineConstantExpr::get(value, context);
+ setUpperBound({}, AffineMap::get(0, 0, expr, {}, context));
+}
//===----------------------------------------------------------------------===//
// IfStmt
@@ -277,12 +387,12 @@
return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
};
- if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
- SmallVector<MLValue *, 8> operands;
- operands.reserve(opStmt->getNumOperands());
- for (auto *opValue : opStmt->getOperands())
- operands.push_back(remapOperand(opValue));
+ SmallVector<MLValue *, 8> operands;
+ operands.reserve(getNumOperands());
+ for (auto *opValue : getOperands())
+ operands.push_back(remapOperand(opValue));
+ if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
SmallVector<Type *, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
@@ -297,13 +407,18 @@
}
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
- auto *newFor =
- new ForStmt(getLoc(), forStmt->getLowerBound(),
- forStmt->getUpperBound(), forStmt->getStep(), context);
+ auto *lbMap = forStmt->getLowerBoundMap();
+ auto *ubMap = forStmt->getUpperBoundMap();
+
+ auto *newFor = ForStmt::create(
+ getLoc(),
+ ArrayRef<MLValue *>(operands).take_front(lbMap->getNumOperands()),
+ lbMap, ArrayRef<MLValue *>(operands).take_back(ubMap->getNumOperands()),
+ ubMap, forStmt->getStep(), context);
+
// Remember the induction variable mapping.
operandMap[forStmt] = newFor;
- // TODO: remap operands in loop bounds when they are added.
// Recursively clone the body of the for loop.
for (auto &subStmt : *forStmt)
newFor->push_back(subStmt.clone(operandMap, context));