blob: 7139899b79dea86056f824d2496650b4d415c4c6 [file] [log] [blame]
Michael Butler60296322019-01-17 17:54:51 -08001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Michael Butler89e99ba2019-01-24 02:36:37 -080017#define LOG_TAG "ExecutionBurstServer"
18
Michael Butler60296322019-01-17 17:54:51 -080019#include "ExecutionBurstServer.h"
20
21#include <android-base/logging.h>
Michael Butler3260db92019-04-26 17:51:23 -070022
Michael Butler4ef48f12019-05-02 14:09:17 -070023#include <cstring>
Michael Butlerc932ebb2019-04-11 14:24:06 -070024#include <limits>
Michael Butler3260db92019-04-26 17:51:23 -070025#include <map>
26
Michael Butler3db6fe52019-01-29 11:20:30 -080027#include "Tracing.h"
Michael Butler60296322019-01-17 17:54:51 -080028
Michael Butler3db6fe52019-01-29 11:20:30 -080029namespace android::nn {
Michael Butler238fe722019-03-21 12:17:27 -070030namespace {
Michael Butler60296322019-01-17 17:54:51 -080031
Michael Butlerc932ebb2019-04-11 14:24:06 -070032constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
33 std::numeric_limits<uint64_t>::max()};
34
Michael Butler238fe722019-03-21 12:17:27 -070035// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
36// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
37// hidl_memory object, and the execution forwards calls to the provided
38// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory
39// must be mapped and unmapped for each execution.
40class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
41 public:
42 DefaultBurstExecutorWithCache(IPreparedModel* preparedModel) : mpPreparedModel(preparedModel) {}
Michael Butler60296322019-01-17 17:54:51 -080043
Michael Butler238fe722019-03-21 12:17:27 -070044 bool isCacheEntryPresent(int32_t slot) const override {
Michael Butler3260db92019-04-26 17:51:23 -070045 const auto it = mMemoryCache.find(slot);
46 if (it == mMemoryCache.end()) {
47 return false;
48 }
49 return it->second.valid();
Michael Butler238fe722019-03-21 12:17:27 -070050 }
Michael Butler47c988f62019-03-14 17:34:48 -070051
Michael Butler238fe722019-03-21 12:17:27 -070052 void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
Michael Butler238fe722019-03-21 12:17:27 -070053 mMemoryCache[slot] = memory;
54 }
Michael Butler60296322019-01-17 17:54:51 -080055
Michael Butler3260db92019-04-26 17:51:23 -070056 void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
Michael Butler238fe722019-03-21 12:17:27 -070057
58 std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
59 const Request& request, const std::vector<int32_t>& slots,
60 MeasureTiming measure) override {
61 // convert slots to pools
62 hidl_vec<hidl_memory> pools(slots.size());
Michael Butler3260db92019-04-26 17:51:23 -070063 std::transform(slots.begin(), slots.end(), pools.begin(),
64 [this](int32_t slot) { return mMemoryCache[slot]; });
Michael Butler238fe722019-03-21 12:17:27 -070065
66 // create full request
67 Request fullRequest = request;
68 fullRequest.pools = std::move(pools);
69
70 // setup execution
71 ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
72 hidl_vec<OutputShape> returnedOutputShapes;
73 Timing returnedTiming;
74 auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
75 ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
76 const Timing& timing) {
77 returnedStatus = status;
78 returnedOutputShapes = outputShapes;
79 returnedTiming = timing;
Michael Butler47c988f62019-03-14 17:34:48 -070080 };
Michael Butler60296322019-01-17 17:54:51 -080081
Michael Butler238fe722019-03-21 12:17:27 -070082 // execute
83 const Return<void> ret = mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
84 if (!ret.isOk() || returnedStatus != ErrorStatus::NONE) {
85 LOG(ERROR) << "IPreparedModelAdapter::execute -- Error executing";
86 return {ErrorStatus::GENERAL_FAILURE, {}, {}};
Michael Butler89e99ba2019-01-24 02:36:37 -080087 }
Michael Butler60296322019-01-17 17:54:51 -080088
Michael Butler238fe722019-03-21 12:17:27 -070089 return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
Michael Butler60296322019-01-17 17:54:51 -080090 }
91
Michael Butler238fe722019-03-21 12:17:27 -070092 private:
93 IPreparedModel* const mpPreparedModel;
Michael Butler3260db92019-04-26 17:51:23 -070094 std::map<int32_t, hidl_memory> mMemoryCache;
Michael Butler238fe722019-03-21 12:17:27 -070095};
Michael Butler47c988f62019-03-14 17:34:48 -070096
Michael Butler238fe722019-03-21 12:17:27 -070097} // anonymous namespace
Michael Butler60296322019-01-17 17:54:51 -080098
Michael Butler60296322019-01-17 17:54:51 -080099// serialize result
Michael Butlerc932ebb2019-04-11 14:24:06 -0700100std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
101 const std::vector<OutputShape>& outputShapes, Timing timing) {
Michael Butler60296322019-01-17 17:54:51 -0800102 // count how many elements need to be sent for a request
103 size_t count = 2 + outputShapes.size();
104 for (const auto& outputShape : outputShapes) {
105 count += outputShape.dimensions.size();
106 }
107
108 // create buffer to temporarily store elements
109 std::vector<FmqResultDatum> data;
110 data.reserve(count);
111
112 // package packetInfo
113 {
114 FmqResultDatum datum;
115 datum.packetInformation({/*.packetSize=*/static_cast<uint32_t>(count),
116 /*.errorStatus=*/errorStatus,
117 /*.numberOfOperands=*/static_cast<uint32_t>(outputShapes.size())});
118 data.push_back(datum);
119 }
120
121 // package output shape data
122 for (const auto& operand : outputShapes) {
123 // package operand information
124 FmqResultDatum datum;
125 datum.operandInformation(
126 {/*.isSufficient=*/operand.isSufficient,
127 /*.numberOfDimensions=*/static_cast<uint32_t>(operand.dimensions.size())});
128 data.push_back(datum);
129
130 // package operand dimensions
131 for (uint32_t dimension : operand.dimensions) {
132 FmqResultDatum datum;
133 datum.operandDimensionValue(dimension);
134 data.push_back(datum);
135 }
136 }
137
138 // package executionTiming
139 {
140 FmqResultDatum datum;
141 datum.executionTiming(timing);
142 data.push_back(datum);
143 }
144
145 // return result
146 return data;
147}
148
Michael Butlerc932ebb2019-04-11 14:24:06 -0700149// deserialize request
150std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
151 const std::vector<FmqRequestDatum>& data) {
152 using discriminator = FmqRequestDatum::hidl_discriminator;
Michael Butler60296322019-01-17 17:54:51 -0800153
Michael Butlerc932ebb2019-04-11 14:24:06 -0700154 size_t index = 0;
155
156 // validate packet information
Michael Butler3260db92019-04-26 17:51:23 -0700157 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700158 LOG(ERROR) << "FMQ Request packet ill-formed";
159 return std::nullopt;
160 }
161
162 // unpackage packet information
163 const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
164 index++;
165 const uint32_t packetSize = packetInfo.packetSize;
166 const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
167 const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
168 const uint32_t numberOfPools = packetInfo.numberOfPools;
169
Michael Butler3260db92019-04-26 17:51:23 -0700170 // verify packet size
171 if (data.size() != packetSize) {
172 LOG(ERROR) << "FMQ Request packet ill-formed";
173 return std::nullopt;
174 }
175
Michael Butlerc932ebb2019-04-11 14:24:06 -0700176 // unpackage input operands
177 std::vector<RequestArgument> inputs;
178 inputs.reserve(numberOfInputOperands);
179 for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
180 // validate input operand information
181 if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
182 LOG(ERROR) << "FMQ Request packet ill-formed";
183 return std::nullopt;
Michael Butler60296322019-01-17 17:54:51 -0800184 }
185
Michael Butlerc932ebb2019-04-11 14:24:06 -0700186 // unpackage operand information
187 const FmqRequestDatum::OperandInformation& operandInfo =
188 data[index].inputOperandInformation();
189 index++;
190 const bool hasNoValue = operandInfo.hasNoValue;
191 const DataLocation location = operandInfo.location;
192 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
Michael Butler3db6fe52019-01-29 11:20:30 -0800193
Michael Butlerc932ebb2019-04-11 14:24:06 -0700194 // unpackage operand dimensions
195 std::vector<uint32_t> dimensions;
196 dimensions.reserve(numberOfDimensions);
197 for (size_t i = 0; i < numberOfDimensions; ++i) {
198 // validate dimension
199 if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
200 LOG(ERROR) << "FMQ Request packet ill-formed";
201 return std::nullopt;
202 }
203
204 // unpackage dimension
205 const uint32_t dimension = data[index].inputOperandDimensionValue();
206 index++;
207
208 // store result
209 dimensions.push_back(dimension);
210 }
211
212 // store result
213 inputs.push_back(
214 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
215 }
216
217 // unpackage output operands
218 std::vector<RequestArgument> outputs;
219 outputs.reserve(numberOfOutputOperands);
220 for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
221 // validate output operand information
222 if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
223 LOG(ERROR) << "FMQ Request packet ill-formed";
224 return std::nullopt;
225 }
226
227 // unpackage operand information
228 const FmqRequestDatum::OperandInformation& operandInfo =
229 data[index].outputOperandInformation();
230 index++;
231 const bool hasNoValue = operandInfo.hasNoValue;
232 const DataLocation location = operandInfo.location;
233 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
234
235 // unpackage operand dimensions
236 std::vector<uint32_t> dimensions;
237 dimensions.reserve(numberOfDimensions);
238 for (size_t i = 0; i < numberOfDimensions; ++i) {
239 // validate dimension
240 if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
241 LOG(ERROR) << "FMQ Request packet ill-formed";
242 return std::nullopt;
243 }
244
245 // unpackage dimension
246 const uint32_t dimension = data[index].outputOperandDimensionValue();
247 index++;
248
249 // store result
250 dimensions.push_back(dimension);
251 }
252
253 // store result
254 outputs.push_back(
255 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
256 }
257
258 // unpackage pools
259 std::vector<int32_t> slots;
260 slots.reserve(numberOfPools);
261 for (size_t pool = 0; pool < numberOfPools; ++pool) {
262 // validate input operand information
263 if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
264 LOG(ERROR) << "FMQ Request packet ill-formed";
265 return std::nullopt;
266 }
267
268 // unpackage operand information
269 const int32_t poolId = data[index].poolIdentifier();
270 index++;
271
272 // store result
273 slots.push_back(poolId);
274 }
275
276 // validate measureTiming
277 if (data[index].getDiscriminator() != discriminator::measureTiming) {
278 LOG(ERROR) << "FMQ Request packet ill-formed";
279 return std::nullopt;
280 }
281
282 // unpackage measureTiming
283 const MeasureTiming measure = data[index].measureTiming();
284 index++;
285
286 // validate packet information
287 if (index != packetSize) {
288 LOG(ERROR) << "FMQ Result packet ill-formed";
289 return std::nullopt;
290 }
291
292 // return request
293 Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}};
294 return std::make_tuple(std::move(request), std::move(slots), measure);
295}
296
297// RequestChannelReceiver methods
298
299std::unique_ptr<RequestChannelReceiver> RequestChannelReceiver::create(
300 const FmqRequestDescriptor& requestChannel) {
301 std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
302 std::make_unique<FmqRequestChannel>(requestChannel);
303 if (!fmqRequestChannel->isValid()) {
304 LOG(ERROR) << "Unable to create RequestChannelReceiver";
305 return nullptr;
306 }
307 const bool blocking = fmqRequestChannel->getEventFlagWord() != nullptr;
308 return std::make_unique<RequestChannelReceiver>(std::move(fmqRequestChannel), blocking);
309}
310
311RequestChannelReceiver::RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
312 bool blocking)
313 : mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {}
314
315std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>>
316RequestChannelReceiver::getBlocking() {
317 const auto packet = getPacketBlocking();
318 if (!packet) {
319 return std::nullopt;
320 }
321
322 return deserialize(*packet);
323}
324
325void RequestChannelReceiver::invalidate() {
326 mTeardown = true;
327
328 // force unblock
329 // ExecutionBurstServer is by default waiting on a request packet. If the
330 // client process destroys its burst object, the server will still be
331 // waiting on the futex (assuming mBlocking is true). This force unblock
332 // wakes up any thread waiting on the futex.
333 if (mBlocking) {
334 // TODO: look for a different/better way to signal/notify the futex to
335 // wake up any thread waiting on it
336 FmqRequestDatum datum;
337 datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
338 /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
339 mFmqRequestChannel->writeBlocking(&datum, 1);
340 }
341}
342
343std::optional<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() {
344 using discriminator = FmqRequestDatum::hidl_discriminator;
345
346 if (mTeardown) {
347 return std::nullopt;
348 }
349
350 // wait for request packet and read first element of request packet
Michael Butlerc932ebb2019-04-11 14:24:06 -0700351 FmqRequestDatum datum;
352 bool success = false;
353 if (mBlocking) {
354 success = mFmqRequestChannel->readBlocking(&datum, 1);
355 } else {
356 while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
357 !mFmqRequestChannel->read(&datum, 1)) {
358 }
359 }
360
Michael Butlerc932ebb2019-04-11 14:24:06 -0700361 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
362
Michael Butlerc932ebb2019-04-11 14:24:06 -0700363 // retrieve remaining elements
364 // NOTE: all of the data is already available at this point, so there's no
365 // need to do a blocking wait to wait for more data. This is known because
366 // in FMQ, all writes are published (made available) atomically. Currently,
367 // the producer always publishes the entire packet in one function call, so
368 // if the first element of the packet is available, the remaining elements
369 // are also available.
Michael Butler3260db92019-04-26 17:51:23 -0700370 const size_t count = mFmqRequestChannel->availableToRead();
371 std::vector<FmqRequestDatum> packet(count + 1);
Michael Butler4ef48f12019-05-02 14:09:17 -0700372 std::memcpy(&packet.front(), &datum, sizeof(datum));
Michael Butler3260db92019-04-26 17:51:23 -0700373 success &= mFmqRequestChannel->read(packet.data() + 1, count);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700374
Michael Butler3260db92019-04-26 17:51:23 -0700375 // terminate loop
376 if (mTeardown) {
377 return std::nullopt;
378 }
379
380 // ensure packet was successfully received
Michael Butlerc932ebb2019-04-11 14:24:06 -0700381 if (!success) {
Michael Butler3260db92019-04-26 17:51:23 -0700382 LOG(ERROR) << "Error receiving packet";
383 return std::nullopt;
Michael Butlerc932ebb2019-04-11 14:24:06 -0700384 }
385
Michael Butler4ef48f12019-05-02 14:09:17 -0700386 return std::make_optional(std::move(packet));
Michael Butlerc932ebb2019-04-11 14:24:06 -0700387}
388
389// ResultChannelSender methods
390
391std::unique_ptr<ResultChannelSender> ResultChannelSender::create(
392 const FmqResultDescriptor& resultChannel) {
393 std::unique_ptr<FmqResultChannel> fmqResultChannel =
394 std::make_unique<FmqResultChannel>(resultChannel);
395 if (!fmqResultChannel->isValid()) {
396 LOG(ERROR) << "Unable to create RequestChannelSender";
397 return nullptr;
398 }
399 const bool blocking = fmqResultChannel->getEventFlagWord() != nullptr;
400 return std::make_unique<ResultChannelSender>(std::move(fmqResultChannel), blocking);
401}
402
403ResultChannelSender::ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel,
404 bool blocking)
405 : mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {}
406
407bool ResultChannelSender::send(ErrorStatus errorStatus,
408 const std::vector<OutputShape>& outputShapes, Timing timing) {
409 const std::vector<FmqResultDatum> serialized = serialize(errorStatus, outputShapes, timing);
410 return sendPacket(serialized);
411}
412
413bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
Michael Butler3260db92019-04-26 17:51:23 -0700414 if (packet.size() > mFmqResultChannel->availableToWrite()) {
415 LOG(ERROR)
416 << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ";
417 const std::vector<FmqResultDatum> errorPacket =
418 serialize(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
419 return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size());
420 }
421
Michael Butlerc932ebb2019-04-11 14:24:06 -0700422 if (mBlocking) {
423 return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
424 } else {
425 return mFmqResultChannel->write(packet.data(), packet.size());
426 }
427}
428
429// ExecutionBurstServer methods
430
431sp<ExecutionBurstServer> ExecutionBurstServer::create(
432 const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
433 const MQDescriptorSync<FmqResultDatum>& resultChannel,
434 std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {
435 // check inputs
436 if (callback == nullptr || executorWithCache == nullptr) {
437 LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
438 return nullptr;
439 }
440
441 // create FMQ objects
442 std::unique_ptr<RequestChannelReceiver> requestChannelReceiver =
443 RequestChannelReceiver::create(requestChannel);
444 std::unique_ptr<ResultChannelSender> resultChannelSender =
445 ResultChannelSender::create(resultChannel);
446
447 // check FMQ objects
448 if (!requestChannelReceiver || !resultChannelSender) {
449 LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
450 return nullptr;
451 }
452
453 // make and return context
454 return new ExecutionBurstServer(callback, std::move(requestChannelReceiver),
455 std::move(resultChannelSender), std::move(executorWithCache));
456}
457
458sp<ExecutionBurstServer> ExecutionBurstServer::create(
459 const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
460 const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
461 // check relevant input
462 if (preparedModel == nullptr) {
463 LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
464 return nullptr;
465 }
466
467 // adapt IPreparedModel to have caching
468 const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
469 std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
470
471 // make and return context
472 return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
473 preparedModelAdapter);
474}
475
476ExecutionBurstServer::ExecutionBurstServer(
477 const sp<IBurstCallback>& callback, std::unique_ptr<RequestChannelReceiver> requestChannel,
478 std::unique_ptr<ResultChannelSender> resultChannel,
479 std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
480 : mCallback(callback),
481 mRequestChannelReceiver(std::move(requestChannel)),
482 mResultChannelSender(std::move(resultChannel)),
483 mExecutorWithCache(std::move(executorWithCache)) {
484 // TODO: highly document the threading behavior of this class
485 mWorker = std::thread([this] { task(); });
486}
487
488ExecutionBurstServer::~ExecutionBurstServer() {
489 // set teardown flag
490 mTeardown = true;
491 mRequestChannelReceiver->invalidate();
492
493 // wait for task thread to end
494 mWorker.join();
495}
496
497Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
498 mExecutorWithCache->removeCacheEntry(slot);
499 return Void();
500}
501
502void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
503 const auto slotIsKnown = [this](int32_t slot) {
504 return mExecutorWithCache->isCacheEntryPresent(slot);
505 };
506
507 // find unique unknown slots
508 std::vector<int32_t> unknownSlots = slots;
509 auto unknownSlotsEnd = unknownSlots.end();
510 std::sort(unknownSlots.begin(), unknownSlotsEnd);
511 unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
512 unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
513 unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
514
515 // quick-exit if all slots are known
516 if (unknownSlots.empty()) {
517 return;
518 }
519
520 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
521 std::vector<hidl_memory> returnedMemories;
522 auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
523 const hidl_vec<hidl_memory>& memories) {
524 errorStatus = status;
525 returnedMemories = memories;
526 };
527
528 const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
529
530 if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
531 returnedMemories.size() != unknownSlots.size()) {
532 LOG(ERROR) << "Error retrieving memories";
533 return;
534 }
535
536 // add memories to unknown slots
537 for (size_t i = 0; i < unknownSlots.size(); ++i) {
538 mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
539 }
540}
541
542void ExecutionBurstServer::task() {
543 // loop until the burst object is being destroyed
544 while (!mTeardown) {
545 // receive request
546 auto arguments = mRequestChannelReceiver->getBlocking();
547
548 // if the request packet was not properly received, return a generic
549 // error and skip the execution
550 //
551 // if the burst is being torn down, skip the execution exection so the
552 // "task" function can end
553 if (!arguments) {
554 if (!mTeardown) {
555 mResultChannelSender->send(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
556 }
557 continue;
558 }
559
560 // otherwise begin tracing execution
561 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
562 "ExecutionBurstServer getting memory, executing, and returning results");
563
564 // unpack the arguments; types are Request, std::vector<int32_t>, and
Michael Butler238fe722019-03-21 12:17:27 -0700565 // MeasureTiming, respectively
Michael Butlerc932ebb2019-04-11 14:24:06 -0700566 const auto [requestWithoutPools, slotsOfPools, measure] = std::move(*arguments);
Michael Butler60296322019-01-17 17:54:51 -0800567
Michael Butler238fe722019-03-21 12:17:27 -0700568 // ensure executor with cache has required memory
569 std::lock_guard<std::mutex> hold(mMutex);
570 ensureCacheEntriesArePresentLocked(slotsOfPools);
571
572 // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
573 // and Timing, respectively
574 const auto [errorStatus, outputShapes, returnedTiming] =
575 mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
Michael Butler60296322019-01-17 17:54:51 -0800576
577 // return result
Michael Butlerc932ebb2019-04-11 14:24:06 -0700578 mResultChannelSender->send(errorStatus, outputShapes, returnedTiming);
Michael Butler60296322019-01-17 17:54:51 -0800579 }
580}
581
Michael Butler3db6fe52019-01-29 11:20:30 -0800582} // namespace android::nn