Use LLVM-style RTTI (isa/cast/dyn_cast) to disambiguate between StmtBlock subclasses.

PiperOrigin-RevId: 204614520
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index c8388ee..731a191 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -36,9 +36,13 @@
   // TODO: add function arguments and return values once
   // SSA values are implemented
 
-  // Methods for support type inquiry through isa, cast, and dyn_cast
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Function *func) {
-    return func->getKind() == Kind::MLFunc;
+    return func->getKind() == Function::Kind::MLFunc;
+  }
+
+  static bool classof(const StmtBlock *block) {
+    return block->getStmtBlockKind() == StmtBlockKind::MLFunc;
   }
 
   void print(raw_ostream &os) const;
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 92ab8b6..fe58996 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -46,7 +46,7 @@
 /// For statement represents an affine loop nest.
 class ForStmt : public Statement, public StmtBlock {
 public:
-  explicit ForStmt() : Statement(Kind::For), StmtBlock(this) {}
+  explicit ForStmt() : Statement(Kind::For), StmtBlock(StmtBlockKind::For) {}
   //TODO: delete nested statements or assert that they are gone.
   ~ForStmt() {}
 
@@ -56,18 +56,34 @@
   static bool classof(const Statement *stmt) {
     return stmt->getKind() == Kind::For;
   }
+
+  static bool classof(const StmtBlock *block) {
+    return block->getStmtBlockKind() == StmtBlockKind::For;
+  }
 };
 
 /// If clause represents statements contained within then or else clause
 /// of an if statement.
 class IfClause : public StmtBlock {
 public:
-  explicit IfClause(IfStmt *stmt);
+  explicit IfClause(IfStmt *stmt)
+      : StmtBlock(StmtBlockKind::IfClause), ifStmt(stmt) {
+    assert(stmt != nullptr && "If clause must have non-null parent");
+  }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const StmtBlock *block) {
+    return block->getStmtBlockKind() == StmtBlockKind::IfClause;
+  }
 
   //TODO: delete nested statements or assert that they are gone.
   ~IfClause() {}
 
-  IfStmt *getIf() const;
+  /// Returns the if statement that contains this clause.
+  IfStmt *getIf() const { return ifStmt; }
+
+private:
+  IfStmt *ifStmt;
 };
 
 /// If statement restricts execution to a subset of the loop iteration space.
@@ -81,7 +97,7 @@
 
   IfClause *getThenClause() const { return thenClause; }
   IfClause *getElseClause() const { return elseClause; }
-  bool hasElseClause() const {return elseClause != nullptr;}
+  bool hasElseClause() const { return elseClause != nullptr; }
   IfClause *createElseClause() { return (elseClause = new IfClause(this)); }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index ed3f50c..8f4cc20 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -32,9 +32,17 @@
 /// Statement block represents an ordered list of statements.
 class StmtBlock {
 public:
+  enum class StmtBlockKind {
+    MLFunc,  // MLFunction
+    For,     // ForStmt
+    IfClause // IfClause
+  };
+
+  StmtBlockKind getStmtBlockKind() const { return kind; }
+
   /// Returns the closest surrounding statement that contains this block or
   /// nullptr if this is a top-level statement block.
-  Statement *getParent() const { return parent; }
+  Statement *getParentStmt() const;
 
   /// Returns the function that this statement block is part of.
   MLFunction *getFunction() const;
@@ -85,10 +93,10 @@
   }
 
 protected:
-  Statement *parent;
+  explicit StmtBlock(StmtBlockKind kind) : kind(kind) {}
 
-  StmtBlock(Statement *parent=nullptr) : parent(parent) {}
 private:
+  StmtBlockKind kind;
   /// This is the list of statements in the block.
   StmtListType statements;