Prepare for implementation of TensorFlow passes:
- Sketch out a TensorFlow/IR directory that will hold op definitions and common TF support logic. We will eventually have TensorFlow/TF2HLO, TensorFlow/Grappler, TensorFlow/TFLite, etc.
- Add sketches of a Switch/Merge op definition, including supporting pieces that were missing, such as the TwoResults trait. Add a skeleton of a pass to raise this form.
- Beef up the Pass/FunctionPass definitions slightly, moving the common code out of LoopUnroll.cpp into a new IR/Pass.cpp file. Pass hooks such as runOnMLFunction now return void instead of a "changed" flag.
- Switch ConvertToCFG.cpp to be a ModulePass.
- Allow bare identifiers to start with '_', since TF attribute names need this.
PiperOrigin-RevId: 206502517
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 9592ef7..c631bda 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -1,4 +1,4 @@
-//===- Unroll.cpp - Code to perform loop unrolling ---------------------===//
+//===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -24,36 +24,25 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Pass.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
-#include "mlir/Pass.h"
-#include "mlir/Transforms/Loop.h"
+#include "mlir/Transforms/Passes.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
namespace {
struct LoopUnroll : public MLFunctionPass {
- bool runOnMLFunction(MLFunction *f);
- bool runOnForStmt(ForStmt *forStmt);
- bool runLoopUnroll(MLFunction *f);
+ void runOnMLFunction(MLFunction *f) override;
+ void runOnForStmt(ForStmt *forStmt);
};
-} // namespace
+} // end anonymous namespace
MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
-/// Unrolls all the innermost loops of this Module.
-bool MLFunctionPass::runOnModule(Module *m) {
- bool changed = false;
- for (auto &fn : *m) {
- if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
- changed |= runOnMLFunction(mlFunc);
- }
- return changed;
-}
-
/// Unrolls all the innermost loops of this MLFunction.
-bool LoopUnroll::runOnMLFunction(MLFunction *f) {
+void LoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all innermost loops through a post order pruned walk.
// TODO: figure out the right reusable template here to better refactor code.
class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
@@ -99,14 +88,12 @@
InnermostLoopGatherer ilg;
ilg.walkMLFunction(f);
auto &loops = ilg.loops;
- bool changed = false;
for (auto *forStmt : loops)
- changed |= runOnForStmt(forStmt);
- return changed;
+ runOnForStmt(forStmt);
}
-/// Unrolls this loop completely. Returns true if the unrolling happens.
+/// Unrolls this loop completely.
-bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
@@ -139,5 +126,4 @@
}
forStmt->eraseFromBlock();
- return true;
}