Implement simple loop-invariant code motion (LICM) based on dialect interfaces.
PiperOrigin-RevId: 275004258
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index 7ea9c15..1dc7deb 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -29,11 +29,30 @@
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/SideEffectsInterface.h"
using namespace mlir;
using namespace mlir::loop;
//===----------------------------------------------------------------------===//
+// LoopOpsDialect Interfaces
+//===----------------------------------------------------------------------===//
+namespace {
+
+/// Side-effect model for loop dialect operations, exposed through
+/// SideEffectsDialectInterface (e.g. for loop-invariant code motion).
+struct LoopSideEffectsInterface : public SideEffectsDialectInterface {
+ using SideEffectsDialectInterface::SideEffectsDialectInterface;
+
+ SideEffecting isSideEffecting(Operation *op) const override {
+ // loop.if / loop.for: defer the decision to the ops nested in their
+ // regions (Recursive); everything else uses the default answer.
+ if (isa<IfOp>(op) || isa<ForOp>(op))
+ return Recursive;
+ return SideEffectsDialectInterface::isSideEffecting(op);
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
// LoopOpsDialect
//===----------------------------------------------------------------------===//
@@ -43,6 +62,7 @@
#define GET_OP_LIST
#include "mlir/Dialect/LoopOps/LoopOps.cpp.inc"
>();
+ addInterfaces<LoopSideEffectsInterface>();
}
//===----------------------------------------------------------------------===//
@@ -112,6 +132,18 @@
return success();
}
+/// Returns the region that holds the body of this loop.
+Region &ForOp::getLoopBody() {
+ return region();
+}
+
+/// Returns true when the region in which `value` is defined does not lie
+/// within the body region of this loop.
+bool ForOp::isDefinedOutsideOfLoop(Value *value) {
+ Region *definingRegion = value->getParentRegion();
+ return !region().isAncestor(definingRegion);
+}
+
+/// Moves each operation in `ops`, keeping their relative order, to just
+/// before this loop operation. Always reports success.
+LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
+ Operation *loopOp = getOperation();
+ for (Operation *op : ops)
+ op->moveBefore(loopOp);
+ return success();
+}
+
ForOp mlir::loop::getForInductionVarOwner(Value *val) {
auto *ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg)