Use LLVM-style RTTI (a kind tag plus cast<>) to disambiguate between StmtBlock subclasses.
PiperOrigin-RevId: 204614520
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
index 83e0412..16ddb37 100644
--- a/lib/IR/StmtBlock.cpp
+++ b/lib/IR/StmtBlock.cpp
@@ -15,18 +15,30 @@
// limitations under the License.
// =============================================================================
-#include "mlir/IR/MLFunction.h"
#include "mlir/IR/StmtBlock.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Statements.h"
+#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Statement block
//===----------------------------------------------------------------------===//
+/// Returns the statement that directly contains this block, or nullptr when
+/// this block is an MLFunction body (the outermost block).
+///
+/// Dispatches on the block's `kind` tag (LLVM-style RTTI): a `For` block is
+/// itself a ForStmt, while an IfClause reports its enclosing statement via
+/// getIf().
+Statement *StmtBlock::getParentStmt() const {
+ // The result is a mutable Statement, so drop constness once for all cases.
+ auto *self = const_cast<StmtBlock *>(this);
+ switch (kind) {
+ case StmtBlockKind::MLFunc:
+ return nullptr;
+ case StmtBlockKind::For:
+ return cast<ForStmt>(self);
+ case StmtBlockKind::IfClause:
+ return cast<IfClause>(self)->getIf();
+ }
+ // The switch covers every kind; reaching here means `kind` is corrupt.
+ // Without this, control would fall off the end of a non-void function.
+ llvm_unreachable("unknown StmtBlockKind");
+}
+
MLFunction *StmtBlock::getFunction() const {
StmtBlock *block = const_cast<StmtBlock *>(this);
- while (block->getParent() != nullptr)
- block = block->getParent()->getBlock();
+ while (block->getParentStmt() != nullptr)
+ block = block->getParentStmt()->getBlock();
return static_cast<MLFunction *>(block);
}