[mlir] Implement conditional branch
This looks heavyweight, but most of the code is the large number of operand accessors.
We need to be able to iterate over all operands of the cond_br (all live-outs), and also
over just the true or just the false operands.
PiperOrigin-RevId: 205897704
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index e6e4e4d..ec9ce18 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -146,6 +146,13 @@
return insertTerminator(BranchInst::create(dest));
}
+  /// Create a CondBranchInst that branches on 'condition' to 'trueDest' or
+  /// 'falseDest', and hand it to insertTerminator(). The new instruction starts
+  /// with no successor operands.
+  CondBranchInst *createCondBranchInst(CFGValue *condition,
+                                       BasicBlock *trueDest,
+                                       BasicBlock *falseDest) {
+    auto *condBr = CondBranchInst::create(condition, trueDest, falseDest);
+    return insertTerminator(condBr);
+  }
+
private:
template <typename T>
T *insertTerminator(T *term) {
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index be743e6..cbd9578 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -31,26 +31,20 @@
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
- class OperationInst;
- class BasicBlock;
- class CFGFunction;
+class OperationInst;
+class BasicBlock;
+class CFGFunction;
/// Instruction is the root of the operation and terminator instructions in the
/// hierarchy.
class Instruction {
public:
- enum class Kind {
- Operation,
- Branch,
- Return
- };
+ enum class Kind { Operation, Branch, CondBranch, Return };
Kind getKind() const { return kind; }
/// Return the BasicBlock containing this instruction.
- BasicBlock *getBlock() const {
- return block;
- }
+ BasicBlock *getBlock() const { return block; }
/// Return the CFGFunction containing this instruction.
CFGFunction *getFunction() const;
@@ -122,6 +116,7 @@
// every instruction by a word, is not necessary given the closed nature of
// instruction kinds.
~Instruction();
+
private:
Kind kind;
BasicBlock *block = nullptr;
@@ -290,7 +285,6 @@
/// represent control flow and returns.
class TerminatorInst : public Instruction {
public:
-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst) {
return inst->getKind() != Kind::Operation;
@@ -312,9 +306,7 @@
~BranchInst() {}
/// Return the block this branch jumps to.
- BasicBlock *getDest() const {
- return dest;
- }
+ BasicBlock *getDest() const { return dest; }
unsigned getNumOperands() const { return operands.size(); }
@@ -385,6 +377,202 @@
std::vector<InstOperand> operands;
};
+/// The 'cond_br' instruction is a conditional branch based on a boolean
+/// condition to one of two possible successors. It may pass arguments to each
+/// successor.
+///
+/// All successor arguments live in a single operand list. The operands for
+/// the true destination come first, followed by the operands for the false
+/// destination. The true/false accessors below are views onto the matching
+/// slice of that list.
+class CondBranchInst : public TerminatorInst {
+public:
+  static CondBranchInst *create(CFGValue *condition, BasicBlock *trueDest,
+                                BasicBlock *falseDest) {
+    return new CondBranchInst(condition, trueDest, falseDest);
+  }
+  ~CondBranchInst() {}
+
+  /// Return the branch condition. It is expected to be an i1 value; this class
+  /// does not check that.
+  CFGValue *getCondition() { return condition; }
+  const CFGValue *getCondition() const { return condition; }
+
+  /// Return the destination if the condition is true.
+  BasicBlock *getTrueDest() const { return trueDest; }
+
+  /// Return the destination if the condition is false.
+  BasicBlock *getFalseDest() const { return falseDest; }
+
+  // Support non-const operand iteration.
+  using operand_iterator = OperandIterator<CondBranchInst, CFGValue>;
+  // Support const operand iteration.
+  using const_operand_iterator =
+      OperandIterator<const CondBranchInst, const CFGValue>;
+
+  //
+  // Accessors for the entire operand list. This includes the operands to both
+  // the true and false destinations.
+  //
+
+  CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
+  const CFGValue *getOperand(unsigned idx) const {
+    return getInstOperand(idx).get();
+  }
+  void setOperand(unsigned idx, CFGValue *value) {
+    getInstOperand(idx).set(value);
+  }
+
+  operand_iterator operand_begin() { return operand_iterator(this, 0); }
+  operand_iterator operand_end() {
+    return operand_iterator(this, getNumOperands());
+  }
+  llvm::iterator_range<operand_iterator> getOperands() {
+    return {operand_begin(), operand_end()};
+  }
+
+  const_operand_iterator operand_begin() const {
+    return const_operand_iterator(this, 0);
+  }
+  const_operand_iterator operand_end() const {
+    return const_operand_iterator(this, getNumOperands());
+  }
+  llvm::iterator_range<const_operand_iterator> getOperands() const {
+    return {operand_begin(), operand_end()};
+  }
+
+  ArrayRef<InstOperand> getInstOperands() const { return operands; }
+  MutableArrayRef<InstOperand> getInstOperands() { return operands; }
+
+  InstOperand &getInstOperand(unsigned idx) { return operands[idx]; }
+  const InstOperand &getInstOperand(unsigned idx) const {
+    return operands[idx];
+  }
+  unsigned getNumOperands() const { return operands.size(); }
+
+  //
+  // Accessors for operands to the 'true' destination. These are the first
+  // getNumTrueOperands() entries of the operand list.
+  //
+
+  CFGValue *getTrueOperand(unsigned idx) {
+    return getTrueInstOperand(idx).get();
+  }
+  const CFGValue *getTrueOperand(unsigned idx) const {
+    return getTrueInstOperand(idx).get();
+  }
+  void setTrueOperand(unsigned idx, CFGValue *value) {
+    getTrueInstOperand(idx).set(value);
+  }
+
+  operand_iterator true_operand_begin() { return operand_iterator(this, 0); }
+  operand_iterator true_operand_end() {
+    return operand_iterator(this, getNumTrueOperands());
+  }
+  llvm::iterator_range<operand_iterator> getTrueOperands() {
+    return {true_operand_begin(), true_operand_end()};
+  }
+
+  const_operand_iterator true_operand_begin() const {
+    return const_operand_iterator(this, 0);
+  }
+  const_operand_iterator true_operand_end() const {
+    return const_operand_iterator(this, getNumTrueOperands());
+  }
+  llvm::iterator_range<const_operand_iterator> getTrueOperands() const {
+    return {true_operand_begin(), true_operand_end()};
+  }
+
+  // The slices are built from data() rather than &operands[0]. Subscripting an
+  // empty std::vector is undefined behavior, and the operand list is empty
+  // right after create().
+  ArrayRef<InstOperand> getTrueInstOperands() const {
+    return {operands.data(), getNumTrueOperands()};
+  }
+  MutableArrayRef<InstOperand> getTrueInstOperands() {
+    return {operands.data(), getNumTrueOperands()};
+  }
+
+  // Index through the true-operand slice so the access is range-checked
+  // against getNumTrueOperands() by ArrayRef/MutableArrayRef::operator[].
+  // Indexing the whole list would instead silently reach into the false
+  // operands.
+  InstOperand &getTrueInstOperand(unsigned idx) {
+    return getTrueInstOperands()[idx];
+  }
+  const InstOperand &getTrueInstOperand(unsigned idx) const {
+    return getTrueInstOperands()[idx];
+  }
+  unsigned getNumTrueOperands() const { return numTrueOperands; }
+
+  /// Add one value to the true operand list.
+  // NOTE(review): true operands are stored before false ones, so adding a true
+  // operand after false operands exist must insert into the middle of
+  // 'operands'. Confirm the out-of-line implementation keeps numTrueOperands
+  // in sync.
+  void addTrueOperand(CFGValue *value);
+
+  /// Add a list of values to the true operand list.
+  void addTrueOperands(ArrayRef<CFGValue *> values);
+
+  //
+  // Accessors for operands to the 'false' destination. These are the entries
+  // of the operand list after the first getNumTrueOperands() ones.
+  //
+
+  CFGValue *getFalseOperand(unsigned idx) {
+    return getFalseInstOperand(idx).get();
+  }
+  const CFGValue *getFalseOperand(unsigned idx) const {
+    return getFalseInstOperand(idx).get();
+  }
+  void setFalseOperand(unsigned idx, CFGValue *value) {
+    getFalseInstOperand(idx).set(value);
+  }
+
+  operand_iterator false_operand_begin() {
+    return operand_iterator(this, getNumTrueOperands());
+  }
+  operand_iterator false_operand_end() {
+    return operand_iterator(this, getNumOperands());
+  }
+  llvm::iterator_range<operand_iterator> getFalseOperands() {
+    return {false_operand_begin(), false_operand_end()};
+  }
+
+  const_operand_iterator false_operand_begin() const {
+    return const_operand_iterator(this, getNumTrueOperands());
+  }
+  const_operand_iterator false_operand_end() const {
+    return const_operand_iterator(this, getNumOperands());
+  }
+  llvm::iterator_range<const_operand_iterator> getFalseOperands() const {
+    return {false_operand_begin(), false_operand_end()};
+  }
+
+  // Built from data() for the same empty-vector reason as the true slice.
+  // Adding an offset of 0 to a null data() pointer is well-defined in C++.
+  ArrayRef<InstOperand> getFalseInstOperands() const {
+    return {operands.data() + getNumTrueOperands(), getNumFalseOperands()};
+  }
+  MutableArrayRef<InstOperand> getFalseInstOperands() {
+    return {operands.data() + getNumTrueOperands(), getNumFalseOperands()};
+  }
+
+  // Index through the false-operand slice so 'idx' is range-checked against
+  // getNumFalseOperands().
+  InstOperand &getFalseInstOperand(unsigned idx) {
+    return getFalseInstOperands()[idx];
+  }
+  const InstOperand &getFalseInstOperand(unsigned idx) const {
+    return getFalseInstOperands()[idx];
+  }
+  unsigned getNumFalseOperands() const {
+    return operands.size() - numTrueOperands;
+  }
+
+  /// Add one value to the false operand list.
+  void addFalseOperand(CFGValue *value);
+
+  /// Add a list of values to the false operand list.
+  void addFalseOperands(ArrayRef<CFGValue *> values);
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Instruction *inst) {
+    return inst->getKind() == Kind::CondBranch;
+  }
+
+private:
+  explicit CondBranchInst(CFGValue *condition, BasicBlock *trueDest,
+                          BasicBlock *falseDest)
+      : TerminatorInst(Kind::CondBranch), condition(condition),
+        trueDest(trueDest), falseDest(falseDest), numTrueOperands(0) {}
+
+  // NOTE(review): the condition is held as a raw pointer rather than as an
+  // InstOperand. It is therefore not part of 'operands' and is not visited by
+  // getOperands()/getInstOperands(). Confirm this is intended.
+  CFGValue *condition;
+  BasicBlock *trueDest;
+  BasicBlock *falseDest;
+  // Operand list. The true operands are stored first, followed by the false
+  // operands. numTrueOperands marks the split point.
+  std::vector<InstOperand> operands;
+  unsigned numTrueOperands;
+};
/// The 'return' instruction represents the end of control flow within the
/// current function, and can return zero or more results. The result list is
@@ -471,7 +659,6 @@
} // end namespace mlir
-
//===----------------------------------------------------------------------===//
// ilist_traits for OperationInst
//===----------------------------------------------------------------------===//
@@ -489,10 +676,11 @@
void removeNodeFromList(OperationInst *inst);
void transferNodesFromList(ilist_traits<OperationInst> &otherList,
instr_iterator first, instr_iterator last);
+
private:
mlir::BasicBlock *getContainingBlock();
};
} // end namespace llvm
-#endif // MLIR_IR_INSTRUCTIONS_H
+#endif // MLIR_IR_INSTRUCTIONS_H