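Rework the statement cloning infrastructure so that clone() takes and updates
an operand mapping, which simplifies it a bit. Implement cloning for IfStmt,
and rename getThenClause()/getElseClause()/hasElseClause() to
getThen()/getElse()/hasElse(), which are unambiguous and less repetitive at
call sites. Also rename StmtBlock::print()/dump() to printBlock()/dumpBlock().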
PiperOrigin-RevId: 207915990
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index c2cd257..3a3b196 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -174,10 +174,10 @@
+/// Register the integer set used by this if statement's condition (via
+/// recordIntegerSetReference), then visit every statement in its then clause
+/// and, when one is present, every statement in its else clause.
void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
recordIntegerSetReference(ifStmt->getCondition());
- for (auto &childStmt : *ifStmt->getThenClause())
+ for (auto &childStmt : *ifStmt->getThen())
visitStatement(&childStmt);
- if (ifStmt->hasElseClause())
- for (auto &childStmt : *ifStmt->getElseClause())
+ if (ifStmt->hasElse())
+ for (auto &childStmt : *ifStmt->getElse())
visitStatement(&childStmt);
}
@@ -1270,11 +1270,11 @@
os.indent(numSpaces) << "if (";
printIntegerSetReference(stmt->getCondition());
os << ") {\n";
- print(stmt->getThenClause());
+ print(stmt->getThen());
os.indent(numSpaces) << "}";
- if (stmt->hasElseClause()) {
+ if (stmt->hasElse()) {
os << " else {\n";
- print(stmt->getElseClause());
+ print(stmt->getElse());
os.indent(numSpaces) << "}";
}
}
@@ -1393,14 +1393,14 @@
+/// Print this statement to llvm::errs().
void Statement::dump() const { print(llvm::errs()); }
-void StmtBlock::print(raw_ostream &os) const {
+/// Print this statement block to `os`. The printer state is built from the
+/// MLFunction that contains the block (located with findFunction()).
+// NOTE(review): renamed from print() -- presumably so that classes deriving
+// from both Statement and StmtBlock (e.g. ForStmt) do not inherit two
+// conflicting print()/dump() members; confirm.
+// NOTE(review): the findFunction() result is dereferenced without a null
+// check, so this assumes the block is attached to a function.
+void StmtBlock::printBlock(raw_ostream &os) const {
MLFunction *function = findFunction();
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
MLFunctionPrinter(function, modulePrinter).print(this);
}
-void StmtBlock::dump() const { print(llvm::errs()); }
+/// Print this statement block to llvm::errs() using printBlock().
+void StmtBlock::dumpBlock() const { printBlock(llvm::errs()); }
void Function::print(raw_ostream &os) const {
ModuleState state(getContext());
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index ffc210a..d20b45f 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
+#include "llvm/ADT/DenseMap.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -77,43 +78,6 @@
return nlc.numNestedLoops == 1;
}
-Statement *Statement::clone() const {
- switch (kind) {
- case Kind::Operation:
- return cast<OperationStmt>(this)->clone();
- case Kind::If:
- llvm_unreachable("cloning for if's not implemented yet");
- return cast<IfStmt>(this)->clone();
- case Kind::For:
- return cast<ForStmt>(this)->clone();
- }
-}
-
-/// Replaces all uses of oldVal with newVal.
-// TODO(bondhugula,clattner): do this more efficiently by walking those uses of
-// oldVal that fall within this statement.
-void Statement::replaceUses(MLValue *oldVal, MLValue *newVal) {
- struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
- // Value to be replaced.
- MLValue *oldVal;
- // Value to be replaced with.
- MLValue *newVal;
-
- ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
- : oldVal(oldVal), newVal(newVal){};
-
- void visitOperationStmt(OperationStmt *os) {
- for (auto &operand : os->getStmtOperands()) {
- if (operand.get() == oldVal)
- operand.set(newVal);
- }
- }
- };
-
- ReplaceUseWalker ri(oldVal, newVal);
- ri.walk(this);
-}
-
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
@@ -193,22 +157,6 @@
return stmt;
}
-/// Clone an existing OperationStmt.
-OperationStmt *OperationStmt::clone() const {
- SmallVector<MLValue *, 8> operands;
- SmallVector<Type *, 8> resultTypes;
-
- // TODO(clattner): switch this to iterator logic.
- // Put together operands and results.
- for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
- operands.push_back(getStmtOperand(i).get());
-
- for (unsigned i = 0, e = getNumResults(); i != e; ++i)
- resultTypes.push_back(getStmtResult(i).getType());
-
- return create(getName(), operands, resultTypes, getAttrs(), getContext());
-}
-
OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
unsigned numResults,
ArrayRef<NamedAttribute> attributes,
@@ -256,37 +204,6 @@
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
upperBound(upperBound), step(step) {}
-ForStmt *ForStmt::clone() const {
- auto *forStmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
- Statement::findFunction()->getContext());
-
- // Pairs of <old op stmt result whose uses need to be replaced,
- // new result generated by the corresponding cloned op stmt>.
- SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
- for (auto &s : getStatements()) {
- auto *cloneStmt = s.clone();
- forStmt->getStatements().push_back(cloneStmt);
- if (auto *opStmt = dyn_cast<OperationStmt>(&s)) {
- auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
- for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
- oldNewResultPairs.push_back(
- std::make_pair(const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
- &cloneOpStmt->getStmtResult(i)));
- }
- }
- }
- // Replace uses of old op results' with the newly created ones.
- for (unsigned i = 0, e = oldNewResultPairs.size(); i < e; i++) {
- for (auto &stmt : *forStmt) {
- stmt.replaceUses(oldNewResultPairs[i].first, oldNewResultPairs[i].second);
- }
- }
-
- // Replace uses of old loop IV with the new one.
- forStmt->Statement::replaceUses(const_cast<ForStmt *>(this), forStmt);
- return forStmt;
-}
-
//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//
@@ -299,7 +216,72 @@
// allocated through MLIRContext's bump pointer allocator.
}
-IfStmt *IfStmt::clone() const {
- llvm_unreachable("cloning for if's not implemented yet");
- return nullptr;
+//===----------------------------------------------------------------------===//
+// Statement Cloning
+//===----------------------------------------------------------------------===//
+
+/// Create a deep copy of this statement and everything nested inside it.
+///
+/// Operands that refer to values defined outside of this statement are
+/// remapped through `operandMap`; a value with no entry in the map is used
+/// unchanged. Whenever a nested statement that defines values is cloned
+/// (operation results, a for loop's induction variable), the mapping from each
+/// original value to its copy is added to `operandMap`. Later uses inside the
+/// cloned region then refer to the copies rather than the originals.
+///
+/// \param operandMap  value remapping; consulted and extended in place.
+/// \param context     context used to create the new operation/for statements.
+/// \returns the newly created statement. It is not inserted anywhere; the
+///          caller is expected to push it into a block.
+Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
+                            MLIRContext *context) const {
+  // If the specified value is in operandMap, return the remapped value.
+  // Otherwise return the value itself.
+  auto remapOperand = [&](const MLValue *value) -> MLValue * {
+    auto it = operandMap.find(value);
+    return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
+  };
+
+  // Use an exhaustive switch rather than a dyn_cast chain. A new
+  // Statement::Kind then triggers a -Wswitch warning here instead of being
+  // silently treated as an IfStmt.
+  switch (kind) {
+  case Kind::Operation: {
+    auto *opStmt = cast<OperationStmt>(this);
+
+    SmallVector<MLValue *, 8> operands;
+    operands.reserve(opStmt->getNumOperands());
+    for (auto *opValue : opStmt->getOperands())
+      operands.push_back(remapOperand(opValue));
+
+    SmallVector<Type *, 8> resultTypes;
+    resultTypes.reserve(opStmt->getNumResults());
+    for (auto *result : opStmt->getResults())
+      resultTypes.push_back(result->getType());
+
+    auto *newOp = OperationStmt::create(opStmt->getName(), operands,
+                                        resultTypes, opStmt->getAttrs(),
+                                        context);
+    // Remember the mapping of any results so later users see the new values.
+    for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
+      operandMap[opStmt->getResult(i)] = newOp->getResult(i);
+    return newOp;
+  }
+
+  case Kind::For: {
+    auto *forStmt = cast<ForStmt>(this);
+    auto *newFor =
+        new ForStmt(forStmt->getLowerBound(), forStmt->getUpperBound(),
+                    forStmt->getStep(), context);
+    // Record the induction variable mapping before cloning the body, so body
+    // statements that use the IV are rewritten to use the new loop's IV.
+    operandMap[forStmt] = newFor;
+
+    // TODO: remap operands in loop bounds when they are added.
+    // Recursively clone the body of the for loop.
+    for (auto &subStmt : *forStmt)
+      newFor->push_back(subStmt.clone(operandMap, context));
+    return newFor;
+  }
+
+  case Kind::If: {
+    auto *ifStmt = cast<IfStmt>(this);
+    auto *newIf = new IfStmt(ifStmt->getCondition());
+
+    // TODO: remap operands with remapOperand when if statements have them.
+
+    auto *resultThen = newIf->getThen();
+    for (auto &childStmt : *ifStmt->getThen())
+      resultThen->push_back(childStmt.clone(operandMap, context));
+
+    // The else clause only exists on the copy if the original has one.
+    if (ifStmt->hasElse()) {
+      auto *resultElse = newIf->createElse();
+      for (auto &childStmt : *ifStmt->getElse())
+        resultElse->push_back(childStmt.clone(operandMap, context));
+    }
+    return newIf;
+  }
+  }
+  llvm_unreachable("unknown statement kind");
}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index ee8c68f..dde196c 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -371,10 +371,10 @@
// If this is an if or for, recursively walk the block they contain.
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
- if (walkBlock(*ifStmt->getThenClause()))
+ if (walkBlock(*ifStmt->getThen()))
return true;
- if (auto *elseClause = ifStmt->getElseClause())
+ if (auto *elseClause = ifStmt->getElse())
if (walkBlock(*elseClause))
return true;
}