Implement the return statement as the ReturnOp operation. Add verification of the return statement placement and operands. Add parser and parsing error tests for return statements with a non-zero number of operands. Add a few missing tests for ForStmt parsing errors.
Prior to this CL, the return statement had no explicit representation in MLIR. It is now represented as the ReturnOp standard operation and is pretty-printed using the return statement syntax. This way, statement walkers can process ML function return operands without special-casing them.
PiperOrigin-RevId: 208092424
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 9644a00..c3f815f 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -22,7 +22,9 @@
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
+
using namespace mlir;
static void printDimAndSymbolList(Operation::const_operand_iterator begin,
@@ -60,6 +62,10 @@
return false;
}
+//===----------------------------------------------------------------------===//
+// AddFOp
+//===----------------------------------------------------------------------===//
+
bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
@@ -86,6 +92,10 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getAffineIntType();
@@ -135,6 +145,10 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// AllocOp
+//===----------------------------------------------------------------------===//
+
void AllocOp::print(OpAsmPrinter *p) const {
MemRefType *type = cast<MemRefType>(getMemRef()->getType());
*p << "alloc";
@@ -183,6 +197,10 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant " << *getValue();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
@@ -249,6 +267,10 @@
return result;
}
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
void DimOp::print(OpAsmPrinter *p) const {
*p << "dim " << *getOperand() << ", " << getIndex();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
@@ -293,6 +315,10 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
void LoadOp::print(OpAsmPrinter *p) const {
*p << "load " << *getMemRef() << '[';
p->printOperands(getIndices());
@@ -336,6 +362,52 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+/// Parses a return in the custom form emitted by ReturnOp::print:
+///   return
+///   return %a, %b : type-a, type-b
+/// Returns true on error; each step below short-circuits the chain on failure.
+bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> opInfo;
+  SmallVector<Type *, 2> types;
+
+  // -1 presumably means "any number of operands" (zero included) and
+  // Delimiter::None that the list is not wrapped in parens/brackets -- confirm
+  // against OpAsmParser. The colon-prefixed type list is parsed only when at
+  // least one operand was seen, mirroring the printer. Resolved operands are
+  // written straight into result->operands.
+  return parser->parseOperandList(opInfo, -1, OpAsmParser::Delimiter::None) ||
+         (!opInfo.empty() && parser->parseColonTypeList(types)) ||
+         parser->resolveOperands(opInfo, types, result->operands);
+}
+
+/// Prints a return in the form accepted by ReturnOp::parse: just "return"
+/// when there are no operands, otherwise "return %a, %b : type-a, type-b".
+void ReturnOp::print(OpAsmPrinter *p) const {
+  *p << "return";
+  if (getNumOperands() == 0)
+    return;
+
+  // Operand list first, then the comma-separated types in operand order.
+  *p << " ";
+  p->printOperands(operand_begin(), operand_end());
+  *p << " : ";
+  interleave(
+      operand_begin(), operand_end(),
+      [&](auto *operand) { p->printType(operand->getType()); },
+      [&]() { *p << ", "; });
+}
+
+/// Checks that this return is the final statement directly inside an ML
+/// function's own statement block. Returns nullptr on success, otherwise an
+/// error message describing the misplacement.
+const char *ReturnOp::verify() const {
+  // Anything other than an OperationStmt is reported as the CFG-function case.
+  auto *opStmt = dyn_cast<OperationStmt>(getOperation());
+  if (!opStmt)
+    return "cannot occur in a CFG function.";
+
+  // The enclosing block must itself be the MLFunction (not some other
+  // StmtBlock, and not missing), and this statement must be its last one.
+  auto *fn = dyn_cast_or_null<MLFunction>(opStmt->getBlock());
+  if (!fn || &fn->back() != opStmt)
+    return "must be the last statement in the ML function";
+
+  // Operand types vs. the function signature are checked by the ML function
+  // verifier, not here.
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
void StoreOp::print(OpAsmPrinter *p) const {
*p << "store " << *getValueToStore();
*p << ", " << *getMemRef() << '[';
@@ -391,9 +463,13 @@
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// Register operations.
+//===----------------------------------------------------------------------===//
+
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp,
- StoreOp>(
+                      // Keep alphabetical, matching the op section order in
+                      // this file.
+                      ReturnOp, StoreOp>(
/*prefix=*/"");
}