Implement support for branch instruction operands.
PiperOrigin-RevId: 205666777
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index afb4e73..64b5a2c 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -706,9 +706,23 @@
+/// Print an operation instruction; all of the work is delegated to
+/// printOperation.
void CFGFunctionPrinter::print(const OperationInst *inst) {
printOperation(inst);
}
+
void CFGFunctionPrinter::print(const BranchInst *inst) {
os << " br bb" << getBBID(inst->getDest());
+
+ if (inst->getNumOperands() != 0) {
+ os << '(';
+ // TODO: Use getOperands() when we have it.
+ interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+ printValueID(operand.get());
+ });
+ os << ") : ";
+ interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+ ModulePrinter::print(operand.get()->getType());
+ });
+ }
}
+
void CFGFunctionPrinter::print(const ReturnInst *inst) {
os << " return";
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index 1aad430..7cb6440 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -35,6 +35,31 @@
getFunction()->getBlocks().erase(this);
}
+//===----------------------------------------------------------------------===//
+// Argument list management.
+//===----------------------------------------------------------------------===//
+
+/// Allocate a new BBArgument of type \p type for this block, append it to the
+/// end of the block's argument list, and return it.
+BBArgument *BasicBlock::addArgument(Type *type) {
+  arguments.push_back(new BBArgument(type, this));
+  return arguments.back();
+}
+
+/// Append one new argument to this block for each entry of \p types, in
+/// order, and return the range covering just the newly appended arguments.
+auto BasicBlock::addArguments(ArrayRef<Type *> types)
+    -> llvm::iterator_range<args_iterator> {
+  // Remember where the new arguments start before appending any of them.
+  const auto firstNewIdx = arguments.size();
+  arguments.reserve(firstNewIdx + types.size());
+  for (Type *argType : types)
+    addArgument(argType);
+  auto *base = arguments.data();
+  return {base + firstNewIdx, base + arguments.size()};
+}
+
+//===----------------------------------------------------------------------===//
+// Terminator management
+//===----------------------------------------------------------------------===//
+
void BasicBlock::setTerminator(TerminatorInst *inst) {
// If we already had a terminator, abandon it.
if (terminator)
@@ -46,6 +71,10 @@
inst->block = this;
}
+//===----------------------------------------------------------------------===//
+// ilist_traits for BasicBlock
+//===----------------------------------------------------------------------===//
+
mlir::CFGFunction *
llvm::ilist_traits<::mlir::BasicBlock>::getContainingFunction() {
size_t Offset(
@@ -86,17 +115,3 @@
for (; first != last; ++first)
first->function = curParent;
}
-
-BBArgument *BasicBlock::addArgument(Type *type) {
- arguments.push_back(new BBArgument(type, this));
- return arguments.back();
-}
-
-llvm::iterator_range<BasicBlock::BBArgListType::iterator>
-BasicBlock::addArguments(ArrayRef<Type *> types) {
- auto initial_size = arguments.size();
- for (auto *type : types) {
- addArgument(type);
- }
- return {arguments.data() + initial_size, arguments.data() + arguments.size()};
-}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 428523e..9374933 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -207,3 +207,15 @@
for (auto &operand : getInstOperands())
operand.~InstOperand();
}
+
+/// Append \p value to the end of this branch's operand list.
+void BranchInst::addOperand(CFGValue *value) {
+  // Construct the operand directly in the vector's storage instead of building
+  // a temporary InstOperand and then copying/moving it in (which would create
+  // and destroy an extra operand object).
+  operands.emplace_back(this, value);
+}
+
+/// Append every value of \p values, in order, to this branch's operand list.
+void BranchInst::addOperands(ArrayRef<CFGValue *> values) {
+  const auto count = values.size();
+  operands.reserve(operands.size() + count);
+  for (size_t idx = 0; idx != count; ++idx)
+    addOperand(values[idx]);
+}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index d1bb2ac..aaec3ce 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -100,6 +100,7 @@
bool verifyOperation(const OperationInst &inst);
bool verifyTerminator(const TerminatorInst &term);
bool verifyReturn(const ReturnInst &inst);
+ bool verifyBranch(const BranchInst &inst);
};
} // end anonymous namespace
@@ -163,6 +164,9 @@
if (auto *ret = dyn_cast<ReturnInst>(&term))
return verifyReturn(*ret);
+ if (auto *br = dyn_cast<BranchInst>(&term))
+ return verifyBranch(*br);
+
return false;
}
@@ -175,6 +179,31 @@
Twine(results.size()),
inst);
+  // Each return operand must have the same type as the corresponding entry of
+  // the function's result list.
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    if (inst.getOperand(i)->getType() != results[i])
+      return failure("type of return operand " + Twine(i) +
+                         " doesn't match function result type",
+                     inst);
+
+ return false;
+}
+
+/// Verify that a branch passes exactly one operand per argument of its
+/// destination block, and that each operand's type equals the type of the
+/// block argument it feeds. Reaching the end of all checks returns false.
+bool CFGFuncVerifier::verifyBranch(const BranchInst &inst) {
+  auto target = inst.getDest();
+  auto numOperands = inst.getNumOperands();
+  auto numTargetArgs = target->getNumArguments();
+
+  // The operand count must line up with the successor's BB argument count.
+  if (numOperands != numTargetArgs)
+    return failure("branch has " + Twine(numOperands) +
+                       " operands, but target block has " +
+                       Twine(numTargetArgs),
+                   inst);
+
+  // Pairwise type check between branch operands and target BB arguments.
+  for (unsigned opIdx = 0; opIdx != numOperands; ++opIdx) {
+    auto operandType = inst.getOperand(opIdx)->getType();
+    if (operandType != target->getArgument(opIdx)->getType())
+      return failure("type of branch operand " + Twine(opIdx) +
+                         " doesn't match target bb argument type",
+                     inst);
+  }
+
return false;
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index bc26bda..678b5fc 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -162,6 +162,7 @@
Type *parseMemRefType();
Type *parseFunctionType();
Type *parseType();
+ ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
ParseResult parseTypeList(SmallVectorImpl<Type*> &elements);
// Attribute parsing.
@@ -516,12 +517,27 @@
}
}
+/// Parse a comma-separated list of types that is not wrapped in parentheses,
+/// appending each parsed type to \p elements. The list must have at least one
+/// member.
+///
+/// type-list-no-parens ::= type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
+  return parseCommaSeparatedList([&]() -> ParseResult {
+    auto *type = parseType();
+    // The parsed result is recorded even on failure (as a null entry).
+    elements.push_back(type);
+    if (!type)
+      return ParseFailure;
+    return ParseSuccess;
+  });
+}
+
/// Parse a "type list", which is a singular type, or a parenthesized list of
/// types.
///
/// type-list ::= type-list-parens | type
/// type-list-parens ::= `(` `)`
-/// | `(` type (`,` type)* `)`
+/// | `(` type-list-no-parens `)`
///
ParseResult Parser::parseTypeList(SmallVectorImpl<Type*> &elements) {
auto parseElt = [&]() -> ParseResult {
@@ -1706,7 +1722,7 @@
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
-/// branch-use-list ::= `(` ssa-use-and-type-list? `)`
+/// branch-use-list ::= `(` ssa-use-list `)` `:` type-list-no-parens
/// terminator-stmt ::=
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
/// terminator-stmt ::= `return` ssa-use-and-type-list?
@@ -1730,7 +1746,40 @@
auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
if (!consumeIf(Token::bare_identifier))
return (emitError("expected basic block name"), nullptr);
- return builder.createBranchInst(destBB);
+ auto branch = builder.createBranchInst(destBB);
+
+ // Parse the use list.
+ if (!consumeIf(Token::l_paren))
+ return branch;
+
+ SmallVector<SSAUseInfo, 4> valueIDs;
+ if (parseOptionalSSAUseList(valueIDs))
+ return nullptr;
+ if (!consumeIf(Token::r_paren))
+ return (emitError("expected ')' in branch argument list"), nullptr);
+ if (!consumeIf(Token::colon))
+ return (emitError("expected ':' in branch argument list"), nullptr);
+
+ auto typeLoc = getToken().getLoc();
+ SmallVector<Type *, 4> types;
+ if (parseTypeListNoParens(types))
+ return nullptr;
+
+ if (types.size() != valueIDs.size())
+ return (emitError(typeLoc, "expected " + Twine(valueIDs.size()) +
+ " types to match operand list"),
+ nullptr);
+
+ SmallVector<CFGValue *, 4> values;
+ values.reserve(valueIDs.size());
+ for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
+ if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
+ values.push_back(cast<CFGValue>(value));
+ else
+ return nullptr;
+ }
+ branch->addOperands(values);
+ return branch;
}
// TODO: cond_br.
}