Prepare for implementation of TensorFlow passes:
- Sketch out a TensorFlow/IR directory that will hold op definitions and common TF support logic. We will eventually have TensorFlow/TF2HLO, TensorFlow/Grappler, TensorFlow/TFLite, etc.
- Add sketches of the Switch/Merge op definitions, along with supporting pieces that were previously missing, such as the TwoResults trait. Add a skeleton of a pass to raise this form.
- Beef up the Pass/FunctionPass definitions slightly, moving the common code out of LoopUnroll.cpp into a new IR/Pass.cpp file.
- Switch ConvertToCFG.cpp to be a ModulePass.
- Allow _ to start bare identifiers, since this is important for TF attributes.
PiperOrigin-RevId: 206502517
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 9487cf9..7ce52ea 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -19,11 +19,12 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/ConvertToCFG.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/Pass.h"
+#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseSet.h"
using namespace mlir;
@@ -43,7 +44,7 @@
CFGFunction *cfgFunc;
CFGFuncBuilder builder;
};
-} // namespace
+} // end anonymous namespace
CFGFunction *FunctionConverter::convert(const MLFunction *mlFunc) {
builder.createBlock();
@@ -60,12 +61,14 @@
//===----------------------------------------------------------------------===//
// Module converter
//===----------------------------------------------------------------------===//
+
namespace {
// ModuleConverter class does CFG conversion for the whole module.
-class ModuleConverter {
+class ModuleConverter : public ModulePass {
public:
- explicit ModuleConverter(Module *module) : module(module) {}
- void run();
+ explicit ModuleConverter() {}
+
+ void runOnModule(Module *m) override;
private:
// Generates CFG functions for all ML functions in the module.
@@ -83,14 +86,15 @@
// Map from ML functions to generated CFG functions.
llvm::DenseMap<const MLFunction *, CFGFunction *> generatedFuncs;
- Module *module;
+ Module *module = nullptr;
};
} // end anonymous namespace
// Iterates over all functions in the module generating CFG functions
// equivalent to ML functions and replacing references to ML functions
-// with references to the generated ML functions.
+// with references to the generated CFG functions.
-void ModuleConverter::run() {
+void ModuleConverter::runOnModule(Module *m) {
+ module = m;
convertMLFunctions();
replaceReferences();
}
@@ -153,8 +157,7 @@
// Entry point method
//===----------------------------------------------------------------------===//
-void mlir::convertToCFG(Module *module) {
- ModuleConverter moduleConverter(module);
- moduleConverter.run();
- module->verify();
-}
+/// Replaces all ML functions in the module with equivalent CFG functions.
+/// Function references are appropriately patched to refer to the newly
+/// generated CFG functions.
+ModulePass *mlir::createConvertToCFGPass() { return new ModuleConverter(); }
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 9592ef7..c631bda 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -1,4 +1,4 @@
-//===- Unroll.cpp - Code to perform loop unrolling ---------------------===//
+//===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -24,36 +24,25 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Pass.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
-#include "mlir/Pass.h"
-#include "mlir/Transforms/Loop.h"
+#include "mlir/Transforms/Passes.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
namespace {
struct LoopUnroll : public MLFunctionPass {
- bool runOnMLFunction(MLFunction *f);
- bool runOnForStmt(ForStmt *forStmt);
- bool runLoopUnroll(MLFunction *f);
+ void runOnMLFunction(MLFunction *f) override;
+ void runOnForStmt(ForStmt *forStmt);
};
-} // namespace
+} // end anonymous namespace
MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
-/// Unrolls all the innermost loops of this Module.
-bool MLFunctionPass::runOnModule(Module *m) {
- bool changed = false;
- for (auto &fn : *m) {
- if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
- changed |= runOnMLFunction(mlFunc);
- }
- return changed;
-}
-
/// Unrolls all the innermost loops of this MLFunction.
-bool LoopUnroll::runOnMLFunction(MLFunction *f) {
+void LoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all innermost loops through a post order pruned walk.
// TODO: figure out the right reusable template here to better refactor code.
class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
@@ -99,14 +88,12 @@
InnermostLoopGatherer ilg;
ilg.walkMLFunction(f);
auto &loops = ilg.loops;
- bool changed = false;
for (auto *forStmt : loops)
- changed |= runOnForStmt(forStmt);
- return changed;
+ runOnForStmt(forStmt);
}
-/// Unrolls this loop completely. Returns true if the unrolling happens.
+/// Unrolls this loop completely.
-bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
@@ -139,5 +126,4 @@
}
forStmt->eraseFromBlock();
- return true;
}