blob: 9487cf905943b9f39512d27e79ce74e17783c9fc [file] [log] [blame]
Tatiana Shpeisman6708b452018-07-24 10:15:13 -07001//===- ConvertToCFG.cpp - ML function to CFG function conversion ----------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements APIs to convert ML functions into CFG functions.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/Transforms/ConvertToCFG.h"
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/CFGFunction.h"
25#include "mlir/IR/MLFunction.h"
26#include "mlir/IR/Module.h"
27#include "llvm/ADT/DenseSet.h"
28using namespace mlir;
29
30//===----------------------------------------------------------------------===//
31// ML function converter
32//===----------------------------------------------------------------------===//
33
34namespace {
35// Generates CFG function equivalent to the given ML function.
36class FunctionConverter {
37public:
38 FunctionConverter(CFGFunction *cfgFunc)
39 : cfgFunc(cfgFunc), builder(cfgFunc) {}
40 CFGFunction *convert(const MLFunction *mlFunc);
41
42private:
43 CFGFunction *cfgFunc;
44 CFGFuncBuilder builder;
45};
46} // namespace
47
48CFGFunction *FunctionConverter::convert(const MLFunction *mlFunc) {
49 builder.createBlock();
50
51 // Creates return instruction with no operands.
52 // TODO: convert return operands.
53 builder.createReturnInst({});
54
55 // TODO: convert ML function body.
56
57 return cfgFunc;
58}
59
60//===----------------------------------------------------------------------===//
61// Module converter
62//===----------------------------------------------------------------------===//
63namespace {
64// ModuleConverter class does CFG conversion for the whole module.
65class ModuleConverter {
66public:
67 explicit ModuleConverter(Module *module) : module(module) {}
68 void run();
69
70private:
71 // Generates CFG functions for all ML functions in the module.
72 void convertMLFunctions();
73 // Generates CFG function for the given ML function.
74 CFGFunction *convert(const MLFunction *mlFunc);
75 // Replaces all ML function references in the module
76 // with references to the generated CFG functions.
77 void replaceReferences();
78 // Replaces function references in the given function.
79 void replaceReferences(CFGFunction *cfgFunc);
80 void replaceReferences(MLFunction *mlFunc);
81 // Removes all ML funtions from the module.
82 void removeMLFunctions();
83
84 // Map from ML functions to generated CFG functions.
85 llvm::DenseMap<const MLFunction *, CFGFunction *> generatedFuncs;
86 Module *module;
87};
88} // end anonymous namespace
89
90// Iterates over all functions in the module generating CFG functions
91// equivalent to ML functions and replacing references to ML functions
92// with references to the generated ML functions.
93void ModuleConverter::run() {
94 convertMLFunctions();
95 replaceReferences();
96}
97
98void ModuleConverter::convertMLFunctions() {
Chris Lattnera8e47672018-07-25 14:08:16 -070099 for (Function &fn : *module) {
100 if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700101 generatedFuncs[mlFunc] = convert(mlFunc);
102 }
103}
104
105// Creates CFG function equivalent to the given ML function.
106CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
107 // TODO: ensure that CFG function name is unique.
Chris Lattnera8e47672018-07-25 14:08:16 -0700108 auto *cfgFunc =
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700109 new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
Chris Lattnera8e47672018-07-25 14:08:16 -0700110 module->getFunctions().push_back(cfgFunc);
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700111
112 // Generates the body of the CFG function.
113 return FunctionConverter(cfgFunc).convert(mlFunc);
114}
115
116void ModuleConverter::replaceReferences() {
Chris Lattnera8e47672018-07-25 14:08:16 -0700117 for (Function &fn : *module) {
118 switch (fn.getKind()) {
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700119 case Function::Kind::CFGFunc:
Chris Lattnera8e47672018-07-25 14:08:16 -0700120 replaceReferences(&cast<CFGFunction>(fn));
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700121 break;
122 case Function::Kind::MLFunc:
Chris Lattnera8e47672018-07-25 14:08:16 -0700123 replaceReferences(&cast<MLFunction>(fn));
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700124 break;
125 case Function::Kind::ExtFunc:
126 // nothing to do for external functions
127 break;
128 }
129 }
130}
131
132void ModuleConverter::replaceReferences(CFGFunction *func) {
133 // TODO: NOP for now since function attributes are not yet implemented.
134}
135
136void ModuleConverter::replaceReferences(MLFunction *func) {
137 // TODO: NOP for now since function attributes are not yet implemented.
138}
139
140// Removes all ML functions from the module.
141void ModuleConverter::removeMLFunctions() {
Chris Lattnera8e47672018-07-25 14:08:16 -0700142 // Delete ML functions from the module.
143 for (auto it = module->begin(), e = module->end(); it != e;) {
144 // Manipulate iterator carefully to avoid deleting a function we're pointing
145 // at.
146 Function &fn = *it++;
147 if (auto mlFunc = dyn_cast<MLFunction>(&fn))
148 mlFunc->eraseFromModule();
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700149 }
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700150}
151
152//===----------------------------------------------------------------------===//
153// Entry point method
154//===----------------------------------------------------------------------===//
155
156void mlir::convertToCFG(Module *module) {
157 ModuleConverter moduleConverter(module);
158 moduleConverter.run();
159 module->verify();
160}