Implement support for branch instruction operands.
PiperOrigin-RevId: 205666777
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index bc26bda..678b5fc 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -162,6 +162,7 @@
Type *parseMemRefType();
Type *parseFunctionType();
Type *parseType();
+ ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
ParseResult parseTypeList(SmallVectorImpl<Type*> &elements);
// Attribute parsing.
@@ -516,12 +517,27 @@
}
}
+/// Parse a list of types without enclosing parentheses, appending them to
+/// `elements`. The list must have at least one member. A type that fails to
+/// parse is not appended, so `elements` never receives a null Type*.
+/// type-list-no-parens ::= type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
+  auto parseElt = [&]() -> ParseResult {
+    auto *elt = parseType();
+    if (!elt) // parseType() has already emitted a diagnostic.
+      return ParseFailure;
+    elements.push_back(elt);
+    return ParseSuccess;
+  };
+  return parseCommaSeparatedList(parseElt);
+}
+
/// Parse a "type list", which is a singular type, or a parenthesized list of
/// types.
///
/// type-list ::= type-list-parens | type
/// type-list-parens ::= `(` `)`
-/// | `(` type (`,` type)* `)`
+/// | `(` type-list-no-parens `)`
///
ParseResult Parser::parseTypeList(SmallVectorImpl<Type*> &elements) {
auto parseElt = [&]() -> ParseResult {
@@ -1706,7 +1722,7 @@
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
-/// branch-use-list ::= `(` ssa-use-and-type-list? `)`
+/// branch-use-list ::= `(` ssa-use-list `)` `:` type-list-no-parens
/// terminator-stmt ::=
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
/// terminator-stmt ::= `return` ssa-use-and-type-list?
@@ -1730,7 +1746,40 @@
auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
if (!consumeIf(Token::bare_identifier))
return (emitError("expected basic block name"), nullptr);
- return builder.createBranchInst(destBB);
+ auto branch = builder.createBranchInst(destBB);
+
+ // Parse the use list.
+ if (!consumeIf(Token::l_paren))
+ return branch;
+
+ SmallVector<SSAUseInfo, 4> valueIDs;
+ if (parseOptionalSSAUseList(valueIDs))
+ return nullptr;
+ if (!consumeIf(Token::r_paren))
+ return (emitError("expected ')' in branch argument list"), nullptr);
+ if (!consumeIf(Token::colon))
+ return (emitError("expected ':' in branch argument list"), nullptr);
+
+ auto typeLoc = getToken().getLoc();
+ SmallVector<Type *, 4> types;
+ if (parseTypeListNoParens(types))
+ return nullptr;
+
+ if (types.size() != valueIDs.size())
+ return (emitError(typeLoc, "expected " + Twine(valueIDs.size()) +
+ " types to match operand list"),
+ nullptr);
+
+ SmallVector<CFGValue *, 4> values;
+ values.reserve(valueIDs.size());
+ for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
+ if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
+ values.push_back(cast<CFGValue>(value));
+ else
+ return nullptr;
+ }
+ branch->addOperands(values);
+ return branch;
}
// TODO: cond_br.
}