Use LLVM dynamic dispatch to disambiguate between StmtBlock subclasses.

PiperOrigin-RevId: 204614520
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index c8388ee..731a191 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -36,9 +36,13 @@
   // TODO: add function arguments and return values once
   // SSA values are implemented
 
-  // Methods for support type inquiry through isa, cast, and dyn_cast
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Function *func) {
-    return func->getKind() == Kind::MLFunc;
+    return func->getKind() == Function::Kind::MLFunc;
+  }
+
+  static bool classof(const StmtBlock *block) {
+    return block->getStmtBlockKind() == StmtBlockKind::MLFunc;
   }
 
   void print(raw_ostream &os) const;
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 92ab8b6..fe58996 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -46,7 +46,7 @@
 /// For statement represents an affine loop nest.
 class ForStmt : public Statement, public StmtBlock {
 public:
-  explicit ForStmt() : Statement(Kind::For), StmtBlock(this) {}
+  explicit ForStmt() : Statement(Kind::For), StmtBlock(StmtBlockKind::For) {}
   //TODO: delete nested statements or assert that they are gone.
   ~ForStmt() {}
 
@@ -56,18 +56,34 @@
   static bool classof(const Statement *stmt) {
     return stmt->getKind() == Kind::For;
   }
+
+  static bool classof(const StmtBlock *block) {
+    return block->getStmtBlockKind() == StmtBlockKind::For;
+  }
 };
 
 /// If clause represents statements contained within then or else clause
 /// of an if statement.
 class IfClause : public StmtBlock {
 public:
-  explicit IfClause(IfStmt *stmt);
+  explicit IfClause(IfStmt *stmt)
+      : StmtBlock(StmtBlockKind::IfClause), ifStmt(stmt) {
+    assert(stmt != nullptr && "If clause must have non-null parent");
+  }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast
+  static bool classof(const StmtBlock *block) {
+    return block->getStmtBlockKind() == StmtBlockKind::IfClause;
+  }
 
   //TODO: delete nested statements or assert that they are gone.
   ~IfClause() {}
 
-  IfStmt *getIf() const;
+  /// Returns the if statement that contains this clause.
+  IfStmt *getIf() const { return ifStmt; }
+
+private:
+  IfStmt *ifStmt;
 };
 
 /// If statement restricts execution to a subset of the loop iteration space.
@@ -81,7 +97,7 @@
 
   IfClause *getThenClause() const { return thenClause; }
   IfClause *getElseClause() const { return elseClause; }
-  bool hasElseClause() const {return elseClause != nullptr;}
+  bool hasElseClause() const { return elseClause != nullptr; }
   IfClause *createElseClause() { return (elseClause = new IfClause(this)); }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index ed3f50c..8f4cc20 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -32,9 +32,17 @@
 /// Statement block represents an ordered list of statements.
 class StmtBlock {
 public:
+  enum class StmtBlockKind {
+    MLFunc,  // MLFunction
+    For,     // ForStmt
+    IfClause // IfClause
+  };
+
+  StmtBlockKind getStmtBlockKind() const { return kind; }
+
   /// Returns the closest surrounding statement that contains this block or
   /// nullptr if this is a top-level statement block.
-  Statement *getParent() const { return parent; }
+  Statement *getParentStmt() const;
 
   /// Returns the function that this statement block is part of.
   MLFunction *getFunction() const;
@@ -85,10 +93,10 @@
   }
 
 protected:
-  Statement *parent;
+  StmtBlock(StmtBlockKind kind) : kind(kind) {}
 
-  StmtBlock(Statement *parent=nullptr) : parent(parent) {}
 private:
+  StmtBlockKind kind;
   /// This is the list of statements in the block.
   StmtListType statements;
 
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 453797f..72ec443 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -48,5 +48,4 @@
 //===----------------------------------------------------------------------===//
 
 MLFunction::MLFunction(StringRef name, FunctionType *type)
-  : Function(name, type, Kind::MLFunc), StmtBlock() {
-}
+    : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 336cf7a..dd73a87 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -98,16 +98,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// IfClause
-//===----------------------------------------------------------------------===//
-
-IfClause::IfClause(IfStmt *stmt) : StmtBlock(stmt) {
-  assert(stmt != nullptr && "If clause must have non-null parent");
-}
-
-IfStmt *IfClause::getIf() const { return static_cast<IfStmt *>(parent); }
-
-//===----------------------------------------------------------------------===//
 // IfStmt
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
index 83e0412..16ddb37 100644
--- a/lib/IR/StmtBlock.cpp
+++ b/lib/IR/StmtBlock.cpp
@@ -15,18 +15,30 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/StmtBlock.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Statements.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
 // Statement block
 //===----------------------------------------------------------------------===//
 
+Statement *StmtBlock::getParentStmt() const {
+  switch (kind) {
+  case StmtBlockKind::MLFunc:
+    return nullptr;
+  case StmtBlockKind::For:
+    return cast<ForStmt>(const_cast<StmtBlock *>(this));
+  case StmtBlockKind::IfClause:
+    return cast<IfClause>(this)->getIf();
+  }
+}
+
 MLFunction *StmtBlock::getFunction() const {
   StmtBlock *block = const_cast<StmtBlock *>(this);
 
-  while (block->getParent() != nullptr)
-    block = block->getParent()->getBlock();
+  while (block->getParentStmt() != nullptr)
+    block = block->getParentStmt()->getBlock();
   return static_cast<MLFunction *>(block);
 }