blob: 0bcb57dbeab5b28b14e87d6715ac663e2dfab2fc [file] [log] [blame]
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Michael Butler89e99ba2019-01-24 02:36:37 -080017#define LOG_TAG "ExecutionBurstServer"
18
Michael Butler60296322019-01-17 17:54:51 -080019#include "ExecutionBurstServer.h"
20
21#include <android-base/logging.h>
Michael Butler3260db92019-04-26 17:51:23 -070022
Michael Butlerc82044a2019-06-24 10:36:20 -070023#include <algorithm>
Michael Butler4ef48f12019-05-02 14:09:17 -070024#include <cstring>
Michael Butlerc932ebb2019-04-11 14:24:06 -070025#include <limits>
Michael Butler3260db92019-04-26 17:51:23 -070026#include <map>
Michael Butlerc82044a2019-06-24 10:36:20 -070027#include <memory>
28#include <tuple>
29#include <utility>
30#include <vector>
Michael Butler3260db92019-04-26 17:51:23 -070031
Michael Butler3db6fe52019-01-29 11:20:30 -080032#include "Tracing.h"
Michael Butler60296322019-01-17 17:54:51 -080033
Michael Butler3db6fe52019-01-29 11:20:30 -080034namespace android::nn {
Michael Butler238fe722019-03-21 12:17:27 -070035namespace {
Michael Butler60296322019-01-17 17:54:51 -080036
Michael Butler19af9d22019-07-11 11:45:01 -070037using namespace hal;
38
Michael Butlerc82044a2019-06-24 10:36:20 -070039using hardware::MQDescriptorSync;
40
Michael Butlerc932ebb2019-04-11 14:24:06 -070041constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
42 std::numeric_limits<uint64_t>::max()};
43
Michael Butler238fe722019-03-21 12:17:27 -070044// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
45// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
46// hidl_memory object, and the execution forwards calls to the provided
47// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory
48// must be mapped and unmapped for each execution.
49class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
50 public:
Xusong Wang196a1872019-10-25 12:06:20 -070051 DefaultBurstExecutorWithCache(V1_2::IPreparedModel* preparedModel)
52 : mpPreparedModel(preparedModel) {}
Michael Butler60296322019-01-17 17:54:51 -080053
Michael Butler238fe722019-03-21 12:17:27 -070054 bool isCacheEntryPresent(int32_t slot) const override {
Michael Butler3260db92019-04-26 17:51:23 -070055 const auto it = mMemoryCache.find(slot);
Michael Butler1ee58a52019-04-30 13:49:32 -070056 return (it != mMemoryCache.end()) && it->second.valid();
Michael Butler238fe722019-03-21 12:17:27 -070057 }
Michael Butler47c988f62019-03-14 17:34:48 -070058
Michael Butler238fe722019-03-21 12:17:27 -070059 void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
Michael Butler238fe722019-03-21 12:17:27 -070060 mMemoryCache[slot] = memory;
61 }
Michael Butler60296322019-01-17 17:54:51 -080062
Michael Butler3260db92019-04-26 17:51:23 -070063 void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
Michael Butler238fe722019-03-21 12:17:27 -070064
65 std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
66 const Request& request, const std::vector<int32_t>& slots,
67 MeasureTiming measure) override {
68 // convert slots to pools
69 hidl_vec<hidl_memory> pools(slots.size());
Michael Butler3260db92019-04-26 17:51:23 -070070 std::transform(slots.begin(), slots.end(), pools.begin(),
71 [this](int32_t slot) { return mMemoryCache[slot]; });
Michael Butler238fe722019-03-21 12:17:27 -070072
73 // create full request
74 Request fullRequest = request;
75 fullRequest.pools = std::move(pools);
76
77 // setup execution
78 ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
79 hidl_vec<OutputShape> returnedOutputShapes;
80 Timing returnedTiming;
81 auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
82 ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
83 const Timing& timing) {
84 returnedStatus = status;
85 returnedOutputShapes = outputShapes;
86 returnedTiming = timing;
Michael Butler47c988f62019-03-14 17:34:48 -070087 };
Michael Butler60296322019-01-17 17:54:51 -080088
Michael Butler238fe722019-03-21 12:17:27 -070089 // execute
90 const Return<void> ret = mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
91 if (!ret.isOk() || returnedStatus != ErrorStatus::NONE) {
92 LOG(ERROR) << "IPreparedModelAdapter::execute -- Error executing";
Raksit Ashokc1079232019-05-29 12:55:16 -070093 return {returnedStatus, {}, kNoTiming};
Michael Butler89e99ba2019-01-24 02:36:37 -080094 }
Michael Butler60296322019-01-17 17:54:51 -080095
Michael Butler238fe722019-03-21 12:17:27 -070096 return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
Michael Butler60296322019-01-17 17:54:51 -080097 }
98
Michael Butler238fe722019-03-21 12:17:27 -070099 private:
Xusong Wang196a1872019-10-25 12:06:20 -0700100 V1_2::IPreparedModel* const mpPreparedModel;
Michael Butler3260db92019-04-26 17:51:23 -0700101 std::map<int32_t, hidl_memory> mMemoryCache;
Michael Butler238fe722019-03-21 12:17:27 -0700102};
Michael Butler47c988f62019-03-14 17:34:48 -0700103
Michael Butler238fe722019-03-21 12:17:27 -0700104} // anonymous namespace
Michael Butler60296322019-01-17 17:54:51 -0800105
Michael Butler60296322019-01-17 17:54:51 -0800106// serialize result
Michael Butlerc932ebb2019-04-11 14:24:06 -0700107std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
108 const std::vector<OutputShape>& outputShapes, Timing timing) {
Michael Butler60296322019-01-17 17:54:51 -0800109 // count how many elements need to be sent for a request
110 size_t count = 2 + outputShapes.size();
111 for (const auto& outputShape : outputShapes) {
112 count += outputShape.dimensions.size();
113 }
114
115 // create buffer to temporarily store elements
116 std::vector<FmqResultDatum> data;
117 data.reserve(count);
118
119 // package packetInfo
120 {
121 FmqResultDatum datum;
122 datum.packetInformation({/*.packetSize=*/static_cast<uint32_t>(count),
123 /*.errorStatus=*/errorStatus,
124 /*.numberOfOperands=*/static_cast<uint32_t>(outputShapes.size())});
125 data.push_back(datum);
126 }
127
128 // package output shape data
129 for (const auto& operand : outputShapes) {
130 // package operand information
Steven Moreland393ac6d2019-04-25 15:33:25 -0700131 FmqResultDatum::OperandInformation info{};
132 info.isSufficient = operand.isSufficient;
133 info.numberOfDimensions = static_cast<uint32_t>(operand.dimensions.size());
134
Michael Butler60296322019-01-17 17:54:51 -0800135 FmqResultDatum datum;
Steven Moreland393ac6d2019-04-25 15:33:25 -0700136 datum.operandInformation(info);
Michael Butler60296322019-01-17 17:54:51 -0800137 data.push_back(datum);
138
139 // package operand dimensions
140 for (uint32_t dimension : operand.dimensions) {
141 FmqResultDatum datum;
142 datum.operandDimensionValue(dimension);
143 data.push_back(datum);
144 }
145 }
146
147 // package executionTiming
148 {
149 FmqResultDatum datum;
150 datum.executionTiming(timing);
151 data.push_back(datum);
152 }
153
154 // return result
155 return data;
156}
157
Michael Butlerc932ebb2019-04-11 14:24:06 -0700158// deserialize request
159std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
160 const std::vector<FmqRequestDatum>& data) {
161 using discriminator = FmqRequestDatum::hidl_discriminator;
Michael Butler60296322019-01-17 17:54:51 -0800162
Michael Butlerc932ebb2019-04-11 14:24:06 -0700163 size_t index = 0;
164
165 // validate packet information
Michael Butler3260db92019-04-26 17:51:23 -0700166 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700167 LOG(ERROR) << "FMQ Request packet ill-formed";
168 return std::nullopt;
169 }
170
171 // unpackage packet information
172 const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
173 index++;
174 const uint32_t packetSize = packetInfo.packetSize;
175 const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
176 const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
177 const uint32_t numberOfPools = packetInfo.numberOfPools;
178
Michael Butler3260db92019-04-26 17:51:23 -0700179 // verify packet size
180 if (data.size() != packetSize) {
181 LOG(ERROR) << "FMQ Request packet ill-formed";
182 return std::nullopt;
183 }
184
Michael Butlerc932ebb2019-04-11 14:24:06 -0700185 // unpackage input operands
186 std::vector<RequestArgument> inputs;
187 inputs.reserve(numberOfInputOperands);
188 for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
189 // validate input operand information
190 if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
191 LOG(ERROR) << "FMQ Request packet ill-formed";
192 return std::nullopt;
Michael Butler60296322019-01-17 17:54:51 -0800193 }
194
Michael Butlerc932ebb2019-04-11 14:24:06 -0700195 // unpackage operand information
196 const FmqRequestDatum::OperandInformation& operandInfo =
197 data[index].inputOperandInformation();
198 index++;
199 const bool hasNoValue = operandInfo.hasNoValue;
200 const DataLocation location = operandInfo.location;
201 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
Michael Butler3db6fe52019-01-29 11:20:30 -0800202
Michael Butlerc932ebb2019-04-11 14:24:06 -0700203 // unpackage operand dimensions
204 std::vector<uint32_t> dimensions;
205 dimensions.reserve(numberOfDimensions);
206 for (size_t i = 0; i < numberOfDimensions; ++i) {
207 // validate dimension
208 if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
209 LOG(ERROR) << "FMQ Request packet ill-formed";
210 return std::nullopt;
211 }
212
213 // unpackage dimension
214 const uint32_t dimension = data[index].inputOperandDimensionValue();
215 index++;
216
217 // store result
218 dimensions.push_back(dimension);
219 }
220
221 // store result
222 inputs.push_back(
223 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
224 }
225
226 // unpackage output operands
227 std::vector<RequestArgument> outputs;
228 outputs.reserve(numberOfOutputOperands);
229 for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
230 // validate output operand information
231 if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
232 LOG(ERROR) << "FMQ Request packet ill-formed";
233 return std::nullopt;
234 }
235
236 // unpackage operand information
237 const FmqRequestDatum::OperandInformation& operandInfo =
238 data[index].outputOperandInformation();
239 index++;
240 const bool hasNoValue = operandInfo.hasNoValue;
241 const DataLocation location = operandInfo.location;
242 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
243
244 // unpackage operand dimensions
245 std::vector<uint32_t> dimensions;
246 dimensions.reserve(numberOfDimensions);
247 for (size_t i = 0; i < numberOfDimensions; ++i) {
248 // validate dimension
249 if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
250 LOG(ERROR) << "FMQ Request packet ill-formed";
251 return std::nullopt;
252 }
253
254 // unpackage dimension
255 const uint32_t dimension = data[index].outputOperandDimensionValue();
256 index++;
257
258 // store result
259 dimensions.push_back(dimension);
260 }
261
262 // store result
263 outputs.push_back(
264 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
265 }
266
267 // unpackage pools
268 std::vector<int32_t> slots;
269 slots.reserve(numberOfPools);
270 for (size_t pool = 0; pool < numberOfPools; ++pool) {
271 // validate input operand information
272 if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
273 LOG(ERROR) << "FMQ Request packet ill-formed";
274 return std::nullopt;
275 }
276
277 // unpackage operand information
278 const int32_t poolId = data[index].poolIdentifier();
279 index++;
280
281 // store result
282 slots.push_back(poolId);
283 }
284
285 // validate measureTiming
286 if (data[index].getDiscriminator() != discriminator::measureTiming) {
287 LOG(ERROR) << "FMQ Request packet ill-formed";
288 return std::nullopt;
289 }
290
291 // unpackage measureTiming
292 const MeasureTiming measure = data[index].measureTiming();
293 index++;
294
295 // validate packet information
296 if (index != packetSize) {
297 LOG(ERROR) << "FMQ Result packet ill-formed";
298 return std::nullopt;
299 }
300
301 // return request
302 Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}};
303 return std::make_tuple(std::move(request), std::move(slots), measure);
304}
305
306// RequestChannelReceiver methods
307
308std::unique_ptr<RequestChannelReceiver> RequestChannelReceiver::create(
Michael Butlerc82044a2019-06-24 10:36:20 -0700309 const FmqRequestDescriptor& requestChannel, std::chrono::microseconds pollingTimeWindow) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700310 std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
311 std::make_unique<FmqRequestChannel>(requestChannel);
Michael Butlerc82044a2019-06-24 10:36:20 -0700312
Michael Butlerc932ebb2019-04-11 14:24:06 -0700313 if (!fmqRequestChannel->isValid()) {
314 LOG(ERROR) << "Unable to create RequestChannelReceiver";
315 return nullptr;
316 }
Michael Butlerc82044a2019-06-24 10:36:20 -0700317 if (fmqRequestChannel->getEventFlagWord() == nullptr) {
318 LOG(ERROR)
319 << "RequestChannelReceiver::create was passed an MQDescriptor without an EventFlag";
320 return nullptr;
321 }
322
323 return std::make_unique<RequestChannelReceiver>(std::move(fmqRequestChannel),
324 pollingTimeWindow);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700325}
326
327RequestChannelReceiver::RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
Michael Butlerc82044a2019-06-24 10:36:20 -0700328 std::chrono::microseconds pollingTimeWindow)
329 : mFmqRequestChannel(std::move(fmqRequestChannel)), kPollingTimeWindow(pollingTimeWindow) {}
Michael Butlerc932ebb2019-04-11 14:24:06 -0700330
331std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>>
332RequestChannelReceiver::getBlocking() {
333 const auto packet = getPacketBlocking();
334 if (!packet) {
335 return std::nullopt;
336 }
337
338 return deserialize(*packet);
339}
340
341void RequestChannelReceiver::invalidate() {
342 mTeardown = true;
343
344 // force unblock
345 // ExecutionBurstServer is by default waiting on a request packet. If the
Michael Butlerc82044a2019-06-24 10:36:20 -0700346 // client process destroys its burst object, the server may still be waiting
347 // on the futex. This force unblock wakes up any thread waiting on the
348 // futex.
349 // TODO: look for a different/better way to signal/notify the futex to wake
350 // up any thread waiting on it
351 FmqRequestDatum datum;
352 datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
353 /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
354 mFmqRequestChannel->writeBlocking(&datum, 1);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700355}
356
357std::optional<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() {
358 using discriminator = FmqRequestDatum::hidl_discriminator;
359
360 if (mTeardown) {
361 return std::nullopt;
362 }
363
Michael Butlerc82044a2019-06-24 10:36:20 -0700364 // First spend time polling if results are available in FMQ instead of
365 // waiting on the futex. Polling is more responsive (yielding lower
366 // latencies), but can take up more power, so only poll for a limited period
367 // of time.
368
369 auto& getCurrentTime = std::chrono::high_resolution_clock::now;
370 const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
371
372 while (getCurrentTime() < timeToStopPolling) {
373 // if class is being torn down, immediately return
374 if (mTeardown.load(std::memory_order_relaxed)) {
375 return std::nullopt;
376 }
377
378 // Check if data is available. If it is, immediately retrieve it and
379 // return.
380 const size_t available = mFmqRequestChannel->availableToRead();
381 if (available > 0) {
382 // This is the first point when we know an execution is occurring,
383 // so begin to collect systraces. Note that a similar systrace does
384 // not exist at the corresponding point in
385 // ResultChannelReceiver::getPacketBlocking because the execution is
386 // already in flight.
387 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
388 "ExecutionBurstServer getting packet");
389 std::vector<FmqRequestDatum> packet(available);
390 const bool success = mFmqRequestChannel->read(packet.data(), available);
391 if (!success) {
392 LOG(ERROR) << "Error receiving packet";
393 return std::nullopt;
394 }
395 return std::make_optional(std::move(packet));
Michael Butlerc932ebb2019-04-11 14:24:06 -0700396 }
397 }
398
Michael Butlerc82044a2019-06-24 10:36:20 -0700399 // If we get to this point, we either stopped polling because it was taking
400 // too long or polling was not allowed. Instead, perform a blocking call
401 // which uses a futex to save power.
402
403 // wait for request packet and read first element of request packet
404 FmqRequestDatum datum;
405 bool success = mFmqRequestChannel->readBlocking(&datum, 1);
406
407 // This is the first point when we know an execution is occurring, so begin
408 // to collect systraces. Note that a similar systrace does not exist at the
409 // corresponding point in ResultChannelReceiver::getPacketBlocking because
410 // the execution is already in flight.
Michael Butlerc932ebb2019-04-11 14:24:06 -0700411 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
412
Michael Butlerc932ebb2019-04-11 14:24:06 -0700413 // retrieve remaining elements
414 // NOTE: all of the data is already available at this point, so there's no
415 // need to do a blocking wait to wait for more data. This is known because
416 // in FMQ, all writes are published (made available) atomically. Currently,
417 // the producer always publishes the entire packet in one function call, so
418 // if the first element of the packet is available, the remaining elements
419 // are also available.
Michael Butler3260db92019-04-26 17:51:23 -0700420 const size_t count = mFmqRequestChannel->availableToRead();
421 std::vector<FmqRequestDatum> packet(count + 1);
Michael Butler4ef48f12019-05-02 14:09:17 -0700422 std::memcpy(&packet.front(), &datum, sizeof(datum));
Michael Butler3260db92019-04-26 17:51:23 -0700423 success &= mFmqRequestChannel->read(packet.data() + 1, count);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700424
Michael Butler3260db92019-04-26 17:51:23 -0700425 // terminate loop
426 if (mTeardown) {
427 return std::nullopt;
428 }
429
430 // ensure packet was successfully received
Michael Butlerc932ebb2019-04-11 14:24:06 -0700431 if (!success) {
Michael Butler3260db92019-04-26 17:51:23 -0700432 LOG(ERROR) << "Error receiving packet";
433 return std::nullopt;
Michael Butlerc932ebb2019-04-11 14:24:06 -0700434 }
435
Michael Butler4ef48f12019-05-02 14:09:17 -0700436 return std::make_optional(std::move(packet));
Michael Butlerc932ebb2019-04-11 14:24:06 -0700437}
438
439// ResultChannelSender methods
440
441std::unique_ptr<ResultChannelSender> ResultChannelSender::create(
442 const FmqResultDescriptor& resultChannel) {
443 std::unique_ptr<FmqResultChannel> fmqResultChannel =
444 std::make_unique<FmqResultChannel>(resultChannel);
Michael Butlerc82044a2019-06-24 10:36:20 -0700445
Michael Butlerc932ebb2019-04-11 14:24:06 -0700446 if (!fmqResultChannel->isValid()) {
447 LOG(ERROR) << "Unable to create RequestChannelSender";
448 return nullptr;
449 }
Michael Butlerc82044a2019-06-24 10:36:20 -0700450 if (fmqResultChannel->getEventFlagWord() == nullptr) {
451 LOG(ERROR) << "ResultChannelSender::create was passed an MQDescriptor without an EventFlag";
452 return nullptr;
453 }
454
455 return std::make_unique<ResultChannelSender>(std::move(fmqResultChannel));
Michael Butlerc932ebb2019-04-11 14:24:06 -0700456}
457
Michael Butlerc82044a2019-06-24 10:36:20 -0700458ResultChannelSender::ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel)
459 : mFmqResultChannel(std::move(fmqResultChannel)) {}
Michael Butlerc932ebb2019-04-11 14:24:06 -0700460
461bool ResultChannelSender::send(ErrorStatus errorStatus,
462 const std::vector<OutputShape>& outputShapes, Timing timing) {
463 const std::vector<FmqResultDatum> serialized = serialize(errorStatus, outputShapes, timing);
464 return sendPacket(serialized);
465}
466
467bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
Michael Butler3260db92019-04-26 17:51:23 -0700468 if (packet.size() > mFmqResultChannel->availableToWrite()) {
469 LOG(ERROR)
470 << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ";
471 const std::vector<FmqResultDatum> errorPacket =
472 serialize(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
Michael Butlerc82044a2019-06-24 10:36:20 -0700473
474 // Always send the packet with "blocking" because this signals the futex
475 // and unblocks the consumer if it is waiting on the futex.
476 return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size());
Michael Butler3260db92019-04-26 17:51:23 -0700477 }
478
Michael Butlerc82044a2019-06-24 10:36:20 -0700479 // Always send the packet with "blocking" because this signals the futex and
480 // unblocks the consumer if it is waiting on the futex.
481 return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
Michael Butlerc932ebb2019-04-11 14:24:06 -0700482}
483
484// ExecutionBurstServer methods
485
486sp<ExecutionBurstServer> ExecutionBurstServer::create(
487 const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
488 const MQDescriptorSync<FmqResultDatum>& resultChannel,
Michael Butlerc82044a2019-06-24 10:36:20 -0700489 std::shared_ptr<IBurstExecutorWithCache> executorWithCache,
490 std::chrono::microseconds pollingTimeWindow) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700491 // check inputs
492 if (callback == nullptr || executorWithCache == nullptr) {
493 LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
494 return nullptr;
495 }
496
497 // create FMQ objects
498 std::unique_ptr<RequestChannelReceiver> requestChannelReceiver =
Michael Butlerc82044a2019-06-24 10:36:20 -0700499 RequestChannelReceiver::create(requestChannel, pollingTimeWindow);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700500 std::unique_ptr<ResultChannelSender> resultChannelSender =
501 ResultChannelSender::create(resultChannel);
502
503 // check FMQ objects
504 if (!requestChannelReceiver || !resultChannelSender) {
505 LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
506 return nullptr;
507 }
508
509 // make and return context
510 return new ExecutionBurstServer(callback, std::move(requestChannelReceiver),
511 std::move(resultChannelSender), std::move(executorWithCache));
512}
513
514sp<ExecutionBurstServer> ExecutionBurstServer::create(
515 const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
Xusong Wang196a1872019-10-25 12:06:20 -0700516 const MQDescriptorSync<FmqResultDatum>& resultChannel, V1_2::IPreparedModel* preparedModel,
Michael Butlerc82044a2019-06-24 10:36:20 -0700517 std::chrono::microseconds pollingTimeWindow) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700518 // check relevant input
519 if (preparedModel == nullptr) {
520 LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
521 return nullptr;
522 }
523
524 // adapt IPreparedModel to have caching
525 const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
526 std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
527
528 // make and return context
529 return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
Michael Butlerc82044a2019-06-24 10:36:20 -0700530 preparedModelAdapter, pollingTimeWindow);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700531}
532
533ExecutionBurstServer::ExecutionBurstServer(
534 const sp<IBurstCallback>& callback, std::unique_ptr<RequestChannelReceiver> requestChannel,
535 std::unique_ptr<ResultChannelSender> resultChannel,
536 std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
537 : mCallback(callback),
538 mRequestChannelReceiver(std::move(requestChannel)),
539 mResultChannelSender(std::move(resultChannel)),
540 mExecutorWithCache(std::move(executorWithCache)) {
541 // TODO: highly document the threading behavior of this class
542 mWorker = std::thread([this] { task(); });
543}
544
545ExecutionBurstServer::~ExecutionBurstServer() {
546 // set teardown flag
547 mTeardown = true;
548 mRequestChannelReceiver->invalidate();
549
550 // wait for task thread to end
551 mWorker.join();
552}
553
554Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
Michael Butlerba59a542019-06-28 17:06:27 -0700555 std::lock_guard<std::mutex> hold(mMutex);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700556 mExecutorWithCache->removeCacheEntry(slot);
557 return Void();
558}
559
560void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
561 const auto slotIsKnown = [this](int32_t slot) {
562 return mExecutorWithCache->isCacheEntryPresent(slot);
563 };
564
565 // find unique unknown slots
566 std::vector<int32_t> unknownSlots = slots;
567 auto unknownSlotsEnd = unknownSlots.end();
568 std::sort(unknownSlots.begin(), unknownSlotsEnd);
569 unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
570 unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
571 unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
572
573 // quick-exit if all slots are known
574 if (unknownSlots.empty()) {
575 return;
576 }
577
578 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
579 std::vector<hidl_memory> returnedMemories;
580 auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
581 const hidl_vec<hidl_memory>& memories) {
582 errorStatus = status;
583 returnedMemories = memories;
584 };
585
586 const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
587
588 if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
589 returnedMemories.size() != unknownSlots.size()) {
590 LOG(ERROR) << "Error retrieving memories";
591 return;
592 }
593
594 // add memories to unknown slots
595 for (size_t i = 0; i < unknownSlots.size(); ++i) {
596 mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
597 }
598}
599
600void ExecutionBurstServer::task() {
601 // loop until the burst object is being destroyed
602 while (!mTeardown) {
603 // receive request
604 auto arguments = mRequestChannelReceiver->getBlocking();
605
606 // if the request packet was not properly received, return a generic
607 // error and skip the execution
608 //
609 // if the burst is being torn down, skip the execution exection so the
610 // "task" function can end
611 if (!arguments) {
612 if (!mTeardown) {
613 mResultChannelSender->send(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
614 }
615 continue;
616 }
617
618 // otherwise begin tracing execution
619 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
620 "ExecutionBurstServer getting memory, executing, and returning results");
621
622 // unpack the arguments; types are Request, std::vector<int32_t>, and
Michael Butler238fe722019-03-21 12:17:27 -0700623 // MeasureTiming, respectively
Michael Butlerc932ebb2019-04-11 14:24:06 -0700624 const auto [requestWithoutPools, slotsOfPools, measure] = std::move(*arguments);
Michael Butler60296322019-01-17 17:54:51 -0800625
Michael Butler238fe722019-03-21 12:17:27 -0700626 // ensure executor with cache has required memory
627 std::lock_guard<std::mutex> hold(mMutex);
628 ensureCacheEntriesArePresentLocked(slotsOfPools);
629
630 // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
631 // and Timing, respectively
632 const auto [errorStatus, outputShapes, returnedTiming] =
633 mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
Michael Butler60296322019-01-17 17:54:51 -0800634
635 // return result
Michael Butlerc932ebb2019-04-11 14:24:06 -0700636 mResultChannelSender->send(errorStatus, outputShapes, returnedTiming);
Michael Butler60296322019-01-17 17:54:51 -0800637 }
638}
639
Michael Butler3db6fe52019-01-29 11:20:30 -0800640} // namespace android::nn