Prepare for the implementation of TensorFlow passes:
- Sketch out a TensorFlow/IR directory that will hold op definitions and common TF support logic. We will eventually have TensorFlow/TF2HLO, TensorFlow/Grappler, TensorFlow/TFLite, etc.
- Add a sketch of the Switch/Merge op definitions, including missing pieces such as the TwoResults trait. Add a skeleton of a pass that raises this form to a CFG.
- Slightly extend the Pass/FunctionPass definitions, moving the common code out of LoopUnroll.cpp into a new IR/Pass.cpp file.
- Switch ConvertToCFG.cpp to be a ModulePass.
- Allow bare identifiers to start with '_', which is important for TF attributes.
PiperOrigin-RevId: 206502517
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 4f62b28..8ec042b 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -24,10 +24,11 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/Pass.h"
#include "mlir/Parser.h"
-#include "mlir/Pass.h"
-#include "mlir/Transforms/ConvertToCFG.h"
-#include "mlir/Transforms/Loop.h"
+#include "mlir/TensorFlow/ControlFlowOps.h"
+#include "mlir/TensorFlow/Passes.h"
+#include "mlir/Transforms/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
@@ -48,6 +49,7 @@
checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
cl::init(false));
+// TODO(clattner): replace these bool options with an enum list option.
static cl::opt<bool> convertToCFGOpt(
"convert-to-cfg",
cl::desc("Convert all ML functions in the module to CFG ones"));
@@ -56,6 +58,10 @@
cl::desc("Unroll innermost loops"),
cl::init(false));
+static cl::opt<bool> raiseTFControlFlow(
+ "tf-raise-control-flow",
+ cl::desc("Raise TensorFlow Switch/Merge nodes to a CFG"));
+
enum OptResult { OptSuccess, OptFailure };
/// Open the specified output file and return it, exiting if there is any I/O or
@@ -72,6 +78,10 @@
return result;
}
+static void initializeMLIRContext(MLIRContext &context) {
+  TFControlFlow::registerOperations(context); // Register TF control-flow ops.
+}
+
/// Parses the memory buffer and, if successfully parsed, prints the parsed
/// output. Optionally, convert ML functions into CFG functions.
/// TODO: pull parsing and printing into separate functions.
@@ -82,17 +92,31 @@
// Parse the input file.
MLIRContext context;
+ initializeMLIRContext(context);
std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
if (!module)
return OptFailure;
// Convert ML functions into CFG functions
- if (convertToCFGOpt)
- convertToCFG(module.get());
+ if (convertToCFGOpt) {
+ auto *pass = createConvertToCFGPass();
+ pass->runOnModule(module.get());
+ delete pass;
+ module->verify();
+ }
if (unrollInnermostLoops) {
- MLFunctionPass *loopUnroll = createLoopUnrollPass();
- loopUnroll->runOnModule(module.get());
+ auto *pass = createLoopUnrollPass();
+ pass->runOnModule(module.get());
+ delete pass;
+ module->verify();
+ }
+
+ if (raiseTFControlFlow) {
+ auto *pass = createRaiseTFControlFlowPass();
+ pass->runOnModule(module.get());
+ delete pass;
+ module->verify();
}
// Print the output.
@@ -187,6 +211,7 @@
// Parse the input file.
MLIRContext context;
+ initializeMLIRContext(context);
std::unique_ptr<Module> module(
parseSourceFile(sourceMgr, &context, checker));