blob: 077e068d8630f255182ff0d4c037e2a9ad3dff41 [file] [log] [blame]
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -07001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
David Gross3ced3cf2017-09-13 10:45:21 -070017#define LOG_TAG "ExecutionBuilder"
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -070018
David Gross3ced3cf2017-09-13 10:45:21 -070019#include "ExecutionBuilder.h"
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -070020
David Gross83e24dc2017-09-10 14:31:58 -070021#include "CompilationBuilder.h"
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -070022#include "CpuExecutor.h"
23#include "HalInterfaces.h"
24#include "Manager.h"
25#include "ModelBuilder.h"
26
Michael Butler689d8922017-09-01 10:58:46 -070027#include <mutex>
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070028#include <thread>
Michael Butler689d8922017-09-01 10:58:46 -070029#include <vector>
30
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -070031namespace android {
32namespace nn {
33
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070034int ModelArgumentInfo::setFromPointer(const Operand& operand,
35 const ANeuralNetworksOperandType* type, void* data,
36 uint32_t length) {
37 int n = updateDimensionInfo(operand, type);
38 if (n != ANEURALNETWORKS_NO_ERROR) {
39 return n;
40 }
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -070041 if (data == nullptr) {
42 if (length) {
43 LOG(ERROR) << "Setting argument as having no value but non-zero length passed.";
44 return ANEURALNETWORKS_BAD_DATA;
45 }
46 state = ModelArgumentInfo::HAS_NO_VALUE;
47 } else {
48 state = ModelArgumentInfo::POINTER;
49 }
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070050 buffer = data;
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -070051 locationAndLength = {.poolIndex = 0, .offset = 0, .length = length};
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070052 return ANEURALNETWORKS_NO_ERROR;
53}
54
55int ModelArgumentInfo::setFromMemory(const Operand& operand, const ANeuralNetworksOperandType* type,
56 uint32_t poolIndex, uint32_t offset, uint32_t length) {
57 int n = updateDimensionInfo(operand, type);
58 if (n != ANEURALNETWORKS_NO_ERROR) {
59 return n;
60 }
61 state = ModelArgumentInfo::MEMORY;
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -070062 locationAndLength = {.poolIndex = poolIndex, .offset = offset, .length = length};
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070063 buffer = nullptr;
64 return ANEURALNETWORKS_NO_ERROR;
65}
66
David Gross8fb14e92017-10-04 15:16:02 -070067int ModelArgumentInfo::setFromTemporaryMemory(const Operand& operand,
68 uint32_t poolIndex, uint32_t offset) {
David Gross4d83c522017-10-03 23:20:24 -070069 dimensions = operand.dimensions;
David Gross96811e22017-10-02 14:40:09 -070070 state = ModelArgumentInfo::MEMORY;
David Gross4d83c522017-10-03 23:20:24 -070071 locationAndLength =
David Gross8fb14e92017-10-04 15:16:02 -070072 {.poolIndex = poolIndex, .offset = offset, .length = sizeOfData(operand)};
David Gross96811e22017-10-02 14:40:09 -070073 buffer = nullptr;
74 return ANEURALNETWORKS_NO_ERROR;
75}
76
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070077int ModelArgumentInfo::updateDimensionInfo(const Operand& operand,
78 const ANeuralNetworksOperandType* newType) {
79 if (newType == nullptr) {
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -070080 dimensions = hidl_vec<uint32_t>();
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070081 } else {
Jean-Luc Brouilletd2d0c032017-09-12 12:03:41 -070082 uint32_t count = newType->dimensionCount;
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070083 if (static_cast<OperandType>(newType->type) != operand.type ||
84 count != operand.dimensions.size()) {
David Gross3ced3cf2017-09-13 10:45:21 -070085 LOG(ERROR) << "ANeuralNetworksExecution_setInput/Output incompatible types";
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070086 return ANEURALNETWORKS_BAD_DATA;
87 }
88 for (uint32_t i = 0; i < count; i++) {
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -070089 dimensions[i] = newType->dimensions[i];
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -070090 }
91 }
92 return ANEURALNETWORKS_NO_ERROR;
93}
94
David Gross3ced3cf2017-09-13 10:45:21 -070095ExecutionBuilder::ExecutionBuilder(const CompilationBuilder* compilation) :
David Gross83e24dc2017-09-10 14:31:58 -070096 mModel(compilation->mModel),
David Gross1f438152017-09-28 09:18:51 -070097 mPlan(&compilation->mPlan),
David Gross83e24dc2017-09-10 14:31:58 -070098 mInputs(mModel->inputCount()),
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -070099 mOutputs(mModel->outputCount()) {
Miao Wang820215d2017-10-04 19:45:45 -0700100 VLOG(EXECUTION) << "ExecutionBuilder::ExecutionBuilder";
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700101}
102
David Gross3ced3cf2017-09-13 10:45:21 -0700103int ExecutionBuilder::setInput(uint32_t index, const ANeuralNetworksOperandType* type,
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700104 const void* buffer, size_t length) {
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700105 uint32_t count = static_cast<uint32_t>(mInputs.size());
106 if (index >= count) {
David Gross3ced3cf2017-09-13 10:45:21 -0700107 LOG(ERROR) << "ANeuralNetworksExecution_setInput bad index " << index << " " << count;
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700108 return ANEURALNETWORKS_BAD_DATA;
109 }
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700110 if (type != nullptr) {
111 int n = validateOperandType(*type, "ANeuralNetworksExecution_setInput", false);
112 if (n != ANEURALNETWORKS_NO_ERROR) {
113 return n;
114 }
115 }
116 if (length > 0xFFFFFFFF) {
117 LOG(ERROR) << "ANeuralNetworksExecution_setInput input exceeds max length " << length;
118 return ANEURALNETWORKS_BAD_DATA;
119 }
120 uint32_t l = static_cast<uint32_t>(length);
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -0700121 return mInputs[index].setFromPointer(mModel->getInputOperand(index), type,
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700122 const_cast<void*>(buffer), l);
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700123}
124
David Gross3ced3cf2017-09-13 10:45:21 -0700125int ExecutionBuilder::setInputFromMemory(uint32_t index, const ANeuralNetworksOperandType* type,
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700126 const Memory* memory, size_t offset, size_t length) {
David Gross96811e22017-10-02 14:40:09 -0700127 // Should be similar to StepExecutor::setInputOrOutputFromTemporaryMemory()
128
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700129 uint32_t count = static_cast<uint32_t>(mInputs.size());
130 if (index >= count) {
David Gross3ced3cf2017-09-13 10:45:21 -0700131 LOG(ERROR) << "ANeuralNetworksExecution_setInputFromMemory bad index " << index << " "
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700132 << count;
133 return ANEURALNETWORKS_BAD_DATA;
134 }
Miao Wang105807d2017-09-05 14:41:05 -0700135 if (!memory->validateSize(offset, length)) {
136 return ANEURALNETWORKS_BAD_DATA;
137 }
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700138 // TODO validate the rest
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -0700139 uint32_t poolIndex = mMemories.add(memory);
140 return mInputs[index].setFromMemory(mModel->getInputOperand(index), type, poolIndex, offset,
141 length);
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700142}
143
David Gross3ced3cf2017-09-13 10:45:21 -0700144int ExecutionBuilder::setOutput(uint32_t index, const ANeuralNetworksOperandType* type, void* buffer,
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700145 size_t length) {
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700146 uint32_t count = static_cast<uint32_t>(mOutputs.size());
147 if (index >= count) {
David Gross3ced3cf2017-09-13 10:45:21 -0700148 LOG(ERROR) << "ANeuralNetworksExecution_setOutput bad index " << index << " " << count;
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700149 return ANEURALNETWORKS_BAD_DATA;
150 }
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700151 if (type != nullptr) {
152 int n = validateOperandType(*type, "ANeuralNetworksExecution_setOutput", false);
153 if (n != ANEURALNETWORKS_NO_ERROR) {
154 return n;
155 }
156 }
157 if (length > 0xFFFFFFFF) {
158 LOG(ERROR) << "ANeuralNetworksExecution_setOutput input exceeds max length " << length;
159 return ANEURALNETWORKS_BAD_DATA;
160 }
161 uint32_t l = static_cast<uint32_t>(length);
162 return mOutputs[index].setFromPointer(mModel->getOutputOperand(index), type, buffer, l);
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700163}
164
David Gross3ced3cf2017-09-13 10:45:21 -0700165int ExecutionBuilder::setOutputFromMemory(uint32_t index, const ANeuralNetworksOperandType* type,
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700166 const Memory* memory, size_t offset, size_t length) {
David Gross96811e22017-10-02 14:40:09 -0700167 // Should be similar to StepExecutor::setInputOrOutputFromTemporaryMemory()
168
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700169 uint32_t count = static_cast<uint32_t>(mOutputs.size());
170 if (index >= count) {
David Gross3ced3cf2017-09-13 10:45:21 -0700171 LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory bad index " << index << " "
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700172 << count;
173 return ANEURALNETWORKS_BAD_DATA;
174 }
Miao Wang105807d2017-09-05 14:41:05 -0700175 if (!memory->validateSize(offset, length)) {
176 return ANEURALNETWORKS_BAD_DATA;
177 }
Jean-Luc Brouillete127e492017-09-27 23:59:20 -0700178 // TODO validate the rest
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -0700179 uint32_t poolIndex = mMemories.add(memory);
180 return mOutputs[index].setFromMemory(mModel->getOutputOperand(index), type, poolIndex, offset,
181 length);
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700182}
183
David Gross5e8feed2017-10-04 23:05:05 -0700184// Attempt synchronous execution of full model on CPU.
185// Ensure that executionCallback->notify() is called.
186static void cpuFallbackFull(const ExecutionBuilder* executionBuilder,
187 const sp<ExecutionCallback>& executionCallback) {
Miao Wang820215d2017-10-04 19:45:45 -0700188 VLOG(EXECUTION) << "cpuFallbackFull";
David Gross5e8feed2017-10-04 23:05:05 -0700189 StepExecutor executor(executionBuilder, executionBuilder->getModel(),
190 nullptr /* no IDevice, so CPU */,
191 nullptr /* no IPreparedModel */);
192 executor.mapInputsAndOutputsTrivially();
193 sp<ExecutionCallback> fallbackCallback;
194 if (executor.startCompute(&fallbackCallback) != ANEURALNETWORKS_NO_ERROR) {
195 executionCallback->notify(ErrorStatus::GENERAL_FAILURE);
196 return;
197 }
198 fallbackCallback->wait();
199 executionCallback->notify(fallbackCallback->getStatus());
200}
201
202// Attempt synchronous execution on CPU.
203// (1) First, attempt to execute this step on CPU. If successful,
204// return true. (Do not call executionCallback->notify().)
205// (2) If unsuccessful, attempt to execute the full model on CPU,
206// ensure that executionCallback->notify() is called, and return
207// false.
208static bool cpuFallbackPartial(const ExecutionBuilder* executionBuilder,
209 const ExecutionPlan* plan,
210 std::shared_ptr<ExecutionPlan::Controller> controller,
211 const sp<ExecutionCallback>& executionCallback) {
Miao Wang820215d2017-10-04 19:45:45 -0700212 VLOG(EXECUTION) << "cpuFallbackPartial";
David Gross5e8feed2017-10-04 23:05:05 -0700213 std::shared_ptr<StepExecutor> executor;
214 int n = plan->fallback(controller, &executor);
215 if (n != ANEURALNETWORKS_NO_ERROR || executor->isCpu()) {
216 cpuFallbackFull(executionBuilder, executionCallback);
217 return false;
218 }
219 sp<ExecutionCallback> fallbackCallback;
220 if (executor->startComputeOnCpu(&fallbackCallback) != ANEURALNETWORKS_NO_ERROR) {
221 cpuFallbackFull(executionBuilder, executionCallback);
222 return false;
223 }
224 fallbackCallback->wait();
225 if (fallbackCallback->getStatus() != ErrorStatus::NONE) {
226 cpuFallbackFull(executionBuilder, executionCallback);
227 return false;
228 }
229 return true;
230}
231
232static void asyncStartComputePartitioned(const ExecutionBuilder* executionBuilder,
233 const ExecutionPlan* plan,
David Grosse3178822017-10-04 16:05:58 -0700234 std::shared_ptr<ExecutionPlan::Controller> controller,
David Gross5e8feed2017-10-04 23:05:05 -0700235 bool allowFallback,
236 const sp<ExecutionCallback>& executionCallback) {
Miao Wang820215d2017-10-04 19:45:45 -0700237 VLOG(EXECUTION) << "ExecutionBuilder::startCompute (from plan, iteratively)";
David Grosse3178822017-10-04 16:05:58 -0700238 while (true) {
239 std::shared_ptr<StepExecutor> executor;
Miao Wang820215d2017-10-04 19:45:45 -0700240 VLOG(EXECUTION) << "looking for next StepExecutor";
David Grosse3178822017-10-04 16:05:58 -0700241 int n = plan->next(controller, &executor);
David Gross5e8feed2017-10-04 23:05:05 -0700242 if (n != ANEURALNETWORKS_NO_ERROR) {
243 if (allowFallback) {
244 cpuFallbackFull(executionBuilder, executionCallback);
245 } else {
246 executionCallback->notify(ErrorStatus::GENERAL_FAILURE);
247 }
248 return;
249 }
250 if (executor == nullptr) {
251 executionCallback->notify(ErrorStatus::NONE);
David Grosse3178822017-10-04 16:05:58 -0700252 return;
253 }
254
255 sp<ExecutionCallback> stepCallback;
256 n = executor->startCompute(&stepCallback);
257 if (n != ANEURALNETWORKS_NO_ERROR) {
David Gross5e8feed2017-10-04 23:05:05 -0700258 if (allowFallback) {
259 if (cpuFallbackPartial(executionBuilder, plan, controller, executionCallback)) {
260 // Successfully executed one step on CPU.
261 continue;
262 } else {
263 // Either successfully executed entire plan on
264 // CPU, or tried and failed to do so.
265 return;
266 }
267 } else {
268 executionCallback->notify(ErrorStatus::GENERAL_FAILURE);
269 return;
270 }
David Grosse3178822017-10-04 16:05:58 -0700271 }
272 stepCallback->wait();
273 ErrorStatus status = stepCallback->getStatus();
274 if (status != ErrorStatus::NONE) {
David Gross5e8feed2017-10-04 23:05:05 -0700275 if (allowFallback) {
276 if (cpuFallbackPartial(executionBuilder, plan, controller, executionCallback)) {
277 // Successfully executed one step on CPU.
278 continue;
279 } else {
280 // Either successfully executed entire plan on
281 // CPU, or tried and failed to do so.
282 return;
283 }
284 } else {
285 executionCallback->notify(status);
286 return;
287 }
David Grosse3178822017-10-04 16:05:58 -0700288 }
289 }
290}
291
Michael Butler033b8a62017-09-22 18:21:59 -0700292int ExecutionBuilder::startCompute(sp<ExecutionCallback>* synchronizationCallback) {
293 *synchronizationCallback = nullptr;
David Gross425b2592017-09-13 19:33:14 -0700294
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700295 // TODO validate that we have full types for all inputs and outputs,
296 // that the graph is not cyclic,
Yang Nif1817c62017-08-22 16:18:50 -0700297
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700298 for (auto& p : mInputs) {
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -0700299 if (p.state == ModelArgumentInfo::UNSPECIFIED) {
David Gross3ced3cf2017-09-13 10:45:21 -0700300 LOG(ERROR) << "ANeuralNetworksExecution_startCompute not all inputs specified";
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700301 return ANEURALNETWORKS_BAD_DATA;
302 }
303 }
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700304 for (auto& p : mOutputs) {
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -0700305 if (p.state == ModelArgumentInfo::UNSPECIFIED) {
David Gross3ced3cf2017-09-13 10:45:21 -0700306 LOG(ERROR) << "ANeuralNetworksExecution_startCompute not all outputs specified";
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700307 return ANEURALNETWORKS_BAD_DATA;
308 }
309 }
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700310
David Gross5e8feed2017-10-04 23:05:05 -0700311#ifndef DISABLE_PARTITIONED_EXECUTION
312 {
313 // TODO: Remove the non-plan-based path once we've fully integrated ExecutionPlan
314 // with the compilation and execution phases of the NN API? Or retain that path
315 // as a fallback in the case of partitioning failure?
316 //
317 // TODO: Entire plan-based-path should run in an asynchronous thread --
318 // take the asynchronous thread logic out of startComputeOnCpu() and use
319 // it to wrap the plan-based-path.
320 const uint32_t partitioning = DeviceManager::get()->getPartitioning();
321 if (partitioning > 0) {
322 const bool allowFallback = DeviceManager::partitioningAllowsFallback(partitioning);
323 std::shared_ptr<ExecutionPlan::Controller> controller = mPlan->makeController(this);
324 if (controller == nullptr) {
325 if (!allowFallback) {
326 return ANEURALNETWORKS_OP_FAILED;
327 }
328 } else {
329 // TODO: use a thread pool
330
331 // Prepare the callback for asynchronous execution.
332 // sp<ExecutionCallback> object is returned when the
333 // execution has been successfully launched, otherwise a
334 // nullptr is returned. The executionCallback is
335 // abstracted in the NN API as an "event".
336 sp<ExecutionCallback> executionCallback = new ExecutionCallback();
337 std::thread thread(asyncStartComputePartitioned, this, mPlan, controller,
338 allowFallback,
339 executionCallback);
340 executionCallback->bind_thread(std::move(thread));
341 *synchronizationCallback = executionCallback;
342 return ANEURALNETWORKS_NO_ERROR;
David Grossa2a03632017-10-03 12:49:47 -0700343 }
David Gross1f438152017-09-28 09:18:51 -0700344 }
345 }
David Gross5e8feed2017-10-04 23:05:05 -0700346#else
347 {
348 // Find a driver that can handle all the operations.
349 // TODO: Does not handle CPU fallback (which is tricky because
350 // StepExecutor::startCompute() is designed as
351 // asynchronous).
352 // TODO: Does not actually behave asynchronously (because
353 // StepExecutor::startCompute() isn't actually asynchronous
354 // on a device as opposed to a CPU).
355 Model hidlModel;
356 mModel->setHidlModel(&hidlModel);
357 const std::vector<std::shared_ptr<Device>>& devices = DeviceManager::get()->getDrivers();
358 for (const auto& device : devices) {
359 hidl_vec<bool> supports;
Miao Wang820215d2017-10-04 19:45:45 -0700360 VLOG(EXECUTION) << "Checking " << device->getName();
David Gross5e8feed2017-10-04 23:05:05 -0700361 device->getSupportedOperations(hidlModel, &supports);
362 if (std::find(supports.begin(), supports.end(), false) == supports.end()) {
Miao Wang820215d2017-10-04 19:45:45 -0700363 VLOG(EXECUTION) << "ExecutionBuilder::startCompute (without plan) on " << device->getName();
David Gross5e8feed2017-10-04 23:05:05 -0700364 StepExecutor executor(this, mModel, device->getInterface(),
365 nullptr /* no IPreparedModel, so compile */);
366 executor.mapInputsAndOutputsTrivially();
367 return executor.startCompute(synchronizationCallback);
368 }
Jean-Luc Brouilletef22aa52017-09-15 22:53:49 -0700369 }
370 }
David Gross5e8feed2017-10-04 23:05:05 -0700371#endif // DISABLE_PARTITIONED_EXECUTION
372
373 // Run on the CPU.
Miao Wang820215d2017-10-04 19:45:45 -0700374 VLOG(EXECUTION) << "ExecutionBuilder::startCompute (without plan) on CPU";
David Grossb2604912017-10-01 15:26:33 -0700375 StepExecutor executor(this, mModel,
376 nullptr /* no IDevice, so CPU */,
377 nullptr /* no IPreparedModel */);
378 executor.mapInputsAndOutputsTrivially();
Michael Butler033b8a62017-09-22 18:21:59 -0700379 return executor.startCompute(synchronizationCallback);
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700380}
381
382// Figures out how to place each of the input or outputs in a buffer. This just does the layout,
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700383// it does not copy data. Aligns each input a bit.
David Grossb2604912017-10-01 15:26:33 -0700384int StepExecutor::allocatePointerArgumentsToPool(std::vector<ModelArgumentInfo>* args,
385 Memory* memory) {
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700386 uint32_t nextPoolIndex = mMemories.size();
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700387 int64_t total = 0;
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700388 for (auto& info : *args) {
389 if (info.state == ModelArgumentInfo::POINTER) {
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700390 DataLocation& loc = info.locationAndLength;
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700391 // TODO Good enough alignment?
392 total += alignBytesNeeded(static_cast<uint32_t>(total), loc.length);
393 loc.poolIndex = nextPoolIndex;
394 loc.offset = static_cast<uint32_t>(total);
395 total += loc.length;
396 }
397 };
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700398 if (total > 0xFFFFFFFF) {
David Gross3ced3cf2017-09-13 10:45:21 -0700399 LOG(ERROR) << "ANeuralNetworksExecution_startCompute Size of all inputs or outputs exceeds "
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700400 "2^32.";
401 return ANEURALNETWORKS_BAD_DATA;
402 }
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700403 hidl_memory hidlMemory;
404 if (total > 0) {
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -0700405 memory->create(total); // TODO check error
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700406 mMemories.add(memory);
407 }
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700408 return ANEURALNETWORKS_NO_ERROR;
409}
410
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700411static void setRequestArgumentArray(const std::vector<ModelArgumentInfo>& argumentInfos,
Jean-Luc Brouilleta4e2ee82017-09-09 19:27:25 -0700412 hidl_vec<RequestArgument>* ioInfos) {
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700413 size_t count = argumentInfos.size();
414 ioInfos->resize(count);
415 for (size_t i = 0; i < count; i++) {
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700416 const auto& info = argumentInfos[i];
417 (*ioInfos)[i] = { .hasNoValue = info.state == ModelArgumentInfo::HAS_NO_VALUE,
418 .location = info.locationAndLength,
419 .dimensions = info.dimensions,
420 };
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700421 }
422}
423
David Grossb2604912017-10-01 15:26:33 -0700424StepExecutor::StepExecutor(const ExecutionBuilder* executionBuilder,
425 const ModelBuilder* model,
426 sp<IDevice> driver, sp<IPreparedModel> preparedModel) :
427 mExecutionBuilder(executionBuilder), mModel(model),
David Gross891b10f2017-10-01 20:48:10 -0700428 mDriver(driver), mPreparedModel(preparedModel),
429 mInputs(model->inputCount()), mOutputs(model->outputCount()) {}
David Grossb2604912017-10-01 15:26:33 -0700430
431void StepExecutor::mapInputsAndOutputsTrivially() {
432 mInputs = mExecutionBuilder->mInputs;
433 mOutputs = mExecutionBuilder->mOutputs;
434 mMemories = mExecutionBuilder->mMemories;
435}
436
David Gross891b10f2017-10-01 20:48:10 -0700437void StepExecutor::mapInputOrOutput(const ModelArgumentInfo& builderInputOrOutput,
438 ModelArgumentInfo* executorInputOrOutput) {
439 *executorInputOrOutput = builderInputOrOutput;
440 switch (executorInputOrOutput->state) {
441 default:
442 nnAssert(!"unexpected ModelArgumentInfo::state");
443 case ModelArgumentInfo::POINTER:
444 case ModelArgumentInfo::UNSPECIFIED:
445 break;
446 case ModelArgumentInfo::MEMORY: {
447 const uint32_t builderPoolIndex =
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700448 builderInputOrOutput.locationAndLength.poolIndex;
David Gross891b10f2017-10-01 20:48:10 -0700449 const Memory* memory = mExecutionBuilder->mMemories[builderPoolIndex];
450 const uint32_t executorPoolIndex = mMemories.add(memory);
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700451 executorInputOrOutput->locationAndLength.poolIndex =
David Gross891b10f2017-10-01 20:48:10 -0700452 executorPoolIndex;
453 break;
454 }
455 }
456}
457
David Gross96811e22017-10-02 14:40:09 -0700458int StepExecutor::setInputOrOutputFromTemporaryMemory(const Operand& inputOrOutputOperand,
David Gross8fb14e92017-10-04 15:16:02 -0700459 const Memory* memory, uint32_t offset,
David Gross96811e22017-10-02 14:40:09 -0700460 ModelArgumentInfo* inputOrOutputInfo) {
461 // Should be similar to
462 // ExecutionBuilder::setInputFromMemory()
463 // ExecutionBuilder::setOutputFromMemory()
464
465 uint32_t poolIndex = mMemories.add(memory);
David Gross8fb14e92017-10-04 15:16:02 -0700466 return inputOrOutputInfo->setFromTemporaryMemory(inputOrOutputOperand, poolIndex, offset);
David Gross96811e22017-10-02 14:40:09 -0700467}
468
Michael Butler033b8a62017-09-22 18:21:59 -0700469int StepExecutor::startCompute(sp<ExecutionCallback>* synchronizationCallback) {
David Grossb2604912017-10-01 15:26:33 -0700470 if (mDriver == nullptr) {
Michael Butler033b8a62017-09-22 18:21:59 -0700471 return startComputeOnCpu(synchronizationCallback);
David Grossb2604912017-10-01 15:26:33 -0700472 } else {
Michael Butler033b8a62017-09-22 18:21:59 -0700473 return startComputeOnDevice(synchronizationCallback);
David Grossb2604912017-10-01 15:26:33 -0700474 }
475}
476
Michael Butler033b8a62017-09-22 18:21:59 -0700477int StepExecutor::startComputeOnDevice(sp<ExecutionCallback>* synchronizationCallback) {
David Grossb2604912017-10-01 15:26:33 -0700478 nnAssert(mDriver != nullptr);
479
Michael Butler033b8a62017-09-22 18:21:59 -0700480 *synchronizationCallback = nullptr;
David Gross425b2592017-09-13 19:33:14 -0700481
David Grossb2604912017-10-01 15:26:33 -0700482 // TODO: Remove the mPreparedModel == nullptr case once we've fully integrated
David Gross1f438152017-09-28 09:18:51 -0700483 // ExecutionPlan with the compilation and execution phases of the NN API
David Grossb2604912017-10-01 15:26:33 -0700484 if (mPreparedModel == nullptr) {
David Gross1f438152017-09-28 09:18:51 -0700485 Model model;
486 mModel->setHidlModel(&model);
Michael Butler5f916fc2017-09-11 21:10:36 -0700487
David Gross1f438152017-09-28 09:18:51 -0700488 // TODO Dangerous! In async, the model will outlive it here. Safe for now
Michael Butler033b8a62017-09-22 18:21:59 -0700489 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
490 Return<ErrorStatus> prepareLaunchStatus =
491 mDriver->prepareModel(model, preparedModelCallback);
492 if (!prepareLaunchStatus.isOk() || prepareLaunchStatus != ErrorStatus::NONE) {
493 return ANEURALNETWORKS_OP_FAILED;
494 }
Michael Butler5f916fc2017-09-11 21:10:36 -0700495
Michael Butler033b8a62017-09-22 18:21:59 -0700496 // Immediately synchronize with callback object for now
David Gross1f438152017-09-28 09:18:51 -0700497 // TODO: change to asynchronous later
Michael Butler033b8a62017-09-22 18:21:59 -0700498 preparedModelCallback->wait();
499 ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
500 mPreparedModel = preparedModelCallback->getPreparedModel();
501 if (prepareReturnStatus != ErrorStatus::NONE || mPreparedModel == nullptr) {
David Gross1f438152017-09-28 09:18:51 -0700502 return ANEURALNETWORKS_OP_FAILED;
503 }
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700504 }
505
David Grosse413eef2017-09-29 11:43:32 -0700506 // We separate the input & output pools so that we reduce the copying done if we
507 // do an eventual remoting (hidl_memory->update()). We could also use it to set
508 // protection on read only memory but that's not currently done.
509 Memory inputPointerArguments;
510 Memory outputPointerArguments;
511
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700512 // Layout the input and output data
David Grosse413eef2017-09-29 11:43:32 -0700513 int n = allocatePointerArgumentsToPool(&mInputs, &inputPointerArguments);
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700514 if (n != ANEURALNETWORKS_NO_ERROR) {
515 return n;
516 }
David Grosse413eef2017-09-29 11:43:32 -0700517 n = allocatePointerArgumentsToPool(&mOutputs, &outputPointerArguments);
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700518 if (n != ANEURALNETWORKS_NO_ERROR) {
519 return n;
520 }
521
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700522 // Copy the input data that was specified via a pointer.
David Grosse413eef2017-09-29 11:43:32 -0700523 // inputPointerArguments.update();
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700524 for (auto& info : mInputs) {
525 if (info.state == ModelArgumentInfo::POINTER) {
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700526 DataLocation& loc = info.locationAndLength;
Jean-Luc Brouillet2150f1d2017-09-01 13:29:08 -0700527 uint8_t* data = nullptr;
David Grosse413eef2017-09-29 11:43:32 -0700528 int n = inputPointerArguments.getPointer(&data);
Jean-Luc Brouillet2150f1d2017-09-01 13:29:08 -0700529 if (n != ANEURALNETWORKS_NO_ERROR) {
530 return n;
531 }
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700532 memcpy(data + loc.offset, info.buffer, loc.length);
533 }
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700534 }
David Grosse413eef2017-09-29 11:43:32 -0700535 // TODO: Add inputPointerArguments.commit() and .update() at all the right places
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700536
537 Request request;
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700538 setRequestArgumentArray(mInputs, &request.inputs);
539 setRequestArgumentArray(mOutputs, &request.outputs);
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700540 uint32_t count = mMemories.size();
541 request.pools.resize(count);
542 for (uint32_t i = 0; i < count; i++) {
543 request.pools[i] = mMemories[i]->getHidlMemory();
544 }
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700545
Michael Butler033b8a62017-09-22 18:21:59 -0700546 // Prepare the callback for asynchronous execution. sp<ExecutionCallback>
547 // object is returned when the execution has been successfully launched,
548 // otherwise a nullptr is returned. The executionCallback is abstracted in
549 // the NN API as an "event".
Michael Butler689d8922017-09-01 10:58:46 -0700550 //
Michael Butler033b8a62017-09-22 18:21:59 -0700551 // The sp is used for ref-counting purposes. Without it, the HIDL service
552 // could attempt to communicate with a dead callback object.
553 //
554 // TODO: Explain the "dead callback" problem further, either here or
Michael Butler689d8922017-09-01 10:58:46 -0700555 // in the design document.
Michael Butler033b8a62017-09-22 18:21:59 -0700556 sp<ExecutionCallback> executionCallback = new ExecutionCallback();
Michael Butler689d8922017-09-01 10:58:46 -0700557
Miao Wang820215d2017-10-04 19:45:45 -0700558 VLOG(EXECUTION) << "Before mPreparedModel->execute() " << toString(request);
David Gross3ced3cf2017-09-13 10:45:21 -0700559 // Execute.
Michael Butler033b8a62017-09-22 18:21:59 -0700560 // TODO: What happens to the Callback if the service dies abnormally
561 // -- won't that keep the Callback live forever, because the service
Michael Butler689d8922017-09-01 10:58:46 -0700562 // never has the opportunity to bump the reference count down? Or
563 // maybe the HIDL infrastructure handles this magically? At worst,
Michael Butler033b8a62017-09-22 18:21:59 -0700564 // it seems like this is a small memory leak, if the Callback stays
Michael Butler689d8922017-09-01 10:58:46 -0700565 // alive forever.
Michael Butler033b8a62017-09-22 18:21:59 -0700566 if (mPreparedModel->execute(request, executionCallback) != ErrorStatus::NONE) {
Miao Wang820215d2017-10-04 19:45:45 -0700567 VLOG(EXECUTION) << "**Execute failed**";
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700568 return ANEURALNETWORKS_OP_FAILED;
569 }
570
Michael Butler689d8922017-09-01 10:58:46 -0700571 // TODO: Remove this synchronization point when the block of code below is
572 // removed.
Michael Butler033b8a62017-09-22 18:21:59 -0700573 executionCallback->wait();
574 Return<ErrorStatus> executionStatus = executionCallback->getStatus();
575 if (!executionStatus.isOk() || executionStatus != ErrorStatus::NONE) {
Miao Wang820215d2017-10-04 19:45:45 -0700576 VLOG(EXECUTION) << "**Execute async failed**";
Michael Butler689d8922017-09-01 10:58:46 -0700577 return ANEURALNETWORKS_OP_FAILED;
578 }
579
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700580 // Copy the output data from shared memory to the output buffers.
Michael Butler689d8922017-09-01 10:58:46 -0700581 // TODO: Move this block of code somewhere else. It should not be in the
582 // startCompute function.
583 // TODO: outputMemory->update(); outputMemory->commit()
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700584 for (auto& info : mOutputs) {
585 if (info.state == ModelArgumentInfo::POINTER) {
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700586 DataLocation& loc = info.locationAndLength;
Jean-Luc Brouillet2150f1d2017-09-01 13:29:08 -0700587 uint8_t* data = nullptr;
David Grosse413eef2017-09-29 11:43:32 -0700588 int n = outputPointerArguments.getPointer(&data);
Jean-Luc Brouillet2150f1d2017-09-01 13:29:08 -0700589 if (n != ANEURALNETWORKS_NO_ERROR) {
590 return n;
591 }
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700592 memcpy(info.buffer, data + loc.offset, loc.length);
593 }
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700594 }
Miao Wang820215d2017-10-04 19:45:45 -0700595 VLOG(EXECUTION) << "StepExecutor::startComputeOnDevice completed";
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700596
Michael Butler033b8a62017-09-22 18:21:59 -0700597 *synchronizationCallback = executionCallback;
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700598 return ANEURALNETWORKS_NO_ERROR;
599}
600
Michael Butler689d8922017-09-01 10:58:46 -0700601static void asyncStartComputeOnCpu(const Model& model, const Request& request,
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700602 const std::vector<RunTimePoolInfo>& modelPoolInfos,
603 const std::vector<RunTimePoolInfo>& requestPoolInfos,
Michael Butler033b8a62017-09-22 18:21:59 -0700604 const sp<IExecutionCallback>& executionCallback) {
Michael Butler689d8922017-09-01 10:58:46 -0700605 CpuExecutor executor;
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700606 int err = executor.run(model, request, modelPoolInfos, requestPoolInfos);
Michael Butler5f916fc2017-09-11 21:10:36 -0700607 ErrorStatus status = err == ANEURALNETWORKS_NO_ERROR ?
608 ErrorStatus::NONE : ErrorStatus::GENERAL_FAILURE;
Michael Butler033b8a62017-09-22 18:21:59 -0700609 executionCallback->notify(status);
Michael Butler689d8922017-09-01 10:58:46 -0700610}
611
Michael Butler033b8a62017-09-22 18:21:59 -0700612int StepExecutor::startComputeOnCpu(sp<ExecutionCallback>* synchronizationCallback) {
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700613 // TODO: use a thread pool
Michael Butler689d8922017-09-01 10:58:46 -0700614
David Gross1f438152017-09-28 09:18:51 -0700615 Model model;
616 mModel->setHidlModel(&model);
617
Michael Butler033b8a62017-09-22 18:21:59 -0700618 // Prepare the callback for asynchronous execution. sp<ExecutionCallback>
619 // object is returned when the execution has been successfully launched,
620 // otherwise a nullptr is returned. The executionCallback is abstracted in
621 // the NN API as an "event".
622 sp<ExecutionCallback> executionCallback = new ExecutionCallback();
623 *synchronizationCallback = nullptr;
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700624
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700625 std::vector<RunTimePoolInfo> modelPoolInfos;
626 if (!setRunTimePoolInfosFromHidlMemories(&modelPoolInfos, model.pools)) {
627 return ANEURALNETWORKS_UNMAPPABLE;
628 }
629
630 std::vector<RunTimePoolInfo> requestPoolInfos;
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700631 uint32_t count = mMemories.size();
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700632 requestPoolInfos.resize(count);
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700633 for (uint32_t i = 0; i < count; i++) {
634 const Memory* mem = mMemories[i];
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700635 if (!requestPoolInfos[i].set(mem->getHidlMemory())) {
636 return ANEURALNETWORKS_UNMAPPABLE;
637 }
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700638 }
639 // Create as many pools as there are input / output.
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700640 auto fixPointerArguments = [&requestPoolInfos](std::vector<ModelArgumentInfo>& argumentInfos) {
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700641 for (ModelArgumentInfo& argumentInfo : argumentInfos) {
642 if (argumentInfo.state == ModelArgumentInfo::POINTER) {
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -0700643 RunTimePoolInfo runTimeInfo = {
644 .buffer = static_cast<uint8_t*>(argumentInfo.buffer)};
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700645 argumentInfo.locationAndLength.poolIndex =
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700646 static_cast<uint32_t>(requestPoolInfos.size());
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700647 argumentInfo.locationAndLength.offset = 0;
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700648 requestPoolInfos.push_back(runTimeInfo);
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700649 }
650 }
651 };
652 fixPointerArguments(mInputs);
653 fixPointerArguments(mOutputs);
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700654
Jean-Luc Brouillet8b99bb12017-08-20 18:16:36 -0700655 Request request;
Jean-Luc Brouillet62cc2752017-09-27 19:18:51 -0700656 setRequestArgumentArray(mInputs, &request.inputs);
657 setRequestArgumentArray(mOutputs, &request.outputs);
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700658
Michael Butler689d8922017-09-01 10:58:46 -0700659 // TODO: should model be moved with a std::cref?
660 std::thread thread(asyncStartComputeOnCpu, model, std::move(request),
Jean-Luc Brouillet1da8fed2017-10-11 22:34:04 -0700661 std::move(modelPoolInfos), std::move(requestPoolInfos),
662 executionCallback);
Michael Butler033b8a62017-09-22 18:21:59 -0700663 executionCallback->bind_thread(std::move(thread));
Michael Butler689d8922017-09-01 10:58:46 -0700664
Michael Butler033b8a62017-09-22 18:21:59 -0700665 *synchronizationCallback = executionCallback;
Michael Butler689d8922017-09-01 10:58:46 -0700666 return ANEURALNETWORKS_NO_ERROR;
Jean-Luc Brouillet707dbd22017-07-25 00:17:50 -0700667}
668
Jean-Luc Brouillet389f26c2017-09-02 23:05:37 -0700669} // namespace nn
670} // namespace android