Push location information more tightly into the IR, providing space for every
operation and statement to have a location, and require a location to be
specified whenever one of them is created (though a null location is still
allowed).  This is to encourage compiler authors to propagate location info
properly, allowing our failability (error-reporting) story to work well.

This is still a WIP - it isn't clear whether we want to continue abusing Attribute
for location information, or whether we should introduce a new class hierarchy
to do so.  This is a good step along the way, and unblocks some of the tf/xla
work that builds upon it.

PiperOrigin-RevId: 210001406
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 7da08c2..8a6248e 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -16,6 +16,7 @@
 // =============================================================================
 
 #include "mlir/IR/MLFunction.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
@@ -58,6 +59,21 @@
   }
 }
 
+/// Return the context this statement is associated with.
+MLIRContext *Statement::getContext() const {
+  // Work a bit to avoid calling findFunction() and getting its context.
+  switch (getKind()) {
+  case Kind::Operation:
+    return cast<OperationStmt>(this)->getContext();
+  case Kind::For:
+    return cast<ForStmt>(this)->getType()->getContext();
+  case Kind::If:
+    // TODO(shpeisman): When if statement has value operands, we can get a
+    // context from their type.
+    return findFunction()->getContext();
+  }
+}
+
 Statement *Statement::getParentStmt() const {
   return block ? block->getParentStmt() : nullptr;
 }
@@ -78,6 +94,28 @@
   return nlc.numNestedLoops == 1;
 }
 
+/// Emit a note about this statement, reporting up to any diagnostic
+/// handlers that may be listening.
+void Statement::emitNote(const Twine &message) const {
+  getContext()->emitDiagnostic(getLoc(), message,
+                               MLIRContext::DiagnosticKind::Note);
+}
+
+/// Emit a warning about this statement, reporting up to any diagnostic
+/// handlers that may be listening.
+void Statement::emitWarning(const Twine &message) const {
+  getContext()->emitDiagnostic(getLoc(), message,
+                               MLIRContext::DiagnosticKind::Warning);
+}
+
+/// Emit an error about fatal conditions with this statement, reporting up to
+/// any diagnostic handlers that may be listening.  NOTE: This may terminate
+/// the containing application; only use it when the IR is in an inconsistent
+/// state.
+void Statement::emitError(const Twine &message) const {
+  getContext()->emitDiagnostic(getLoc(), message,
+                               MLIRContext::DiagnosticKind::Error);
+}
 //===----------------------------------------------------------------------===//
 // ilist_traits for Statement
 //===----------------------------------------------------------------------===//
@@ -133,7 +171,7 @@
 //===----------------------------------------------------------------------===//
 
 /// Create a new OperationStmt with the specific fields.
-OperationStmt *OperationStmt::create(Identifier name,
+OperationStmt *OperationStmt::create(Attribute *location, Identifier name,
                                      ArrayRef<MLValue *> operands,
                                      ArrayRef<Type *> resultTypes,
                                      ArrayRef<NamedAttribute> attributes,
@@ -144,7 +182,7 @@
 
   // Initialize the OperationStmt part of the statement.
   auto stmt = ::new (rawMem) OperationStmt(
-      name, operands.size(), resultTypes.size(), attributes, context);
+      location, name, operands.size(), resultTypes.size(), attributes, context);
 
   // Initialize the operands and results.
   auto stmtOperands = stmt->getStmtOperands();
@@ -157,12 +195,12 @@
   return stmt;
 }
 
-OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
-                             unsigned numResults,
+OperationStmt::OperationStmt(Attribute *location, Identifier name,
+                             unsigned numOperands, unsigned numResults,
                              ArrayRef<NamedAttribute> attributes,
                              MLIRContext *context)
-    : Operation(name, /*isInstruction=*/false, attributes, context),
-      Statement(Kind::Operation), numOperands(numOperands),
+    : Operation(/*isInstruction=*/false, name, attributes, context),
+      Statement(Kind::Operation, location), numOperands(numOperands),
       numResults(numResults) {}
 
 OperationStmt::~OperationStmt() {
@@ -197,9 +235,10 @@
 // ForStmt
 //===----------------------------------------------------------------------===//
 
-ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
-                 int64_t step, MLIRContext *context)
-    : Statement(Kind::For),
+ForStmt::ForStmt(Attribute *location, AffineConstantExpr *lowerBound,
+                 AffineConstantExpr *upperBound, int64_t step,
+                 MLIRContext *context)
+    : Statement(Kind::For, location),
       MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
       StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
       upperBound(upperBound), step(step) {}
@@ -208,6 +247,10 @@
 // IfStmt
 //===----------------------------------------------------------------------===//
 
+IfStmt::IfStmt(Attribute *location, IntegerSet *condition)
+    : Statement(Kind::If, location), thenClause(new IfClause(this)),
+      elseClause(nullptr), condition(condition) {}
+
 IfStmt::~IfStmt() {
   delete thenClause;
   if (elseClause)
@@ -244,8 +287,9 @@
     resultTypes.reserve(opStmt->getNumResults());
     for (auto *result : opStmt->getResults())
       resultTypes.push_back(result->getType());
-    auto *newOp = OperationStmt::create(
-        opStmt->getName(), operands, resultTypes, opStmt->getAttrs(), context);
+    auto *newOp =
+        OperationStmt::create(getLoc(), opStmt->getName(), operands,
+                              resultTypes, opStmt->getAttrs(), context);
     // Remember the mapping of any results.
     for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
       operandMap[opStmt->getResult(i)] = newOp->getResult(i);
@@ -254,8 +298,8 @@
 
   if (auto *forStmt = dyn_cast<ForStmt>(this)) {
     auto *newFor =
-        new ForStmt(forStmt->getLowerBound(), forStmt->getUpperBound(),
-                    forStmt->getStep(), context);
+        new ForStmt(getLoc(), forStmt->getLowerBound(),
+                    forStmt->getUpperBound(), forStmt->getStep(), context);
     // Remember the induction variable mapping.
     operandMap[forStmt] = newFor;
 
@@ -269,7 +313,7 @@
 
   // Otherwise, we must have an If statement.
   auto *ifStmt = cast<IfStmt>(this);
-  auto *newIf = new IfStmt(ifStmt->getCondition());
+  auto *newIf = new IfStmt(getLoc(), ifStmt->getCondition());
 
   // TODO: remap operands with remapOperand when if statements have them.