Implement MLValue, statement operands, operation statement operands and values. ML functions now have full support for expressing operations. Induction variables, function arguments and return values are still TODO.
PiperOrigin-RevId: 206253643
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 4fd7e61..90115ec 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -29,6 +29,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
@@ -579,8 +580,6 @@
} // end anonymous namespace
void FunctionPrinter::printOperation(const Operation *op) {
- os << " ";
-
if (op->getNumResults()) {
printValueID(op->getResult(0), /*dontPrintResultNo*/ true);
os << " = ";
@@ -717,6 +716,7 @@
os << ":\n";
for (auto &inst : block->getOperations()) {
+ os << " ";
print(&inst);
os << '\n';
}
@@ -743,7 +743,7 @@
}
void CFGFunctionPrinter::print(const BranchInst *inst) {
- os << " br bb" << getBBID(inst->getDest());
+ os << "br bb" << getBBID(inst->getDest());
if (inst->getNumOperands() != 0) {
os << '(';
@@ -759,7 +759,7 @@
}
void CFGFunctionPrinter::print(const CondBranchInst *inst) {
- os << " cond_br ";
+ os << "cond_br ";
printValueID(inst->getCondition());
os << ", bb" << getBBID(inst->getTrueDest());
@@ -788,7 +788,7 @@
}
void CFGFunctionPrinter::print(const ReturnInst *inst) {
- os << " return";
+ os << "return";
if (inst->getNumOperands() != 0)
os << ' ';
@@ -830,6 +830,8 @@
const static unsigned indentWidth = 2;
private:
+ void numberValues();
+
const MLFunction *function;
int numSpaces;
};
@@ -837,7 +839,26 @@
MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
const ModulePrinter &other)
- : FunctionPrinter(other), function(function), numSpaces(0) {}
+ : FunctionPrinter(other), function(function), numSpaces(0) {
+ numberValues();
+}
+
+/// Number the SSA values defined by operation statements in this function.
+void MLFunctionPrinter::numberValues() {
+ // Visits all operation statements and numbers the first result.
+ struct NumberValuesPass : public StmtVisitor<NumberValuesPass> {
+ NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
+ void visitOperationStmt(OperationStmt *stmt) {
+ if (stmt->getNumResults() != 0)
+ printer->numberValueID(stmt->getResult(0));
+ }
+ MLFunctionPrinter *printer;
+ };
+
+ NumberValuesPass pass(this);
+ // TODO: it'd be cleaner to have a const visitor instead of using const_cast.
+ pass.visit(const_cast<MLFunction *>(function));
+}
void MLFunctionPrinter::print() {
os << "mlfunc ";
@@ -870,6 +891,7 @@
}
void MLFunctionPrinter::print(const OperationStmt *stmt) {
+ os.indent(numSpaces); // printOperation() itself emits no indentation.
printOperation(stmt);
}
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index b06be4d..e598f2a 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/StringRef.h"
using namespace mlir;
@@ -120,6 +121,9 @@
: Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
MLFunction::~MLFunction() {
- // TODO: When move SSA stuff is supported.
- // dropAllReferences();
+ struct DropReferencesPass : public StmtVisitor<DropReferencesPass> {
+ void visitOperationStmt(OperationStmt *stmt) { stmt->dropAllReferences(); }
+ };
+ DropReferencesPass pass;
+ pass.visit(this); // Drop all operand uses to break cycles between statements.
}
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index 92a01ba..f2f9eb4 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -39,10 +39,7 @@
if (auto *inst = dyn_cast<OperationInst>(this)) {
return inst->getNumOperands();
} else {
- auto *stmt = cast<OperationStmt>(this);
- (void)stmt;
- // TODO: Add operands to OperationStmt.
- return 0;
+ return cast<OperationStmt>(this)->getNumOperands();
}
}
@@ -51,9 +48,7 @@
return inst->getOperand(idx);
} else {
auto *stmt = cast<OperationStmt>(this);
- (void)stmt;
- // TODO: Add operands to OperationStmt.
- abort();
+ return stmt->getOperand(idx);
}
}
@@ -62,9 +57,7 @@
inst->setOperand(idx, cast<CFGValue>(value));
} else {
auto *stmt = cast<OperationStmt>(this);
- (void)stmt;
- // TODO: Add operands to OperationStmt.
- abort();
+ stmt->setOperand(idx, cast<MLValue>(value)); // Stmt operands are MLValues.
}
}
@@ -74,9 +67,7 @@
return inst->getNumResults();
} else {
auto *stmt = cast<OperationStmt>(this);
- (void)stmt;
- // TODO: Add results to OperationStmt.
- return 0;
+ return stmt->getNumResults();
}
}
@@ -86,9 +77,7 @@
return inst->getResult(idx);
} else {
auto *stmt = cast<OperationStmt>(this);
- (void)stmt;
- // TODO: Add operands to OperationStmt.
- abort();
+ return stmt->getResult(idx);
}
}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 6b272e0..4b6ddc7 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -21,6 +21,17 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
+// StmtResult
+//===----------------------------------------------------------------------===//
+
+/// Return the result number of this result.
+unsigned StmtResult::getResultNumber() const {
+ // Results are always stored consecutively, so use pointer subtraction to
+ // figure out what number this is.
+ return this - &getOwner()->getStmtResults()[0];
+}
+
+//===----------------------------------------------------------------------===//
// Statement
//===------------------------------------------------------------------===//
@@ -34,7 +45,7 @@
void Statement::destroy() {
switch (this->getKind()) {
case Kind::Operation:
- delete cast<OperationStmt>(this);
+ cast<OperationStmt>(this)->destroy();
break;
case Kind::For:
delete cast<ForStmt>(this);
@@ -113,6 +124,73 @@
}
//===----------------------------------------------------------------------===//
+// OperationStmt
+//===----------------------------------------------------------------------===//
+
+/// Create an OperationStmt with the given fields; release it with destroy().
+OperationStmt *OperationStmt::create(Identifier name,
+ ArrayRef<MLValue *> operands,
+ ArrayRef<Type *> resultTypes,
+ ArrayRef<NamedAttribute> attributes,
+ MLIRContext *context) {
+ auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
+ resultTypes.size());
+ void *rawMem = malloc(byteSize);
+
+ // Initialize the OperationStmt part of the statement.
+ auto stmt = ::new (rawMem) OperationStmt(
+ name, operands.size(), resultTypes.size(), attributes, context);
+
+ // Initialize the operands and results.
+ auto stmtOperands = stmt->getStmtOperands();
+ for (unsigned i = 0, e = operands.size(); i != e; ++i)
+ new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
+
+ auto stmtResults = stmt->getStmtResults();
+ for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+ new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
+ return stmt;
+}
+
+OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
+ unsigned numResults,
+ ArrayRef<NamedAttribute> attributes,
+ MLIRContext *context)
+ : Operation(name, /*isInstruction=*/false, attributes, context),
+ Statement(Kind::Operation), numOperands(numOperands),
+ numResults(numResults) {}
+
+OperationStmt::~OperationStmt() {
+ // Explicitly run the destructors for the operands and results.
+ for (auto &operand : getStmtOperands())
+ operand.~StmtOperand();
+
+ for (auto &result : getStmtResults())
+ result.~StmtResult();
+}
+
+void OperationStmt::destroy() {
+ this->~OperationStmt();
+ free(this); // Pairs with the malloc() in create().
+}
+
+/// This drops all operand uses from this statement, which is an essential
+/// step in breaking cyclic dependences between references when they are to
+/// be deleted.
+void OperationStmt::dropAllReferences() {
+ for (auto &op : getStmtOperands())
+ op.drop();
+}
+
+/// If this value is the result of an OperationStmt, return the statement
+/// that defines it.
+OperationStmt *SSAValue::getDefiningStmt() {
+ if (auto *result = dyn_cast<StmtResult>(this))
+ return result->getOwner();
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 362f868..11850ea 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1209,7 +1209,9 @@
/// notably for dealing with operations and SSA values.
class FunctionParser : public Parser {
public:
- FunctionParser(ParserState &state) : Parser(state) {}
+ enum class Kind { CFGFunc, MLFunc };
+
+ Kind getKind() const { return kind; }
/// After the function is finished parsing, this function checks to see if
/// there are any remaining issues.
@@ -1257,7 +1259,12 @@
Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc);
Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
+protected:
+ FunctionParser(ParserState &state, Kind kind) : Parser(state), kind(kind) {}
+
private:
+ /// Whether this parser is parsing a CFG function or an ML function.
+ Kind kind;
/// This keeps track of all of the SSA values we are tracking, indexed by
/// their name. This has one entry per result number.
llvm::StringMap<SmallVector<std::pair<SSAValue *, SMLoc>, 1>> values;
@@ -1290,7 +1297,7 @@
return inst->getResult(0);
}
-/// Given an unbound reference to an SSA value and its type, return a the value
+/// Given an unbound reference to an SSA value and its type, return the value
/// it specifies. This returns null on failure.
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
auto &entries = values[useInfo.name];
@@ -1318,7 +1325,13 @@
return (emitError(useInfo.loc, "reference to invalid result number"),
nullptr);
- // Otherwise, this is a forward reference. Create a placeholder and remember
+ // Otherwise, this is a forward reference. ML functions do not allow forward
+ // references, so report an error in that case.
+ if (getKind() == Kind::MLFunc)
+ return (
+ emitError(useInfo.loc, "use of undefined SSA value " + useInfo.name),
+ nullptr);
+ // In a CFG function, create a placeholder and remember
// that we did so.
auto *result = createForwardReferencePlaceholder(useInfo.loc, type);
entries[useInfo.number].first = result;
@@ -1532,15 +1545,11 @@
// If the instruction had a name, register it.
if (!resultID.empty()) {
- // FIXME: Add result infra to handle Stmt results as well to make this
- // generic.
- if (auto *inst = dyn_cast<OperationInst>(op)) {
- if (inst->getNumResults() == 0)
- return emitError(loc, "cannot name an operation with no results");
+ if (op->getNumResults() == 0)
+ return emitError(loc, "cannot name an operation with no results");
- for (unsigned i = 0, e = inst->getNumResults(); i != e; ++i)
- addDefinition({resultID, i, loc}, inst->getResult(i));
- }
+ for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
+ addDefinition({resultID, i, loc}, op->getResult(i));
}
return ParseSuccess;
@@ -1778,7 +1787,8 @@
class CFGFunctionParser : public FunctionParser {
public:
CFGFunctionParser(ParserState &state, CFGFunction *function)
- : FunctionParser(state), function(function), builder(function) {}
+ : FunctionParser(state, Kind::CFGFunc), function(function),
+ builder(function) {}
ParseResult parseFunctionBody();
@@ -2014,7 +2024,8 @@
class MLFunctionParser : public FunctionParser {
public:
MLFunctionParser(ParserState &state, MLFunction *function)
- : FunctionParser(state), function(function), builder(function) {}
+ : FunctionParser(state, Kind::MLFunc), function(function),
+ builder(function) {}
ParseResult parseFunctionBody();
@@ -2176,7 +2187,11 @@
auto createOpFunc = [&](Identifier name, ArrayRef<SSAValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attrs) -> Operation * {
- return builder.createOperation(name, attrs);
+ SmallVector<MLValue *, 8> stmtOperands;
+ stmtOperands.reserve(operands.size());
+ for (auto *op : operands)
+ stmtOperands.push_back(cast<MLValue>(op)); // cast<> asserts an MLValue.
+ return builder.createOperation(name, stmtOperands, resultTypes, attrs);
};
builder.setInsertionPoint(block);
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 337f558..70367e9 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -96,7 +96,9 @@
break;
case Statement::Kind::Operation:
auto *op = cast<OperationStmt>(&stmt);
- builder.createOperation(op->getName(), op->getAttrs());
+ // TODO: clone operands and result types; the copy currently has neither.
+ builder.createOperation(op->getName(), /*operands*/ {},
+ /*resultTypes*/ {}, op->getAttrs());
// TODO: loop iterator parsing not yet implemented; replace loop
// iterator uses in unrolled body appropriately.
break;