Push location information more tightly into the IR, providing space for every
operation and statement to have a location, and make it so a location is
required to be specified whenever you make one (though a null location is still
allowed). This is to encourage compiler authors to propagate loc info
properly, allowing our failability story to work well.
This is still a WIP - it isn't clear if we want to continue abusing Attribute
for location information, or whether we should introduce a new class heirarchy
to do so. This is good step along the way, and unblocks some of the tf/xla
work that builds upon it.
PiperOrigin-RevId: 210001406
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 95d73a4..48e78c7 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -174,8 +174,8 @@
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
- OpPointer<OpTy> create(Args... args) {
- OperationState state(getContext(), OpTy::getOperationName());
+ OpPointer<OpTy> create(Attribute *location, Args... args) {
+ OperationState state(getContext(), location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *inst = createOperation(state);
auto result = inst->template getAs<OpTy>();
@@ -191,19 +191,20 @@
// Terminators.
- ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
- return insertTerminator(ReturnInst::create(operands));
+ ReturnInst *createReturnInst(Attribute *location,
+ ArrayRef<CFGValue *> operands) {
+ return insertTerminator(ReturnInst::create(location, operands));
}
- BranchInst *createBranchInst(BasicBlock *dest) {
- return insertTerminator(BranchInst::create(dest));
+ BranchInst *createBranchInst(Attribute *location, BasicBlock *dest) {
+ return insertTerminator(BranchInst::create(location, dest));
}
- CondBranchInst *createCondBranchInst(CFGValue *condition,
+ CondBranchInst *createCondBranchInst(Attribute *location, CFGValue *condition,
BasicBlock *trueDest,
BasicBlock *falseDest) {
return insertTerminator(
- CondBranchInst::create(condition, trueDest, falseDest));
+ CondBranchInst::create(location, condition, trueDest, falseDest));
}
private:
@@ -280,8 +281,8 @@
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
- OpPointer<OpTy> create(Args... args) {
- OperationState state(getContext(), OpTy::getOperationName());
+ OpPointer<OpTy> create(Attribute *location, Args... args) {
+ OperationState state(getContext(), location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *stmt = createOperation(state);
auto result = stmt->template getAs<OpTy>();
@@ -302,14 +303,10 @@
}
// Creates for statement. When step is not specified, it is set to 1.
- ForStmt *createFor(AffineConstantExpr *lowerBound,
+ ForStmt *createFor(Attribute *location, AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound, int64_t step = 1);
- IfStmt *createIf(IntegerSet *condition) {
- auto *stmt = new IfStmt(condition);
- block->getStatements().insert(insertPoint, stmt);
- return stmt;
- }
+ IfStmt *createIf(Attribute *location, IntegerSet *condition);
private:
StmtBlock *block = nullptr;
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 1eac36c..da8033a 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -49,6 +49,10 @@
/// Return the context this operation is associated with.
MLIRContext *getContext() const;
+ /// The source location the operation was defined or derived from. Note that
+ /// it is possible for this pointer to be null.
+ Attribute *getLoc() const { return location; }
+
/// Return the BasicBlock containing this instruction.
BasicBlock *getBlock() const { return block; }
@@ -119,8 +123,23 @@
/// be deleted.
void dropAllReferences();
+ /// Emit an error about fatal conditions with this operation, reporting up to
+ /// any diagnostic handlers that may be listening. NOTE: This may terminate
+ /// the containing application, only use when the IR is in an inconsistent
+ /// state.
+ void emitError(const Twine &message) const;
+
+ /// Emit a warning about this operation, reporting up to any diagnostic
+ /// handlers that may be listening.
+ void emitWarning(const Twine &message) const;
+
+ /// Emit a note about this operation, reporting up to any diagnostic
+ /// handlers that may be listening.
+ void emitNote(const Twine &message) const;
+
protected:
- Instruction(Kind kind) : kind(kind) {}
+ Instruction(Kind kind, Attribute *location)
+ : kind(kind), location(location) {}
// Instructions are deleted through the destroy() member because this class
// does not have a virtual destructor. A vtable would bloat the size of
@@ -132,6 +151,10 @@
Kind kind;
BasicBlock *block = nullptr;
+ /// This holds information about the source location the operation was defined
+ /// or derived from.
+ Attribute *location;
+
friend struct llvm::ilist_traits<OperationInst>;
friend class BasicBlock;
};
@@ -150,13 +173,17 @@
private llvm::TrailingObjects<OperationInst, InstOperand, InstResult> {
public:
/// Create a new OperationInst with the specific fields.
- static OperationInst *create(Identifier name, ArrayRef<CFGValue *> operands,
+ static OperationInst *create(Attribute *location, Identifier name,
+ ArrayRef<CFGValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context);
- /// Return the context this operation is associated with.
- MLIRContext *getContext() const { return Instruction::getContext(); }
+ using Instruction::emitError;
+ using Instruction::emitNote;
+ using Instruction::emitWarning;
+ using Instruction::getContext;
+ using Instruction::getLoc;
OperationInst *clone() const;
@@ -278,8 +305,9 @@
private:
const unsigned numOperands, numResults;
- OperationInst(Identifier name, unsigned numOperands, unsigned numResults,
- ArrayRef<NamedAttribute> attributes, MLIRContext *context);
+ OperationInst(Attribute *location, Identifier name, unsigned numOperands,
+ unsigned numResults, ArrayRef<NamedAttribute> attributes,
+ MLIRContext *context);
~OperationInst();
// This stuff is used by the TrailingObjects template.
@@ -323,7 +351,8 @@
}
protected:
- TerminatorInst(Kind kind) : Instruction(kind) {}
+ TerminatorInst(Kind kind, Attribute *location)
+ : Instruction(kind, location) {}
~TerminatorInst() {}
};
@@ -331,7 +360,9 @@
/// and may pass basic block arguments to the successor.
class BranchInst : public TerminatorInst {
public:
- static BranchInst *create(BasicBlock *dest) { return new BranchInst(dest); }
+ static BranchInst *create(Attribute *location, BasicBlock *dest) {
+ return new BranchInst(location, dest);
+ }
~BranchInst() {}
/// Return the block this branch jumps to.
@@ -361,7 +392,7 @@
}
private:
- explicit BranchInst(BasicBlock *dest);
+ explicit BranchInst(Attribute *location, BasicBlock *dest);
BasicBlockOperand dest;
std::vector<InstOperand> operands;
@@ -375,9 +406,9 @@
enum { trueIndex = 0, falseIndex = 1 };
public:
- static CondBranchInst *create(CFGValue *condition, BasicBlock *trueDest,
- BasicBlock *falseDest) {
- return new CondBranchInst(condition, trueDest, falseDest);
+ static CondBranchInst *create(Attribute *location, CFGValue *condition,
+ BasicBlock *trueDest, BasicBlock *falseDest) {
+ return new CondBranchInst(location, condition, trueDest, falseDest);
}
~CondBranchInst() {}
@@ -521,7 +552,7 @@
}
private:
- CondBranchInst(CFGValue *condition, BasicBlock *trueDest,
+ CondBranchInst(Attribute *location, CFGValue *condition, BasicBlock *trueDest,
BasicBlock *falseDest);
CFGValue *condition;
@@ -541,7 +572,7 @@
private llvm::TrailingObjects<ReturnInst, InstOperand> {
public:
/// Create a new ReturnInst with the specific fields.
- static ReturnInst *create(ArrayRef<CFGValue *> operands);
+ static ReturnInst *create(Attribute *location, ArrayRef<CFGValue *> operands);
unsigned getNumOperands() const { return numOperands; }
@@ -566,8 +597,7 @@
return numOperands;
}
- explicit ReturnInst(unsigned numOperands)
- : TerminatorInst(Kind::Return), numOperands(numOperands) {}
+ ReturnInst(Attribute *location, unsigned numOperands);
~ReturnInst();
unsigned numOperands;
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index 2e44cd5..158362e 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -46,6 +46,7 @@
/// this in a collection.
struct OperationState {
MLIRContext *const context;
+ Attribute *location;
Identifier name;
SmallVector<SSAValue *, 4> operands;
/// Types of the results of this operation.
@@ -53,16 +54,18 @@
SmallVector<NamedAttribute, 4> attributes;
public:
- OperationState(MLIRContext *context, StringRef name)
- : context(context), name(Identifier::get(name, context)) {}
+ OperationState(MLIRContext *context, Attribute *location, StringRef name)
+ : context(context), location(location),
+ name(Identifier::get(name, context)) {}
- OperationState(MLIRContext *context, Identifier name)
- : context(context), name(name) {}
+ OperationState(MLIRContext *context, Attribute *location, Identifier name)
+ : context(context), location(location), name(name) {}
- OperationState(MLIRContext *context, StringRef name,
+ OperationState(MLIRContext *context, Attribute *location, StringRef name,
ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
ArrayRef<NamedAttribute> attributes = {})
- : context(context), name(Identifier::get(name, context)),
+ : context(context), location(location),
+ name(Identifier::get(name, context)),
operands(operands.begin(), operands.end()),
types(types.begin(), types.end()),
attributes(attributes.begin(), attributes.end()) {}
@@ -89,6 +92,10 @@
/// Return the context this operation is associated with.
MLIRContext *getContext() const;
+ /// The source location the operation was defined or derived from. Note that
+ /// it is possible for this pointer to be null.
+ Attribute *getLoc() const;
+
/// Return the function this operation is defined in. This has a verbose
/// name to avoid name lookup ambiguities.
Function *getOperationFunction();
@@ -236,7 +243,7 @@
}
protected:
- Operation(Identifier name, bool isInstruction, ArrayRef<NamedAttribute> attrs,
+ Operation(bool isInstruction, Identifier name, ArrayRef<NamedAttribute> attrs,
MLIRContext *context);
~Operation();
@@ -247,6 +254,8 @@
/// This holds the name of the operation, and a bool. The bool is true if
/// this operation is an OperationInst, false if it is a OperationStmt.
llvm::PointerIntPair<Identifier, 1, bool> nameAndIsInstruction;
+
+ /// This holds general named attributes for the operation.
AttributeListStorage *attrs;
};
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index a4c7e6b..9669849 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -27,6 +27,7 @@
#include "llvm/ADT/ilist_node.h"
namespace mlir {
+class Attribute;
class MLFunction;
class StmtBlock;
class ForStmt;
@@ -46,6 +47,14 @@
};
Kind getKind() const { return kind; }
+
+ /// Return the context this operation is associated with.
+ MLIRContext *getContext() const;
+
+ /// The source location the operation was defined or derived from. Note that
+ /// it is possible for this pointer to be null.
+ Attribute *getLoc() const { return location; }
+
/// Remove this statement from its block and delete it.
void eraseFromBlock();
@@ -82,8 +91,22 @@
void print(raw_ostream &os) const;
void dump() const;
+ /// Emit an error about fatal conditions with this operation, reporting up to
+ /// any diagnostic handlers that may be listening. NOTE: This may terminate
+ /// the containing application, only use when the IR is in an inconsistent
+ /// state.
+ void emitError(const Twine &message) const;
+
+ /// Emit a warning about this operation, reporting up to any diagnostic
+ /// handlers that may be listening.
+ void emitWarning(const Twine &message) const;
+
+ /// Emit a note about this operation, reporting up to any diagnostic
+ /// handlers that may be listening.
+ void emitNote(const Twine &message) const;
+
protected:
- Statement(Kind kind) : kind(kind) {}
+ Statement(Kind kind, Attribute *location) : kind(kind), location(location) {}
// Statements are deleted through the destroy() member because this class
// does not have a virtual destructor.
~Statement();
@@ -93,6 +116,10 @@
/// The statement block that containts this statement.
StmtBlock *block = nullptr;
+ /// This holds information about the source location the operation was defined
+ /// or derived from.
+ Attribute *location;
+
// allow ilist_traits access to 'block' field.
friend struct llvm::ilist_traits<Statement>;
};
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 1a68aab..47b2427 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -40,7 +40,8 @@
private llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult> {
public:
/// Create a new OperationStmt with the specific fields.
- static OperationStmt *create(Identifier name, ArrayRef<MLValue *> operands,
+ static OperationStmt *create(Attribute *location, Identifier name,
+ ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context);
@@ -48,6 +49,11 @@
/// Return the context this operation is associated with.
MLIRContext *getContext() const;
+ using Statement::emitError;
+ using Statement::emitNote;
+ using Statement::emitWarning;
+ using Statement::getLoc;
+
/// Check if this statement is a return statement.
bool isReturn() const { return getName().str() == "return"; }
@@ -179,8 +185,9 @@
private:
const unsigned numOperands, numResults;
- OperationStmt(Identifier name, unsigned numOperands, unsigned numResults,
- ArrayRef<NamedAttribute> attributes, MLIRContext *context);
+ OperationStmt(Attribute *location, Identifier name, unsigned numOperands,
+ unsigned numResults, ArrayRef<NamedAttribute> attributes,
+ MLIRContext *context);
~OperationStmt();
// This stuff is used by the TrailingObjects template.
@@ -198,7 +205,7 @@
public:
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
- explicit ForStmt(AffineConstantExpr *lowerBound,
+ explicit ForStmt(Attribute *location, AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound, int64_t step,
MLIRContext *context);
@@ -275,10 +282,7 @@
/// If statement restricts execution to a subset of the loop iteration space.
class IfStmt : public Statement {
public:
- explicit IfStmt(IntegerSet *condition)
- : Statement(Kind::If), thenClause(new IfClause(this)),
- elseClause(nullptr), condition(condition) {}
-
+ explicit IfStmt(Attribute *location, IntegerSet *condition);
~IfStmt();
IfClause *getThen() const { return thenClause; }
@@ -296,6 +300,8 @@
}
private:
+ // TODO: The 'If' always has an associated 'theClause', we should be able to
+ // store the IfClause object for it inline to save an extra allocation.
IfClause *thenClause;
IfClause *elseClause;
// TODO(shpeisman): please name the ifStmt's conditional encapsulating
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 73beea6..bd78929 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -182,8 +182,8 @@
for (auto elt : state.operands)
operands.push_back(cast<CFGValue>(elt));
- auto *op = OperationInst::create(state.name, operands, state.types,
- state.attributes, context);
+ auto *op = OperationInst::create(state.location, state.name, operands,
+ state.types, state.attributes, context);
block->getOperations().insert(insertPoint, op);
return op;
}
@@ -199,16 +199,23 @@
for (auto elt : state.operands)
operands.push_back(cast<MLValue>(elt));
- auto *op = OperationStmt::create(state.name, operands, state.types,
- state.attributes, context);
+ auto *op = OperationStmt::create(state.location, state.name, operands,
+ state.types, state.attributes, context);
block->getStatements().insert(insertPoint, op);
return op;
}
-ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
+ForStmt *MLFuncBuilder::createFor(Attribute *location,
+ AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
int64_t step) {
- auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
+ auto *stmt = new ForStmt(location, lowerBound, upperBound, step, context);
+ block->getStatements().insert(insertPoint, stmt);
+ return stmt;
+}
+
+IfStmt *MLFuncBuilder::createIf(Attribute *location, IntegerSet *condition) {
+ auto *stmt = new IfStmt(location, condition);
block->getStatements().insert(insertPoint, stmt);
return stmt;
}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 8f64066..7378b7a 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Instructions.h"
#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLIRContext.h"
using namespace mlir;
/// Replace all uses of 'this' value with the new value, updating anything in
@@ -118,12 +119,35 @@
dest.drop();
}
+/// Emit a note about this instruction, reporting up to any diagnostic
+/// handlers that may be listening.
+void Instruction::emitNote(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Note);
+}
+
+/// Emit a warning about this operation, reporting up to any diagnostic
+/// handlers that may be listening.
+void Instruction::emitWarning(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Warning);
+}
+
+/// Emit an error about fatal conditions with this instruction, reporting up to
+/// any diagnostic handlers that may be listening. NOTE: This may terminate
+/// the containing application, only use when the IR is in an inconsistent
+/// state.
+void Instruction::emitError(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Error);
+}
+
//===----------------------------------------------------------------------===//
// OperationInst
//===----------------------------------------------------------------------===//
/// Create a new OperationInst with the specific fields.
-OperationInst *OperationInst::create(Identifier name,
+OperationInst *OperationInst::create(Attribute *location, Identifier name,
ArrayRef<CFGValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes,
@@ -134,7 +158,7 @@
// Initialize the OperationInst part of the instruction.
auto inst = ::new (rawMem) OperationInst(
- name, operands.size(), resultTypes.size(), attributes, context);
+ location, name, operands.size(), resultTypes.size(), attributes, context);
// Initialize the operands and results.
auto instOperands = inst->getInstOperands();
@@ -151,23 +175,23 @@
SmallVector<CFGValue *, 8> operands;
SmallVector<Type *, 8> resultTypes;
- // TODO(clattner): switch to iterator logic.
// Put together the operands and results.
- for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
- operands.push_back(getInstOperand(i).get());
+ for (auto *operand : getOperands())
+ operands.push_back(const_cast<CFGValue *>(operand));
- for (unsigned i = 0, e = getNumResults(); i != e; ++i)
- resultTypes.push_back(getInstResult(i).getType());
+ for (auto *result : getResults())
+ resultTypes.push_back(result->getType());
- return create(getName(), operands, resultTypes, getAttrs(), getContext());
+ return create(getLoc(), getName(), operands, resultTypes, getAttrs(),
+ getContext());
}
-OperationInst::OperationInst(Identifier name, unsigned numOperands,
- unsigned numResults,
+OperationInst::OperationInst(Attribute *location, Identifier name,
+ unsigned numOperands, unsigned numResults,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
- : Operation(name, /*isInstruction=*/true, attributes, context),
- Instruction(Kind::Operation), numOperands(numOperands),
+ : Operation(/*isInstruction=*/true, name, attributes, context),
+ Instruction(Kind::Operation, location), numOperands(numOperands),
numResults(numResults) {}
OperationInst::~OperationInst() {
@@ -257,12 +281,13 @@
//===----------------------------------------------------------------------===//
/// Create a new OperationInst with the specific fields.
-ReturnInst *ReturnInst::create(ArrayRef<CFGValue *> operands) {
+ReturnInst *ReturnInst::create(Attribute *location,
+ ArrayRef<CFGValue *> operands) {
auto byteSize = totalSizeToAlloc<InstOperand>(operands.size());
void *rawMem = malloc(byteSize);
// Initialize the ReturnInst part of the instruction.
- auto inst = ::new (rawMem) ReturnInst(operands.size());
+ auto inst = ::new (rawMem) ReturnInst(location, operands.size());
// Initialize the operands and results.
auto instOperands = inst->getInstOperands();
@@ -271,6 +296,9 @@
return inst;
}
+ReturnInst::ReturnInst(Attribute *location, unsigned numOperands)
+ : TerminatorInst(Kind::Return, location), numOperands(numOperands) {}
+
void ReturnInst::destroy() {
this->~ReturnInst();
free(this);
@@ -286,8 +314,8 @@
// BranchInst
//===----------------------------------------------------------------------===//
-BranchInst::BranchInst(BasicBlock *dest)
- : TerminatorInst(Kind::Branch), dest(this, dest) {}
+BranchInst::BranchInst(Attribute *location, BasicBlock *dest)
+ : TerminatorInst(Kind::Branch, location), dest(this, dest) {}
void BranchInst::setDest(BasicBlock *block) { dest.set(block); }
@@ -307,9 +335,9 @@
// CondBranchInst
//===----------------------------------------------------------------------===//
-CondBranchInst::CondBranchInst(CFGValue *condition, BasicBlock *trueDest,
- BasicBlock *falseDest)
- : TerminatorInst(Kind::CondBranch),
+CondBranchInst::CondBranchInst(Attribute *location, CFGValue *condition,
+ BasicBlock *trueDest, BasicBlock *falseDest)
+ : TerminatorInst(Kind::CondBranch, location),
condition(condition), dests{{this}, {this}}, numTrueOperands(0) {
dests[falseIndex].set(falseDest);
dests[trueIndex].set(trueDest);
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index aed6fcc..71167bd 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -24,7 +24,7 @@
#include "mlir/IR/Statements.h"
using namespace mlir;
-Operation::Operation(Identifier name, bool isInstruction,
+Operation::Operation(bool isInstruction, Identifier name,
ArrayRef<NamedAttribute> attrs, MLIRContext *context)
: nameAndIsInstruction(name, isInstruction) {
this->attrs = AttributeListStorage::get(attrs, context);
@@ -44,6 +44,14 @@
return cast<OperationStmt>(this)->getContext();
}
+/// The source location the operation was defined or derived from. Note that
+/// it is possible for this pointer to be null.
+Attribute *Operation::getLoc() const {
+ if (auto *inst = dyn_cast<OperationInst>(this))
+ return inst->getLoc();
+ return cast<OperationStmt>(this)->getLoc();
+}
+
/// Return the function this operation is defined in.
Function *Operation::getOperationFunction() {
if (auto *inst = dyn_cast<OperationInst>(this))
@@ -139,14 +147,14 @@
/// Emit a note about this operation, reporting up to any diagnostic
/// handlers that may be listening.
void Operation::emitNote(const Twine &message) const {
- getContext()->emitDiagnostic(getAttr(":location"), message,
+ getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Note);
}
/// Emit a warning about this operation, reporting up to any diagnostic
/// handlers that may be listening.
void Operation::emitWarning(const Twine &message) const {
- getContext()->emitDiagnostic(getAttr(":location"), message,
+ getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Warning);
}
@@ -155,6 +163,6 @@
/// the containing application, only use when the IR is in an inconsistent
/// state.
void Operation::emitError(const Twine &message) const {
- getContext()->emitDiagnostic(getAttr(":location"), message,
+ getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Error);
}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 7da08c2..8a6248e 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
@@ -58,6 +59,21 @@
}
}
+/// Return the context this operation is associated with.
+MLIRContext *Statement::getContext() const {
+ // Work a bit to avoid calling findFunction() and getting its context.
+ switch (getKind()) {
+ case Kind::Operation:
+ return cast<OperationStmt>(this)->getContext();
+ case Kind::For:
+ return cast<ForStmt>(this)->getType()->getContext();
+ case Kind::If:
+ // TODO(shpeisman): When if statement has value operands, we can get a
+ // context from their type.
+ return findFunction()->getContext();
+ }
+}
+
Statement *Statement::getParentStmt() const {
return block ? block->getParentStmt() : nullptr;
}
@@ -78,6 +94,28 @@
return nlc.numNestedLoops == 1;
}
+/// Emit a note about this statement, reporting up to any diagnostic
+/// handlers that may be listening.
+void Statement::emitNote(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Note);
+}
+
+/// Emit a warning about this statement, reporting up to any diagnostic
+/// handlers that may be listening.
+void Statement::emitWarning(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Warning);
+}
+
+/// Emit an error about fatal conditions with this statement, reporting up to
+/// any diagnostic handlers that may be listening. NOTE: This may terminate
+/// the containing application, only use when the IR is in an inconsistent
+/// state.
+void Statement::emitError(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Error);
+}
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
@@ -133,7 +171,7 @@
//===----------------------------------------------------------------------===//
/// Create a new OperationStmt with the specific fields.
-OperationStmt *OperationStmt::create(Identifier name,
+OperationStmt *OperationStmt::create(Attribute *location, Identifier name,
ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes,
@@ -144,7 +182,7 @@
// Initialize the OperationStmt part of the statement.
auto stmt = ::new (rawMem) OperationStmt(
- name, operands.size(), resultTypes.size(), attributes, context);
+ location, name, operands.size(), resultTypes.size(), attributes, context);
// Initialize the operands and results.
auto stmtOperands = stmt->getStmtOperands();
@@ -157,12 +195,12 @@
return stmt;
}
-OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
- unsigned numResults,
+OperationStmt::OperationStmt(Attribute *location, Identifier name,
+ unsigned numOperands, unsigned numResults,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
- : Operation(name, /*isInstruction=*/false, attributes, context),
- Statement(Kind::Operation), numOperands(numOperands),
+ : Operation(/*isInstruction=*/false, name, attributes, context),
+ Statement(Kind::Operation, location), numOperands(numOperands),
numResults(numResults) {}
OperationStmt::~OperationStmt() {
@@ -197,9 +235,10 @@
// ForStmt
//===----------------------------------------------------------------------===//
-ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
- int64_t step, MLIRContext *context)
- : Statement(Kind::For),
+ForStmt::ForStmt(Attribute *location, AffineConstantExpr *lowerBound,
+ AffineConstantExpr *upperBound, int64_t step,
+ MLIRContext *context)
+ : Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
upperBound(upperBound), step(step) {}
@@ -208,6 +247,10 @@
// IfStmt
//===----------------------------------------------------------------------===//
+IfStmt::IfStmt(Attribute *location, IntegerSet *condition)
+ : Statement(Kind::If, location), thenClause(new IfClause(this)),
+ elseClause(nullptr), condition(condition) {}
+
IfStmt::~IfStmt() {
delete thenClause;
if (elseClause)
@@ -244,8 +287,9 @@
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());
- auto *newOp = OperationStmt::create(
- opStmt->getName(), operands, resultTypes, opStmt->getAttrs(), context);
+ auto *newOp =
+ OperationStmt::create(getLoc(), opStmt->getName(), operands,
+ resultTypes, opStmt->getAttrs(), context);
// Remember the mapping of any results.
for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
operandMap[opStmt->getResult(i)] = newOp->getResult(i);
@@ -254,8 +298,8 @@
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
auto *newFor =
- new ForStmt(forStmt->getLowerBound(), forStmt->getUpperBound(),
- forStmt->getStep(), context);
+ new ForStmt(getLoc(), forStmt->getLowerBound(),
+ forStmt->getUpperBound(), forStmt->getStep(), context);
// Remember the induction variable mapping.
operandMap[forStmt] = newFor;
@@ -269,7 +313,7 @@
// Otherwise, we must have an If statement.
auto *ifStmt = cast<IfStmt>(this);
- auto *newIf = new IfStmt(ifStmt->getCondition());
+ auto *newIf = new IfStmt(getLoc(), ifStmt->getCondition());
// TODO: remap operands with remapOperand when if statements have them.
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index de82e45..829769e 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -120,6 +120,10 @@
const Token &getToken() const { return state.curToken; }
StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
+ /// Encode the specified source location information into an attribute for
+ /// attachment to the IR.
+ Attribute *getEncodedSourceLocation(llvm::SMLoc loc);
+
/// Emit an error and return failure.
ParseResult emitError(const Twine &message) {
return emitError(state.curToken.getLoc(), message);
@@ -203,6 +207,20 @@
// Helper methods.
//===----------------------------------------------------------------------===//
+/// Encode the specified source location information into an attribute for
+/// attachment to the IR.
+Attribute *Parser::getEncodedSourceLocation(llvm::SMLoc loc) {
+ // TODO(clattner): Switch to an more structured form that includes
+ // file/line/column instead of just byte offset in the file. This will
+ // eliminate this block of low level code poking at the SourceMgr directly.
+ auto &sourceMgr = getSourceMgr();
+ auto fileID = sourceMgr.FindBufferContainingLoc(loc);
+
+ auto *srcBuffer = sourceMgr.getMemoryBuffer(fileID);
+ unsigned locationEncoding = loc.getPointer() - srcBuffer->getBufferStart();
+ return builder.getIntegerAttr(locationEncoding);
+}
+
ParseResult Parser::emitError(SMLoc loc, const Twine &message) {
// If we hit a parse error in response to a lexer error, then the lexer
// already reported the error.
@@ -1367,8 +1385,11 @@
// We create these placeholders as having an empty name, which we know cannot
// be created through normal user input, allowing us to distinguish them.
auto name = Identifier::get("placeholder", getContext());
- auto *inst = OperationInst::create(name, /*operands*/ {}, type, /*attrs*/ {},
- getContext());
+ auto *inst =
+ // FIXME(clattner): encode the location into the placeholder instead of
+ // into the forwardReferencePlaceholders map!
+ OperationInst::create(/*location=*/nullptr, name, /*operands=*/{}, type,
+ /*attrs=*/{}, getContext());
forwardReferencePlaceholders[inst->getResult(0)] = loc;
return inst->getResult(0);
}
@@ -1606,19 +1627,6 @@
if (!op)
return ParseFailure;
- // Apply location information to the instruction.
- // TODO(clattner): make this more principled. We shouldn't overwrite existing
- // location info, we should use a better serialized form, and we shouldn't
- // be using the :location attribute. This is also pretty inefficient.
- {
- auto &sourceMgr = getSourceMgr();
- auto fileID = sourceMgr.FindBufferContainingLoc(loc);
- auto *srcBuffer = sourceMgr.getMemoryBuffer(fileID);
- unsigned locationEncoding = loc.getPointer() - srcBuffer->getBufferStart();
- op->setAttr(builder.getIdentifier(":location"),
- builder.getIntegerAttr(locationEncoding));
- }
-
// We just parsed an operation. If it is a recognized one, verify that it
// is structurally as we expect. If not, produce an error with a reasonable
// source location.
@@ -1642,13 +1650,17 @@
Operation *FunctionParser::parseVerboseOperation(
const CreateOperationFunction &createOpFunc) {
+
+ // Get location information for the operation.
+ auto *srcLocation = getEncodedSourceLocation(getToken().getLoc());
+
auto name = getToken().getStringValue();
if (name.empty())
return (emitError("empty operation name is invalid"), nullptr);
consumeToken(Token::string);
- OperationState result(builder.getContext(), name);
+ OperationState result(builder.getContext(), srcLocation, name);
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
@@ -1915,8 +1927,11 @@
llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
opNameStr.c_str());
+ // Get location information for the operation.
+ auto *srcLocation = getEncodedSourceLocation(opLoc);
+
// Have the op implementation take a crack and parsing this.
- OperationState opState(builder.getContext(), opName);
+ OperationState opState(builder.getContext(), srcLocation, opName);
if (opDefinition->parseAssembly(&opAsmParser, &opState))
return nullptr;
@@ -2100,6 +2115,8 @@
/// terminator-stmt ::= `return` ssa-use-and-type-list?
///
TerminatorInst *CFGFunctionParser::parseTerminator() {
+ auto loc = getToken().getLoc();
+
switch (getToken().getKind()) {
default:
return (emitError("expected terminator at end of basic block"), nullptr);
@@ -2111,7 +2128,7 @@
SmallVector<CFGValue *, 8> operands;
if (parseOptionalSSAUseAndTypeList(operands, /*isParenthesized*/ false))
return nullptr;
- return builder.createReturnInst(operands);
+ return builder.createReturnInst(getEncodedSourceLocation(loc), operands);
}
case Token::kw_br: {
@@ -2120,7 +2137,8 @@
SmallVector<CFGValue *, 4> values;
if (parseBranchBlockAndUseList(destBB, values))
return nullptr;
- auto branch = builder.createBranchInst(destBB);
+ auto branch =
+ builder.createBranchInst(getEncodedSourceLocation(loc), destBB);
branch->addOperands(values);
return branch;
}
@@ -2149,7 +2167,8 @@
if (parseBranchBlockAndUseList(falseBlock, falseOperands))
return nullptr;
- auto branch = builder.createCondBranchInst(cast<CFGValue>(cond), trueBlock,
+ auto branch = builder.createCondBranchInst(getEncodedSourceLocation(loc),
+ cast<CFGValue>(cond), trueBlock,
falseBlock);
branch->addTrueOperands(trueOperands);
branch->addFalseOperands(falseOperands);
@@ -2240,7 +2259,8 @@
}
// Create for statement.
- ForStmt *forStmt = builder.createFor(lowerBound, upperBound, step);
+ ForStmt *forStmt = builder.createFor(getEncodedSourceLocation(loc),
+ lowerBound, upperBound, step);
// Create SSA value definition for the induction variable.
if (addDefinition({inductionVariableName, 0, loc}, forStmt))
@@ -2394,6 +2414,7 @@
/// | ml-if-head `else` `{` ml-stmt* `}`
///
ParseResult MLFunctionParser::parseIfStmt() {
+ auto loc = getToken().getLoc();
consumeToken(Token::kw_if);
if (parseToken(Token::l_paren, "expected '('"))
@@ -2406,7 +2427,7 @@
if (parseToken(Token::r_paren, "expected ')'"))
return ParseFailure;
- IfStmt *ifStmt = builder.createIf(condition);
+ IfStmt *ifStmt = builder.createIf(getEncodedSourceLocation(loc), condition);
IfClause *thenClause = ifStmt->getThen();
// When parsing of an if statement body fails, the IR contains
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 1bac43f..50cf7d7 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -51,7 +51,7 @@
// Creates return instruction with no operands.
// TODO: convert return operands.
- builder.createReturnInst({});
+ builder.createReturnInst(mlFunc->getReturnStmt()->getLoc(), {});
// TODO: convert ML function body.
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index a4a11a7..ffb2947 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -196,7 +196,8 @@
// value and add an operand mapping for it.
if (!forStmt->use_empty()) {
auto *ivConst =
- funcTopBuilder.create<ConstantAffineIntOp>(i)->getResult();
+ funcTopBuilder.create<ConstantAffineIntOp>(forStmt->getLoc(), i)
+ ->getResult();
operandMapping[forStmt] = cast<MLValue>(ivConst);
}
@@ -261,7 +262,8 @@
builder.getConstantExpr(i * step));
auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {});
auto *ivUnroll =
- builder.create<AffineApplyOp>(bumpMap, forStmt)->getResult(0);
+ builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
+ ->getResult(0);
operandMapping[forStmt] = cast<MLValue>(ivUnroll);
}