[mlir] Add basic block arguments
This patch adds support for basic block arguments, including parsing and printing.
While doing so, I noticed that `ssa-id-and-type` is undefined in the MLIR spec, so I suggested a definition for it in the spec doc.
PiperOrigin-RevId: 205593369
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 62a9fe7..bc26bda 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1219,7 +1219,17 @@
// SSA parsing productions.
ParseResult parseSSAUse(SSAUseInfo &result);
ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
- SSAValue *parseSSAUseAndType();
+ /// Parse `ssa-id : type` and return `action(id, type)`; nullptr on error.
+ template <typename ResultType>
+ ResultType parseSSADefOrUseAndType(
+ const std::function<ResultType(SSAUseInfo, Type *)> &action);
+ /// Parse `ssa-use : type` and resolve it to the referenced SSA value.
+ SSAValue *parseSSAUseAndType() {
+ return parseSSADefOrUseAndType<SSAValue *>(
+ [&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
+ return resolveSSAUse(useInfo, type);
+ });
+ }
template <typename ValueTy>
ParseResult
@@ -1355,8 +1365,7 @@
/// Parse a SSA operand for an instruction or statement.
///
-/// ssa-use ::= ssa-id | ssa-constant
-/// TODO: SSA Constants.
+/// ssa-use ::= ssa-id
///
ParseResult FunctionParser::parseSSAUse(SSAUseInfo &result) {
result.name = getTokenSpelling();
@@ -1398,7 +1407,9 @@
/// Parse an SSA use with an associated type.
///
/// ssa-use-and-type ::= ssa-use `:` type
-SSAValue *FunctionParser::parseSSAUseAndType() {
+template <typename ResultType>
+ResultType FunctionParser::parseSSADefOrUseAndType(
+ const std::function<ResultType(SSAUseInfo, Type *)> &action) {
SSAUseInfo useInfo;
if (parseSSAUse(useInfo))
return nullptr;
@@ -1410,7 +1421,7 @@
if (!type)
return nullptr;
- return resolveSSAUse(useInfo, type);
+ return action(useInfo, type);
}
/// Parse a (possibly empty) list of SSA operands with types.
@@ -1570,12 +1581,39 @@
return blockAndLoc.first;
}
+ ParseResult
+ parseOptionalBasicBlockArgList(SmallVectorImpl<BBArgument *> &results,
+ BasicBlock *owner); // Parses the contents of a block's `( ... )` list.
+
ParseResult parseBasicBlock();
OperationInst *parseCFGOperation();
TerminatorInst *parseTerminator();
};
} // end anonymous namespace
+/// Parse a (possibly empty) list of `ssa-id : type` pairs as basic block
+/// arguments, adding each one to `owner` and appending it to `results`. Unlike
+/// the ssa-use-and-type list parser, the SSA IDs are treated as defs, not uses.
+///
+/// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
+///
+ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList(
+ SmallVectorImpl<BBArgument *> &results, BasicBlock *owner) {
+ if (getToken().is(Token::r_paren)) // List is closed by `)`: e.g. `bb0():`.
+ return ParseSuccess;
+ return parseCommaSeparatedList([&]() -> ParseResult {
+ auto parsedType = parseSSADefOrUseAndType<Type *>(
+ [&](SSAUseInfo useInfo, Type *type) -> Type * {
+ BBArgument *arg = owner->addArgument(type);
+ results.push_back(arg);
+ if (addDefinition(useInfo, arg) == ParseFailure)
+ return nullptr;
+ return type;
+ });
+ return parsedType ? ParseSuccess : ParseFailure;
+ });
+}
+
ParseResult CFGFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
if (!consumeIf(Token::l_brace))
@@ -1625,20 +1663,18 @@
if (block->getFunction())
return emitError(nameLoc, "redefinition of block '" + name.str() + "'");
- // Add the block to the function.
- function->push_back(block);
-
// If an argument list is present, parse it.
if (consumeIf(Token::l_paren)) {
- SmallVector<SSAUseInfo, 8> bbArgs;
- if (parseOptionalSSAUseList(bbArgs))
+ SmallVector<BBArgument *, 8> bbArgs;
+ if (parseOptionalBasicBlockArgList(bbArgs, block))
return ParseFailure;
if (!consumeIf(Token::r_paren))
return emitError("expected ')' to end argument list");
-
- // TODO: attach it.
}
+ // Add the block to the function.
+ function->push_back(block);
+
if (!consumeIf(Token::colon))
return emitError("expected ':' after basic block name");