Use LLVM dynamic dispatch to disambiguate between StmtBlock subclasses.
PiperOrigin-RevId: 204614520
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index c8388ee..731a191 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -36,9 +36,13 @@
// TODO: add function arguments and return values once
// SSA values are implemented
- // Methods for support type inquiry through isa, cast, and dyn_cast
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Function *func) {
- return func->getKind() == Kind::MLFunc;
+ return func->getKind() == Function::Kind::MLFunc;
+ }
+
+ static bool classof(const StmtBlock *block) {
+ return block->getStmtBlockKind() == StmtBlockKind::MLFunc;
}
void print(raw_ostream &os) const;
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 92ab8b6..fe58996 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -46,7 +46,7 @@
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public StmtBlock {
public:
- explicit ForStmt() : Statement(Kind::For), StmtBlock(this) {}
+ explicit ForStmt() : Statement(Kind::For), StmtBlock(StmtBlockKind::For) {}
//TODO: delete nested statements or assert that they are gone.
~ForStmt() {}
@@ -56,18 +56,34 @@
static bool classof(const Statement *stmt) {
return stmt->getKind() == Kind::For;
}
+
+ static bool classof(const StmtBlock *block) {
+ return block->getStmtBlockKind() == StmtBlockKind::For;
+ }
};
/// If clause represents statements contained within then or else clause
/// of an if statement.
class IfClause : public StmtBlock {
public:
- explicit IfClause(IfStmt *stmt);
+ explicit IfClause(IfStmt *stmt)
+ : StmtBlock(StmtBlockKind::IfClause), ifStmt(stmt) {
+ assert(stmt != nullptr && "If clause must have non-null parent");
+ }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast
+ static bool classof(const StmtBlock *block) {
+ return block->getStmtBlockKind() == StmtBlockKind::IfClause;
+ }
//TODO: delete nested statements or assert that they are gone.
~IfClause() {}
- IfStmt *getIf() const;
+ /// Returns the if statement that contains this clause.
+ IfStmt *getIf() const { return ifStmt; }
+
+private:
+ IfStmt *ifStmt;
};
/// If statement restricts execution to a subset of the loop iteration space.
@@ -81,7 +97,7 @@
IfClause *getThenClause() const { return thenClause; }
IfClause *getElseClause() const { return elseClause; }
- bool hasElseClause() const {return elseClause != nullptr;}
+ bool hasElseClause() const { return elseClause != nullptr; }
IfClause *createElseClause() { return (elseClause = new IfClause(this)); }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index ed3f50c..8f4cc20 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -32,9 +32,17 @@
/// Statement block represents an ordered list of statements.
class StmtBlock {
public:
+ enum class StmtBlockKind {
+ MLFunc, // MLFunction
+ For, // ForStmt
+ IfClause // IfClause
+ };
+
+ StmtBlockKind getStmtBlockKind() const { return kind; }
+
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
- Statement *getParent() const { return parent; }
+ Statement *getParentStmt() const;
/// Returns the function that this statement block is part of.
MLFunction *getFunction() const;
@@ -85,10 +93,10 @@
}
protected:
- Statement *parent;
+ StmtBlock(StmtBlockKind kind) : kind(kind) {}
- StmtBlock(Statement *parent=nullptr) : parent(parent) {}
private:
+ StmtBlockKind kind;
/// This is the list of statements in the block.
StmtListType statements;
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 453797f..72ec443 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -48,5 +48,4 @@
//===----------------------------------------------------------------------===//
MLFunction::MLFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::MLFunc), StmtBlock() {
-}
+ : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 336cf7a..dd73a87 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -98,16 +98,6 @@
}
//===----------------------------------------------------------------------===//
-// IfClause
-//===----------------------------------------------------------------------===//
-
-IfClause::IfClause(IfStmt *stmt) : StmtBlock(stmt) {
- assert(stmt != nullptr && "If clause must have non-null parent");
-}
-
-IfStmt *IfClause::getIf() const { return static_cast<IfStmt *>(parent); }
-
-//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
index 83e0412..16ddb37 100644
--- a/lib/IR/StmtBlock.cpp
+++ b/lib/IR/StmtBlock.cpp
@@ -15,18 +15,30 @@
// limitations under the License.
// =============================================================================
-#include "mlir/IR/MLFunction.h"
#include "mlir/IR/StmtBlock.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Statements.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Statement block
//===----------------------------------------------------------------------===//
+Statement *StmtBlock::getParentStmt() const {
+ switch (kind) {
+ case StmtBlockKind::MLFunc:
+ return nullptr;
+ case StmtBlockKind::For:
+ return cast<ForStmt>(const_cast<StmtBlock *>(this));
+ case StmtBlockKind::IfClause:
+ return cast<IfClause>(this)->getIf();
+ }
+}
+
MLFunction *StmtBlock::getFunction() const {
StmtBlock *block = const_cast<StmtBlock *>(this);
- while (block->getParent() != nullptr)
- block = block->getParent()->getBlock();
+ while (block->getParentStmt() != nullptr)
+ block = block->getParentStmt()->getBlock();
return static_cast<MLFunction *>(block);
}