Rework the cloning infrastructure for statements so that it takes and updates
an operand mapping, which simplifies it a bit.  Implement cloning for IfStmt.
Rename getThenClause()/getElseClause()/hasElseClause() to getThen()/getElse()/
hasElse(), which are unambiguous and less repetitive at use sites, and rename
StmtBlock::print()/dump() to printBlock()/dumpBlock().

PiperOrigin-RevId: 207915990
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index c2cd257..3a3b196 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -174,10 +174,10 @@
 
 void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
   recordIntegerSetReference(ifStmt->getCondition());
-  for (auto &childStmt : *ifStmt->getThenClause())
+  for (auto &childStmt : *ifStmt->getThen())
     visitStatement(&childStmt);
-  if (ifStmt->hasElseClause())
-    for (auto &childStmt : *ifStmt->getElseClause())
+  if (ifStmt->hasElse())
+    for (auto &childStmt : *ifStmt->getElse())
       visitStatement(&childStmt);
 }
 
@@ -1270,11 +1270,11 @@
   os.indent(numSpaces) << "if (";
   printIntegerSetReference(stmt->getCondition());
   os << ") {\n";
-  print(stmt->getThenClause());
+  print(stmt->getThen());
   os.indent(numSpaces) << "}";
-  if (stmt->hasElseClause()) {
+  if (stmt->hasElse()) {
     os << " else {\n";
-    print(stmt->getElseClause());
+    print(stmt->getElse());
     os.indent(numSpaces) << "}";
   }
 }
@@ -1393,14 +1393,14 @@
 
 void Statement::dump() const { print(llvm::errs()); }
 
-void StmtBlock::print(raw_ostream &os) const {
+void StmtBlock::printBlock(raw_ostream &os) const {
   MLFunction *function = findFunction();
   ModuleState state(function->getContext());
   ModulePrinter modulePrinter(os, state);
   MLFunctionPrinter(function, modulePrinter).print(this);
 }
 
-void StmtBlock::dump() const { print(llvm::errs()); }
+void StmtBlock::dumpBlock() const { printBlock(llvm::errs()); }
 
 void Function::print(raw_ostream &os) const {
   ModuleState state(getContext());
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index ffc210a..d20b45f 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
+#include "llvm/ADT/DenseMap.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -77,43 +78,6 @@
   return nlc.numNestedLoops == 1;
 }
 
-Statement *Statement::clone() const {
-  switch (kind) {
-  case Kind::Operation:
-    return cast<OperationStmt>(this)->clone();
-  case Kind::If:
-    llvm_unreachable("cloning for if's not implemented yet");
-    return cast<IfStmt>(this)->clone();
-  case Kind::For:
-    return cast<ForStmt>(this)->clone();
-  }
-}
-
-/// Replaces all uses of oldVal with newVal.
-// TODO(bondhugula,clattner): do this more efficiently by walking those uses of
-// oldVal that fall within this statement.
-void Statement::replaceUses(MLValue *oldVal, MLValue *newVal) {
-  struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
-    // Value to be replaced.
-    MLValue *oldVal;
-    // Value to be replaced with.
-    MLValue *newVal;
-
-    ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
-        : oldVal(oldVal), newVal(newVal){};
-
-    void visitOperationStmt(OperationStmt *os) {
-      for (auto &operand : os->getStmtOperands()) {
-        if (operand.get() == oldVal)
-          operand.set(newVal);
-      }
-    }
-  };
-
-  ReplaceUseWalker ri(oldVal, newVal);
-  ri.walk(this);
-}
-
 //===----------------------------------------------------------------------===//
 // ilist_traits for Statement
 //===----------------------------------------------------------------------===//
@@ -193,22 +157,6 @@
   return stmt;
 }
 
-/// Clone an existing OperationStmt.
-OperationStmt *OperationStmt::clone() const {
-  SmallVector<MLValue *, 8> operands;
-  SmallVector<Type *, 8> resultTypes;
-
-  // TODO(clattner): switch this to iterator logic.
-  // Put together operands and results.
-  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
-    operands.push_back(getStmtOperand(i).get());
-
-  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
-    resultTypes.push_back(getStmtResult(i).getType());
-
-  return create(getName(), operands, resultTypes, getAttrs(), getContext());
-}
-
 OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
                              unsigned numResults,
                              ArrayRef<NamedAttribute> attributes,
@@ -256,37 +204,6 @@
       StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
       upperBound(upperBound), step(step) {}
 
-ForStmt *ForStmt::clone() const {
-  auto *forStmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
-                              Statement::findFunction()->getContext());
-
-  // Pairs of <old op stmt result whose uses need to be replaced,
-  // new result generated by the corresponding cloned op stmt>.
-  SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
-  for (auto &s : getStatements()) {
-    auto *cloneStmt = s.clone();
-    forStmt->getStatements().push_back(cloneStmt);
-    if (auto *opStmt = dyn_cast<OperationStmt>(&s)) {
-      auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
-      for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
-        oldNewResultPairs.push_back(
-            std::make_pair(const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
-                           &cloneOpStmt->getStmtResult(i)));
-      }
-    }
-  }
-  // Replace uses of old op results' with the newly created ones.
-  for (unsigned i = 0, e = oldNewResultPairs.size(); i < e; i++) {
-    for (auto &stmt : *forStmt) {
-      stmt.replaceUses(oldNewResultPairs[i].first, oldNewResultPairs[i].second);
-    }
-  }
-
-  // Replace uses of old loop IV with the new one.
-  forStmt->Statement::replaceUses(const_cast<ForStmt *>(this), forStmt);
-  return forStmt;
-}
-
 //===----------------------------------------------------------------------===//
 // IfStmt
 //===----------------------------------------------------------------------===//
@@ -299,7 +216,87 @@
   // allocated through MLIRContext's bump pointer allocator.
 }
 
-IfStmt *IfStmt::clone() const {
-  llvm_unreachable("cloning for if's not implemented yet");
-  return nullptr;
+//===----------------------------------------------------------------------===//
+// Statement Cloning
+//===----------------------------------------------------------------------===//
+
+/// Create a deep copy of this statement and everything nested inside it.
+///
+/// Each operand of a cloned operation is looked up in `operandMap`: if an
+/// entry is present the copy uses the mapped value, otherwise it keeps
+/// referring to the original value.  As the copy is built, the results of
+/// cloned operations and the induction variables of cloned for loops are
+/// added to `operandMap` (original -> copy), so later uses inside the cloned
+/// region refer to the new values rather than the originals.  Callers can
+/// pre-populate the map to redirect values defined outside this statement,
+/// and can inspect it afterwards to find the copies of nested values.
+///
+/// `context` is passed to the constructors of cloned OperationStmt and
+/// ForStmt nodes.  The returned statement is not inserted into any block.
+Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
+                            MLIRContext *context) const {
+  // If the specified value is in operandMap, return the remapped value.
+  // Otherwise return the value itself; the const_cast is needed because
+  // operands reached through this const statement are const, while new
+  // statements take non-const operands.
+  auto remapOperand = [&](const MLValue *value) -> MLValue * {
+    auto it = operandMap.find(value);
+    return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
+  };
+
+  if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
+    SmallVector<MLValue *, 8> operands;
+    operands.reserve(opStmt->getNumOperands());
+    for (auto *opValue : opStmt->getOperands())
+      operands.push_back(remapOperand(opValue));
+
+    SmallVector<Type *, 8> resultTypes;
+    resultTypes.reserve(opStmt->getNumResults());
+    for (auto *result : opStmt->getResults())
+      resultTypes.push_back(result->getType());
+    auto *newOp = OperationStmt::create(
+        opStmt->getName(), operands, resultTypes, opStmt->getAttrs(), context);
+    // Remember the mapping of any results, so that later statements using
+    // them are cloned with operands that refer to newOp's results instead.
+    for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
+      operandMap[opStmt->getResult(i)] = newOp->getResult(i);
+    return newOp;
+  }
+
+  if (auto *forStmt = dyn_cast<ForStmt>(this)) {
+    auto *newFor =
+        new ForStmt(forStmt->getLowerBound(), forStmt->getUpperBound(),
+                    forStmt->getStep(), context);
+    // Remember the induction variable mapping.  This must happen before the
+    // body is cloned so uses of the induction variable inside the body are
+    // remapped to newFor.
+    operandMap[forStmt] = newFor;
+
+    // TODO: remap operands in loop bounds when they are added.
+    // Recursively clone the body of the for loop.
+    for (auto &subStmt : *forStmt)
+      newFor->push_back(subStmt.clone(operandMap, context));
+
+    return newFor;
+  }
+
+  // Otherwise, we must have an If statement.
+  auto *ifStmt = cast<IfStmt>(this);
+  auto *newIf = new IfStmt(ifStmt->getCondition());
+
+  // TODO: remap operands with remapOperand when if statements have them.
+
+  // Clone the then clause into the then clause of the new statement.
+  auto *resultThen = newIf->getThen();
+  for (auto &childStmt : *ifStmt->getThen())
+    resultThen->push_back(childStmt.clone(operandMap, context));
+
+  // The else clause is only created on the copy when the original has one.
+  if (ifStmt->hasElse()) {
+    auto *resultElse = newIf->createElse();
+    for (auto &childStmt : *ifStmt->getElse())
+      resultElse->push_back(childStmt.clone(operandMap, context));
+  }
+
+  return newIf;
 }
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index ee8c68f..dde196c 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -371,10 +371,10 @@
 
       // If this is an if or for, recursively walk the block they contain.
       if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
-        if (walkBlock(*ifStmt->getThenClause()))
+        if (walkBlock(*ifStmt->getThen()))
           return true;
 
-        if (auto *elseClause = ifStmt->getElseClause())
+        if (auto *elseClause = ifStmt->getElse())
           if (walkBlock(*elseClause))
             return true;
       }