Implement the groundwork for predecessor/successor iterators on basic blocks.
Give BasicBlock a use/def list, making references to them in TerminatorInst's
into a type that maintains the list.
PiperOrigin-RevId: 206166388
diff --git a/include/mlir/IR/BasicBlock.h b/include/mlir/IR/BasicBlock.h
index c26203c..ce8df72 100644
--- a/include/mlir/IR/BasicBlock.h
+++ b/include/mlir/IR/BasicBlock.h
@@ -30,7 +30,8 @@
/// Basic blocks form a graph (the CFG) which can be traversed through
/// predecessor and successor edges.
class BasicBlock
- : public llvm::ilist_node_with_parent<BasicBlock, CFGFunction> {
+ : public IRObjectWithUseList,
+ public llvm::ilist_node_with_parent<BasicBlock, CFGFunction> {
public:
explicit BasicBlock();
~BasicBlock();
diff --git a/include/mlir/IR/CFGValue.h b/include/mlir/IR/CFGValue.h
index 5dd070b..8dd635c 100644
--- a/include/mlir/IR/CFGValue.h
+++ b/include/mlir/IR/CFGValue.h
@@ -39,7 +39,7 @@
};
/// The operand of a CFG Instruction contains a CFGValue.
-using InstOperand = SSAOperandImpl<CFGValue, Instruction>;
+using InstOperand = IROperandImpl<CFGValue, Instruction>;
/// CFGValue is the base class for SSA values in CFG functions.
class CFGValue : public SSAValueImpl<InstOperand, CFGValueKind> {
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index becc39a..da41825 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -30,9 +30,13 @@
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
-class OperationInst;
class BasicBlock;
class CFGFunction;
+class OperationInst;
+class TerminatorInst;
+
+/// The operand of a CFG Instruction contains a CFGValue.
+using BBDestination = IROperandImpl<BasicBlock, TerminatorInst>;
/// Instruction is the root of the operation and terminator instructions in the
/// hierarchy.
@@ -292,6 +296,21 @@
/// Remove this terminator from its BasicBlock and delete it.
void eraseFromBlock();
+ /// Return the list of destination entries that this terminator branches to.
+ MutableArrayRef<BBDestination> getDestinations();
+
+ ArrayRef<BBDestination> getDestinations() const {
+ return const_cast<TerminatorInst *>(this)->getDestinations();
+ }
+
+ unsigned getNumSuccessors() const { return getDestinations().size(); }
+
+ const BasicBlock *getSuccessor(unsigned i) const {
+ return getDestinations()[i].get();
+ }
+
+ BasicBlock *getSuccessor(unsigned i) { return getDestinations()[i].get(); }
+
protected:
TerminatorInst(Kind kind) : Instruction(kind) {}
~TerminatorInst() {}
@@ -305,7 +324,8 @@
~BranchInst() {}
/// Return the block this branch jumps to.
- BasicBlock *getDest() const { return dest; }
+ BasicBlock *getDest() const { return dest.get(); }
+ void setDest(BasicBlock *block);
unsigned getNumOperands() const { return operands.size(); }
@@ -321,16 +341,18 @@
/// Erase a specific argument from the arg list.
// TODO: void eraseArgument(int Index);
+ MutableArrayRef<BBDestination> getDestinations() { return dest; }
+ ArrayRef<BBDestination> getDestinations() const { return dest; }
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst) {
return inst->getKind() == Kind::Branch;
}
private:
- explicit BranchInst(BasicBlock *dest)
- : TerminatorInst(Kind::Branch), dest(dest) {}
+ explicit BranchInst(BasicBlock *dest);
- BasicBlock *dest;
+ BBDestination dest;
std::vector<InstOperand> operands;
};
@@ -338,6 +360,9 @@
/// condition to one of two possible successors. It may pass arguments to each
/// successor.
class CondBranchInst : public TerminatorInst {
+ // These are the indices into the dests list.
+ enum { trueIndex = 0, falseIndex = 1 };
+
public:
static CondBranchInst *create(CFGValue *condition, BasicBlock *trueDest,
BasicBlock *falseDest) {
@@ -350,10 +375,10 @@
const CFGValue *getCondition() const { return condition; }
/// Return the destination if the condition is true.
- BasicBlock *getTrueDest() const { return trueDest; }
+ BasicBlock *getTrueDest() const { return dests[trueIndex].get(); }
/// Return the destination if the condition is false.
- BasicBlock *getFalseDest() const { return falseDest; }
+ BasicBlock *getFalseDest() const { return dests[falseIndex].get(); }
// Support non-const operand iteration.
using operand_iterator = OperandIterator<CondBranchInst, CFGValue>;
@@ -476,20 +501,21 @@
/// Add a list of values to the operand list.
void addFalseOperands(ArrayRef<CFGValue *> values);
+ MutableArrayRef<BBDestination> getDestinations() { return dests; }
+ ArrayRef<BBDestination> getDestinations() const { return dests; }
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst) {
return inst->getKind() == Kind::CondBranch;
}
private:
- explicit CondBranchInst(CFGValue *condition, BasicBlock *trueDest,
- BasicBlock *falseDest)
- : TerminatorInst(Kind::CondBranch), condition(condition),
- trueDest(trueDest), falseDest(falseDest), numTrueOperands(0) {}
+ CondBranchInst(CFGValue *condition, BasicBlock *trueDest,
+ BasicBlock *falseDest);
CFGValue *condition;
- BasicBlock *trueDest;
- BasicBlock *falseDest;
+ BBDestination dests[2]; // 0 is the true dest, 1 is the false dest.
+
// Operand list. The true operands are stored first, followed by the false
// operands.
std::vector<InstOperand> operands;
diff --git a/include/mlir/IR/SSAOperand.h b/include/mlir/IR/SSAOperand.h
index 7685c31..5f47cfb 100644
--- a/include/mlir/IR/SSAOperand.h
+++ b/include/mlir/IR/SSAOperand.h
@@ -29,16 +29,16 @@
/// A reference to a value, suitable for use as an operand of an instruction,
/// statement, etc.
-class SSAOperand {
+class IROperand {
public:
- SSAOperand() {}
- SSAOperand(SSAValue *value) : value(value) { insertIntoCurrent(); }
+ IROperand() {}
+ IROperand(IRObjectWithUseList *value) : value(value) { insertIntoCurrent(); }
/// Return the current value being used by this operand.
- SSAValue *get() const { return value; }
+ IRObjectWithUseList *get() const { return value; }
/// Set the current value being used by this operand.
- void set(SSAValue *newValue) {
+ void set(IRObjectWithUseList *newValue) {
// It isn't worth optimizing for the case of switching operands on a single
// value.
removeFromCurrent();
@@ -54,16 +54,16 @@
back = nullptr;
}
- ~SSAOperand() { removeFromCurrent(); }
+ ~IROperand() { removeFromCurrent(); }
/// Return the next operand on the use-list of the value we are referring to.
/// This should generally only be used by the internal implementation details
/// of the SSA machinery.
- SSAOperand *getNextOperandUsingThisValue() { return nextUse; }
+ IROperand *getNextOperandUsingThisValue() { return nextUse; }
- /// We support a move constructor so SSAOperands can be in vectors, but this
+ /// We support a move constructor so IROperand's can be in vectors, but this
/// shouldn't be used by general clients.
- SSAOperand(SSAOperand &&other) {
+ IROperand(IROperand &&other) {
other.removeFromCurrent();
value = other.value;
other.value = nullptr;
@@ -75,17 +75,17 @@
private:
/// The value used as this operand. This can be null when in a
/// "dropAllUses" state.
- SSAValue *value = nullptr;
+ IRObjectWithUseList *value = nullptr;
/// The next operand in the use-chain.
- SSAOperand *nextUse = nullptr;
+ IROperand *nextUse = nullptr;
/// This points to the previous link in the use-chain.
- SSAOperand **back = nullptr;
+ IROperand **back = nullptr;
/// Operands are not copyable or assignable.
- SSAOperand(const SSAOperand &use) = delete;
- SSAOperand &operator=(const SSAOperand &use) = delete;
+ IROperand(const IROperand &use) = delete;
+ IROperand &operator=(const IROperand &use) = delete;
void removeFromCurrent() {
if (!back)
@@ -105,52 +105,53 @@
};
/// A reference to a value, suitable for use as an operand of an instruction,
-/// statement, etc. SSAValueTy is the root type to use for values this tracks,
+/// statement, etc. IRValueTy is the root type to use for values this tracks,
/// and SSAUserTy is the type that will contain operands.
-template <typename SSAValueTy, typename SSAOwnerTy>
-class SSAOperandImpl : public SSAOperand {
+template <typename IRValueTy, typename IROwnerTy>
+class IROperandImpl : public IROperand {
public:
- SSAOperandImpl(SSAOwnerTy *owner) : owner(owner) {}
- SSAOperandImpl(SSAOwnerTy *owner, SSAValueTy *value)
- : SSAOperand(value), owner(owner) {}
+ IROperandImpl(IROwnerTy *owner) : owner(owner) {}
+ IROperandImpl(IROwnerTy *owner, IRValueTy *value)
+ : IROperand(value), owner(owner) {}
/// Return the current value being used by this operand.
- SSAValueTy *get() const { return (SSAValueTy *)SSAOperand::get(); }
+ IRValueTy *get() const { return (IRValueTy *)IROperand::get(); }
/// Set the current value being used by this operand.
- void set(SSAValueTy *newValue) { SSAOperand::set(newValue); }
+ void set(IRValueTy *newValue) { IROperand::set(newValue); }
/// Return the user that owns this use.
- SSAOwnerTy *getOwner() { return owner; }
- const SSAOwnerTy *getOwner() const { return owner; }
+ IROwnerTy *getOwner() { return owner; }
+ const IROwnerTy *getOwner() const { return owner; }
/// Return which operand this is in the operand list of the User.
// TODO: unsigned getOperandNumber() const;
- /// We support a move constructor so SSAOperands can be in vectors, but this
+ /// We support a move constructor so IROperand's can be in vectors, but this
/// shouldn't be used by general clients.
- SSAOperandImpl(SSAOperandImpl &&other)
- : SSAOperand(std::move(other)), owner(other.owner) {}
+ IROperandImpl(IROperandImpl &&other)
+ : IROperand(std::move(other)), owner(other.owner) {}
private:
/// The owner of this operand.
- SSAOwnerTy *const owner;
+ IROwnerTy *const owner;
};
-inline auto SSAValue::use_begin() const -> use_iterator {
- return SSAValue::use_iterator(firstUse);
+inline auto IRObjectWithUseList::use_begin() const -> use_iterator {
+ return use_iterator(firstUse);
}
-inline auto SSAValue::use_end() const -> use_iterator {
- return SSAValue::use_iterator(nullptr);
+inline auto IRObjectWithUseList::use_end() const -> use_iterator {
+ return use_iterator(nullptr);
}
-inline auto SSAValue::getUses() const -> llvm::iterator_range<use_iterator> {
+inline auto IRObjectWithUseList::getUses() const
+ -> llvm::iterator_range<use_iterator> {
return {use_begin(), use_end()};
}
/// Returns true if this value has exactly one use.
-inline bool SSAValue::hasOneUse() const {
+inline bool IRObjectWithUseList::hasOneUse() const {
return firstUse && firstUse->getNextOperandUsingThisValue() == nullptr;
}
diff --git a/include/mlir/IR/SSAValue.h b/include/mlir/IR/SSAValue.h
index 0b7648c..f4fd833 100644
--- a/include/mlir/IR/SSAValue.h
+++ b/include/mlir/IR/SSAValue.h
@@ -29,9 +29,43 @@
namespace mlir {
class OperationInst;
-class SSAOperand;
+class IROperand;
template <typename OperandType, typename OwnerType> class SSAValueUseIterator;
+class IRObjectWithUseList {
+public:
+ ~IRObjectWithUseList() {
+ assert(use_empty() && "Cannot destroy a value that still has uses!");
+ }
+
+ /// Returns true if this value has no uses.
+ bool use_empty() const { return firstUse == nullptr; }
+
+ /// Returns true if this value has exactly one use.
+ inline bool hasOneUse() const;
+
+ using use_iterator = SSAValueUseIterator<IROperand, void>;
+ using use_range = llvm::iterator_range<use_iterator>;
+
+ inline use_iterator use_begin() const;
+ inline use_iterator use_end() const;
+
+ /// Returns a range of all uses, which is useful for iterating over all uses.
+ inline use_range getUses() const;
+
+ /// Replace all uses of 'this' value with the new value, updating anything in
+ /// the IR that uses 'this' to use the other value instead. When this returns
+ /// there are zero uses of 'this'.
+ void replaceAllUsesWith(IRObjectWithUseList *newValue);
+
+protected:
+ IRObjectWithUseList() {}
+
+private:
+ friend class IROperand;
+ IROperand *firstUse = nullptr;
+};
+
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class SSAValueKind {
BBArgument,
@@ -45,35 +79,20 @@
/// This is the common base class for all values in the MLIR system,
/// representing a computable value that has a type and a set of users.
///
-class SSAValue {
+class SSAValue : public IRObjectWithUseList {
public:
- ~SSAValue() {
- assert(use_empty() && "Cannot destroy a value that still has uses!");
- }
+ ~SSAValue() {}
SSAValueKind getKind() const { return typeAndKind.getInt(); }
Type *getType() const { return typeAndKind.getPointer(); }
- /// Returns true if this value has no uses.
- bool use_empty() const { return firstUse == nullptr; }
-
- /// Returns true if this value has exactly one use.
- inline bool hasOneUse() const;
-
- using use_iterator = SSAValueUseIterator<SSAOperand, void>;
- using use_range = llvm::iterator_range<use_iterator>;
-
- inline use_iterator use_begin() const;
- inline use_iterator use_end() const;
-
- /// Returns a range of all uses, which is useful for iterating over all uses.
- inline use_range getUses() const;
-
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
- void replaceAllUsesWith(SSAValue *newValue);
+ void replaceAllUsesWith(SSAValue *newValue) {
+ IRObjectWithUseList::replaceAllUsesWith(newValue);
+ }
/// If this value is the result of an OperationInst, return the instruction
/// that defines it.
@@ -84,27 +103,24 @@
protected:
SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {}
-
private:
- friend class SSAOperand;
const llvm::PointerIntPair<Type *, 3, SSAValueKind> typeAndKind;
- SSAOperand *firstUse = nullptr;
};
/// This template unifies the implementation logic for CFGValue and StmtValue
/// while providing more type-specific APIs when walking use lists etc.
///
-/// SSAOperandTy is the concrete instance of SSAOperand to use (including
+/// IROperandTy is the concrete instance of IROperand to use (including
/// substituted template arguments) and KindTy is the enum 'kind' discriminator
/// that subclasses want to use.
///
-template <typename SSAOperandTy, typename KindTy>
+template <typename IROperandTy, typename KindTy>
class SSAValueImpl : public SSAValue {
public:
// Provide more specific implementations of the base class functionality.
KindTy getKind() const { return (KindTy)SSAValue::getKind(); }
- // TODO: using use_iterator = SSAValueUseIterator<SSAOperandTy>;
+ // TODO: using use_iterator = SSAValueUseIterator<IROperandTy>;
// TODO: using use_range = llvm::iterator_range<use_iterator>;
// TODO: inline use_iterator use_begin() const;
@@ -122,10 +138,10 @@
/// An iterator over all uses of a ValueBase.
template <typename OperandType, typename OwnerType>
class SSAValueUseIterator
- : public std::iterator<std::forward_iterator_tag, SSAOperand> {
+ : public std::iterator<std::forward_iterator_tag, IROperand> {
public:
SSAValueUseIterator() = default;
- explicit SSAValueUseIterator(SSAOperand *current) : current(current) {}
+ explicit SSAValueUseIterator(IROperand *current) : current(current) {}
OperandType *operator->() const { return current; }
OperandType &operator*() const { return current; }
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 8476b06..b06be4d 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -108,6 +108,7 @@
for (auto &bb : *this) {
for (auto &inst : bb)
inst.dropAllReferences();
+ bb.getTerminator()->dropAllReferences();
}
}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index a1cdcb3..7afaf68 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -22,7 +22,7 @@
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
-void SSAValue::replaceAllUsesWith(SSAValue *newValue) {
+void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) {
assert(this != newValue && "cannot RAUW a value with itself");
while (!use_empty()) {
use_begin()->set(newValue);
@@ -105,6 +105,10 @@
void Instruction::dropAllReferences() {
for (auto &op : getInstOperands())
op.drop();
+
+ if (auto *term = dyn_cast<TerminatorInst>(this))
+ for (auto &dest : term->getDestinations())
+ dest.drop();
}
//===----------------------------------------------------------------------===//
@@ -209,7 +213,7 @@
}
//===----------------------------------------------------------------------===//
-// Terminators
+// TerminatorInst
//===----------------------------------------------------------------------===//
/// Remove this terminator from its BasicBlock and delete it.
@@ -219,6 +223,25 @@
destroy();
}
+/// Return the list of destination entries that this terminator branches to.
+MutableArrayRef<BBDestination> TerminatorInst::getDestinations() {
+ switch (getKind()) {
+ case Kind::Operation:
+ assert(0 && "not a terminator");
+ case Kind::Branch:
+ return cast<BranchInst>(this)->getDestinations();
+ case Kind::CondBranch:
+ return cast<CondBranchInst>(this)->getDestinations();
+ case Kind::Return:
+ // Return has no basic block successors.
+ return {};
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnInst
+//===----------------------------------------------------------------------===//
+
/// Create a new OperationInst with the specific fields.
ReturnInst *ReturnInst::create(ArrayRef<CFGValue *> operands) {
auto byteSize = totalSizeToAlloc<InstOperand>(operands.size());
@@ -245,6 +268,15 @@
operand.~InstOperand();
}
+//===----------------------------------------------------------------------===//
+// BranchInst
+//===----------------------------------------------------------------------===//
+
+BranchInst::BranchInst(BasicBlock *dest)
+ : TerminatorInst(Kind::Branch), dest(this, dest) {}
+
+void BranchInst::setDest(BasicBlock *block) { dest.set(block); }
+
/// Add one value to the operand list.
void BranchInst::addOperand(CFGValue *value) {
operands.emplace_back(InstOperand(this, value));
@@ -257,6 +289,18 @@
addOperand(value);
}
+//===----------------------------------------------------------------------===//
+// CondBranchInst
+//===----------------------------------------------------------------------===//
+
+CondBranchInst::CondBranchInst(CFGValue *condition, BasicBlock *trueDest,
+ BasicBlock *falseDest)
+ : TerminatorInst(Kind::CondBranch),
+ condition(condition), dests{{this}, {this}}, numTrueOperands(0) {
+ dests[falseIndex].set(falseDest);
+ dests[trueIndex].set(trueDest);
+}
+
/// Add one value to the true operand list.
void CondBranchInst::addTrueOperand(CFGValue *value) {
assert(getNumFalseOperands() == 0 &&