blob: 19e6d626e2386cffb2e618f796fc017d632c9c40 [file] [log] [blame]
//===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for interfacing with tensorflow C APIs.
//
//===----------------------------------------------------------------------===//
13
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

#include <cassert>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

using namespace llvm;
26
27namespace {
28
Mircea Trofin4f763b22020-07-14 19:32:37 -070029using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
30using TFSessionOptionsPtr =
31 std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
32using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
33
Mircea Trofincaf395e2020-07-13 14:12:32 -070034struct TFInitializer {
35 TFInitializer() {
36 assert(!IsInitialized && "TFInitialized should be called only once");
37 int Argc = 1;
38 const char *Name = "";
39 const char **NamePtr = &Name;
40 TF_InitMain(Name, &Argc, const_cast<char ***>(&NamePtr));
41 IsInitialized = true;
42 }
43 bool IsInitialized = false;
44};
45
46llvm::ManagedStatic<TFInitializer> TFLibInitializer;
47
48bool ensureInitTF() { return TFLibInitializer->IsInitialized; }
49
Mircea Trofin4f763b22020-07-14 19:32:37 -070050TFGraphPtr createTFGraph() {
51 return TFGraphPtr(TF_NewGraph(), &TF_DeleteGraph);
Mircea Trofincaf395e2020-07-13 14:12:32 -070052}
53
Mircea Trofin4f763b22020-07-14 19:32:37 -070054TFStatusPtr createTFStatus() {
55 return TFStatusPtr(TF_NewStatus(), &TF_DeleteStatus);
Mircea Trofincaf395e2020-07-13 14:12:32 -070056}
57
Mircea Trofin4f763b22020-07-14 19:32:37 -070058TFSessionOptionsPtr createTFSessionOptions() {
59 return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
Mircea Trofincaf395e2020-07-13 14:12:32 -070060}
61} // namespace
62
Mircea Trofin4f763b22020-07-14 19:32:37 -070063namespace llvm {
64class EvaluationResultImpl {
65public:
66 EvaluationResultImpl(size_t OutputSize)
67 : OutputSize(OutputSize), Output(OutputSize){};
68
69 ~EvaluationResultImpl() {
70 for (auto *P : Output)
71 if (P)
72 TF_DeleteTensor(P);
73 }
74
75 EvaluationResultImpl(const EvaluationResultImpl &) = delete;
76 EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
77 std::vector<TF_Tensor *> &getOutput() { return Output; }
78
79private:
80 const size_t OutputSize;
81 std::vector<TF_Tensor *> Output;
82};
83
84class TFModelEvaluatorImpl {
85public:
86 TFModelEvaluatorImpl(StringRef SavedModelPath,
87 const std::vector<std::string> &InputNames,
88 const std::vector<std::string> &OutputNames,
89 const char *Tags);
90
91 bool isValid() const { return IsValid; }
92 size_t OutputSize() const { return OutputFeed.size(); }
93
94 void evaluate(TF_Tensor **Output, TF_Status *Status) {
95 TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
96 Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
97 nullptr, 0, nullptr, Status);
98 }
99
100 void initInput(size_t Index, TF_DataType Type,
101 const std::vector<int64_t> &Dimensions);
102 const std::vector<TF_Tensor *> &getInput() const { return Input; }
103
104 ~TFModelEvaluatorImpl();
105
106private:
107 /// The objects necessary for carrying out an evaluation of the SavedModel.
108 /// They are expensive to set up, and we maintain them accross all the
109 /// evaluations of the model.
110 TF_Session *Session = nullptr;
111 TFGraphPtr Graph;
112 TFSessionOptionsPtr Options;
113
114 /// The specification of the input nodes.
115 std::vector<TF_Output> InputFeed;
116
117 /// The input tensors. They must match by index of the corresponding InputFeed
118 /// value. We set up the tensors once and just mutate theirs scalars before
119 /// each evaluation. The input tensors keep their value after an evaluation.
120 std::vector<TF_Tensor *> Input;
121
122 /// The specification of the output nodes. When evaluating, the tensors in the
123 /// output tensor vector must match by index the corresponding element in the
124 /// OutputFeed.
125 std::vector<TF_Output> OutputFeed;
126
127 void invalidate() { IsValid = false; }
128
129 bool IsValid = true;
130
131 /// Reusable utility for ensuring we can bind the requested Name to a node in
132 /// the SavedModel Graph.
133 bool checkReportAndInvalidate(const TF_Output &Output, StringRef Name);
134};
135} // namespace llvm
136
137TFModelEvaluatorImpl::TFModelEvaluatorImpl(
138 StringRef SavedModelPath, const std::vector<std::string> &InputNames,
139 const std::vector<std::string> &OutputNames, const char *Tags)
Mircea Trofincaf395e2020-07-13 14:12:32 -0700140 : Graph(createTFGraph()), Options(createTFSessionOptions()),
141 InputFeed(InputNames.size()), Input(InputNames.size()),
142 OutputFeed(OutputNames.size()) {
143 if (!ensureInitTF()) {
144 errs() << "Tensorflow should have been initialized";
145 return;
146 }
147 auto Status = createTFStatus();
148
149 Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
150 SavedModelPath.str().c_str(), &Tags, 1,
151 Graph.get(), nullptr, Status.get());
152 if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
153 errs() << TF_Message(Status.get());
Mircea Trofin4f763b22020-07-14 19:32:37 -0700154 invalidate();
Mircea Trofincaf395e2020-07-13 14:12:32 -0700155 }
156 for (size_t I = 0; I < InputNames.size(); ++I) {
157 InputFeed[I] = {
158 TF_GraphOperationByName(Graph.get(), (InputNames[I]).c_str()), 0};
Mircea Trofin4f763b22020-07-14 19:32:37 -0700159 if (!checkReportAndInvalidate(InputFeed[I], InputNames[I]))
Mircea Trofincaf395e2020-07-13 14:12:32 -0700160 return;
161 }
162 for (size_t I = 0; I < OutputNames.size(); ++I) {
163 OutputFeed[I] = {
164 TF_GraphOperationByName(Graph.get(), (OutputNames[I]).c_str()), 0};
Mircea Trofin4f763b22020-07-14 19:32:37 -0700165 if (!checkReportAndInvalidate(OutputFeed[I], OutputNames[I]))
Mircea Trofincaf395e2020-07-13 14:12:32 -0700166 return;
167 }
168}
169
Mircea Trofin4f763b22020-07-14 19:32:37 -0700170TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
171 const std::vector<std::string> &InputNames,
172 const std::vector<std::string> &OutputNames,
173 const char *Tags)
174 : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputNames, OutputNames,
175 Tags)) {
176 if (!Impl->isValid())
177 Impl.reset();
178}
179
180TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
Mircea Trofincaf395e2020-07-13 14:12:32 -0700181 for (auto *T : Input) {
182 TF_DeleteTensor(T);
183 }
Mircea Trofincaf395e2020-07-13 14:12:32 -0700184 if (Session == nullptr)
185 return;
186 auto Status = createTFStatus();
187 TF_DeleteSession(Session, Status.get());
188 Session = nullptr;
189 if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
190 errs() << "Could not delete TF session";
191}
192
Mircea Trofin4f763b22020-07-14 19:32:37 -0700193bool TFModelEvaluatorImpl::checkReportAndInvalidate(const TF_Output &Output,
194 StringRef Name) {
195 if (Output.oper)
196 return true;
197 errs() << "Could not find TF_Output named: " + Name;
198 IsValid = false;
199 return IsValid;
200}
201
Mircea Trofincaf395e2020-07-13 14:12:32 -0700202Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
203 if (!isValid())
204 return None;
Mircea Trofin4f763b22020-07-14 19:32:37 -0700205 std::unique_ptr<EvaluationResultImpl> Ret =
206 std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
Mircea Trofincaf395e2020-07-13 14:12:32 -0700207 auto Status = createTFStatus();
Mircea Trofin4f763b22020-07-14 19:32:37 -0700208 Impl->evaluate(Ret->getOutput().data(), Status.get());
Mircea Trofincaf395e2020-07-13 14:12:32 -0700209 if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
210 errs() << TF_Message(Status.get());
Mircea Trofin4f763b22020-07-14 19:32:37 -0700211 Impl.reset();
Mircea Trofincaf395e2020-07-13 14:12:32 -0700212 return None;
213 }
Mircea Trofin4f763b22020-07-14 19:32:37 -0700214 return EvaluationResult(std::move(Ret));
Mircea Trofincaf395e2020-07-13 14:12:32 -0700215}
216
Mircea Trofin4f763b22020-07-14 19:32:37 -0700217void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
218 const std::vector<int64_t> &Dimensions) {
Mircea Trofincaf395e2020-07-13 14:12:32 -0700219 int64_t TotalSize = TF_DataTypeSize(Type);
220 for (auto &D : Dimensions)
221 TotalSize *= D;
222
223 Input[Index] =
224 TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
225 std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
Mircea Trofin4f763b22020-07-14 19:32:37 -0700226}
227
228void *TFModelEvaluator::getUntypedInput(size_t Index) {
229 return TF_TensorData(Impl->getInput()[Index]);
230}
231
232TFModelEvaluator::EvaluationResult::EvaluationResult(
233 std::unique_ptr<EvaluationResultImpl> Impl)
234 : Impl(std::move(Impl)) {}
235
236TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
237 : Impl(std::move(Other.Impl)) {}
238
239void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
240 return TF_TensorData(Impl->getOutput()[Index]);
241}
242
243void TFModelEvaluator::initInput(size_t Index, int TypeIndex,
244 const std::vector<int64_t> &Dimensions) {
245 Impl->initInput(Index, static_cast<TF_DataType>(TypeIndex), Dimensions);
246}
247
248template <> int TFModelEvaluator::getModelTypeIndex<float>() {
249 return TF_FLOAT;
250}
251
252template <> int TFModelEvaluator::getModelTypeIndex<double>() {
253 return TF_DOUBLE;
254}
255
256template <> int TFModelEvaluator::getModelTypeIndex<int8_t>() {
257 return TF_INT8;
258}
259
260template <> int TFModelEvaluator::getModelTypeIndex<uint8_t>() {
261 return TF_UINT8;
262}
263
264template <> int TFModelEvaluator::getModelTypeIndex<int16_t>() {
265 return TF_INT16;
266}
267
268template <> int TFModelEvaluator::getModelTypeIndex<uint16_t>() {
269 return TF_UINT16;
270}
271
272template <> int TFModelEvaluator::getModelTypeIndex<int32_t>() {
273 return TF_INT32;
274}
275
276template <> int TFModelEvaluator::getModelTypeIndex<uint32_t>() {
277 return TF_UINT32;
278}
279
280template <> int TFModelEvaluator::getModelTypeIndex<int64_t>() {
281 return TF_INT64;
282}
283
284template <> int TFModelEvaluator::getModelTypeIndex<uint64_t>() {
285 return TF_UINT64;
286}
287
288TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
289TFModelEvaluator::~TFModelEvaluator() {}