Finish support for function attributes, and improve lots of things:
 - Have the parser rewrite forward references to their resolved values at the
   end of parsing.
 - Implement verifier support for detecting malformed function attrs.
 - Add efficient query for (in general, recursive) attributes to tell if they
   contain a function.

As part of this, improve other general infrastructure:
 - Implement support for verifying OperationStmt's in ml functions, refactoring
   and generalizing support for operations in the verifier.
 - Refactor location handling code in mlir-opt to have the non-error expecting
   form of mlir-opt invocations to report error locations precisely.
 - Fix parser to detect verifier failures and report them through errorReporter
   instead of printing the error and crashing.

This regresses the location info for verifier errors in the parser that were
previously ascribed to the function.  This will get resolved in future patches
by adding support for function attributes, which we can use to manage location
information.

PiperOrigin-RevId: 209600980
diff --git a/include/mlir/IR/Attributes.h b/include/mlir/IR/Attributes.h
index 8ccc73e..b7d11ec 100644
--- a/include/mlir/IR/Attributes.h
+++ b/include/mlir/IR/Attributes.h
@@ -27,6 +27,8 @@
 class MLIRContext;
 class Type;
 
+/// Attributes are known-constant values of operations and functions.
+///
 /// Instances of the Attribute class are immutable, uniqued, immortal, and owned
 /// by MLIRContext.  As such, they are passed around by raw non-const pointer.
 class Attribute {
@@ -47,18 +49,25 @@
     return kind;
   }
 
+  /// Return true if this field is, or contains, a function attribute.
+  bool isOrContainsFunction() const { return isOrContainsFunctionCache; }
+
   /// Print the attribute.
   void print(raw_ostream &os) const;
   void dump() const;
 
 protected:
-  explicit Attribute(Kind kind) : kind(kind) {}
+  explicit Attribute(Kind kind, bool isOrContainsFunction)
+      : kind(kind), isOrContainsFunctionCache(isOrContainsFunction) {}
   ~Attribute() {}
 
 private:
   /// Classification of the subclass, used for type checking.
   Kind kind : 8;
 
+  /// This field is true if this is, or contains, a function attribute.
+  bool isOrContainsFunctionCache : 1;
+
   Attribute(const Attribute&) = delete;
   void operator=(const Attribute&) = delete;
 };
@@ -81,7 +90,8 @@
     return attr->getKind() == Kind::Bool;
   }
 private:
-  BoolAttr(bool value) : Attribute(Kind::Bool), value(value) {}
+  BoolAttr(bool value)
+      : Attribute(Kind::Bool, /*isOrContainsFunction=*/false), value(value) {}
   ~BoolAttr() = delete;
   bool value;
 };
@@ -99,7 +109,9 @@
     return attr->getKind() == Kind::Integer;
   }
 private:
-  IntegerAttr(int64_t value) : Attribute(Kind::Integer), value(value) {}
+  IntegerAttr(int64_t value)
+      : Attribute(Kind::Integer, /*isOrContainsFunction=*/false), value(value) {
+  }
   ~IntegerAttr() = delete;
   int64_t value;
 };
@@ -117,7 +129,8 @@
     return attr->getKind() == Kind::Float;
   }
 private:
-  FloatAttr(double value) : Attribute(Kind::Float), value(value) {}
+  FloatAttr(double value)
+      : Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {}
   ~FloatAttr() = delete;
   double value;
 };
@@ -135,11 +148,14 @@
     return attr->getKind() == Kind::String;
   }
 private:
-  StringAttr(StringRef value) : Attribute(Kind::String), value(value) {}
+  StringAttr(StringRef value)
+      : Attribute(Kind::String, /*isOrContainsFunction=*/false), value(value) {}
   ~StringAttr() = delete;
   StringRef value;
 };
 
+/// Array attributes are lists of other attributes.  They are not necessarily
+/// type homogenous given that attributes don't, in general, carry types.
 class ArrayAttr : public Attribute {
 public:
   static ArrayAttr *get(ArrayRef<Attribute*> value, MLIRContext *context);
@@ -153,7 +169,8 @@
     return attr->getKind() == Kind::Array;
   }
 private:
-  ArrayAttr(ArrayRef<Attribute*> value) : Attribute(Kind::Array), value(value){}
+  ArrayAttr(ArrayRef<Attribute *> value, bool isOrContainsFunction)
+      : Attribute(Kind::Array, isOrContainsFunction), value(value) {}
   ~ArrayAttr() = delete;
   ArrayRef<Attribute*> value;
 };
@@ -171,7 +188,9 @@
     return attr->getKind() == Kind::AffineMap;
   }
 private:
-  AffineMapAttr(AffineMap *value) : Attribute(Kind::AffineMap), value(value) {}
+  AffineMapAttr(AffineMap *value)
+      : Attribute(Kind::AffineMap, /*isOrContainsFunction=*/false),
+        value(value) {}
   ~AffineMapAttr() = delete;
   AffineMap *value;
 };
@@ -188,7 +207,8 @@
   }
 
 private:
-  TypeAttr(Type *value) : Attribute(Kind::Type), value(value) {}
+  TypeAttr(Type *value)
+      : Attribute(Kind::Type, /*isOrContainsFunction=*/false), value(value) {}
   ~TypeAttr() = delete;
   Type *value;
 };
@@ -216,7 +236,9 @@
   static void dropFunctionReference(Function *value);
 
 private:
-  FunctionAttr(Function *value) : Attribute(Kind::Function), value(value) {}
+  FunctionAttr(Function *value)
+      : Attribute(Kind::Function, /*isOrContainsFunction=*/true), value(value) {
+  }
   ~FunctionAttr() = delete;
   Function *value;
 };
diff --git a/include/mlir/IR/MLIRContext.h b/include/mlir/IR/MLIRContext.h
index fc35d97..d782be6 100644
--- a/include/mlir/IR/MLIRContext.h
+++ b/include/mlir/IR/MLIRContext.h
@@ -65,6 +65,9 @@
   /// message and a boolean that indicates whether this is an error or warning.
   void registerDiagnosticHandler(const DiagnosticHandlerTy &handler);
 
+  /// Return the current diagnostic handler, or null if none is present.
+  DiagnosticHandlerTy getDiagnosticHandler() const;
+
   /// This emits an diagnostic using the registered issue handle if present, or
   /// with the default behavior if not.  The MLIR compiler should not generally
   /// interact with this, it should use methods on Operation instead.
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index 51f7ee1..fed0d4e 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -31,6 +31,7 @@
 template <typename OpType> class OpPointer;
 template <typename ObjectType, typename ElementType> class OperandIterator;
 template <typename ObjectType, typename ElementType> class ResultIterator;
+class Function;
 class SSAValue;
 class Type;
 
@@ -69,6 +70,14 @@
   /// Return the context this operation is associated with.
   MLIRContext *getContext() const;
 
+  /// Return the function this operation is defined in.  This has a verbose
+  /// name to avoid name lookup ambiguities.
+  Function *getOperationFunction();
+
+  const Function *getOperationFunction() const {
+    return const_cast<Operation *>(this)->getOperationFunction();
+  }
+
   /// The name of an operation is the key identifier for it.
   Identifier getName() const { return nameAndIsInstruction.getPointer(); }
 
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index a960223..68839d8 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -285,6 +285,11 @@
   getImpl().diagnosticHandler = handler;
 }
 
+/// Return the current diagnostic handler, or null if none is present.
+auto MLIRContext::getDiagnosticHandler() const -> DiagnosticHandlerTy {
+  return getImpl().diagnosticHandler;
+}
+
 /// This emits a diagnostic using the registered issue handle if present, or
 /// with the default behavior if not.  The MLIR compiler should not generally
 /// interact with this, it should use methods on Operation instead.
@@ -608,8 +613,16 @@
   // Copy the elements into the bump pointer.
   value = impl.copyInto(value);
 
+  // Check to see if any of the elements have a function attr.
+  bool hasFunctionAttr = false;
+  for (auto *elt : value)
+    if (elt->isOrContainsFunction()) {
+      hasFunctionAttr = true;
+      break;
+    }
+
   // Initialize the memory using placement new.
-  new (result) ArrayAttr(value);
+  new (result) ArrayAttr(value, hasFunctionAttr);
 
   // Cache and return it.
   return *existing.first = result;
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index af937bc..aed6fcc 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -17,7 +17,9 @@
 
 #include "mlir/IR/Operation.h"
 #include "AttributeListStorage.h"
+#include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/Instructions.h"
+#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Statements.h"
 using namespace mlir;
@@ -42,22 +44,26 @@
   return cast<OperationStmt>(this)->getContext();
 }
 
+/// Return the function this operation is defined in.
+Function *Operation::getOperationFunction() {
+  if (auto *inst = dyn_cast<OperationInst>(this))
+    return inst->getFunction();
+  return cast<OperationStmt>(this)->findFunction();
+}
+
 /// Return the number of operands this operation has.
 unsigned Operation::getNumOperands() const {
-  if (auto *inst = dyn_cast<OperationInst>(this)) {
+  if (auto *inst = dyn_cast<OperationInst>(this))
     return inst->getNumOperands();
-  } else {
-    return cast<OperationStmt>(this)->getNumOperands();
-  }
+
+  return cast<OperationStmt>(this)->getNumOperands();
 }
 
 SSAValue *Operation::getOperand(unsigned idx) {
-  if (auto *inst = dyn_cast<OperationInst>(this)) {
+  if (auto *inst = dyn_cast<OperationInst>(this))
     return inst->getOperand(idx);
-  } else {
-    auto *stmt = cast<OperationStmt>(this);
-    return stmt->getOperand(idx);
-  }
+
+  return cast<OperationStmt>(this)->getOperand(idx);
 }
 
 void Operation::setOperand(unsigned idx, SSAValue *value) {
@@ -71,22 +77,18 @@
 
 /// Return the number of results this operation has.
 unsigned Operation::getNumResults() const {
-  if (auto *inst = dyn_cast<OperationInst>(this)) {
+  if (auto *inst = dyn_cast<OperationInst>(this))
     return inst->getNumResults();
-  } else {
-    auto *stmt = cast<OperationStmt>(this);
-    return stmt->getNumResults();
-  }
+
+  return cast<OperationStmt>(this)->getNumResults();
 }
 
 /// Return the indicated result.
 SSAValue *Operation::getResult(unsigned idx) {
-  if (auto *inst = dyn_cast<OperationInst>(this)) {
+  if (auto *inst = dyn_cast<OperationInst>(this))
     return inst->getResult(idx);
-  } else {
-    auto *stmt = cast<OperationStmt>(this);
-    return stmt->getResult(idx);
-  }
+
+  return cast<OperationStmt>(this)->getResult(idx);
 }
 
 ArrayRef<NamedAttribute> Operation::getAttrs() const {
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 272395c..b74e113 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -33,11 +33,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/PrettyStackTrace.h"
@@ -81,31 +83,93 @@
     return true;
   }
 
+  bool verifyOperation(const Operation &op);
+  bool verifyAttribute(Attribute *attr, const Operation &op);
+
 protected:
-  explicit Verifier(std::string *errorResult) : errorResult(errorResult) {}
+  explicit Verifier(std::string *errorResult, const Function &fn)
+      : errorResult(errorResult), fn(fn),
+        operationSet(OperationSet::get(fn.getContext())) {}
 
 private:
+  /// If the verifier is returning errors back to a client, this is the error to
+  /// fill in.
   std::string *errorResult;
+
+  /// The function being checked.
+  const Function &fn;
+
+  /// The operation set installed in the current MLIR context.
+  OperationSet &operationSet;
 };
 } // end anonymous namespace
 
+// Check that function attributes are all well formed.
+bool Verifier::verifyAttribute(Attribute *attr, const Operation &op) {
+  if (!attr->isOrContainsFunction())
+    return false;
+
+  // If we have a function attribute, check that it is non-null and in the
+  // same module as the operation that refers to it.
+  if (auto *fnAttr = dyn_cast<FunctionAttr>(attr)) {
+    if (!fnAttr->getValue())
+      return opFailure("attribute refers to deallocated function!", op);
+
+    if (fnAttr->getValue()->getModule() != fn.getModule())
+      return opFailure("attribute refers to function '" +
+                           Twine(fnAttr->getValue()->getName()) +
+                           "' defined in another module!",
+                       op);
+    return false;
+  }
+
+  // Otherwise, we must have an array attribute, remap the elements.
+  for (auto *elt : cast<ArrayAttr>(attr)->getValue()) {
+    if (verifyAttribute(elt, op))
+      return true;
+  }
+
+  return false;
+}
+
+/// Check the invariants of the specified operation instruction or statement.
+bool Verifier::verifyOperation(const Operation &op) {
+  if (op.getOperationFunction() != &fn)
+    return opFailure("operation in the wrong function", op);
+
+  // TODO: Check that operands are non-nil and structurally ok.
+
+  // Verify all attributes are ok.  We need to check Function attributes, since
+  // they are actually mutable (the function they refer to can be deleted), and
+  // we have to check array attributes that can refer to them.
+  for (auto attr : op.getAttrs()) {
+    if (verifyAttribute(attr.second, op))
+      return true;
+  }
+
+  // If we can get operation info for this, check the custom hook.
+  if (auto *opInfo = op.getAbstractOperation()) {
+    if (auto *errorMessage = opInfo->verifyInvariants(&op))
+      return opFailure(Twine("'") + op.getName().str() + "' op " + errorMessage,
+                       op);
+  }
+
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // CFG Functions
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CFGFuncVerifier : public Verifier {
-public:
+struct CFGFuncVerifier : public Verifier {
   const CFGFunction &fn;
-  OperationSet &operationSet;
 
   CFGFuncVerifier(const CFGFunction &fn, std::string *errorResult)
-      : Verifier(errorResult), fn(fn),
-        operationSet(OperationSet::get(fn.getContext())) {}
+      : Verifier(errorResult, fn), fn(fn) {}
 
   bool verify();
   bool verifyBlock(const BasicBlock &block);
-  bool verifyOperation(const OperationInst &inst);
   bool verifyTerminator(const TerminatorInst &term);
   bool verifyReturn(const ReturnInst &inst);
   bool verifyBranch(const BranchInst &inst);
@@ -281,39 +345,31 @@
   return false;
 }
 
-bool CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
-  if (inst.getFunction() != &fn)
-    return opFailure("operation in the wrong function", inst);
-
-  // TODO: Check that operands are structurally ok.
-
-  // See if we can get operation info for this.
-  if (auto *opInfo = inst.getAbstractOperation()) {
-    if (auto errorMessage = opInfo->verifyInvariants(&inst))
-      return opFailure(
-          Twine("'") + inst.getName().str() + "' op " + errorMessage, inst);
-  }
-
-  return false;
-}
-
 //===----------------------------------------------------------------------===//
 // ML Functions
 //===----------------------------------------------------------------------===//
 
 namespace {
-class MLFuncVerifier : public Verifier {
-public:
+struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
   const MLFunction &fn;
+  bool hadError = false;
 
   MLFuncVerifier(const MLFunction &fn, std::string *errorResult)
-      : Verifier(errorResult), fn(fn) {}
+      : Verifier(errorResult, fn), fn(fn) {}
+
+  void visitOperationStmt(OperationStmt *opStmt) {
+    hadError |= verifyOperation(*opStmt);
+  }
 
   bool verify() {
     llvm::PrettyStackTraceFormat fmt("MLIR Verifier: mlfunc @%s",
                                      fn.getName().c_str());
 
-    // TODO: check basic structural properties
+    // Check basic structural properties.
+    walk(const_cast<MLFunction *>(&fn));
+    if (hadError)
+      return true;
+
     // TODO: check that operation is not a return statement unless it's
     // the last one in the function.
     if (verifyReturn())
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 5310daa..174d6ca 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -26,10 +26,11 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLFunction.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSet.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/SourceMgr.h"
@@ -997,16 +998,10 @@
 ///
 ///   affine-expr ::= integer-literal
 AffineExpr *AffineParser::parseIntegerExpr() {
-  // No need to handle negative numbers separately here. They are naturally
-  // handled via the unary negation operator, although (FIXME) MININT_64 still
-  // not correctly handled.
-  if (getToken().isNot(Token::integer))
-    return (emitError("expected integer"), nullptr);
-
   auto val = getToken().getUInt64IntegerValue();
-  if (!val.hasValue() || (int64_t)val.getValue() < 0) {
+  if (!val.hasValue() || (int64_t)val.getValue() < 0)
     return (emitError("constant too large for affineint"), nullptr);
-  }
+
   consumeToken(Token::integer);
   return builder.getConstantExpr((int64_t)val.getValue());
 }
@@ -1454,11 +1449,6 @@
     return ParseFailure;
   }
 
-  // Run the verifier on this function.  If an error is detected, report it.
-  std::string errorString;
-  if (func->verify(&errorString))
-    return emitError(loc, errorString);
-
   return ParseSuccess;
 }
 
@@ -2220,9 +2210,6 @@
 
   // Reset insertion point to the current block.
   builder.setInsertionPointToEnd(forStmt->getBlock());
-
-  // TODO: remove definition of the induction variable.
-
   return ParseSuccess;
 }
 
@@ -2700,11 +2687,65 @@
   return parser.parseFunctionBody();
 }
 
+/// Given an attribute that could refer to a function attribute in the remapping
+/// table, walk it and rewrite it to use the mapped function.  If it doesn't
+/// refer to anything in the table, then it is returned unmodified.
+static Attribute *
+remapFunctionAttrs(Attribute *input,
+                   DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable,
+                   MLIRContext *context) {
+  // Most attributes are trivially unrelated to function attributes, skip them
+  // rapidly.
+  if (!input->isOrContainsFunction())
+    return input;
+
+  // If we have a function attribute, remap it.
+  if (auto *fnAttr = dyn_cast<FunctionAttr>(input)) {
+    auto it = remappingTable.find(fnAttr);
+    return it != remappingTable.end() ? it->second : input;
+  }
+
+  // Otherwise, we must have an array attribute, remap the elements.
+  auto *arrayAttr = cast<ArrayAttr>(input);
+  SmallVector<Attribute *, 8> remappedElts;
+  bool anyChange = false;
+  for (auto *elt : arrayAttr->getValue()) {
+    auto *newElt = remapFunctionAttrs(elt, remappingTable, context);
+    remappedElts.push_back(newElt);
+    anyChange |= (elt != newElt);
+  }
+
+  if (!anyChange)
+    return input;
+
+  return ArrayAttr::get(remappedElts, context);
+}
+
+/// Remap function attributes to resolve forward references to their actual
+/// definition.
+static void remapFunctionAttrsInOperation(
+    Operation *op, DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable) {
+  for (auto attr : op->getAttrs()) {
+    // Do the remapping, if we got the same thing back, then it must contain
+    // functions that aren't getting remapped.
+    auto *newVal =
+        remapFunctionAttrs(attr.second, remappingTable, op->getContext());
+    if (newVal == attr.second)
+      continue;
+
+    // Otherwise, replace the existing attribute with the new one.  It is safe
+    // to mutate the attribute list while we walk it because underlying
+    // attribute lists are uniqued and immortal.
+    op->setAttr(attr.first, newVal);
+  }
+}
+
 /// Finish the end of module parsing - when the result is valid, do final
 /// checking.
 ParseResult ModuleParser::finalizeModule() {
 
-  // Resolve all forward references.
+  // Resolve all forward references, building a remapping table of attributes.
+  DenseMap<FunctionAttr *, FunctionAttr *> remappingTable;
   for (auto forwardRef : getState().functionForwardRefs) {
     auto name = forwardRef.first;
 
@@ -2714,10 +2755,47 @@
       return emitError(forwardRef.second.second,
                        "reference to undefined function '" + name.str() + "'");
 
-    // TODO(clattner): actually go through and update references in the module
-    // to the new function.
+    remappingTable[builder.getFunctionAttr(forwardRef.second.first)] =
+        builder.getFunctionAttr(resolvedFunction);
   }
 
+  // If there was nothing to remap, then we're done.
+  if (remappingTable.empty())
+    return ParseSuccess;
+
+  // Otherwise, walk the entire module replacing uses of one attribute set with
+  // the correct ones.
+  for (auto &fn : *getModule()) {
+    if (auto *cfgFn = dyn_cast<CFGFunction>(&fn)) {
+      for (auto &bb : *cfgFn) {
+        for (auto &inst : bb) {
+          remapFunctionAttrsInOperation(&inst, remappingTable);
+        }
+      }
+    }
+
+    // Otherwise, look at MLFunctions.  We ignore ExtFunctions.
+    auto *mlFn = dyn_cast<MLFunction>(&fn);
+    if (!mlFn)
+      continue;
+
+    struct MLFnWalker : public StmtWalker<MLFnWalker> {
+      MLFnWalker(DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable)
+          : remappingTable(remappingTable) {}
+      void visitOperationStmt(OperationStmt *opStmt) {
+        remapFunctionAttrsInOperation(opStmt, remappingTable);
+      }
+
+      DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable;
+    };
+
+    MLFnWalker(remappingTable).walk(mlFn);
+  }
+
+  // Now that all references to the forward definition placeholders are
+  // resolved, we can deallocate the placeholders.
+  for (auto forwardRef : getState().functionForwardRefs)
+    forwardRef.second.first->destroy();
   return ParseSuccess;
 }
 
@@ -2778,16 +2856,60 @@
 /// MLIR module if it was valid.  If not, it emits diagnostics and returns null.
 Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
                               SMDiagnosticHandlerTy errorReporter) {
+  if (!errorReporter)
+    errorReporter = defaultErrorReporter;
+
+  // We are going to replace the context's handler and redirect it to use the
+  // error reporter.  Save the existing handler and reinstate it when we're
+  // done.
+  auto existingContextHandler = context->getDiagnosticHandler();
+
+  // Install a new handler that uses the error reporter.
+  context->registerDiagnosticHandler([&](Attribute *location, StringRef message,
+                                         MLIRContext::DiagnosticKind kind) {
+    auto offset = cast<IntegerAttr>(location)->getValue();
+    auto *mainBuffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+    auto ptr = mainBuffer->getBufferStart() + offset;
+    SourceMgr::DiagKind diagKind;
+    switch (kind) {
+    case MLIRContext::DiagnosticKind::Error:
+      diagKind = SourceMgr::DK_Error;
+      break;
+    case MLIRContext::DiagnosticKind::Warning:
+      diagKind = SourceMgr::DK_Warning;
+      break;
+    case MLIRContext::DiagnosticKind::Note:
+      diagKind = SourceMgr::DK_Note;
+      break;
+    }
+    errorReporter(
+        sourceMgr.GetMessage(SMLoc::getFromPointer(ptr), diagKind, message));
+  });
+
   // This is the result module we are parsing into.
   std::unique_ptr<Module> module(new Module(context));
 
-  ParserState state(sourceMgr, module.get(),
-                    errorReporter ? errorReporter : defaultErrorReporter);
-  if (ModuleParser(state).parseModule())
+  ParserState state(sourceMgr, module.get(), errorReporter);
+  if (ModuleParser(state).parseModule()) {
+    context->registerDiagnosticHandler(existingContextHandler);
     return nullptr;
+  }
 
   // Make sure the parse module has no other structural problems detected by the
   // verifier.
-  module->verify();
+  std::string errorResult;
+  module->verify(&errorResult);
+
+  // We don't have location information for general verifier errors, so emit the
+  // error on the first line.
+  if (!errorResult.empty()) {
+    auto *mainBuffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+    errorReporter(sourceMgr.GetMessage(
+        SMLoc::getFromPointer(mainBuffer->getBufferStart()),
+        SourceMgr::DK_Error, errorResult));
+    return nullptr;
+  }
+
+  context->registerDiagnosticHandler(existingContextHandler);
   return module.release();
 }
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
index ad821c0..17f805a 100644
--- a/test/IR/invalid-ops.mlir
+++ b/test/IR/invalid-ops.mlir
@@ -98,3 +98,10 @@
   %4 = store %3, %0[%1, %2] : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1> // expected-error {{cannot name an operation with no results}}
   return
 }
+
+// -----
+
+mlfunc @mlfunc_constant() {
+  %x = "constant"(){value: "xyz"} : () -> i32 // expected-error {{'constant' op requires 'value' to be an integer for an integer result type}}
+  return
+}
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 4b8a22b..dc983c6 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -104,13 +104,13 @@
 // -----
 
 mlfunc @empty() {
-// expected-error@-1 {{ML function must end with return statement}}
+  //expected-error@-3 {{ML function must end with return statement}}
 }
 
 // -----
 
 mlfunc @no_return() {
-// expected-error@-1 {{ML function must end with return statement}}
+  // expected-error@-3 {{ML function must end with return statement}}
   "foo"() : () -> ()
 }
 
@@ -275,14 +275,14 @@
 
 // -----
 
-cfgfunc @resulterror() -> i32 {  // expected-error {{return has 0 operands, but enclosing function returns 1}}
+cfgfunc @resulterror() -> i32 {
 bb42:
-  return
+  return    // expected-error@-4{{return has 0 operands, but enclosing function returns 1}}
 }
 
 // -----
 
-mlfunc @mlfunc_resulterror() -> i32 {  // expected-error {{return has 0 operands, but enclosing function returns 1}}
+mlfunc @mlfunc_resulterror() -> i32 {  // expected-error@-2 {{return has 0 operands, but enclosing function returns 1}}
   return
 }
 
@@ -297,14 +297,14 @@
 
 // -----
 
-cfgfunc @bbargMismatch(i32, f32) { // expected-error {{first block of cfgfunc must have 2 arguments to match function signature}}
+cfgfunc @bbargMismatch(i32, f32) { // expected-error @-2 {{first block of cfgfunc must have 2 arguments to match function signature}}
 bb42(%0: f32):
   return
 }
 
 // -----
 
-cfgfunc @br_mismatch() {  // expected-error {{branch has 2 operands, but target block has 1}}
+cfgfunc @br_mismatch() {  // expected-error @-2 {{branch has 2 operands, but target block has 1}}
 bb0:
   %0 = "foo"() : () -> (i1, i17)
   br bb1(%0#1, %0#0 : i17, i1)
@@ -370,7 +370,7 @@
 
 // -----
 
-mlfunc @duplicate_induction_var() {  // expected-error {{}}
+mlfunc @dominance_failure() {
   for %i = 1 to 10 {
   }
   "xxx"(%i) : (affineint)->()   // expected-error {{operand #0 does not dominate this use}}
@@ -380,7 +380,7 @@
 // -----
 
 mlfunc @return_type_mismatch() -> i32 {
-  // expected-error@-1 {{type of return operand 0 doesn't match function result type}}
+  // expected-error@-3 {{type of return operand 0 doesn't match function result type}}
   %0 = "foo"() : ()->f32
   return %0 : f32
 }
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 29fc9c4..2c6b76b 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -169,11 +169,11 @@
   for (auto &subbuffer : sourceBuffers) {
     SourceMgr sourceMgr;
     // Tell sourceMgr about this buffer, which is what the parser will pick up.
-    auto bufferId = sourceMgr.AddNewSourceBuffer(
-        MemoryBuffer::getMemBufferCopy(subbuffer), SMLoc());
+    sourceMgr.AddNewSourceBuffer(MemoryBuffer::getMemBufferCopy(subbuffer),
+                                 SMLoc());
 
     // Extract the expected errors.
-    llvm::Regex expected("expected-error(@[+-][0-9]+)? *{{(.*)}}");
+    llvm::Regex expected("expected-error *(@[+-][0-9]+)? *{{(.*)}}");
     SmallVector<ExpectedError, 2> expectedErrors;
     SmallVector<StringRef, 100> lines;
     subbuffer.split(lines, '\n');
@@ -221,30 +221,6 @@
     // Parse the input file.
     MLIRContext context;
     initializeMLIRContext(&context);
-
-    // TODO: refactor into initializeMLIRContext so the normal parser pass
-    // gets to use this.
-    context.registerDiagnosticHandler([&](Attribute *location,
-                                          StringRef message,
-                                          MLIRContext::DiagnosticKind kind) {
-      auto offset = cast<IntegerAttr>(location)->getValue();
-      auto ptr = sourceMgr.getMemoryBuffer(bufferId)->getBufferStart() + offset;
-      SourceMgr::DiagKind diagKind;
-      switch (kind) {
-      case MLIRContext::DiagnosticKind::Error:
-        diagKind = SourceMgr::DK_Error;
-        break;
-      case MLIRContext::DiagnosticKind::Warning:
-        diagKind = SourceMgr::DK_Warning;
-        break;
-      case MLIRContext::DiagnosticKind::Note:
-        diagKind = SourceMgr::DK_Note;
-        break;
-      }
-      checker(
-          sourceMgr.GetMessage(SMLoc::getFromPointer(ptr), diagKind, message));
-    });
-
     std::unique_ptr<Module> module(
         parseSourceFile(sourceMgr, &context, checker));