Implement MLValue, statement operands, operation statement operands and values. ML functions now have full support for expressing operations. Induction variables, function arguments and return values are still todo.

PiperOrigin-RevId: 206253643
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 4fd7e61..90115ec 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -29,6 +29,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/DenseMap.h"
@@ -579,8 +580,6 @@
 } // end anonymous namespace
 
 void FunctionPrinter::printOperation(const Operation *op) {
-  os << "  ";
-
   if (op->getNumResults()) {
     printValueID(op->getResult(0), /*dontPrintResultNo*/ true);
     os << " = ";
@@ -717,6 +716,7 @@
   os << ":\n";
 
   for (auto &inst : block->getOperations()) {
+    os << "  ";
     print(&inst);
     os << '\n';
   }
@@ -743,7 +743,7 @@
 }
 
 void CFGFunctionPrinter::print(const BranchInst *inst) {
-  os << "  br bb" << getBBID(inst->getDest());
+  os << "br bb" << getBBID(inst->getDest());
 
   if (inst->getNumOperands() != 0) {
     os << '(';
@@ -759,7 +759,7 @@
 }
 
 void CFGFunctionPrinter::print(const CondBranchInst *inst) {
-  os << "  cond_br ";
+  os << "cond_br ";
   printValueID(inst->getCondition());
 
   os << ", bb" << getBBID(inst->getTrueDest());
@@ -788,7 +788,7 @@
 }
 
 void CFGFunctionPrinter::print(const ReturnInst *inst) {
-  os << "  return";
+  os << "return";
 
   if (inst->getNumOperands() != 0)
     os << ' ';
@@ -830,6 +830,8 @@
   const static unsigned indentWidth = 2;
 
 private:
+  void numberValues();
+
   const MLFunction *function;
   int numSpaces;
 };
@@ -837,7 +839,26 @@
 
 MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
                                      const ModulePrinter &other)
-    : FunctionPrinter(other), function(function), numSpaces(0) {}
+    : FunctionPrinter(other), function(function), numSpaces(0) {
+  numberValues(); // Assign IDs to op results before any printing happens.
+}
+
+/// Number the first result of every result-producing operation statement.
+void MLFunctionPrinter::numberValues() {
+  // Visitor that numbers result 0 of every operation statement with results.
+  struct NumberValuesPass : public StmtVisitor<NumberValuesPass> {
+    NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
+    void visitOperationStmt(OperationStmt *stmt) {
+      if (stmt->getNumResults() != 0)
+        printer->numberValueID(stmt->getResult(0));
+    }
+    MLFunctionPrinter *printer;
+  };
+
+  NumberValuesPass pass(this);
+  // TODO: it'd be cleaner to have a const visitor instead of using const_cast.
+  pass.visit(const_cast<MLFunction *>(function));
+}
 
 void MLFunctionPrinter::print() {
   os << "mlfunc ";
@@ -870,6 +891,7 @@
 }
 
 void MLFunctionPrinter::print(const OperationStmt *stmt) {
+  os.indent(numSpaces); // printOperation() emits no leading indentation itself.
   printOperation(stmt);
 }
 
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index b06be4d..e598f2a 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/StringRef.h"
 using namespace mlir;
@@ -120,6 +121,9 @@
     : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
 
 MLFunction::~MLFunction() {
-  // TODO: When move SSA stuff is supported.
-  // dropAllReferences();
+  struct DropReferencesPass : public StmtVisitor<DropReferencesPass> {
+    void visitOperationStmt(OperationStmt *stmt) { stmt->dropAllReferences(); }
+  };
+  DropReferencesPass pass; // Drops operand uses of each visited OperationStmt.
+  pass.visit(this);
 }
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index 92a01ba..f2f9eb4 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -39,10 +39,7 @@
   if (auto *inst = dyn_cast<OperationInst>(this)) {
     return inst->getNumOperands();
   } else {
-    auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add operands to OperationStmt.
-    return 0;
+    return cast<OperationStmt>(this)->getNumOperands();
   }
 }
 
@@ -51,9 +48,7 @@
     return inst->getOperand(idx);
   } else {
     auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add operands to OperationStmt.
-    abort();
+    return stmt->getOperand(idx);
   }
 }
 
@@ -62,9 +57,7 @@
     inst->setOperand(idx, cast<CFGValue>(value));
   } else {
     auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add operands to OperationStmt.
-    abort();
+    stmt->setOperand(idx, cast<MLValue>(value));
   }
 }
 
@@ -74,9 +67,7 @@
     return inst->getNumResults();
   } else {
     auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add results to OperationStmt.
-    return 0;
+    return stmt->getNumResults();
   }
 }
 
@@ -86,9 +77,7 @@
     return inst->getResult(idx);
   } else {
     auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add operands to OperationStmt.
-    abort();
+    return stmt->getResult(idx);
   }
 }
 
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 6b272e0..4b6ddc7 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -21,6 +21,17 @@
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
+// StmtResult
+//===----------------------------------------------------------------------===//
+
+/// Return the index of this result within its owning statement's result list.
+unsigned StmtResult::getResultNumber() const {
+  // Results are stored contiguously, so the index is this object's pointer
+  // distance from the owner's first result.
+  return this - &getOwner()->getStmtResults()[0];
+}
+
+//===----------------------------------------------------------------------===//
 // Statement
 //===------------------------------------------------------------------===//
 
@@ -34,7 +45,7 @@
 void Statement::destroy() {
   switch (this->getKind()) {
   case Kind::Operation:
-    delete cast<OperationStmt>(this);
+    cast<OperationStmt>(this)->destroy();
     break;
   case Kind::For:
     delete cast<ForStmt>(this);
@@ -113,6 +124,73 @@
 }
 
 //===----------------------------------------------------------------------===//
+// OperationStmt
+//===----------------------------------------------------------------------===//
+
+/// Create an OperationStmt whose operands and results share its allocation.
+OperationStmt *OperationStmt::create(Identifier name,
+                                     ArrayRef<MLValue *> operands,
+                                     ArrayRef<Type *> resultTypes,
+                                     ArrayRef<NamedAttribute> attributes,
+                                     MLIRContext *context) {
+  auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
+                                                            resultTypes.size());
+  void *rawMem = malloc(byteSize); // Not null-checked; freed by destroy().
+
+  // Placement-new the OperationStmt itself at the start of the raw buffer.
+  auto stmt = ::new (rawMem) OperationStmt(
+      name, operands.size(), resultTypes.size(), attributes, context);
+
+  // Placement-new each operand and result into the statement's storage.
+  auto stmtOperands = stmt->getStmtOperands();
+  for (unsigned i = 0, e = operands.size(); i != e; ++i)
+    new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
+
+  auto stmtResults = stmt->getStmtResults();
+  for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+    new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
+  return stmt;
+}
+
+OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
+                             unsigned numResults,
+                             ArrayRef<NamedAttribute> attributes,
+                             MLIRContext *context)
+    : Operation(name, /*isInstruction=*/false, attributes, context),
+      Statement(Kind::Operation), numOperands(numOperands),
+      numResults(numResults) {} // Operands/results are built in create().
+
+OperationStmt::~OperationStmt() {
+  // Operands and results were placement-new'd by create(), so destroy by hand.
+  for (auto &operand : getStmtOperands())
+    operand.~StmtOperand();
+
+  for (auto &result : getStmtResults())
+    result.~StmtResult();
+}
+
+void OperationStmt::destroy() {
+  this->~OperationStmt(); // Also runs the operand and result destructors.
+  free(this); // Storage came from malloc() in create(), so no delete here.
+}
+
+/// Drop every operand use held by this statement. This is an essential step
+/// in breaking cyclic dependences between references when statements are to
+/// be deleted.
+void OperationStmt::dropAllReferences() {
+  for (auto &op : getStmtOperands())
+    op.drop();
+}
+
+/// Return the OperationStmt that defines this value if it is a StmtResult;
+/// otherwise return null.
+OperationStmt *SSAValue::getDefiningStmt() {
+  if (auto *result = dyn_cast<StmtResult>(this))
+    return result->getOwner();
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
 // IfStmt
 //===----------------------------------------------------------------------===//