Add support for operands to the return instructions, enhance verifier to report errors through the diagnostics system when invoked by the parser. It doesn't have perfect location info, but it is close enough to be testable.
PiperOrigin-RevId: 205534392
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 1da2312..1a003b6 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -134,10 +134,12 @@
// Terminators.
- ReturnInst *createReturnInst() { return insertTerminator(new ReturnInst()); }
+ ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
+ return insertTerminator(ReturnInst::create(operands));
+ }
BranchInst *createBranchInst(BasicBlock *dest) {
- return insertTerminator(new BranchInst(dest));
+ return insertTerminator(BranchInst::create(dest));
}
private:
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index 7c00cd0..e413172 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -46,8 +46,9 @@
MLIRContext *getContext() const;
/// Perform (potentially expensive) checks of invariants, used to detect
- /// compiler bugs. This aborts on failure.
- void verify() const;
+ /// compiler bugs. On error, this fills in the string and return true,
+ /// or aborts if the string was not provided.
+ bool verify(std::string *errorResult = nullptr) const;
void print(raw_ostream &os) const;
void dump() const;
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 52047d3..8d4068b 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -187,9 +187,7 @@
/// and may pass basic block arguments to the successor.
class BranchInst : public TerminatorInst {
public:
- explicit BranchInst(BasicBlock *dest)
- : TerminatorInst(Kind::Branch), dest(dest) {
- }
+ static BranchInst *create(BasicBlock *dest) { return new BranchInst(dest); }
~BranchInst() {}
/// Return the block this branch jumps to.
@@ -205,6 +203,9 @@
}
private:
+ explicit BranchInst(BasicBlock *dest)
+ : TerminatorInst(Kind::Branch), dest(dest) {}
+
BasicBlock *dest;
};
@@ -212,17 +213,53 @@
/// The 'return' instruction represents the end of control flow within the
/// current function, and can return zero or more results. The result list is
/// required to align with the result list of the containing function's type.
-class ReturnInst : public TerminatorInst {
+class ReturnInst final
+ : public TerminatorInst,
+ private llvm::TrailingObjects<ReturnInst, InstOperand> {
public:
- explicit ReturnInst() : TerminatorInst(Kind::Return) {}
- ~ReturnInst() {}
+ /// Create a new OperationInst with the specific fields.
+ static ReturnInst *create(ArrayRef<CFGValue *> operands);
- // TODO: Needs to take an operand list.
+ unsigned getNumOperands() const { return numOperands; }
+
+ // TODO: Add a getOperands() custom sequence that provides a value projection
+ // of the operand list.
+ CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
+ const CFGValue *getOperand(unsigned idx) const {
+ return getInstOperand(idx).get();
+ }
+
+ ArrayRef<InstOperand> getInstOperands() const {
+ return {getTrailingObjects<InstOperand>(), numOperands};
+ }
+ MutableArrayRef<InstOperand> getInstOperands() {
+ return {getTrailingObjects<InstOperand>(), numOperands};
+ }
+
+ InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
+ const InstOperand &getInstOperand(unsigned idx) const {
+ return getInstOperands()[idx];
+ }
+
+ void destroy();
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst) {
return inst->getKind() == Kind::Return;
}
+
+private:
+ // This stuff is used by the TrailingObjects template.
+ friend llvm::TrailingObjects<ReturnInst, InstOperand>;
+ size_t numTrailingObjects(OverloadToken<InstOperand>) const {
+ return numOperands;
+ }
+
+ explicit ReturnInst(unsigned numOperands)
+ : TerminatorInst(Kind::Return), numOperands(numOperands) {}
+ ~ReturnInst();
+
+ unsigned numOperands;
};
} // end namespace mlir
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index ccc832a..8401252 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -40,8 +40,9 @@
std::vector<Function*> functionList;
/// Perform (potentially expensive) checks of invariants, used to detect
- /// compiler bugs. This aborts on failure.
- void verify() const;
+ /// compiler bugs. On error, this fills in the string and return true,
+ /// or aborts if the string was not provided.
+ bool verify(std::string *errorResult = nullptr) const;
void print(raw_ostream &os) const;
void dump() const;
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index ff6ce3e..d811223 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -715,7 +715,19 @@
void CFGFunctionPrinter::print(const BranchInst *inst) {
os << " br bb" << getBBID(inst->getDest());
}
-void CFGFunctionPrinter::print(const ReturnInst *inst) { os << " return"; }
+void CFGFunctionPrinter::print(const ReturnInst *inst) {
+ os << " return";
+
+ if (inst->getNumOperands() != 0)
+ os << ' ';
+
+ // TODO: Use getOperands() when we have it.
+ interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+ printValueID(operand.get());
+ os << " : ";
+ ModulePrinter::print(operand.get()->getType());
+ });
+}
void ModulePrinter::print(const CFGFunction *fn) {
CFGFunctionPrinter(fn, *this).print();
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index b925cef..428523e 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -56,7 +56,7 @@
delete cast<BranchInst>(this);
break;
case Kind::Return:
- delete cast<ReturnInst>(this);
+ cast<ReturnInst>(this)->destroy();
break;
}
}
@@ -182,4 +182,28 @@
destroy();
}
+/// Create a new OperationInst with the specific fields.
+ReturnInst *ReturnInst::create(ArrayRef<CFGValue *> operands) {
+ auto byteSize = totalSizeToAlloc<InstOperand>(operands.size());
+ void *rawMem = malloc(byteSize);
+ // Initialize the ReturnInst part of the instruction.
+ auto inst = ::new (rawMem) ReturnInst(operands.size());
+
+ // Initialize the operands and results.
+ auto instOperands = inst->getInstOperands();
+ for (unsigned i = 0, e = operands.size(); i != e; ++i)
+ new (&instOperands[i]) InstOperand(inst, operands[i]);
+ return inst;
+}
+
+void ReturnInst::destroy() {
+ this->~ReturnInst();
+ free(this);
+}
+
+ReturnInst::~ReturnInst() {
+ // Explicitly run the destructors for the operands.
+ for (auto &operand : getInstOperands())
+ operand.~InstOperand();
+}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index ffc87aa..3c3bb16 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -41,73 +41,131 @@
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
-template <typename T>
-static void failure(const Twine &message, const T &value) {
- // Print the error message and flush the stream in case printing the value
- // causes a crash.
- llvm::errs() << "MLIR verification failure: " << message << "\n";
- llvm::errs().flush();
- value.dump();
-}
+namespace {
+/// Base class for the verifiers in this file. It is a pervasive truth that
+/// this file treats "true" as an error that needs to be recovered from, and
+/// "false" as success.
+///
+class Verifier {
+public:
+ template <typename T>
+ static void failure(const Twine &message, const T &value, raw_ostream &os) {
+ // Print the error message and flush the stream in case printing the value
+ // causes a crash.
+ os << "MLIR verification failure: " + message + "\n";
+ os.flush();
+ value.print(os);
+ }
+
+ template <typename T>
+ bool failure(const Twine &message, const T &value) {
+ // If the caller isn't trying to collect failure information, just print
+ // the result and abort.
+ if (!errorResult) {
+ failure(message, value, llvm::errs());
+ abort();
+ }
+
+ // Otherwise, emit the error into the string and return true.
+ llvm::raw_string_ostream os(*errorResult);
+ failure(message, value, os);
+ os.flush();
+ return true;
+ }
+
+protected:
+ explicit Verifier(std::string *errorResult) : errorResult(errorResult) {}
+
+private:
+ std::string *errorResult;
+};
+} // end anonymous namespace
//===----------------------------------------------------------------------===//
// CFG Functions
//===----------------------------------------------------------------------===//
namespace {
-class CFGFuncVerifier {
+class CFGFuncVerifier : public Verifier {
public:
const CFGFunction &fn;
OperationSet &operationSet;
- CFGFuncVerifier(const CFGFunction &fn)
- : fn(fn), operationSet(OperationSet::get(fn.getContext())) {}
+ CFGFuncVerifier(const CFGFunction &fn, std::string *errorResult)
+ : Verifier(errorResult), fn(fn),
+ operationSet(OperationSet::get(fn.getContext())) {}
- void verify();
- void verifyBlock(const BasicBlock &block);
- void verifyTerminator(const TerminatorInst &term);
- void verifyOperation(const OperationInst &inst);
+ bool verify();
+ bool verifyBlock(const BasicBlock &block);
+ bool verifyOperation(const OperationInst &inst);
+ bool verifyTerminator(const TerminatorInst &term);
+ bool verifyReturn(const ReturnInst &inst);
};
} // end anonymous namespace
-void CFGFuncVerifier::verify() {
+bool CFGFuncVerifier::verify() {
// TODO: Lots to be done here, including verifying dominance information when
// we have uses and defs.
for (auto &block : fn) {
- verifyBlock(block);
+ if (verifyBlock(block))
+ return true;
}
+ return false;
}
-void CFGFuncVerifier::verifyBlock(const BasicBlock &block) {
+bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) {
if (!block.getTerminator())
- failure("basic block with no terminator", block);
- verifyTerminator(*block.getTerminator());
+ return failure("basic block with no terminator", block);
+
+ if (verifyTerminator(*block.getTerminator()))
+ return true;
for (auto &inst : block) {
- verifyOperation(inst);
+ if (verifyOperation(inst))
+ return true;
}
+ return false;
}
-void CFGFuncVerifier::verifyTerminator(const TerminatorInst &term) {
+bool CFGFuncVerifier::verifyTerminator(const TerminatorInst &term) {
if (term.getFunction() != &fn)
- failure("terminator in the wrong function", term);
+ return failure("terminator in the wrong function", term);
// TODO: Check that operands are structurally ok.
// TODO: Check that successors are in the right function.
+
+ if (auto *ret = dyn_cast<ReturnInst>(&term))
+ return verifyReturn(*ret);
+
+ return false;
}
-void CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
+bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
+ // Verify that the return operands match the results of the function.
+ auto results = fn.getType()->getResults();
+ if (inst.getNumOperands() != results.size())
+ return failure("return has " + Twine(inst.getNumOperands()) +
+ " operands, but enclosing function returns " +
+ Twine(results.size()),
+ inst);
+
+ return false;
+}
+
+bool CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
if (inst.getFunction() != &fn)
- failure("operation in the wrong function", inst);
+ return failure("operation in the wrong function", inst);
// TODO: Check that operands are structurally ok.
// See if we can get operation info for this.
if (auto *opInfo = inst.getAbstractOperation(fn.getContext())) {
if (auto errorMessage = opInfo->verifyInvariants(&inst))
- failure(errorMessage, inst);
+ return failure(errorMessage, inst);
}
+
+ return false;
}
//===----------------------------------------------------------------------===//
@@ -115,14 +173,16 @@
//===----------------------------------------------------------------------===//
namespace {
-class MLFuncVerifier {
+class MLFuncVerifier : public Verifier {
public:
const MLFunction &fn;
- MLFuncVerifier(const MLFunction &fn) : fn(fn) {}
+ MLFuncVerifier(const MLFunction &fn, std::string *errorResult)
+ : Verifier(errorResult), fn(fn) {}
- void verify() {
+ bool verify() {
// TODO.
+ return false;
}
};
} // end anonymous namespace
@@ -132,24 +192,33 @@
//===----------------------------------------------------------------------===//
/// Perform (potentially expensive) checks of invariants, used to detect
-/// compiler bugs. This aborts on failure.
-void Function::verify() const {
+/// compiler bugs. On error, this fills in the string and return true,
+/// or aborts if the string was not provided.
+bool Function::verify(std::string *errorResult) const {
switch (getKind()) {
case Kind::ExtFunc:
// No body, nothing can be wrong here.
- break;
+ return false;
case Kind::CFGFunc:
- return CFGFuncVerifier(*cast<CFGFunction>(this)).verify();
+ return CFGFuncVerifier(*cast<CFGFunction>(this), errorResult).verify();
case Kind::MLFunc:
- return MLFuncVerifier(*cast<MLFunction>(this)).verify();
+ return MLFuncVerifier(*cast<MLFunction>(this), errorResult).verify();
}
}
/// Perform (potentially expensive) checks of invariants, used to detect
-/// compiler bugs. This aborts on failure.
-void Module::verify() const {
+/// compiler bugs. On error, this fills in the string and return true,
+/// or aborts if the string was not provided.
+bool Module::verify(std::string *errorResult) const {
+
/// Check that each function is correct.
for (auto fn : functionList) {
- fn->verify();
+ if (fn->verify(errorResult))
+ return true;
}
+
+ // Make sure the error string is empty on success.
+ if (errorResult)
+ errorResult->clear();
+ return false;
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 1fd3432..62a9fe7 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -137,10 +137,16 @@
return true;
}
- ParseResult parseCommaSeparatedList(
- Token::Kind rightToken,
- const std::function<ParseResult()> &parseElement,
- bool allowEmptyList = true);
+ /// Parse a comma-separated list of elements up until the specified end token.
+ ParseResult
+ parseCommaSeparatedListUntil(Token::Kind rightToken,
+ const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList = true);
+
+ /// Parse a comma separated list of elements that must have at least one entry
+ /// in it.
+ ParseResult
+ parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);
// We have two forms of parsing methods - those that return a non-null
// pointer on success, and those that return a ParseResult to indicate whether
@@ -188,24 +194,10 @@
return ParseFailure;
}
-/// Parse a comma-separated list of elements, terminated with an arbitrary
-/// token. This allows empty lists if allowEmptyList is true.
-///
-/// abstract-list ::= rightToken // if allowEmptyList == true
-/// abstract-list ::= element (',' element)* rightToken
-///
-ParseResult Parser::
-parseCommaSeparatedList(Token::Kind rightToken,
- const std::function<ParseResult()> &parseElement,
- bool allowEmptyList) {
- // Handle the empty case.
- if (getToken().is(rightToken)) {
- if (!allowEmptyList)
- return emitError("expected list element");
- consumeToken(rightToken);
- return ParseSuccess;
- }
-
+/// Parse a comma separated list of elements that must have at least one entry
+/// in it.
+ParseResult Parser::parseCommaSeparatedList(
+ const std::function<ParseResult()> &parseElement) {
// Non-empty case starts with an element.
if (parseElement())
return ParseFailure;
@@ -215,6 +207,28 @@
if (parseElement())
return ParseFailure;
}
+ return ParseSuccess;
+}
+
+/// Parse a comma-separated list of elements, terminated with an arbitrary
+/// token. This allows empty lists if allowEmptyList is true.
+///
+/// abstract-list ::= rightToken // if allowEmptyList == true
+/// abstract-list ::= element (',' element)* rightToken
+///
+ParseResult Parser::parseCommaSeparatedListUntil(
+ Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList) {
+ // Handle the empty case.
+ if (getToken().is(rightToken)) {
+ if (!allowEmptyList)
+ return emitError("expected list element");
+ consumeToken(rightToken);
+ return ParseSuccess;
+ }
+
+ if (parseCommaSeparatedList(parseElement))
+ return ParseFailure;
// Consume the end character.
if (!consumeIf(rightToken))
@@ -447,8 +461,8 @@
};
// Parse comma separated list of affine maps, followed by memory space.
- if (parseCommaSeparatedList(Token::greater, parseElt,
- /*allowEmptyList=*/false)) {
+ if (parseCommaSeparatedListUntil(Token::greater, parseElt,
+ /*allowEmptyList=*/false)) {
return nullptr;
}
// Check that MemRef type specifies at least one affine map in composition.
@@ -520,7 +534,7 @@
if (!consumeIf(Token::l_paren))
return parseElt();
- if (parseCommaSeparatedList(Token::r_paren, parseElt))
+ if (parseCommaSeparatedListUntil(Token::r_paren, parseElt))
return ParseFailure;
return ParseSuccess;
@@ -585,7 +599,7 @@
return elements.back() ? ParseSuccess : ParseFailure;
};
- if (parseCommaSeparatedList(Token::r_bracket, parseElt))
+ if (parseCommaSeparatedListUntil(Token::r_bracket, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
@@ -628,7 +642,7 @@
return ParseSuccess;
};
- if (parseCommaSeparatedList(Token::r_brace, parseElt))
+ if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
return ParseFailure;
return ParseSuccess;
@@ -717,7 +731,8 @@
/// for non-conforming expressions.
AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
AffineExpr *lhs,
- AffineExpr *rhs, SMLoc opLoc) {
+ AffineExpr *rhs,
+ SMLoc opLoc) {
// TODO: make the error location info accurate.
switch (op) {
case Mul:
@@ -1066,7 +1081,7 @@
return emitError("expected '['");
auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(false); };
- return parseCommaSeparatedList(Token::r_bracket, parseElt);
+ return parseCommaSeparatedListUntil(Token::r_bracket, parseElt);
}
/// Parse the list of dimensional identifiers to an affine map.
@@ -1075,7 +1090,7 @@
return emitError("expected '(' at start of dimensional identifiers list");
auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(true); };
- return parseCommaSeparatedList(Token::r_paren, parseElt);
+ return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
}
/// Parse an affine map definition.
@@ -1114,7 +1129,7 @@
// Parse a multi-dimensional affine expression (a comma-separated list of 1-d
// affine expressions); the list cannot be empty.
// Grammar: multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
- if (parseCommaSeparatedList(Token::r_paren, parseElt, false))
+ if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false))
return nullptr;
// Parse optional range sizes.
@@ -1137,7 +1152,7 @@
};
setSymbolicParsing(true);
- if (parseCommaSeparatedList(Token::r_paren, parseRangeSize, false))
+ if (parseCommaSeparatedListUntil(Token::r_paren, parseRangeSize, false))
return nullptr;
if (exprs.size() > rangeSizes.size())
return (emitError(loc, "fewer range sizes than range expressions"),
@@ -1182,7 +1197,7 @@
/// After the function is finished parsing, this function checks to see if
/// there are any remaining issues.
- ParseResult finalizeFunction();
+ ParseResult finalizeFunction(Function *func, SMLoc loc);
/// This represents a use of an SSA value in the program. The first two
/// entries in the tuple are the name and result number of a reference. The
@@ -1203,12 +1218,12 @@
// SSA parsing productions.
ParseResult parseSSAUse(SSAUseInfo &result);
- ParseResult parseOptionalSSAUseList(Token::Kind endToken,
- SmallVectorImpl<SSAUseInfo> &results);
+ ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
SSAValue *parseSSAUseAndType();
+
+ template <typename ValueTy>
ParseResult
- parseOptionalSSAUseAndTypeList(Token::Kind endToken,
- SmallVectorImpl<SSAValue *> &results);
+ parseOptionalSSAUseAndTypeList(SmallVectorImpl<ValueTy *> &results);
// Operations
ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
@@ -1315,7 +1330,7 @@
/// After the function is finished parsing, this function checks to see if
/// there are any remaining issues.
-ParseResult FunctionParser::finalizeFunction() {
+ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) {
// Check for any forward references that are left. If we find any, error out.
if (!forwardReferencePlaceholders.empty()) {
SmallVector<std::pair<const char *, SSAValue *>, 4> errors;
@@ -1330,6 +1345,11 @@
return ParseFailure;
}
+ // Run the verifier on this function. If an error is detected, report it.
+ std::string errorString;
+ if (func->verify(&errorString))
+ return emitError(loc, errorString);
+
return ParseSuccess;
}
@@ -1363,9 +1383,10 @@
/// ssa-use-list-opt ::= ssa-use-list?
///
ParseResult
-FunctionParser::parseOptionalSSAUseList(Token::Kind endToken,
- SmallVectorImpl<SSAUseInfo> &results) {
- return parseCommaSeparatedList(endToken, [&]() -> ParseResult {
+FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
+ if (!getToken().is(Token::percent_identifier))
+ return ParseSuccess;
+ return parseCommaSeparatedList([&]() -> ParseResult {
SSAUseInfo result;
if (parseSSAUse(result))
return ParseFailure;
@@ -1396,11 +1417,15 @@
///
/// ssa-use-and-type-list ::= ssa-use-and-type (`,` ssa-use-and-type)*
///
+template <typename ValueTy>
ParseResult FunctionParser::parseOptionalSSAUseAndTypeList(
- Token::Kind endToken, SmallVectorImpl<SSAValue *> &results) {
- return parseCommaSeparatedList(endToken, [&]() -> ParseResult {
+ SmallVectorImpl<ValueTy *> &results) {
+ if (getToken().isNot(Token::percent_identifier))
+ return ParseSuccess;
+
+ return parseCommaSeparatedList([&]() -> ParseResult {
if (auto *value = parseSSAUseAndType()) {
- results.push_back(value);
+ results.push_back(cast<ValueTy>(value));
return ParseSuccess;
}
return ParseFailure;
@@ -1442,7 +1467,11 @@
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
- parseOptionalSSAUseList(Token::r_paren, operandInfos);
+ if (parseOptionalSSAUseList(operandInfos))
+ return ParseFailure;
+
+ if (!consumeIf(Token::r_paren))
+ return emitError("expected ')' to end operand list");
SmallVector<NamedAttribute, 4> attributes;
if (getToken().is(Token::l_brace)) {
@@ -1548,6 +1577,7 @@
} // end anonymous namespace
ParseResult CFGFunctionParser::parseFunctionBody() {
+ auto braceLoc = getToken().getLoc();
if (!consumeIf(Token::l_brace))
return emitError("expected '{' in CFG function");
@@ -1572,7 +1602,7 @@
getModule()->functionList.push_back(function);
- return finalizeFunction();
+ return finalizeFunction(function, braceLoc);
}
/// Basic block declaration.
@@ -1600,9 +1630,11 @@
// If an argument list is present, parse it.
if (consumeIf(Token::l_paren)) {
- SmallVector<SSAValue *, 8> bbArgs;
- if (parseOptionalSSAUseAndTypeList(Token::r_paren, bbArgs))
+ SmallVector<SSAUseInfo, 8> bbArgs;
+ if (parseOptionalSSAUseList(bbArgs))
return ParseFailure;
+ if (!consumeIf(Token::r_paren))
+ return emitError("expected ')' to end argument list");
// TODO: attach it.
}
@@ -1648,9 +1680,14 @@
default:
return (emitError("expected terminator at end of basic block"), nullptr);
- case Token::kw_return:
+ case Token::kw_return: {
consumeToken(Token::kw_return);
- return builder.createReturnInst();
+ SmallVector<CFGValue *, 8> results;
+ if (parseOptionalSSAUseAndTypeList(results))
+ return nullptr;
+
+ return builder.createReturnInst(results);
+ }
case Token::kw_br: {
consumeToken(Token::kw_br);
@@ -1693,6 +1730,7 @@
} // end anonymous namespace
ParseResult MLFunctionParser::parseFunctionBody() {
+ auto braceLoc = getToken().getLoc();
if (!consumeIf(Token::l_brace))
return emitError("expected '{' in ML function");
@@ -1705,12 +1743,15 @@
// TODO: store return operands in the IR.
SmallVector<SSAUseInfo, 4> dummyUseInfo;
- if (parseOptionalSSAUseList(Token::r_brace, dummyUseInfo))
+ if (parseOptionalSSAUseList(dummyUseInfo))
return ParseFailure;
+ if (!consumeIf(Token::r_brace))
+ return emitError("expected '}' to end mlfunc");
+
getModule()->functionList.push_back(function);
- return finalizeFunction();
+ return finalizeFunction(function, braceLoc);
}
/// For statement.
@@ -1959,7 +2000,7 @@
if (!consumeIf(Token::l_paren))
llvm_unreachable("expected '('");
- return parseCommaSeparatedList(Token::r_paren, parseElt);
+ return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
}
/// Parse a function signature, starting with a name and including the parameter
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index 1d29c1b..6177136 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -193,7 +193,7 @@
mlfunc @missing_rbrace() {
return %a
-mlfunc @d {return} // expected-error {{expected ',' or '}'}}
+mlfunc @d {return} // expected-error {{expected '}' to end mlfunc}}
// -----
@@ -202,3 +202,9 @@
// -----
+cfgfunc @resulterror() -> i32 { // expected-error {{return has 0 operands, but enclosing function returns 1}}
+bb42:
+ return
+}
+
+// -----
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index d10b1b6..98844dc 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -79,8 +79,8 @@
// CHECK: }
}
-// CHECK-LABEL: cfgfunc @multiblock() -> i32 {
-cfgfunc @multiblock() -> i32 {
+// CHECK-LABEL: cfgfunc @multiblock() {
+cfgfunc @multiblock() {
bb0: // CHECK: bb0:
return // CHECK: return
bb1: // CHECK: bb1:
@@ -179,17 +179,19 @@
return
}
-// CHECK-LABEL: cfgfunc @ssa_values() {
-cfgfunc @ssa_values() {
+// CHECK-LABEL: cfgfunc @ssa_values() -> (i16, i8) {
+cfgfunc @ssa_values() -> (i16, i8) {
bb0: // CHECK: bb0:
// CHECK: %0 = "foo"() : () -> (i1, i17)
%0 = "foo"() : () -> (i1, i17)
br bb2
bb1: // CHECK: bb1:
- // CHECK: %1 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> i16
- %1 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> i16
- return
+ // CHECK: %1 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> (i16, i8)
+ %1 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> (i16, i8)
+
+ // CHECK: return %1#0 : i16, %1#1 : i8
+ return %1#0 : i16, %1#1 : i8
bb2: // CHECK: bb2:
// CHECK: %2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)