blob: 7ce52eab241213bf4a6e7bc875e06115aa0c5749 [file] [log] [blame]
Tatiana Shpeisman6708b452018-07-24 10:15:13 -07001//===- ConvertToCFG.cpp - ML function to CFG function conversion ----------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements APIs to convert ML functions into CFG functions.
19//
20//===----------------------------------------------------------------------===//
21
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070022#include "mlir/IR/Builders.h"
23#include "mlir/IR/CFGFunction.h"
24#include "mlir/IR/MLFunction.h"
25#include "mlir/IR/Module.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070026#include "mlir/IR/Pass.h"
27#include "mlir/Transforms/Passes.h"
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070028#include "llvm/ADT/DenseSet.h"
29using namespace mlir;
30
31//===----------------------------------------------------------------------===//
32// ML function converter
33//===----------------------------------------------------------------------===//
34
35namespace {
36// Generates CFG function equivalent to the given ML function.
37class FunctionConverter {
38public:
39 FunctionConverter(CFGFunction *cfgFunc)
40 : cfgFunc(cfgFunc), builder(cfgFunc) {}
41 CFGFunction *convert(const MLFunction *mlFunc);
42
43private:
44 CFGFunction *cfgFunc;
45 CFGFuncBuilder builder;
46};
Chris Lattneree0c2ae2018-07-29 12:37:35 -070047} // end anonymous namespace
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070048
49CFGFunction *FunctionConverter::convert(const MLFunction *mlFunc) {
50 builder.createBlock();
51
52 // Creates return instruction with no operands.
53 // TODO: convert return operands.
54 builder.createReturnInst({});
55
56 // TODO: convert ML function body.
57
58 return cfgFunc;
59}
60
61//===----------------------------------------------------------------------===//
62// Module converter
63//===----------------------------------------------------------------------===//
Chris Lattneree0c2ae2018-07-29 12:37:35 -070064
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070065namespace {
66// ModuleConverter class does CFG conversion for the whole module.
Chris Lattneree0c2ae2018-07-29 12:37:35 -070067class ModuleConverter : public ModulePass {
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070068public:
Chris Lattneree0c2ae2018-07-29 12:37:35 -070069 explicit ModuleConverter() {}
70
71 void runOnModule(Module *m) override;
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070072
73private:
74 // Generates CFG functions for all ML functions in the module.
75 void convertMLFunctions();
76 // Generates CFG function for the given ML function.
77 CFGFunction *convert(const MLFunction *mlFunc);
78 // Replaces all ML function references in the module
79 // with references to the generated CFG functions.
80 void replaceReferences();
81 // Replaces function references in the given function.
82 void replaceReferences(CFGFunction *cfgFunc);
83 void replaceReferences(MLFunction *mlFunc);
84 // Removes all ML funtions from the module.
85 void removeMLFunctions();
86
87 // Map from ML functions to generated CFG functions.
88 llvm::DenseMap<const MLFunction *, CFGFunction *> generatedFuncs;
Chris Lattneree0c2ae2018-07-29 12:37:35 -070089 Module *module = nullptr;
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070090};
91} // end anonymous namespace
92
93// Iterates over all functions in the module generating CFG functions
94// equivalent to ML functions and replacing references to ML functions
95// with references to the generated ML functions.
Chris Lattneree0c2ae2018-07-29 12:37:35 -070096void ModuleConverter::runOnModule(Module *m) {
97 module = m;
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070098 convertMLFunctions();
99 replaceReferences();
100}
101
102void ModuleConverter::convertMLFunctions() {
Chris Lattnera8e47672018-07-25 14:08:16 -0700103 for (Function &fn : *module) {
104 if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700105 generatedFuncs[mlFunc] = convert(mlFunc);
106 }
107}
108
109// Creates CFG function equivalent to the given ML function.
110CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
111 // TODO: ensure that CFG function name is unique.
Chris Lattnera8e47672018-07-25 14:08:16 -0700112 auto *cfgFunc =
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700113 new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
Chris Lattnera8e47672018-07-25 14:08:16 -0700114 module->getFunctions().push_back(cfgFunc);
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700115
116 // Generates the body of the CFG function.
117 return FunctionConverter(cfgFunc).convert(mlFunc);
118}
119
120void ModuleConverter::replaceReferences() {
Chris Lattnera8e47672018-07-25 14:08:16 -0700121 for (Function &fn : *module) {
122 switch (fn.getKind()) {
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700123 case Function::Kind::CFGFunc:
Chris Lattnera8e47672018-07-25 14:08:16 -0700124 replaceReferences(&cast<CFGFunction>(fn));
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700125 break;
126 case Function::Kind::MLFunc:
Chris Lattnera8e47672018-07-25 14:08:16 -0700127 replaceReferences(&cast<MLFunction>(fn));
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700128 break;
129 case Function::Kind::ExtFunc:
130 // nothing to do for external functions
131 break;
132 }
133 }
134}
135
136void ModuleConverter::replaceReferences(CFGFunction *func) {
137 // TODO: NOP for now since function attributes are not yet implemented.
138}
139
140void ModuleConverter::replaceReferences(MLFunction *func) {
141 // TODO: NOP for now since function attributes are not yet implemented.
142}
143
144// Removes all ML functions from the module.
145void ModuleConverter::removeMLFunctions() {
Chris Lattnera8e47672018-07-25 14:08:16 -0700146 // Delete ML functions from the module.
147 for (auto it = module->begin(), e = module->end(); it != e;) {
148 // Manipulate iterator carefully to avoid deleting a function we're pointing
149 // at.
150 Function &fn = *it++;
151 if (auto mlFunc = dyn_cast<MLFunction>(&fn))
152 mlFunc->eraseFromModule();
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700153 }
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700154}
155
156//===----------------------------------------------------------------------===//
157// Entry point method
158//===----------------------------------------------------------------------===//
159
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700160/// Replaces all ML functions in the module with equivalent CFG functions.
161/// Function references are appropriately patched to refer to the newly
162/// generated CFG functions.
163ModulePass *mlir::createConvertToCFGPass() { return new ModuleConverter(); }