blob: 0a98e2bf3eb7038df0386bef52d5334cd95ea9a9 [file] [log] [blame]
Tatiana Shpeisman6708b452018-07-24 10:15:13 -07001//===- ConvertToCFG.cpp - ML function to CFG function conversion ----------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements APIs to convert ML functions into CFG functions.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/Transforms/ConvertToCFG.h"
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/CFGFunction.h"
25#include "mlir/IR/MLFunction.h"
26#include "mlir/IR/Module.h"
27#include "llvm/ADT/DenseSet.h"
28using namespace mlir;
29
30//===----------------------------------------------------------------------===//
31// ML function converter
32//===----------------------------------------------------------------------===//
33
34namespace {
35// Generates CFG function equivalent to the given ML function.
36class FunctionConverter {
37public:
38 FunctionConverter(CFGFunction *cfgFunc)
39 : cfgFunc(cfgFunc), builder(cfgFunc) {}
40 CFGFunction *convert(const MLFunction *mlFunc);
41
42private:
43 CFGFunction *cfgFunc;
44 CFGFuncBuilder builder;
45};
46} // namespace
47
48CFGFunction *FunctionConverter::convert(const MLFunction *mlFunc) {
49 builder.createBlock();
50
51 // Creates return instruction with no operands.
52 // TODO: convert return operands.
53 builder.createReturnInst({});
54
55 // TODO: convert ML function body.
56
57 return cfgFunc;
58}
59
60//===----------------------------------------------------------------------===//
61// Module converter
62//===----------------------------------------------------------------------===//
63namespace {
64// ModuleConverter class does CFG conversion for the whole module.
65class ModuleConverter {
66public:
67 explicit ModuleConverter(Module *module) : module(module) {}
68 void run();
69
70private:
71 // Generates CFG functions for all ML functions in the module.
72 void convertMLFunctions();
73 // Generates CFG function for the given ML function.
74 CFGFunction *convert(const MLFunction *mlFunc);
75 // Replaces all ML function references in the module
76 // with references to the generated CFG functions.
77 void replaceReferences();
78 // Replaces function references in the given function.
79 void replaceReferences(CFGFunction *cfgFunc);
80 void replaceReferences(MLFunction *mlFunc);
81 // Removes all ML funtions from the module.
82 void removeMLFunctions();
83
84 // Map from ML functions to generated CFG functions.
85 llvm::DenseMap<const MLFunction *, CFGFunction *> generatedFuncs;
86 Module *module;
87};
88} // end anonymous namespace
89
90// Iterates over all functions in the module generating CFG functions
91// equivalent to ML functions and replacing references to ML functions
92// with references to the generated ML functions.
93void ModuleConverter::run() {
94 convertMLFunctions();
95 replaceReferences();
96}
97
98void ModuleConverter::convertMLFunctions() {
99 for (Function *fn : module->functionList) {
100 if (auto mlFunc = dyn_cast<MLFunction>(fn))
101 generatedFuncs[mlFunc] = convert(mlFunc);
102 }
103}
104
105// Creates CFG function equivalent to the given ML function.
106CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
107 // TODO: ensure that CFG function name is unique.
108 CFGFunction *cfgFunc =
109 new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
110 module->functionList.push_back(cfgFunc);
111
112 // Generates the body of the CFG function.
113 return FunctionConverter(cfgFunc).convert(mlFunc);
114}
115
116void ModuleConverter::replaceReferences() {
117 for (Function *fn : module->functionList) {
118 switch (fn->getKind()) {
119 case Function::Kind::CFGFunc:
120 replaceReferences(cast<CFGFunction>(fn));
121 break;
122 case Function::Kind::MLFunc:
123 replaceReferences(cast<MLFunction>(fn));
124 break;
125 case Function::Kind::ExtFunc:
126 // nothing to do for external functions
127 break;
128 }
129 }
130}
131
132void ModuleConverter::replaceReferences(CFGFunction *func) {
133 // TODO: NOP for now since function attributes are not yet implemented.
134}
135
136void ModuleConverter::replaceReferences(MLFunction *func) {
137 // TODO: NOP for now since function attributes are not yet implemented.
138}
139
140// Removes all ML functions from the module.
141void ModuleConverter::removeMLFunctions() {
142 std::vector<Function *> &fnList = module->functionList;
143
144 // Delete ML functions and its data.
145 for (auto &fn : fnList) {
146 if (auto mlFunc = dyn_cast<MLFunction>(fn)) {
147 delete mlFunc;
148 fn = nullptr;
149 }
150 }
151
152 // Remove ML functions from the function list.
153 fnList.erase(std::remove_if(fnList.begin(), fnList.end(),
154 [](Function *fn) { return !fn; }),
155 fnList.end());
156}
157
158//===----------------------------------------------------------------------===//
159// Entry point method
160//===----------------------------------------------------------------------===//
161
162void mlir::convertToCFG(Module *module) {
163 ModuleConverter moduleConverter(module);
164 moduleConverter.run();
165 module->verify();
166}