Implement the return statement as a ReturnOp operation. Add verification of the return statement's placement and operands. Add parser and parsing-error tests for return statements with a non-zero number of operands. Add a few missing tests for ForStmt parsing errors.
Prior to this CL, the return statement had no explicit representation in MLIR. Now it is represented as a ReturnOp standard operation and is pretty-printed according to the return statement syntax. This way, statement walkers can process ML function return operands without special-casing them.
PiperOrigin-RevId: 208092424
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 3a3b196..892ee5d 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -1207,7 +1207,6 @@
printFunctionSignature();
os << " {\n";
print(function);
- os << " return\n";
os << "}\n\n";
}
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 9644a00..c3f815f 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -22,7 +22,9 @@
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
+
using namespace mlir;
static void printDimAndSymbolList(Operation::const_operand_iterator begin,
@@ -60,6 +62,10 @@
return false;
}
+//===----------------------------------------------------------------------===//
+// AddFOp
+//===----------------------------------------------------------------------===//
+
bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
@@ -86,6 +92,10 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getAffineIntType();
@@ -135,6 +145,10 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// AllocOp
+//===----------------------------------------------------------------------===//
+
void AllocOp::print(OpAsmPrinter *p) const {
MemRefType *type = cast<MemRefType>(getMemRef()->getType());
*p << "alloc";
@@ -183,6 +197,10 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant " << *getValue();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
@@ -249,6 +267,10 @@
return result;
}
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
void DimOp::print(OpAsmPrinter *p) const {
*p << "dim " << *getOperand() << ", " << getIndex();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
@@ -293,6 +315,10 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
void LoadOp::print(OpAsmPrinter *p) const {
*p << "load " << *getMemRef() << '[';
p->printOperands(getIndices());
@@ -336,6 +362,52 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
+  // Syntax: `return` (ssa-use-list `:` type-list)?
+  SmallVector<OpAsmParser::OperandType, 2> opInfo;
+  SmallVector<Type *, 2> types;
+  // The `: type-list` suffix is only expected when operands are present.
+  return parser->parseOperandList(opInfo, -1, OpAsmParser::Delimiter::None) ||
+         (!opInfo.empty() && parser->parseColonTypeList(types)) ||
+         parser->resolveOperands(opInfo, types, result->operands);
+}
+
+void ReturnOp::print(OpAsmPrinter *p) const {
+  *p << "return";
+  if (getNumOperands() == 0)
+    return; // Bare form: no operands, hence no trailing type list either.
+  *p << " ";
+  p->printOperands(operand_begin(), operand_end());
+  *p << " : ";
+  interleave(operand_begin(), operand_end(),
+             [&](auto *operand) { p->printType(operand->getType()); },
+             [&]() { *p << ", "; });
+}
+
+const char *ReturnOp::verify() const {
+  // ReturnOp must be part of an ML function: an operation that is not an
+  // OperationStmt is reported as occurring in a CFG function.
+  auto *stmt = dyn_cast<OperationStmt>(getOperation());
+  if (!stmt)
+    return "cannot occur in a CFG function.";
+  // It must also be the last statement directly in the ML function body
+  // (its enclosing block has to be the MLFunction itself).
+  StmtBlock *block = stmt->getBlock();
+  auto *fn = block ? dyn_cast<MLFunction>(block) : nullptr;
+  if (!fn || &fn->back() != stmt)
+    return "must be the last statement in the ML function";
+  // Success. Operand types vs. the function signature are checked in the ML
+  // function verifier.
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
void StoreOp::print(OpAsmPrinter *p) const {
*p << "store " << *getValueToStore();
*p << ", " << *getMemRef() << '[';
@@ -391,9 +463,13 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// Register operations.
+//===----------------------------------------------------------------------===//
+
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp,
- StoreOp>(
+ ReturnOp, StoreOp>(
/*prefix=*/"");
}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index dde196c..272395c 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -207,7 +207,7 @@
for (unsigned i = 0, e = results.size(); i != e; ++i)
if (inst.getOperand(i)->getType() != results[i])
return failure("type of return operand " + Twine(i) +
- " doesn't match result function result type",
+ " doesn't match function result type",
inst);
return false;
@@ -313,7 +313,11 @@
llvm::PrettyStackTraceFormat fmt("MLIR Verifier: mlfunc @%s",
fn.getName().c_str());
- // TODO: check basic structural properties.
+ // TODO: check basic structural properties
+ // TODO: check that operation is not a return statement unless it's
+ // the last one in the function.
+ if (verifyReturn())
+ return true;
return verifyDominance();
}
@@ -321,6 +325,9 @@
/// Walk all of the code in this MLFunc and verify that the operands of any
/// operations are properly dominated by their definitions.
bool verifyDominance();
+
+ /// Verify that function has a return statement that matches its signature.
+ bool verifyReturn();
};
} // end anonymous namespace
@@ -390,6 +397,37 @@
return walkBlock(fn);
}
+bool MLFuncVerifier::verifyReturn() {
+  // TODO: fold return verification in the pass that verifies all statements.
+  const char missingReturnMsg[] = "ML function must end with return statement";
+
+  // The body must be non-empty and end in a return operation; an empty body
+  // and any other trailing statement are reported with the same message.
+  if (fn.getStatements().empty())
+    return failure(missingReturnMsg, fn);
+  const auto &lastStmt = fn.getStatements().back();
+  const auto *op = dyn_cast<OperationStmt>(&lastStmt);
+  if (!op || !op->isReturn())
+    return failure(missingReturnMsg, fn);
+
+  // The operand number and types must match the function signature.
+  // TODO: move this verification into ReturnOp::verify() if printing
+  // of the error messages below can be made to work there.
+  const auto &results = fn.getType()->getResults();
+  if (op->getNumOperands() != results.size())
+    return failure("return has " + Twine(op->getNumOperands()) +
+                       " operands, but enclosing function returns " +
+                       Twine(results.size()),
+                   *op);
+
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    if (op->getOperand(i)->getType() != results[i])
+      return failure("type of return operand " + Twine(i) +
+                         " doesn't match function result type",
+                     *op);
+  return false;
+}
+
//===----------------------------------------------------------------------===//
// Entrypoints
//===----------------------------------------------------------------------===//
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 45ac4e6..f263804 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -2117,20 +2117,9 @@
ParseResult MLFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
- // Parse statements in this function
- if (parseToken(Token::l_brace, "expected '{' in ML function") ||
- parseStatements(function)) {
- return ParseFailure;
- }
-
- // TODO: store return operands in the IR.
- SmallVector<SSAUseInfo, 4> dummyUseInfo;
-
- if (parseToken(Token::kw_return,
- "ML function must end with return statement") ||
- parseOptionalSSAUseList(dummyUseInfo) ||
- parseToken(Token::r_brace, "expected '}' to end mlfunc"))
+ // Parse statements in this function.
+ if (parseStmtBlock(function))
return ParseFailure;
getModule()->getFunctions().push_back(function);
@@ -2154,7 +2143,7 @@
StringRef inductionVariableName = getTokenSpelling();
consumeToken(Token::percent_identifier);
- if (parseToken(Token::equal, "expected ="))
+ if (parseToken(Token::equal, "expected '='"))
return ParseFailure;
// Parse loop bounds
@@ -2387,7 +2376,10 @@
builder.setInsertionPointToEnd(block);
- while (getToken().isNot(Token::kw_return, Token::r_brace)) {
+ // Parse statements till we see '}' or 'return'.
+ // Return statement is parsed separately to emit a more intuitive error
+ // when '}' is missing after the return statement.
+ while (getToken().isNot(Token::r_brace, Token::kw_return)) {
switch (getToken().getKind()) {
default:
if (parseOperation(createOpFunc))
@@ -2404,6 +2396,11 @@
} // end switch
}
+ // Parse the return statement.
+ if (getToken().is(Token::kw_return))
+ if (parseOperation(createOpFunc))
+ return ParseFailure;
+
return ParseSuccess;
}
@@ -2413,8 +2410,7 @@
ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
if (parseToken(Token::l_brace, "expected '{' before statement list") ||
parseStatements(block) ||
- parseToken(Token::r_brace,
- "expected '}' at the end of the statement block"))
+ parseToken(Token::r_brace, "expected '}' after statement list"))
return ParseFailure;
return ParseSuccess;