Implement call and call_indirect ops.
This also fixes an infinite recursion in VariadicOperands that implementing these ops turned up.
PiperOrigin-RevId: 209692932
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6bcaad4..cd90e3e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -247,6 +247,7 @@
}
void print(const Module *module);
+ void printFunctionReference(const Function *func);
void printAttribute(const Attribute *attr);
void printType(const Type *type);
void print(const Function *fn);
@@ -387,6 +388,10 @@
}
}
+/// Print a symbolic reference to `func`: its name preceded by the '@' sigil.
+void ModulePrinter::printFunctionReference(const Function *func) {
+  os << '@';
+  os << func->getName();
+}
+
void ModulePrinter::printAttribute(const Attribute *attr) {
switch (attr->getKind()) {
case Attribute::Kind::Bool:
@@ -420,7 +425,8 @@
if (!function) {
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
} else {
- os << '@' << function->getName() << " : ";
+ printFunctionReference(function);
+ os << " : ";
printType(function->getType());
}
break;
@@ -768,6 +774,9 @@
void printAffineExpr(const AffineExpr *expr) {
return ModulePrinter::printAffineExpr(expr);
}
+  /// Print a function reference by delegating to the module-level printer.
+  void printFunctionReference(const Function *func) {
+    ModulePrinter::printFunctionReference(func);
+  }
void printOperand(const SSAValue *value) { printValueID(value); }
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index b1dce25..73beea6 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -111,7 +111,7 @@
return TypeAttr::get(type, context);
}
+/// Return the FunctionAttr referring to `value`, obtained via
+/// FunctionAttr::get in this builder's context.
-FunctionAttr *Builder::getFunctionAttr(Function *value) {
+FunctionAttr *Builder::getFunctionAttr(const Function *value) {
return FunctionAttr::get(value, context);
}
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 68839d8..e59b59c 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -255,7 +255,7 @@
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
- DenseMap<Function *, FunctionAttr *> functionAttrs;
+ DenseMap<const Function *, FunctionAttr *> functionAttrs;
public:
MLIRContextImpl() : identifiers(allocator) {
@@ -648,16 +648,20 @@
return result;
}
+/// Return the FunctionAttr for `value`, creating it on first use and caching
+/// it in the context so each function maps to a single attribute instance.
-FunctionAttr *FunctionAttr::get(Function *value, MLIRContext *context) {
+FunctionAttr *FunctionAttr::get(const Function *value, MLIRContext *context) {
+ assert(value && "Cannot get FunctionAttr for a null function");
+
+ // Look up (or default-insert) the cache slot for this function; a non-null
+ // entry means the attribute was already created.
auto *&result = context->getImpl().functionAttrs[value];
if (result)
return result;
result = context->getImpl().allocator.Allocate<FunctionAttr>();
- new (result) FunctionAttr(value);
+ // NOTE(review): the FunctionAttr constructor appears to take a non-const
+ // Function *, hence the const_cast; nothing here modifies the function.
+ new (result) FunctionAttr(const_cast<Function *>(value));
return result;
}
+/// Return the type of the referenced function. The reference is nulled out
+/// when the function is deleted (see dropFunctionReference), so calling this
+/// on such an attribute is a programming error; assert instead of
+/// dereferencing null.
+FunctionType *FunctionAttr::getType() const {
+  auto *fn = getValue();
+  assert(fn && "FunctionAttr refers to a deleted function");
+  return fn->getType();
+}
+
/// This function is used by the internals of the Function class to null out
/// attributes refering to functions that are about to be deleted.
void FunctionAttr::dropFunctionReference(Function *value) {
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index cec23e0..fb7dfe7 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -129,7 +129,7 @@
// Check that affine map attribute was specified.
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr)
- return "requires an affine map.";
+ return "requires an affine map";
// Check input and output dimensions match.
auto *map = affineMapAttr->getValue();
@@ -198,7 +198,152 @@
}
//===----------------------------------------------------------------------===//
-// ConstantOp
+// CallOp
+//===----------------------------------------------------------------------===//
+
+/// Build a 'call' operation invoking `callee` with `operands`. The callee is
+/// recorded in the "callee" function attribute, and the result types are
+/// taken from the callee's function type.
+OperationState CallOp::build(Builder *builder, Function *callee,
+                             ArrayRef<SSAValue *> operands) {
+  auto *calleeType = callee->getType();
+  auto resultTypes = calleeType->getResults();
+
+  OperationState state(builder->getIdentifier("call"));
+  state.attributes.push_back(
+      {builder->getIdentifier("callee"), builder->getFunctionAttr(callee)});
+  state.operands.append(operands.begin(), operands.end());
+  state.types.append(resultTypes.begin(), resultTypes.end());
+  return state;
+}
+
+/// Parse a 'call' operation, e.g.
+///
+///   call @callee(%arg0, %arg1) {attr-dict}? : (T0, T1) -> R
+///
+/// The callee name is resolved against the trailing function type. Operand
+/// and result types are taken from that type. Problems with the operand list
+/// are reported at the operand list (matching CallIndirectOp::parse) rather
+/// than at the callee name.
+bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
+  StringRef calleeName;
+  llvm::SMLoc calleeLoc;
+  llvm::SMLoc operandsLoc;
+  FunctionType *calleeType = nullptr;
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  Function *callee = nullptr;
+  if (parser->parseFunctionName(calleeName, calleeLoc) ||
+      parser->getCurrentLocation(&operandsLoc) ||
+      parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
+                               OpAsmParser::Delimiter::Paren) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(calleeType) ||
+      parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
+      parser->addTypesToList(calleeType->getResults(), result->types) ||
+      parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
+                              result->operands))
+    return true;
+
+  // The callee is not an operand: record it as the "callee" attribute, which
+  // CallOp::verify reads and CallOp::print elides from the attribute dict.
+  auto &builder = parser->getBuilder();
+  result->attributes.push_back(
+      {builder.getIdentifier("callee"), builder.getFunctionAttr(callee)});
+
+  return false;
+}
+
+/// Print as `call @callee(operands) {attr-dict} : fn-type`. The "callee"
+/// attribute is elided from the dictionary because it is already printed as
+/// the function reference.
+void CallOp::print(OpAsmPrinter *p) const {
+  auto *callee = getCallee();
+  *p << "call ";
+  p->printFunctionReference(callee);
+  *p << '(';
+  p->printOperands(getOperands());
+  *p << ')';
+  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
+  *p << " : " << *callee->getType();
+}
+
+/// Verify a 'call' operation. It must carry a "callee" FunctionAttr that
+/// still refers to a function, and its operand and result types must match
+/// that function's type exactly. Returns an error message, or nullptr if the
+/// operation is valid.
+const char *CallOp::verify() const {
+  // Check that the callee attribute was specified.
+  auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
+  if (!fnAttr)
+    return "requires a 'callee' function attribute";
+
+  // Function attributes are nulled out when the referenced function is
+  // deleted; report that instead of dereferencing a null callee.
+  auto *callee = fnAttr->getValue();
+  if (!callee)
+    return "'callee' refers to a deleted function";
+
+  // Verify that the operand and result types match the callee.
+  auto *fnType = callee->getType();
+  if (fnType->getNumInputs() != getNumOperands())
+    return "incorrect number of operands for callee";
+
+  for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
+    if (getOperand(i)->getType() != fnType->getInput(i))
+      return "operand type mismatch";
+  }
+
+  if (fnType->getNumResults() != getNumResults())
+    return "incorrect number of results for callee";
+
+  for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
+    if (getResult(i)->getType() != fnType->getResult(i))
+      return "result type mismatch";
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// CallIndirectOp
+//===----------------------------------------------------------------------===//
+
+/// Build a 'call_indirect' operation. `callee` must be an SSA value of
+/// function type (enforced by the cast). It becomes operand #0, followed by
+/// `operands`. The result types come from the callee's function type.
+OperationState CallIndirectOp::build(Builder *builder, SSAValue *callee,
+                                     ArrayRef<SSAValue *> operands) {
+  auto resultTypes = cast<FunctionType>(callee->getType())->getResults();
+
+  OperationState state(builder->getIdentifier("call_indirect"));
+  state.operands.push_back(callee);
+  state.operands.append(operands.begin(), operands.end());
+  state.types.append(resultTypes.begin(), resultTypes.end());
+  return state;
+}
+
+/// Parse a 'call_indirect' operation, e.g.
+///
+///   call_indirect %callee(%arg0, %arg1) {attr-dict}? : (T0, T1) -> R
+///
+/// The trailing function type gives the callee's type, the argument types and
+/// the result types. Returns true on failure.
+bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType calleeInfo;
+  SmallVector<OpAsmParser::OperandType, 4> argInfos;
+  llvm::SMLoc argsLoc;
+  FunctionType *fnType = nullptr;
+
+  // Parse the textual form.
+  if (parser->parseOperand(calleeInfo) ||
+      parser->getCurrentLocation(&argsLoc) ||
+      parser->parseOperandList(argInfos, /*requiredOperandCount=*/-1,
+                               OpAsmParser::Delimiter::Paren) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(fnType))
+    return true;
+
+  // Resolve the callee first so it becomes operand #0, then the call
+  // arguments, then record the result types.
+  if (parser->resolveOperand(calleeInfo, fnType, result->operands) ||
+      parser->resolveOperands(argInfos, fnType->getInputs(), argsLoc,
+                              result->operands) ||
+      parser->addTypesToList(fnType->getResults(), result->types))
+    return true;
+
+  return false;
+}
+
+/// Print as `call_indirect %callee(args) {attr-dict} : fn-type`. Operand #0 is
+/// the callee and is printed before the parenthesized argument list.
+void CallIndirectOp::print(OpAsmPrinter *p) const {
+  *p << "call_indirect ";
+  p->printOperand(getCallee());
+  *p << '(';
+  auto operandRange = getOperands();
+  p->printOperands(++operandRange.begin(), operandRange.end());
+  *p << ')';
+  // The callee is an operand here, not an attribute, so nothing is elided:
+  // eliding "callee" would silently drop a user attribute of that name.
+  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{});
+  *p << " : " << *getCallee()->getType();
+}
+
+/// Verify a 'call_indirect' operation. Operand #0 (the callee) must have
+/// function type, and the remaining operands and the results must match that
+/// type exactly. Returns an error message, or nullptr if the operation is
+/// valid.
+const char *CallIndirectOp::verify() const {
+  // The callee is operand #0. Without it there is nothing to call, and the
+  // unsigned operand-count arithmetic below would underflow.
+  if (getNumOperands() == 0)
+    return "expected callee operand";
+
+  // The callee must be a function.
+  auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
+  if (!fnType)
+    return "callee must have function type";
+
+  // Verify that the operand and result types match the callee.
+  if (fnType->getNumInputs() != getNumOperands() - 1)
+    return "incorrect number of operands for callee";
+
+  for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
+    if (getOperand(i + 1)->getType() != fnType->getInput(i))
+      return "operand type mismatch";
+  }
+
+  if (fnType->getNumResults() != getNumResults())
+    return "incorrect number of results for callee";
+
+  for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
+    if (getResult(i)->getType() != fnType->getResult(i))
+      return "result type mismatch";
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Constant*Op
//===----------------------------------------------------------------------===//
void ConstantOp::print(OpAsmPrinter *p) const {
@@ -444,10 +589,10 @@
+/// Parse a 'return' operation: an optional operand list, followed by a colon
+/// and type list only when at least one operand is present.
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type *, 2> types;
-
- return parser->parseOperandList(opInfo, -1, OpAsmParser::Delimiter::None) ||
+ // Record where the operand list begins; it is passed to resolveOperands
+ // (presumably so operand/type mismatch diagnostics point there).
+ llvm::SMLoc loc;
+ return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
- parser->resolveOperands(opInfo, types, result->operands);
+ parser->resolveOperands(opInfo, types, loc, result->operands);
}
void ReturnOp::print(OpAsmPrinter *p) const {
@@ -541,7 +686,7 @@
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
- opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DeallocOp,
- DimOp, LoadOp, ReturnOp, StoreOp>(
+ // All standard ops, including CallOp and CallIndirectOp, are added with an
+ // empty prefix. NOTE(review): presumably this lets them be spelled by bare
+ // names such as "call" / "call_indirect"; confirm against OperationSet.
+ opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
+ ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
/*prefix=*/"");
}