Prepare for implementation of TensorFlow passes:
- Sketch out a TensorFlow/IR directory that will hold op definitions and common TF support logic. We will eventually have TensorFlow/TF2HLO, TensorFlow/Grappler, TensorFlow/TFLite, etc.
- Add sketches of a Switch/Merge op definition, including some missing stuff like the TwoResults trait. Add a skeleton of a pass to raise this form.
- Beef up the Pass/FunctionPass definitions slightly, moving the common code out of LoopUnroll.cpp into a new IR/Pass.cpp file.
- Switch ConvertToCFG.cpp to be a ModulePass.
- Allow _ to start bare identifiers, since this is important for TF attributes.
PiperOrigin-RevId: 206502517
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index 3ec2c5e..481c36d 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -361,6 +361,26 @@
}
};
+/// This class provides the API for ops that are known to have exactly two
+/// results.
+template <typename ConcreteType>
+class TwoResults : public TraitImpl<ConcreteType, TwoResults> {
+public:
+ const SSAValue *getResult(unsigned i) const {
+ return this->getOperation()->getResult(i);
+ }
+
+ SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); }
+
+ Type *getType(unsigned i) const { return getResult(i)->getType(); }
+
+ static const char *verifyTrait(const Operation *op) {
+ if (op->getNumResults() != 2)
+ return "requires two results";
+ return nullptr;
+ }
+};
+
/// This class provides the API for ops which have an unknown number of
/// results.
template <typename ConcreteType>
diff --git a/include/mlir/IR/Pass.h b/include/mlir/IR/Pass.h
new file mode 100644
index 0000000..7ca9b08
--- /dev/null
+++ b/include/mlir/IR/Pass.h
@@ -0,0 +1,57 @@
+//===- mlir/Pass.h - Base classes for compiler passes -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_PASS_H
+#define MLIR_PASS_H
+
+namespace mlir {
+class CFGFunction;
+class MLFunction;
+class Module;
+
+class Pass {
+public:
+ virtual ~Pass() = default;
+};
+
+class ModulePass : public Pass {
+public:
+ virtual void runOnModule(Module *m) = 0;
+};
+
+class FunctionPass : public Pass {
+public:
+ virtual void runOnCFGFunction(CFGFunction *f) = 0;
+ virtual void runOnMLFunction(MLFunction *f) = 0;
+ virtual void runOnModule(Module *m);
+};
+
+class CFGFunctionPass : public FunctionPass {
+public:
+ virtual void runOnMLFunction(MLFunction *f) override {}
+ virtual void runOnCFGFunction(CFGFunction *f) override = 0;
+};
+
+class MLFunctionPass : public FunctionPass {
+public:
+ virtual void runOnCFGFunction(CFGFunction *f) override {}
+ virtual void runOnMLFunction(MLFunction *f) override = 0;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_PASS_H
diff --git a/include/mlir/Pass.h b/include/mlir/Pass.h
deleted file mode 100644
index f93061b..0000000
--- a/include/mlir/Pass.h
+++ /dev/null
@@ -1,47 +0,0 @@
-//===- mlir/Pass.h - Base class for passes ----------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines a base class that indicates that a specified class is a
-// transformation pass implementation.
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_PASS_H
-#define MLIR_PASS_H
-
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Module.h"
-
-namespace mlir {
-
-class Pass {
-protected:
- virtual ~Pass() = default;
-};
-
-class FunctionPass : public Pass {};
-
-class CFGFunctionPass : public FunctionPass {};
-
-class MLFunctionPass : public FunctionPass {
-public:
- virtual bool runOnMLFunction(MLFunction *f) = 0;
- virtual bool runOnModule(Module *m);
-};
-
-} // end namespace mlir
-
-#endif // MLIR_PASS_H
diff --git a/include/mlir/Transforms/ConvertToCFG.h b/include/mlir/Transforms/ConvertToCFG.h
deleted file mode 100644
index a1e9be9..0000000
--- a/include/mlir/Transforms/ConvertToCFG.h
+++ /dev/null
@@ -1,34 +0,0 @@
-//===- ConvertToCFG.h - Convert ML functions to CFG ones --------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines APIs to convert ML functions into CFG functions.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TRANSFORMS_CONVERTTOCFG_H
-#define MLIR_TRANSFORMS_CONVERTTOCFG_H
-
-namespace mlir {
-class Module;
-
-/// Replaces all ML functions in the module with equivalent CFG functions.
-/// Function references are appropriately patched to refer only
-/// to CFG functions.
-void convertToCFG(Module *module);
-
-} // namespace mlir
-#endif // MLIR_TRANSFORMS_CONVERTTOCFG_H
diff --git a/include/mlir/Transforms/Loop.h b/include/mlir/Transforms/Passes.h
similarity index 68%
rename from include/mlir/Transforms/Loop.h
rename to include/mlir/Transforms/Passes.h
index cff6609..7352ad0 100644
--- a/include/mlir/Transforms/Loop.h
+++ b/include/mlir/Transforms/Passes.h
@@ -1,4 +1,4 @@
-//===- Loop.h - Loop Transformations ----------------------------*- C++ -*-===//
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -15,21 +15,27 @@
// limitations under the License.
// =============================================================================
//
-// This header file defines prototypes that expose passes in the loop
+// This header file defines prototypes that expose pass constructors in the loop
// transformation library.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TRANSFORMS_LOOP_H
-#define MLIR_TRANSFORMS_LOOP_H
+#ifndef MLIR_TRANSFORMS_PASSES_H
+#define MLIR_TRANSFORMS_PASSES_H
namespace mlir {
class MLFunctionPass;
+class ModulePass;
/// A loop unrolling pass.
MLFunctionPass *createLoopUnrollPass();
+/// Replaces all ML functions in the module with equivalent CFG functions.
+/// Function references are appropriately patched to refer to the newly
+/// generated CFG functions.
+ModulePass *createConvertToCFGPass();
+
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOP_H
diff --git a/lib/IR/Pass.cpp b/lib/IR/Pass.cpp
new file mode 100644
index 0000000..97bf986
--- /dev/null
+++ b/lib/IR/Pass.cpp
@@ -0,0 +1,38 @@
+//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements loop unrolling.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Pass.h"
+#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Module.h"
+
+using namespace mlir;
+
+/// Function passes walk a module and look at each function with their
+/// corresponding hooks.
+void FunctionPass::runOnModule(Module *m) {
+ for (auto &fn : *m) {
+ if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
+ runOnMLFunction(mlFunc);
+ if (auto *cfgFunc = dyn_cast<CFGFunction>(&fn))
+ runOnCFGFunction(cfgFunc);
+ }
+}
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index ce99b19..2aa0c57 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -57,6 +57,10 @@
// Unknown character, emit an error.
return emitError(tokStart, "unexpected character");
+ case '_':
+ // Handle bare identifiers.
+ return lexBareIdentifierOrKeyword(tokStart);
+
case 0:
// This may either be a nul character in the source file or may be the EOF
// marker that llvm::MemoryBuffer guarantees will be there.
@@ -151,7 +155,7 @@
/// Lex a bare identifier or keyword that starts with a letter.
///
-/// bare-id ::= letter (letter|digit|[_$])*
+/// bare-id ::= (letter|[_]) (letter|digit|[_$])*
/// integer-type ::= `i[1-9][0-9]*`
///
Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 9487cf9..7ce52ea 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -19,11 +19,12 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/ConvertToCFG.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/Pass.h"
+#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseSet.h"
using namespace mlir;
@@ -43,7 +44,7 @@
CFGFunction *cfgFunc;
CFGFuncBuilder builder;
};
-} // namespace
+} // end anonymous namespace
CFGFunction *FunctionConverter::convert(const MLFunction *mlFunc) {
builder.createBlock();
@@ -60,12 +61,14 @@
//===----------------------------------------------------------------------===//
// Module converter
//===----------------------------------------------------------------------===//
+
namespace {
// ModuleConverter class does CFG conversion for the whole module.
-class ModuleConverter {
+class ModuleConverter : public ModulePass {
public:
- explicit ModuleConverter(Module *module) : module(module) {}
- void run();
+ explicit ModuleConverter() {}
+
+ void runOnModule(Module *m) override;
private:
// Generates CFG functions for all ML functions in the module.
@@ -83,14 +86,15 @@
// Map from ML functions to generated CFG functions.
llvm::DenseMap<const MLFunction *, CFGFunction *> generatedFuncs;
- Module *module;
+ Module *module = nullptr;
};
} // end anonymous namespace
// Iterates over all functions in the module generating CFG functions
// equivalent to ML functions and replacing references to ML functions
// with references to the generated ML functions.
-void ModuleConverter::run() {
+void ModuleConverter::runOnModule(Module *m) {
+ module = m;
convertMLFunctions();
replaceReferences();
}
@@ -153,8 +157,7 @@
// Entry point method
//===----------------------------------------------------------------------===//
-void mlir::convertToCFG(Module *module) {
- ModuleConverter moduleConverter(module);
- moduleConverter.run();
- module->verify();
-}
+/// Replaces all ML functions in the module with equivalent CFG functions.
+/// Function references are appropriately patched to refer to the newly
+/// generated CFG functions.
+ModulePass *mlir::createConvertToCFGPass() { return new ModuleConverter(); }
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 9592ef7..c631bda 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -1,4 +1,4 @@
-//===- Unroll.cpp - Code to perform loop unrolling ---------------------===//
+//===- Unroll.cpp - Code to perform loop unrolling ------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -24,36 +24,25 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Pass.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
-#include "mlir/Pass.h"
-#include "mlir/Transforms/Loop.h"
+#include "mlir/Transforms/Passes.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
namespace {
struct LoopUnroll : public MLFunctionPass {
- bool runOnMLFunction(MLFunction *f);
- bool runOnForStmt(ForStmt *forStmt);
- bool runLoopUnroll(MLFunction *f);
+ void runOnMLFunction(MLFunction *f) override;
+ void runOnForStmt(ForStmt *forStmt);
};
-} // namespace
+} // end anonymous namespace
MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
-/// Unrolls all the innermost loops of this Module.
-bool MLFunctionPass::runOnModule(Module *m) {
- bool changed = false;
- for (auto &fn : *m) {
- if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
- changed |= runOnMLFunction(mlFunc);
- }
- return changed;
-}
-
/// Unrolls all the innermost loops of this MLFunction.
-bool LoopUnroll::runOnMLFunction(MLFunction *f) {
+void LoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all innermost loops through a post order pruned walk.
// TODO: figure out the right reusable template here to better refactor code.
class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
@@ -99,14 +88,12 @@
InnermostLoopGatherer ilg;
ilg.walkMLFunction(f);
auto &loops = ilg.loops;
- bool changed = false;
for (auto *forStmt : loops)
- changed |= runOnForStmt(forStmt);
- return changed;
+ runOnForStmt(forStmt);
}
/// Unrolls this loop completely. Returns true if the unrolling happens.
-bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
@@ -139,5 +126,4 @@
}
forStmt->eraseFromBlock();
- return true;
}
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 4f62b28..8ec042b 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -24,10 +24,11 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/Pass.h"
#include "mlir/Parser.h"
-#include "mlir/Pass.h"
-#include "mlir/Transforms/ConvertToCFG.h"
-#include "mlir/Transforms/Loop.h"
+#include "mlir/TensorFlow/ControlFlowOps.h"
+#include "mlir/TensorFlow/Passes.h"
+#include "mlir/Transforms/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
@@ -48,6 +49,7 @@
checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
cl::init(false));
+// TODO(clattner): replace these bool options with an enum list option.
static cl::opt<bool> convertToCFGOpt(
"convert-to-cfg",
cl::desc("Convert all ML functions in the module to CFG ones"));
@@ -56,6 +58,10 @@
cl::desc("Unroll innermost loops"),
cl::init(false));
+static cl::opt<bool> raiseTFControlFlow(
+ "tf-raise-control-flow",
+ cl::desc("Raise TensorFlow Switch/Match nodes to a CFG"));
+
enum OptResult { OptSuccess, OptFailure };
/// Open the specified output file and return it, exiting if there is any I/O or
@@ -72,6 +78,10 @@
return result;
}
+static void initializeMLIRContext(MLIRContext &ctx) {
+ TFControlFlow::registerOperations(ctx);
+}
+
/// Parses the memory buffer and, if successfully parsed, prints the parsed
/// output. Optionally, convert ML functions into CFG functions.
/// TODO: pull parsing and printing into separate functions.
@@ -82,17 +92,31 @@
// Parse the input file.
MLIRContext context;
+ initializeMLIRContext(context);
std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
if (!module)
return OptFailure;
// Convert ML functions into CFG functions
- if (convertToCFGOpt)
- convertToCFG(module.get());
+ if (convertToCFGOpt) {
+ auto *pass = createConvertToCFGPass();
+ pass->runOnModule(module.get());
+ delete pass;
+ module->verify();
+ }
if (unrollInnermostLoops) {
- MLFunctionPass *loopUnroll = createLoopUnrollPass();
- loopUnroll->runOnModule(module.get());
+ auto *pass = createLoopUnrollPass();
+ pass->runOnModule(module.get());
+ delete pass;
+ module->verify();
+ }
+
+ if (raiseTFControlFlow) {
+ auto *pass = createRaiseTFControlFlowPass();
+ pass->runOnModule(module.get());
+ delete pass;
+ module->verify();
}
// Print the output.
@@ -187,6 +211,7 @@
// Parse the input file.
MLIRContext context;
+ initializeMLIRContext(context);
std::unique_ptr<Module> module(
parseSourceFile(sourceMgr, &context, checker));