Push location information more tightly into the IR, providing space for every
operation and statement to have a location, and require a location to be
specified whenever one of them is created (a null location is still allowed).
This is to encourage compiler authors to propagate location information
properly, which our failability story relies on to work well.
This is still a WIP: it isn't clear whether we want to continue abusing
Attribute for location information or introduce a new class hierarchy for it.
This is a good step along the way, and it unblocks some of the tf/xla work that
builds upon it.
PiperOrigin-RevId: 210001406
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 7da08c2..8a6248e 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
@@ -58,6 +59,21 @@
}
}
+/// Return the context this statement is associated with.
+MLIRContext *Statement::getContext() const {
+  // Work a bit to avoid calling findFunction() and getting its context.
+  switch (getKind()) {
+  case Kind::Operation:
+    return cast<OperationStmt>(this)->getContext();
+  case Kind::For:
+    return cast<ForStmt>(this)->getType()->getContext();
+  case Kind::If:
+    // TODO(shpeisman): When if statement has value operands, we can get a
+    // context from their type.
+    break;
+  }
+
+  // 'if' statements (and any kind not handled above) fall back to the context
+  // of the enclosing function. Returning here rather than inside the switch
+  // guarantees that no path flows off the end of this non-void function.
+  // NOTE(review): assumes the statement is attached to a function, so that
+  // findFunction() yields a usable function — confirm for detached statements.
+  return findFunction()->getContext();
+}
+
Statement *Statement::getParentStmt() const {
return block ? block->getParentStmt() : nullptr;
}
@@ -78,6 +94,28 @@
return nlc.numNestedLoops == 1;
}
+/// Report a note about this statement, attributed to its location, to
+/// whatever diagnostic handlers are registered with its context.
+void Statement::emitNote(const Twine &message) const {
+  MLIRContext *context = getContext();
+  context->emitDiagnostic(getLoc(), message,
+                          MLIRContext::DiagnosticKind::Note);
+}
+
+/// Report a warning about this statement, attributed to its location, to
+/// whatever diagnostic handlers are registered with its context.
+void Statement::emitWarning(const Twine &message) const {
+  MLIRContext *context = getContext();
+  context->emitDiagnostic(getLoc(), message,
+                          MLIRContext::DiagnosticKind::Warning);
+}
+
+/// Report an error about a fatal condition with this statement, attributed to
+/// its location, to whatever diagnostic handlers are registered with its
+/// context. NOTE: This may terminate the containing application; only use it
+/// when the IR is in an inconsistent state.
+void Statement::emitError(const Twine &message) const {
+  MLIRContext *context = getContext();
+  context->emitDiagnostic(getLoc(), message,
+                          MLIRContext::DiagnosticKind::Error);
+}
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
@@ -133,7 +171,7 @@
//===----------------------------------------------------------------------===//
/// Create a new OperationStmt with the specific fields.
-OperationStmt *OperationStmt::create(Identifier name,
+OperationStmt *OperationStmt::create(Attribute *location, Identifier name,
ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes,
@@ -144,7 +182,7 @@
// Initialize the OperationStmt part of the statement.
auto stmt = ::new (rawMem) OperationStmt(
- name, operands.size(), resultTypes.size(), attributes, context);
+ location, name, operands.size(), resultTypes.size(), attributes, context);
// Initialize the operands and results.
auto stmtOperands = stmt->getStmtOperands();
@@ -157,12 +195,12 @@
return stmt;
}
-OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
- unsigned numResults,
+OperationStmt::OperationStmt(Attribute *location, Identifier name,
+ unsigned numOperands, unsigned numResults,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
- : Operation(name, /*isInstruction=*/false, attributes, context),
- Statement(Kind::Operation), numOperands(numOperands),
+ : Operation(/*isInstruction=*/false, name, attributes, context),
+ Statement(Kind::Operation, location), numOperands(numOperands),
numResults(numResults) {}
OperationStmt::~OperationStmt() {
@@ -197,9 +235,10 @@
// ForStmt
//===----------------------------------------------------------------------===//
-ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
- int64_t step, MLIRContext *context)
- : Statement(Kind::For),
+ForStmt::ForStmt(Attribute *location, AffineConstantExpr *lowerBound,
+ AffineConstantExpr *upperBound, int64_t step,
+ MLIRContext *context)
+ : Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
upperBound(upperBound), step(step) {}
@@ -208,6 +247,10 @@
// IfStmt
//===----------------------------------------------------------------------===//
+/// Build an 'if' statement at `location`, guarded by `condition`. The 'then'
+/// clause is allocated up front and released by ~IfStmt; the 'else' clause
+/// starts out absent.
+IfStmt::IfStmt(Attribute *location, IntegerSet *condition)
+    : Statement(Kind::If, location),
+      thenClause(new IfClause(this)),
+      elseClause(nullptr),
+      condition(condition) {}
+
IfStmt::~IfStmt() {
delete thenClause;
if (elseClause)
@@ -244,8 +287,9 @@
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());
- auto *newOp = OperationStmt::create(
- opStmt->getName(), operands, resultTypes, opStmt->getAttrs(), context);
+ auto *newOp =
+ OperationStmt::create(getLoc(), opStmt->getName(), operands,
+ resultTypes, opStmt->getAttrs(), context);
// Remember the mapping of any results.
for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
operandMap[opStmt->getResult(i)] = newOp->getResult(i);
@@ -254,8 +298,8 @@
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
auto *newFor =
- new ForStmt(forStmt->getLowerBound(), forStmt->getUpperBound(),
- forStmt->getStep(), context);
+ new ForStmt(getLoc(), forStmt->getLowerBound(),
+ forStmt->getUpperBound(), forStmt->getStep(), context);
// Remember the induction variable mapping.
operandMap[forStmt] = newFor;
@@ -269,7 +313,7 @@
// Otherwise, we must have an If statement.
auto *ifStmt = cast<IfStmt>(this);
- auto *newIf = new IfStmt(ifStmt->getCondition());
+ auto *newIf = new IfStmt(getLoc(), ifStmt->getCondition());
// TODO: remap operands with remapOperand when if statements have them.