Implement call and call_indirect ops.
This also fixes an infinite recursion in VariadicOperands that implementing these ops uncovered.
PiperOrigin-RevId: 209692932
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 0debe2d..d2ea5bf 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -33,6 +33,7 @@
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using llvm::SMLoc;
@@ -180,6 +181,8 @@
ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
// Attribute parsing.
+ Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
+ FunctionType *type);
Attribute *parseAttribute();
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@@ -578,6 +581,33 @@
// Attribute parsing.
//===----------------------------------------------------------------------===//
+/// Map a parsed function reference such as '@foo', together with the function
+/// type it is expected to have, onto a concrete Function. If the module does
+/// not define the function yet, a forward-reference ExtFunction placeholder is
+/// looked up or created (recording the location of the first reference). If
+/// the resolved function's type differs from `type`, an error is emitted at
+/// `nameLoc` and null is returned.
+Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
+                                           FunctionType *type) {
+  // The token spelling includes the leading '@'; the identifier does not.
+  auto name = builder.getIdentifier(nameStr.drop_front());
+
+  Function *function = getModule()->getNamedFunction(name);
+  if (function == nullptr) {
+    // Not defined yet: reuse the placeholder from an earlier use of this name,
+    // or synthesize one typed with the type seen at this first use.
+    auto &forwardRef = state.functionForwardRefs[name];
+    if (forwardRef.first == nullptr) {
+      forwardRef.first = new ExtFunction(name, type);
+      forwardRef.second = nameLoc;
+    }
+    function = forwardRef.first;
+  }
+
+  // Whether defined or forward-referenced, every use must agree on the type.
+  if (function->getType() == type)
+    return function;
+
+  emitError(nameLoc, "reference to function with mismatched type");
+  return nullptr;
+}
+
/// Attribute parsing.
///
/// attribute-value ::= bool-literal
@@ -664,7 +694,7 @@
case Token::at_identifier: {
auto nameLoc = getToken().getLoc();
- Identifier name = builder.getIdentifier(getTokenSpelling().drop_front());
+ auto nameStr = getTokenSpelling();
consumeToken(Token::at_identifier);
if (parseToken(Token::colon, "expected ':' and function type"))
@@ -673,28 +703,12 @@
Type *type = parseType();
if (!type)
return nullptr;
- auto fnType = dyn_cast<FunctionType>(type);
+ auto *fnType = dyn_cast<FunctionType>(type);
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
- // See if the function has already been defined in the module.
- Function *function = getModule()->getNamedFunction(name);
-
- // If not, get or create a forward reference to one.
- if (!function) {
- auto &entry = state.functionForwardRefs[name];
- if (!entry.first) {
- entry.first = new ExtFunction(name, fnType);
- entry.second = nameLoc;
- }
- function = entry.first;
- }
-
- if (function->getType() != type)
- return (emitError(typeLoc, "reference to function with mismatched type"),
- nullptr);
-
- return builder.getFunctionAttr(function);
+ auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
+ return function ? builder.getFunctionAttr(function) : nullptr;
}
default: {
@@ -1701,6 +1715,10 @@
// High level parsing methods.
//===--------------------------------------------------------------------===//
+  /// Report the location of the token the parser is currently positioned at.
+  /// This cannot fail, so it always returns false (success).
+  bool getCurrentLocation(llvm::SMLoc *loc) override {
+    const auto &currentToken = parser.getToken();
+    *loc = currentToken.getLoc();
+    return false;
+  }
bool parseComma(llvm::SMLoc *loc = nullptr) override {
if (loc)
*loc = parser.getToken().getLoc();
@@ -1753,6 +1771,19 @@
return parser.parseAttributeDict(result) == ParseFailure;
}
+  /// Parse a function name like '@foo' and return the name in a form that can
+  /// be passed to resolveFunctionName when a function type is available.
+  /// `loc` is set to the current token's location before the token is checked,
+  /// so it is meaningful even on failure. If the current token is not an
+  /// '@'-identifier, an error is emitted and emitError's result is returned
+  /// (a failure indication); otherwise the token is consumed and false is
+  /// returned.
+  bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) override {
+    loc = parser.getToken().getLoc();
+
+    if (parser.getToken().isNot(Token::at_identifier))
+      return emitError(loc, "expected function name");
+
+    // Keep the full spelling (including '@'); the leading sigil is stripped
+    // later by Parser::resolveFunctionReference.
+    result = parser.getTokenSpelling();
+    parser.consumeToken(Token::at_identifier);
+    return false;
+  }
+
bool parseOperand(OperandType &result) override {
FunctionParser::SSAUseInfo useInfo;
if (parser.parseSSAUse(useInfo))
@@ -1822,6 +1853,13 @@
return false;
}
+  /// Resolve a parsed function name and a type into a function reference.
+  /// `name` is the spelling produced by parseFunctionName (including the
+  /// leading '@'). On success `result` is set and false is returned. On
+  /// failure (type mismatch) an error has already been emitted, `result` is
+  /// null, and true is returned.
+  bool resolveFunctionName(StringRef name, FunctionType *type,
+                           llvm::SMLoc loc, Function *&result) override {
+    result = parser.resolveFunctionReference(name, loc, type);
+    return result == nullptr;
+  }
+
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
@@ -1830,7 +1868,7 @@
llvm::SMLoc getNameLoc() const override { return nameLoc; }
- bool resolveOperand(OperandType operand, Type *type,
+ bool resolveOperand(const OperandType &operand, Type *type,
SmallVectorImpl<SSAValue *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
@@ -1872,6 +1910,11 @@
consumeToken();
+ // If the custom op parser crashes, produce some indication to help debugging.
+ std::string opNameStr = opName.str();
+ llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
+ opNameStr.c_str());
+
// Have the op implementation take a crack and parsing this.
OperationState opState(builder.getIdentifier(opName));
if (opDefinition->parseAssembly(&opAsmParser, &opState))