Implement support for branch instruction operands.
PiperOrigin-RevId: 205666777
diff --git a/include/mlir/IR/BasicBlock.h b/include/mlir/IR/BasicBlock.h
index be64e4f..d59a83b 100644
--- a/include/mlir/IR/BasicBlock.h
+++ b/include/mlir/IR/BasicBlock.h
@@ -59,9 +59,12 @@
reverse_args_iterator args_rend() const { return getArguments().rend(); }
bool args_empty() const { return arguments.empty(); }
+
+ /// Add one value to the operand list.
BBArgument *addArgument(Type *type);
- llvm::iterator_range<BBArgListType::iterator>
- addArguments(ArrayRef<Type *> types);
+
+ /// Add one argument to the argument list for each type specified in the list.
+ llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types);
unsigned getNumArguments() const { return arguments.size(); }
BBArgument *getArgument(unsigned i) { return arguments[i]; }
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 1a003b6..12b8aed 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -186,7 +186,7 @@
return op;
}
- // Creates for statement. When step is not specified, it is set to 1.
+ // Creates for statement. When step is not specified, it is set to 1.
ForStmt *createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
AffineConstantExpr *step = nullptr);
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 6f374bd..02b4ffc 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -260,7 +260,31 @@
return dest;
}
- // TODO: need to take operands to specify BB arguments
+ unsigned getNumOperands() const { return operands.size(); }
+
+ // TODO: Add a getOperands() custom sequence that provides a value projection
+ // of the operand list.
+ CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
+ const CFGValue *getOperand(unsigned idx) const {
+ return getInstOperand(idx).get();
+ }
+
+ ArrayRef<InstOperand> getInstOperands() const { return operands; }
+ MutableArrayRef<InstOperand> getInstOperands() { return operands; }
+
+ InstOperand &getInstOperand(unsigned idx) { return operands[idx]; }
+ const InstOperand &getInstOperand(unsigned idx) const {
+ return operands[idx];
+ }
+
+ /// Add one value to the operand list.
+ void addOperand(CFGValue *value);
+
+ /// Add a list of values to the operand list.
+ void addOperands(ArrayRef<CFGValue *> values);
+
+ /// Erase a specific argument from the arg list.
+ // TODO: void eraseArgument(int Index);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst) {
@@ -272,6 +296,7 @@
: TerminatorInst(Kind::Branch), dest(dest) {}
BasicBlock *dest;
+ std::vector<InstOperand> operands;
};
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index 8401252..b0b9c96 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -37,6 +37,8 @@
// FIXME: wrong representation and API.
// TODO(someone): This should switch to llvm::iplist<Function>.
+ // TODO(someone): we also need a symbol table for function names +
+ // autorenaming like LLVM does.
std::vector<Function*> functionList;
/// Perform (potentially expensive) checks of invariants, used to detect
diff --git a/include/mlir/IR/SSAOperand.h b/include/mlir/IR/SSAOperand.h
index c8ab7d5..7685c31 100644
--- a/include/mlir/IR/SSAOperand.h
+++ b/include/mlir/IR/SSAOperand.h
@@ -61,6 +61,17 @@
/// of the SSA machinery.
SSAOperand *getNextOperandUsingThisValue() { return nextUse; }
+ /// We support a move constructor so SSAOperands can be in vectors, but this
+ /// shouldn't be used by general clients.
+ SSAOperand(SSAOperand &&other) {
+ other.removeFromCurrent();
+ value = other.value;
+ other.value = nullptr;
+ nextUse = nullptr;
+ back = nullptr;
+ insertIntoCurrent();
+ }
+
private:
/// The value used as this operand. This can be null when in a
/// "dropAllUses" state.
@@ -116,6 +127,11 @@
/// Return which operand this is in the operand list of the User.
// TODO: unsigned getOperandNumber() const;
+ /// We support a move constructor so SSAOperands can be in vectors, but this
+ /// shouldn't be used by general clients.
+ SSAOperandImpl(SSAOperandImpl &&other)
+ : SSAOperand(std::move(other)), owner(other.owner) {}
+
private:
/// The owner of this operand.
SSAOwnerTy *const owner;
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index afb4e73..64b5a2c 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -706,9 +706,23 @@
void CFGFunctionPrinter::print(const OperationInst *inst) {
printOperation(inst);
}
+
void CFGFunctionPrinter::print(const BranchInst *inst) {
os << " br bb" << getBBID(inst->getDest());
+
+ if (inst->getNumOperands() != 0) {
+ os << '(';
+ // TODO: Use getOperands() when we have it.
+ interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+ printValueID(operand.get());
+ });
+ os << ") : ";
+ interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+ ModulePrinter::print(operand.get()->getType());
+ });
+ }
}
+
void CFGFunctionPrinter::print(const ReturnInst *inst) {
os << " return";
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index 1aad430..7cb6440 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -35,6 +35,31 @@
getFunction()->getBlocks().erase(this);
}
+//===----------------------------------------------------------------------===//
+// Argument list management.
+//===----------------------------------------------------------------------===//
+
+BBArgument *BasicBlock::addArgument(Type *type) {
+ auto *arg = new BBArgument(type, this);
+ arguments.push_back(arg);
+ return arg;
+}
+
+/// Add one argument to the argument list for each type specified in the list.
+auto BasicBlock::addArguments(ArrayRef<Type *> types)
+ -> llvm::iterator_range<args_iterator> {
+ arguments.reserve(arguments.size() + types.size());
+ auto initialSize = arguments.size();
+ for (auto *type : types) {
+ addArgument(type);
+ }
+ return {arguments.data() + initialSize, arguments.data() + arguments.size()};
+}
+
+//===----------------------------------------------------------------------===//
+// Terminator management
+//===----------------------------------------------------------------------===//
+
void BasicBlock::setTerminator(TerminatorInst *inst) {
// If we already had a terminator, abandon it.
if (terminator)
@@ -46,6 +71,10 @@
inst->block = this;
}
+//===----------------------------------------------------------------------===//
+// ilist_traits for BasicBlock
+//===----------------------------------------------------------------------===//
+
mlir::CFGFunction *
llvm::ilist_traits<::mlir::BasicBlock>::getContainingFunction() {
size_t Offset(
@@ -86,17 +115,3 @@
for (; first != last; ++first)
first->function = curParent;
}
-
-BBArgument *BasicBlock::addArgument(Type *type) {
- arguments.push_back(new BBArgument(type, this));
- return arguments.back();
-}
-
-llvm::iterator_range<BasicBlock::BBArgListType::iterator>
-BasicBlock::addArguments(ArrayRef<Type *> types) {
- auto initial_size = arguments.size();
- for (auto *type : types) {
- addArgument(type);
- }
- return {arguments.data() + initial_size, arguments.data() + arguments.size()};
-}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 428523e..9374933 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -207,3 +207,15 @@
for (auto &operand : getInstOperands())
operand.~InstOperand();
}
+
+/// Add one value to the operand list.
+void BranchInst::addOperand(CFGValue *value) {
+ operands.emplace_back(InstOperand(this, value));
+}
+
+/// Add a list of values to the operand list.
+void BranchInst::addOperands(ArrayRef<CFGValue *> values) {
+ operands.reserve(operands.size() + values.size());
+ for (auto *value : values)
+ addOperand(value);
+}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index d1bb2ac..aaec3ce 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -100,6 +100,7 @@
bool verifyOperation(const OperationInst &inst);
bool verifyTerminator(const TerminatorInst &term);
bool verifyReturn(const ReturnInst &inst);
+ bool verifyBranch(const BranchInst &inst);
};
} // end anonymous namespace
@@ -163,6 +164,9 @@
if (auto *ret = dyn_cast<ReturnInst>(&term))
return verifyReturn(*ret);
+ if (auto *br = dyn_cast<BranchInst>(&term))
+ return verifyBranch(*br);
+
return false;
}
@@ -175,6 +179,31 @@
Twine(results.size()),
inst);
+ for (unsigned i = 0, e = results.size(); i != e; ++i)
+ if (inst.getOperand(i)->getType() != results[i])
+ return failure("type of return operand " + Twine(i) +
+ " doesn't match result function result type",
+ inst);
+
+ return false;
+}
+
+bool CFGFuncVerifier::verifyBranch(const BranchInst &inst) {
+ // Verify that the number of operands lines up with the number of BB arguments
+ // in the successor.
+ auto dest = inst.getDest();
+ if (inst.getNumOperands() != dest->getNumArguments())
+ return failure("branch has " + Twine(inst.getNumOperands()) +
+ " operands, but target block has " +
+ Twine(dest->getNumArguments()),
+ inst);
+
+ for (unsigned i = 0, e = inst.getNumOperands(); i != e; ++i)
+ if (inst.getOperand(i)->getType() != dest->getArgument(i)->getType())
+ return failure("type of branch operand " + Twine(i) +
+ " doesn't match target bb argument type",
+ inst);
+
return false;
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index bc26bda..678b5fc 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -162,6 +162,7 @@
Type *parseMemRefType();
Type *parseFunctionType();
Type *parseType();
+ ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
ParseResult parseTypeList(SmallVectorImpl<Type*> &elements);
// Attribute parsing.
@@ -516,12 +517,27 @@
}
}
+/// Parse a list of types without an enclosing parenthesis. The list must have
+/// at least one member.
+///
+/// type-list-no-parens ::= type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
+ auto parseElt = [&]() -> ParseResult {
+ auto elt = parseType();
+ elements.push_back(elt);
+ return elt ? ParseSuccess : ParseFailure;
+ };
+
+ return parseCommaSeparatedList(parseElt);
+}
+
/// Parse a "type list", which is a singular type, or a parenthesized list of
/// types.
///
/// type-list ::= type-list-parens | type
/// type-list-parens ::= `(` `)`
-/// | `(` type (`,` type)* `)`
+/// | `(` type-list-no-parens `)`
///
ParseResult Parser::parseTypeList(SmallVectorImpl<Type*> &elements) {
auto parseElt = [&]() -> ParseResult {
@@ -1706,7 +1722,7 @@
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
-/// branch-use-list ::= `(` ssa-use-and-type-list? `)`
+/// branch-use-list ::= `(` ssa-use-list `)` ':' type-list-no-parens
/// terminator-stmt ::=
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
/// terminator-stmt ::= `return` ssa-use-and-type-list?
@@ -1730,7 +1746,40 @@
auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
if (!consumeIf(Token::bare_identifier))
return (emitError("expected basic block name"), nullptr);
- return builder.createBranchInst(destBB);
+ auto branch = builder.createBranchInst(destBB);
+
+ // Parse the use list.
+ if (!consumeIf(Token::l_paren))
+ return branch;
+
+ SmallVector<SSAUseInfo, 4> valueIDs;
+ if (parseOptionalSSAUseList(valueIDs))
+ return nullptr;
+ if (!consumeIf(Token::r_paren))
+ return (emitError("expected ')' in branch argument list"), nullptr);
+ if (!consumeIf(Token::colon))
+ return (emitError("expected ':' in branch argument list"), nullptr);
+
+ auto typeLoc = getToken().getLoc();
+ SmallVector<Type *, 4> types;
+ if (parseTypeListNoParens(types))
+ return nullptr;
+
+ if (types.size() != valueIDs.size())
+ return (emitError(typeLoc, "expected " + Twine(valueIDs.size()) +
+ " types to match operand list"),
+ nullptr);
+
+ SmallVector<CFGValue *, 4> values;
+ values.reserve(valueIDs.size());
+ for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
+ if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
+ values.push_back(cast<CFGValue>(value));
+ else
+ return nullptr;
+ }
+ branch->addOperands(values);
+ return branch;
}
// TODO: cond_br.
}
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index c9086fe..fe7a49c 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -243,3 +243,16 @@
bb42(%0: f32):
return
}
+
+// -----
+
+cfgfunc @br_mismatch() { // expected-error {{branch has 2 operands, but target block has 1}}
+bb0: // CHECK: bb0:
+ // CHECK: %0 = "foo"() : () -> (i1, i17)
+ %0 = "foo"() : () -> (i1, i17)
+ br bb1(%0#1, %0#0) : i17, i1
+
+bb1(%x: i17):
+ return
+}
+
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index a4d45c6..2f95957 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -66,16 +66,16 @@
// CHECK: extfunc @functions((memref<1x?x4x?x?xaffineint, (d0, d1, d2, d3, d4) [s0] -> (d0, d1, d2, d3, d4), 0>, memref<i8, (d0) -> (d0), 0>) -> (), () -> ())
extfunc @functions((memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>) -> (), ()->())
-// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) {
-cfgfunc @simpleCFG(i32, f32) {
+// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) -> i1 {
+cfgfunc @simpleCFG(i32, f32) -> i1 {
// CHECK: bb0(%0: i32, %1: f32):
bb42 (%0: i32, %f: f32):
// CHECK: %2 = "foo"() : () -> i64
%1 = "foo"() : ()->i64
// CHECK: "bar"(%2) : (i64) -> (i1, i1, i1)
%2 = "bar"(%1) : (i64) -> (i1,i1,i1)
- // CHECK: return
- return
+ // CHECK: return %3#1
+ return %2#1 : i1
// CHECK: }
}
@@ -208,4 +208,17 @@
// CHECK: %2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
%2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
br bb1
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: cfgfunc @bbargs() -> (i16, i8) {
+cfgfunc @bbargs() -> (i16, i8) {
+bb0: // CHECK: bb0:
+ // CHECK: %0 = "foo"() : () -> (i1, i17)
+ %0 = "foo"() : () -> (i1, i17)
+ br bb1(%0#1, %0#0) : i17, i1
+
+bb1(%x: i17, %y: i1): // CHECK: bb1(%1: i17, %2: i1):
+ // CHECK: %3 = "baz"(%1, %2, %0#1) : (i17, i1, i17) -> (i16, i8)
+ %1 = "baz"(%x, %y, %0#1) : (i17, i1, i17) -> (i16, i8)
+ return %1#0 : i16, %1#1 : i8
+}