Implement MLValue, statement operands, operation statement operands and values. ML functions now have full support for expressing operations. Induction variables, function arguments, and return values are still TODO.
PiperOrigin-RevId: 206253643
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 6b272e0..4b6ddc7 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -21,6 +21,17 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
+// StmtResult
+//===------------------------------------------------------------------===//
+
+/// Return the index of this result within its owner's list of results.
+unsigned StmtResult::getResultNumber() const {
+  // Results are stored contiguously; our index is the offset from the first.
+  const StmtResult *firstResult = &getOwner()->getStmtResults()[0];
+  return this - firstResult;
+}
+
+//===----------------------------------------------------------------------===//
// Statement
//===------------------------------------------------------------------===//
@@ -34,7 +45,7 @@
void Statement::destroy() {
switch (this->getKind()) {
case Kind::Operation:
- delete cast<OperationStmt>(this);
+ cast<OperationStmt>(this)->destroy();
break;
case Kind::For:
delete cast<ForStmt>(this);
@@ -113,6 +124,80 @@
}
//===----------------------------------------------------------------------===//
+// OperationStmt
+//===----------------------------------------------------------------------===//
+
+/// Create a new OperationStmt with the given name, operands, result types and
+/// attributes. The statement and its trailing operands/results share a single
+/// malloc'd buffer, so release it with destroy() rather than `delete`.
+OperationStmt *OperationStmt::create(Identifier name,
+                                     ArrayRef<MLValue *> operands,
+                                     ArrayRef<Type *> resultTypes,
+                                     ArrayRef<NamedAttribute> attributes,
+                                     MLIRContext *context) {
+  auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
+                                                            resultTypes.size());
+  void *rawMem = malloc(byteSize);
+
+  // Construct the OperationStmt itself at the front of the buffer.
+  auto *stmt = ::new (rawMem) OperationStmt(
+      name, operands.size(), resultTypes.size(), attributes, context);
+
+  // Construct the trailing operands and results in place. Global placement
+  // new is used throughout so class-specific operator new cannot interfere.
+  auto stmtOperands = stmt->getStmtOperands();
+  for (unsigned i = 0, e = operands.size(); i != e; ++i)
+    ::new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
+
+  auto stmtResults = stmt->getStmtResults();
+  for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+    ::new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
+  return stmt;
+}
+
+/// Construct the fixed-size portion of an OperationStmt. The trailing operands
+/// and results are constructed separately by create().
+OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
+                             unsigned numResults,
+                             ArrayRef<NamedAttribute> attributes,
+                             MLIRContext *context)
+    : Operation(name, /*isInstruction=*/false, attributes, context),
+      Statement(Kind::Operation), numOperands(numOperands),
+      numResults(numResults) {}
+
+OperationStmt::~OperationStmt() {
+  // The operands and results were placement-constructed in trailing storage
+  // by create(), so their destructors must be run explicitly.
+  for (auto &operand : getStmtOperands())
+    operand.~StmtOperand();
+
+  for (auto &result : getStmtResults())
+    result.~StmtResult();
+}
+
+/// Destroy this statement and release the buffer allocated by create().
+void OperationStmt::destroy() {
+  this->~OperationStmt();
+  free(this);
+}
+
+/// This drops all operand uses from this statement, which is an essential
+/// step in breaking cyclic dependences between references when they are to
+/// be deleted.
+void OperationStmt::dropAllReferences() {
+  for (auto &op : getStmtOperands())
+    op.drop();
+}
+
+/// If this value is the result of an OperationStmt, return the statement
+/// that defines it; otherwise return null.
+OperationStmt *SSAValue::getDefiningStmt() {
+  if (auto *result = dyn_cast<StmtResult>(this))
+    return result->getOwner();
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//