Use LLVM-style RTTI (isa/cast) to disambiguate between StmtBlock subclasses.

PiperOrigin-RevId: 204614520
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 453797f..72ec443 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -48,5 +48,6 @@
 //===----------------------------------------------------------------------===//
 
+// Both bases record that this object is an MLFunction; StmtBlock dispatches
+// on StmtBlockKind::MLFunc to recognize the function body as the root block.
 MLFunction::MLFunction(StringRef name, FunctionType *type)
-  : Function(name, type, Kind::MLFunc), StmtBlock() {
-}
+    : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 336cf7a..dd73a87 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -98,16 +98,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// IfClause
-//===----------------------------------------------------------------------===//
-
-IfClause::IfClause(IfStmt *stmt) : StmtBlock(stmt) {
-  assert(stmt != nullptr && "If clause must have non-null parent");
-}
-
-IfStmt *IfClause::getIf() const { return static_cast<IfStmt *>(parent); }
-
-//===----------------------------------------------------------------------===//
 // IfStmt
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
index 83e0412..16ddb37 100644
--- a/lib/IR/StmtBlock.cpp
+++ b/lib/IR/StmtBlock.cpp
@@ -15,18 +15,40 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/StmtBlock.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Statements.h"
+#include "llvm/Support/ErrorHandling.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
 // Statement block
 //===----------------------------------------------------------------------===//
 
+/// Returns the statement this block belongs to (the ForStmt itself, or the
+/// IfStmt owning an IfClause), or null when this block is an MLFunction.
+Statement *StmtBlock::getParentStmt() const {
+  switch (kind) {
+  case StmtBlockKind::MLFunc:
+    return nullptr;
+  case StmtBlockKind::For:
+    // A ForStmt is both a Statement and a StmtBlock, so it is its own parent.
+    return cast<ForStmt>(const_cast<StmtBlock *>(this));
+  case StmtBlockKind::IfClause:
+    return cast<IfClause>(this)->getIf();
+  }
+  // All kinds are handled above; without this, control could fall off the end
+  // of a non-void function if `kind` ever held an out-of-range value.
+  llvm_unreachable("unknown StmtBlockKind");
+}
+
+/// Walks up through enclosing statements until reaching a block with no parent
+/// statement, and returns that root block as its MLFunction.
 MLFunction *StmtBlock::getFunction() const {
   StmtBlock *block = const_cast<StmtBlock *>(this);
 
-  while (block->getParent() != nullptr)
-    block = block->getParent()->getBlock();
+  // Query the parent once per step; getParentStmt() dispatches on the kind.
+  while (Statement *parentStmt = block->getParentStmt())
+    block = parentStmt->getBlock();
   return static_cast<MLFunction *>(block);
 }