Implement MLValue, statement operands, and operation statement operands and result values. ML functions now fully support expressing operations. Induction variables, function arguments, and return values are still TODO.

PiperOrigin-RevId: 206253643
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 4fd7e61..90115ec 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -29,6 +29,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/DenseMap.h"
@@ -579,8 +580,6 @@
 } // end anonymous namespace
 
 void FunctionPrinter::printOperation(const Operation *op) {
-  os << "  ";
-
   if (op->getNumResults()) {
     printValueID(op->getResult(0), /*dontPrintResultNo*/ true);
     os << " = ";
@@ -717,6 +716,7 @@
   os << ":\n";
 
   for (auto &inst : block->getOperations()) {
+    os << "  ";
     print(&inst);
     os << '\n';
   }
@@ -743,7 +743,7 @@
 }
 
 void CFGFunctionPrinter::print(const BranchInst *inst) {
-  os << "  br bb" << getBBID(inst->getDest());
+  os << "br bb" << getBBID(inst->getDest());
 
   if (inst->getNumOperands() != 0) {
     os << '(';
@@ -759,7 +759,7 @@
 }
 
 void CFGFunctionPrinter::print(const CondBranchInst *inst) {
-  os << "  cond_br ";
+  os << "cond_br ";
   printValueID(inst->getCondition());
 
   os << ", bb" << getBBID(inst->getTrueDest());
@@ -788,7 +788,7 @@
 }
 
 void CFGFunctionPrinter::print(const ReturnInst *inst) {
-  os << "  return";
+  os << "return";
 
   if (inst->getNumOperands() != 0)
     os << ' ';
@@ -830,6 +830,8 @@
   const static unsigned indentWidth = 2;
 
 private:
+  void numberValues();
+
   const MLFunction *function;
   int numSpaces;
 };
@@ -837,7 +839,26 @@
 
 MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
                                      const ModulePrinter &other)
-    : FunctionPrinter(other), function(function), numSpaces(0) {}
+    : FunctionPrinter(other), function(function), numSpaces(0) {
+  numberValues();
+}
+
+/// Number all of the SSA values in this ML function.
+void MLFunctionPrinter::numberValues() {
+  // Visits all operation statements and numbers the first result.
+  struct NumberValuesPass : public StmtVisitor<NumberValuesPass> {
+    explicit NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
+    void visitOperationStmt(OperationStmt *stmt) {
+      if (stmt->getNumResults() != 0)
+        printer->numberValueID(stmt->getResult(0));
+    }
+    MLFunctionPrinter *printer;
+  };
+
+  NumberValuesPass pass(this);
+  // TODO: it'd be cleaner to have a const visitor instead of using const_cast.
+  pass.visit(const_cast<MLFunction *>(function));
+}
 
 void MLFunctionPrinter::print() {
   os << "mlfunc ";
@@ -870,6 +891,7 @@
 }
 
 void MLFunctionPrinter::print(const OperationStmt *stmt) {
+  os.indent(numSpaces);
   printOperation(stmt);
 }
 
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index b06be4d..e598f2a 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/StringRef.h"
 using namespace mlir;
@@ -120,6 +121,9 @@
     : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
 
 MLFunction::~MLFunction() {
-  // TODO: When move SSA stuff is supported.
-  // dropAllReferences();
+  struct DropReferencesPass : public StmtVisitor<DropReferencesPass> {
+    void visitOperationStmt(OperationStmt *stmt) { stmt->dropAllReferences(); }
+  };
+  DropReferencesPass pass;
+  pass.visit(this);
 }
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index 92a01ba..f2f9eb4 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -39,10 +39,7 @@
   if (auto *inst = dyn_cast<OperationInst>(this)) {
     return inst->getNumOperands();
   } else {
-    auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add operands to OperationStmt.
-    return 0;
+    return cast<OperationStmt>(this)->getNumOperands();
   }
 }
 
@@ -51,9 +48,7 @@
     return inst->getOperand(idx);
   } else {
     auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add operands to OperationStmt.
-    abort();
+    return stmt->getOperand(idx);
   }
 }
 
@@ -62,9 +57,7 @@
     inst->setOperand(idx, cast<CFGValue>(value));
   } else {
     auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add operands to OperationStmt.
-    abort();
+    stmt->setOperand(idx, cast<MLValue>(value));
   }
 }
 
@@ -74,9 +67,7 @@
     return inst->getNumResults();
   } else {
     auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add results to OperationStmt.
-    return 0;
+    return stmt->getNumResults();
   }
 }
 
@@ -86,9 +77,7 @@
     return inst->getResult(idx);
   } else {
     auto *stmt = cast<OperationStmt>(this);
-    (void)stmt;
-    // TODO: Add operands to OperationStmt.
-    abort();
+    return stmt->getResult(idx);
   }
 }
 
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 6b272e0..4b6ddc7 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -21,6 +21,17 @@
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
+// StmtResult
+//===------------------------------------------------------------------===//
+
+/// Return the result number of this result.
+unsigned StmtResult::getResultNumber() const {
+  // Results are always stored consecutively, so use pointer subtraction to
+  // figure out what number this is.
+  return this - &getOwner()->getStmtResults()[0];
+}
+
+//===----------------------------------------------------------------------===//
 // Statement
 //===------------------------------------------------------------------===//
 
@@ -34,7 +45,7 @@
 void Statement::destroy() {
   switch (this->getKind()) {
   case Kind::Operation:
-    delete cast<OperationStmt>(this);
+    cast<OperationStmt>(this)->destroy();
     break;
   case Kind::For:
     delete cast<ForStmt>(this);
@@ -113,6 +124,82 @@
 }
 
 //===----------------------------------------------------------------------===//
+// OperationStmt
+//===----------------------------------------------------------------------===//
+
+/// Create a new OperationStmt with the given name, operands, result types and
+/// attributes.  The operands and results are allocated in the same block of
+/// memory as the statement itself, so the returned statement must be
+/// released with destroy() rather than delete.
+OperationStmt *OperationStmt::create(Identifier name,
+                                     ArrayRef<MLValue *> operands,
+                                     ArrayRef<Type *> resultTypes,
+                                     ArrayRef<NamedAttribute> attributes,
+                                     MLIRContext *context) {
+  auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
+                                                            resultTypes.size());
+  void *rawMem = malloc(byteSize);
+
+  // Initialize the OperationStmt part of the statement.
+  auto stmt = ::new (rawMem) OperationStmt(
+      name, operands.size(), resultTypes.size(), attributes, context);
+
+  // Initialize the operands and results in their trailing storage.
+  auto stmtOperands = stmt->getStmtOperands();
+  for (unsigned i = 0, e = operands.size(); i != e; ++i)
+    ::new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
+
+  auto stmtResults = stmt->getStmtResults();
+  for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+    ::new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
+  return stmt;
+}
+
+/// Construct only the OperationStmt itself; the trailing operands and results
+/// are initialized separately by create().
+OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
+                             unsigned numResults,
+                             ArrayRef<NamedAttribute> attributes,
+                             MLIRContext *context)
+    : Operation(name, /*isInstruction=*/false, attributes, context),
+      Statement(Kind::Operation), numOperands(numOperands),
+      numResults(numResults) {}
+
+/// Tears down the trailing operands and results.  This does not release the
+/// underlying memory; use destroy() for that.
+OperationStmt::~OperationStmt() {
+  // Explicitly run the destructors for the operands and results.
+  for (auto &operand : getStmtOperands())
+    operand.~StmtOperand();
+
+  for (auto &result : getStmtResults())
+    result.~StmtResult();
+}
+
+/// Destroy this statement and free the memory that create() allocated for it.
+/// This must be used instead of delete, since the memory came from malloc.
+void OperationStmt::destroy() {
+  this->~OperationStmt();
+  free(this);
+}
+
+/// This drops all operand uses from this statement, which is an essential
+/// step in breaking cyclic dependences between references when they are to
+/// be deleted.
+void OperationStmt::dropAllReferences() {
+  for (auto &op : getStmtOperands())
+    op.drop();
+}
+
+/// If this value is the result of an OperationStmt, return the statement
+/// that defines it.  Otherwise, return null.
+OperationStmt *SSAValue::getDefiningStmt() {
+  if (auto *result = dyn_cast<StmtResult>(this))
+    return result->getOwner();
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
 // IfStmt
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 362f868..11850ea 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1209,7 +1209,9 @@
 /// notably for dealing with operations and SSA values.
 class FunctionParser : public Parser {
 public:
-  FunctionParser(ParserState &state) : Parser(state) {}
+  enum class Kind { CFGFunc, MLFunc };
+
+  Kind getKind() const { return kind; }
 
   /// After the function is finished parsing, this function checks to see if
   /// there are any remaining issues.
@@ -1257,7 +1259,12 @@
   Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc);
   Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
 
+protected:
+  FunctionParser(ParserState &state, Kind kind) : Parser(state), kind(kind) {}
+
 private:
+  /// Kind indicates if this is CFG or ML function parser.
+  Kind kind;
   /// This keeps track of all of the SSA values we are tracking, indexed by
   /// their name.  This has one entry per result number.
   llvm::StringMap<SmallVector<std::pair<SSAValue *, SMLoc>, 1>> values;
@@ -1290,7 +1297,7 @@
   return inst->getResult(0);
 }
 
-/// Given an unbound reference to an SSA value and its type, return a the value
+/// Given an unbound reference to an SSA value and its type, return the value
 /// it specifies.  This returns null on failure.
 SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
   auto &entries = values[useInfo.name];
@@ -1318,7 +1325,13 @@
     return (emitError(useInfo.loc, "reference to invalid result number"),
             nullptr);
 
-  // Otherwise, this is a forward reference.  Create a placeholder and remember
+  if (getKind() == Kind::MLFunc)
+    return (
+        emitError(useInfo.loc, "use of undefined SSA value " + useInfo.name),
+        nullptr);
+
+  // Otherwise, this is a forward reference in a CFG function (ML functions
+  // were rejected above).  Create a placeholder and remember
   // that we did so.
   auto *result = createForwardReferencePlaceholder(useInfo.loc, type);
   entries[useInfo.number].first = result;
@@ -1532,15 +1545,11 @@
 
   // If the instruction had a name, register it.
   if (!resultID.empty()) {
-    // FIXME: Add result infra to handle Stmt results as well to make this
-    // generic.
-    if (auto *inst = dyn_cast<OperationInst>(op)) {
-      if (inst->getNumResults() == 0)
-        return emitError(loc, "cannot name an operation with no results");
+    if (op->getNumResults() == 0)
+      return emitError(loc, "cannot name an operation with no results");
 
-      for (unsigned i = 0, e = inst->getNumResults(); i != e; ++i)
-        addDefinition({resultID, i, loc}, inst->getResult(i));
-    }
+    for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
+      addDefinition({resultID, i, loc}, op->getResult(i));
   }
 
   return ParseSuccess;
@@ -1778,7 +1787,8 @@
 class CFGFunctionParser : public FunctionParser {
 public:
   CFGFunctionParser(ParserState &state, CFGFunction *function)
-      : FunctionParser(state), function(function), builder(function) {}
+      : FunctionParser(state, Kind::CFGFunc), function(function),
+        builder(function) {}
 
   ParseResult parseFunctionBody();
 
@@ -2014,7 +2024,8 @@
 class MLFunctionParser : public FunctionParser {
 public:
   MLFunctionParser(ParserState &state, MLFunction *function)
-      : FunctionParser(state), function(function), builder(function) {}
+      : FunctionParser(state, Kind::MLFunc), function(function),
+        builder(function) {}
 
   ParseResult parseFunctionBody();
 
@@ -2176,7 +2187,11 @@
   auto createOpFunc = [&](Identifier name, ArrayRef<SSAValue *> operands,
                           ArrayRef<Type *> resultTypes,
                           ArrayRef<NamedAttribute> attrs) -> Operation * {
-    return builder.createOperation(name, attrs);
+    SmallVector<MLValue *, 8> stmtOperands;
+    stmtOperands.reserve(operands.size());
+    for (auto *op : operands)
+      stmtOperands.push_back(cast<MLValue>(op));
+    return builder.createOperation(name, stmtOperands, resultTypes, attrs);
   };
 
   builder.setInsertionPoint(block);
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 337f558..70367e9 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -96,7 +96,9 @@
         break;
       case Statement::Kind::Operation:
         auto *op = cast<OperationStmt>(&stmt);
-        builder.createOperation(op->getName(), op->getAttrs());
+        // TODO: clone operands and result types.
+        builder.createOperation(op->getName(), /*operands*/ {},
+                                /*resultTypes*/ {}, op->getAttrs());
         // TODO: loop iterator parsing not yet implemented; replace loop
         // iterator uses in unrolled body appropriately.
         break;