blob: ec935dad6d54855863100d709ef1850193a54186 [file] [log] [blame]
Michael Butler60296322019-01-17 17:54:51 -08001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Michael Butler89e99ba2019-01-24 02:36:37 -080017#define LOG_TAG "ExecutionBurstServer"
18
Michael Butler60296322019-01-17 17:54:51 -080019#include "ExecutionBurstServer.h"
20
21#include <android-base/logging.h>
Michael Butler3260db92019-04-26 17:51:23 -070022
Michael Butlerc82044a2019-06-24 10:36:20 -070023#include <algorithm>
Michael Butler4ef48f12019-05-02 14:09:17 -070024#include <cstring>
Michael Butlerc932ebb2019-04-11 14:24:06 -070025#include <limits>
Michael Butler3260db92019-04-26 17:51:23 -070026#include <map>
Michael Butlerc82044a2019-06-24 10:36:20 -070027#include <memory>
28#include <tuple>
29#include <utility>
30#include <vector>
Michael Butler3260db92019-04-26 17:51:23 -070031
Michael Butler3db6fe52019-01-29 11:20:30 -080032#include "Tracing.h"
Michael Butler60296322019-01-17 17:54:51 -080033
Michael Butler3db6fe52019-01-29 11:20:30 -080034namespace android::nn {
Michael Butler238fe722019-03-21 12:17:27 -070035namespace {
Michael Butler60296322019-01-17 17:54:51 -080036
Michael Butler19af9d22019-07-11 11:45:01 -070037using namespace hal;
38
Michael Butlerc82044a2019-06-24 10:36:20 -070039using hardware::MQDescriptorSync;
40
Michael Butlerc932ebb2019-04-11 14:24:06 -070041constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
42 std::numeric_limits<uint64_t>::max()};
43
Michael Butler238fe722019-03-21 12:17:27 -070044// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
45// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
46// hidl_memory object, and the execution forwards calls to the provided
47// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory
48// must be mapped and unmapped for each execution.
49class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
50 public:
51 DefaultBurstExecutorWithCache(IPreparedModel* preparedModel) : mpPreparedModel(preparedModel) {}
Michael Butler60296322019-01-17 17:54:51 -080052
Michael Butler238fe722019-03-21 12:17:27 -070053 bool isCacheEntryPresent(int32_t slot) const override {
Michael Butler3260db92019-04-26 17:51:23 -070054 const auto it = mMemoryCache.find(slot);
Michael Butler1ee58a52019-04-30 13:49:32 -070055 return (it != mMemoryCache.end()) && it->second.valid();
Michael Butler238fe722019-03-21 12:17:27 -070056 }
Michael Butler47c988f62019-03-14 17:34:48 -070057
Michael Butler238fe722019-03-21 12:17:27 -070058 void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
Michael Butler238fe722019-03-21 12:17:27 -070059 mMemoryCache[slot] = memory;
60 }
Michael Butler60296322019-01-17 17:54:51 -080061
Michael Butler3260db92019-04-26 17:51:23 -070062 void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
Michael Butler238fe722019-03-21 12:17:27 -070063
64 std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
65 const Request& request, const std::vector<int32_t>& slots,
66 MeasureTiming measure) override {
67 // convert slots to pools
68 hidl_vec<hidl_memory> pools(slots.size());
Michael Butler3260db92019-04-26 17:51:23 -070069 std::transform(slots.begin(), slots.end(), pools.begin(),
70 [this](int32_t slot) { return mMemoryCache[slot]; });
Michael Butler238fe722019-03-21 12:17:27 -070071
72 // create full request
73 Request fullRequest = request;
74 fullRequest.pools = std::move(pools);
75
76 // setup execution
77 ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
78 hidl_vec<OutputShape> returnedOutputShapes;
79 Timing returnedTiming;
80 auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
81 ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
82 const Timing& timing) {
83 returnedStatus = status;
84 returnedOutputShapes = outputShapes;
85 returnedTiming = timing;
Michael Butler47c988f62019-03-14 17:34:48 -070086 };
Michael Butler60296322019-01-17 17:54:51 -080087
Michael Butler238fe722019-03-21 12:17:27 -070088 // execute
89 const Return<void> ret = mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
90 if (!ret.isOk() || returnedStatus != ErrorStatus::NONE) {
91 LOG(ERROR) << "IPreparedModelAdapter::execute -- Error executing";
Raksit Ashokc1079232019-05-29 12:55:16 -070092 return {returnedStatus, {}, kNoTiming};
Michael Butler89e99ba2019-01-24 02:36:37 -080093 }
Michael Butler60296322019-01-17 17:54:51 -080094
Michael Butler238fe722019-03-21 12:17:27 -070095 return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
Michael Butler60296322019-01-17 17:54:51 -080096 }
97
Michael Butler238fe722019-03-21 12:17:27 -070098 private:
99 IPreparedModel* const mpPreparedModel;
Michael Butler3260db92019-04-26 17:51:23 -0700100 std::map<int32_t, hidl_memory> mMemoryCache;
Michael Butler238fe722019-03-21 12:17:27 -0700101};
Michael Butler47c988f62019-03-14 17:34:48 -0700102
Michael Butler238fe722019-03-21 12:17:27 -0700103} // anonymous namespace
Michael Butler60296322019-01-17 17:54:51 -0800104
Michael Butler60296322019-01-17 17:54:51 -0800105// serialize result
Michael Butlerc932ebb2019-04-11 14:24:06 -0700106std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
107 const std::vector<OutputShape>& outputShapes, Timing timing) {
Michael Butler60296322019-01-17 17:54:51 -0800108 // count how many elements need to be sent for a request
109 size_t count = 2 + outputShapes.size();
110 for (const auto& outputShape : outputShapes) {
111 count += outputShape.dimensions.size();
112 }
113
114 // create buffer to temporarily store elements
115 std::vector<FmqResultDatum> data;
116 data.reserve(count);
117
118 // package packetInfo
119 {
120 FmqResultDatum datum;
121 datum.packetInformation({/*.packetSize=*/static_cast<uint32_t>(count),
122 /*.errorStatus=*/errorStatus,
123 /*.numberOfOperands=*/static_cast<uint32_t>(outputShapes.size())});
124 data.push_back(datum);
125 }
126
127 // package output shape data
128 for (const auto& operand : outputShapes) {
129 // package operand information
Steven Moreland393ac6d2019-04-25 15:33:25 -0700130 FmqResultDatum::OperandInformation info{};
131 info.isSufficient = operand.isSufficient;
132 info.numberOfDimensions = static_cast<uint32_t>(operand.dimensions.size());
133
Michael Butler60296322019-01-17 17:54:51 -0800134 FmqResultDatum datum;
Steven Moreland393ac6d2019-04-25 15:33:25 -0700135 datum.operandInformation(info);
Michael Butler60296322019-01-17 17:54:51 -0800136 data.push_back(datum);
137
138 // package operand dimensions
139 for (uint32_t dimension : operand.dimensions) {
140 FmqResultDatum datum;
141 datum.operandDimensionValue(dimension);
142 data.push_back(datum);
143 }
144 }
145
146 // package executionTiming
147 {
148 FmqResultDatum datum;
149 datum.executionTiming(timing);
150 data.push_back(datum);
151 }
152
153 // return result
154 return data;
155}
156
Michael Butlerc932ebb2019-04-11 14:24:06 -0700157// deserialize request
158std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
159 const std::vector<FmqRequestDatum>& data) {
160 using discriminator = FmqRequestDatum::hidl_discriminator;
Michael Butler60296322019-01-17 17:54:51 -0800161
Michael Butlerc932ebb2019-04-11 14:24:06 -0700162 size_t index = 0;
163
164 // validate packet information
Michael Butler3260db92019-04-26 17:51:23 -0700165 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700166 LOG(ERROR) << "FMQ Request packet ill-formed";
167 return std::nullopt;
168 }
169
170 // unpackage packet information
171 const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
172 index++;
173 const uint32_t packetSize = packetInfo.packetSize;
174 const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
175 const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
176 const uint32_t numberOfPools = packetInfo.numberOfPools;
177
Michael Butler3260db92019-04-26 17:51:23 -0700178 // verify packet size
179 if (data.size() != packetSize) {
180 LOG(ERROR) << "FMQ Request packet ill-formed";
181 return std::nullopt;
182 }
183
Michael Butlerc932ebb2019-04-11 14:24:06 -0700184 // unpackage input operands
185 std::vector<RequestArgument> inputs;
186 inputs.reserve(numberOfInputOperands);
187 for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
188 // validate input operand information
189 if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
190 LOG(ERROR) << "FMQ Request packet ill-formed";
191 return std::nullopt;
Michael Butler60296322019-01-17 17:54:51 -0800192 }
193
Michael Butlerc932ebb2019-04-11 14:24:06 -0700194 // unpackage operand information
195 const FmqRequestDatum::OperandInformation& operandInfo =
196 data[index].inputOperandInformation();
197 index++;
198 const bool hasNoValue = operandInfo.hasNoValue;
199 const DataLocation location = operandInfo.location;
200 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
Michael Butler3db6fe52019-01-29 11:20:30 -0800201
Michael Butlerc932ebb2019-04-11 14:24:06 -0700202 // unpackage operand dimensions
203 std::vector<uint32_t> dimensions;
204 dimensions.reserve(numberOfDimensions);
205 for (size_t i = 0; i < numberOfDimensions; ++i) {
206 // validate dimension
207 if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
208 LOG(ERROR) << "FMQ Request packet ill-formed";
209 return std::nullopt;
210 }
211
212 // unpackage dimension
213 const uint32_t dimension = data[index].inputOperandDimensionValue();
214 index++;
215
216 // store result
217 dimensions.push_back(dimension);
218 }
219
220 // store result
221 inputs.push_back(
222 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
223 }
224
225 // unpackage output operands
226 std::vector<RequestArgument> outputs;
227 outputs.reserve(numberOfOutputOperands);
228 for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
229 // validate output operand information
230 if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
231 LOG(ERROR) << "FMQ Request packet ill-formed";
232 return std::nullopt;
233 }
234
235 // unpackage operand information
236 const FmqRequestDatum::OperandInformation& operandInfo =
237 data[index].outputOperandInformation();
238 index++;
239 const bool hasNoValue = operandInfo.hasNoValue;
240 const DataLocation location = operandInfo.location;
241 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
242
243 // unpackage operand dimensions
244 std::vector<uint32_t> dimensions;
245 dimensions.reserve(numberOfDimensions);
246 for (size_t i = 0; i < numberOfDimensions; ++i) {
247 // validate dimension
248 if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
249 LOG(ERROR) << "FMQ Request packet ill-formed";
250 return std::nullopt;
251 }
252
253 // unpackage dimension
254 const uint32_t dimension = data[index].outputOperandDimensionValue();
255 index++;
256
257 // store result
258 dimensions.push_back(dimension);
259 }
260
261 // store result
262 outputs.push_back(
263 {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
264 }
265
266 // unpackage pools
267 std::vector<int32_t> slots;
268 slots.reserve(numberOfPools);
269 for (size_t pool = 0; pool < numberOfPools; ++pool) {
270 // validate input operand information
271 if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
272 LOG(ERROR) << "FMQ Request packet ill-formed";
273 return std::nullopt;
274 }
275
276 // unpackage operand information
277 const int32_t poolId = data[index].poolIdentifier();
278 index++;
279
280 // store result
281 slots.push_back(poolId);
282 }
283
284 // validate measureTiming
285 if (data[index].getDiscriminator() != discriminator::measureTiming) {
286 LOG(ERROR) << "FMQ Request packet ill-formed";
287 return std::nullopt;
288 }
289
290 // unpackage measureTiming
291 const MeasureTiming measure = data[index].measureTiming();
292 index++;
293
294 // validate packet information
295 if (index != packetSize) {
296 LOG(ERROR) << "FMQ Result packet ill-formed";
297 return std::nullopt;
298 }
299
300 // return request
301 Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}};
302 return std::make_tuple(std::move(request), std::move(slots), measure);
303}
304
305// RequestChannelReceiver methods
306
307std::unique_ptr<RequestChannelReceiver> RequestChannelReceiver::create(
Michael Butlerc82044a2019-06-24 10:36:20 -0700308 const FmqRequestDescriptor& requestChannel, std::chrono::microseconds pollingTimeWindow) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700309 std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
310 std::make_unique<FmqRequestChannel>(requestChannel);
Michael Butlerc82044a2019-06-24 10:36:20 -0700311
Michael Butlerc932ebb2019-04-11 14:24:06 -0700312 if (!fmqRequestChannel->isValid()) {
313 LOG(ERROR) << "Unable to create RequestChannelReceiver";
314 return nullptr;
315 }
Michael Butlerc82044a2019-06-24 10:36:20 -0700316 if (fmqRequestChannel->getEventFlagWord() == nullptr) {
317 LOG(ERROR)
318 << "RequestChannelReceiver::create was passed an MQDescriptor without an EventFlag";
319 return nullptr;
320 }
321
322 return std::make_unique<RequestChannelReceiver>(std::move(fmqRequestChannel),
323 pollingTimeWindow);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700324}
325
326RequestChannelReceiver::RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
Michael Butlerc82044a2019-06-24 10:36:20 -0700327 std::chrono::microseconds pollingTimeWindow)
328 : mFmqRequestChannel(std::move(fmqRequestChannel)), kPollingTimeWindow(pollingTimeWindow) {}
Michael Butlerc932ebb2019-04-11 14:24:06 -0700329
330std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>>
331RequestChannelReceiver::getBlocking() {
332 const auto packet = getPacketBlocking();
333 if (!packet) {
334 return std::nullopt;
335 }
336
337 return deserialize(*packet);
338}
339
340void RequestChannelReceiver::invalidate() {
341 mTeardown = true;
342
343 // force unblock
344 // ExecutionBurstServer is by default waiting on a request packet. If the
Michael Butlerc82044a2019-06-24 10:36:20 -0700345 // client process destroys its burst object, the server may still be waiting
346 // on the futex. This force unblock wakes up any thread waiting on the
347 // futex.
348 // TODO: look for a different/better way to signal/notify the futex to wake
349 // up any thread waiting on it
350 FmqRequestDatum datum;
351 datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
352 /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
353 mFmqRequestChannel->writeBlocking(&datum, 1);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700354}
355
356std::optional<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() {
357 using discriminator = FmqRequestDatum::hidl_discriminator;
358
359 if (mTeardown) {
360 return std::nullopt;
361 }
362
Michael Butlerc82044a2019-06-24 10:36:20 -0700363 // First spend time polling if results are available in FMQ instead of
364 // waiting on the futex. Polling is more responsive (yielding lower
365 // latencies), but can take up more power, so only poll for a limited period
366 // of time.
367
368 auto& getCurrentTime = std::chrono::high_resolution_clock::now;
369 const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
370
371 while (getCurrentTime() < timeToStopPolling) {
372 // if class is being torn down, immediately return
373 if (mTeardown.load(std::memory_order_relaxed)) {
374 return std::nullopt;
375 }
376
377 // Check if data is available. If it is, immediately retrieve it and
378 // return.
379 const size_t available = mFmqRequestChannel->availableToRead();
380 if (available > 0) {
381 // This is the first point when we know an execution is occurring,
382 // so begin to collect systraces. Note that a similar systrace does
383 // not exist at the corresponding point in
384 // ResultChannelReceiver::getPacketBlocking because the execution is
385 // already in flight.
386 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
387 "ExecutionBurstServer getting packet");
388 std::vector<FmqRequestDatum> packet(available);
389 const bool success = mFmqRequestChannel->read(packet.data(), available);
390 if (!success) {
391 LOG(ERROR) << "Error receiving packet";
392 return std::nullopt;
393 }
394 return std::make_optional(std::move(packet));
Michael Butlerc932ebb2019-04-11 14:24:06 -0700395 }
396 }
397
Michael Butlerc82044a2019-06-24 10:36:20 -0700398 // If we get to this point, we either stopped polling because it was taking
399 // too long or polling was not allowed. Instead, perform a blocking call
400 // which uses a futex to save power.
401
402 // wait for request packet and read first element of request packet
403 FmqRequestDatum datum;
404 bool success = mFmqRequestChannel->readBlocking(&datum, 1);
405
406 // This is the first point when we know an execution is occurring, so begin
407 // to collect systraces. Note that a similar systrace does not exist at the
408 // corresponding point in ResultChannelReceiver::getPacketBlocking because
409 // the execution is already in flight.
Michael Butlerc932ebb2019-04-11 14:24:06 -0700410 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
411
Michael Butlerc932ebb2019-04-11 14:24:06 -0700412 // retrieve remaining elements
413 // NOTE: all of the data is already available at this point, so there's no
414 // need to do a blocking wait to wait for more data. This is known because
415 // in FMQ, all writes are published (made available) atomically. Currently,
416 // the producer always publishes the entire packet in one function call, so
417 // if the first element of the packet is available, the remaining elements
418 // are also available.
Michael Butler3260db92019-04-26 17:51:23 -0700419 const size_t count = mFmqRequestChannel->availableToRead();
420 std::vector<FmqRequestDatum> packet(count + 1);
Michael Butler4ef48f12019-05-02 14:09:17 -0700421 std::memcpy(&packet.front(), &datum, sizeof(datum));
Michael Butler3260db92019-04-26 17:51:23 -0700422 success &= mFmqRequestChannel->read(packet.data() + 1, count);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700423
Michael Butler3260db92019-04-26 17:51:23 -0700424 // terminate loop
425 if (mTeardown) {
426 return std::nullopt;
427 }
428
429 // ensure packet was successfully received
Michael Butlerc932ebb2019-04-11 14:24:06 -0700430 if (!success) {
Michael Butler3260db92019-04-26 17:51:23 -0700431 LOG(ERROR) << "Error receiving packet";
432 return std::nullopt;
Michael Butlerc932ebb2019-04-11 14:24:06 -0700433 }
434
Michael Butler4ef48f12019-05-02 14:09:17 -0700435 return std::make_optional(std::move(packet));
Michael Butlerc932ebb2019-04-11 14:24:06 -0700436}
437
438// ResultChannelSender methods
439
440std::unique_ptr<ResultChannelSender> ResultChannelSender::create(
441 const FmqResultDescriptor& resultChannel) {
442 std::unique_ptr<FmqResultChannel> fmqResultChannel =
443 std::make_unique<FmqResultChannel>(resultChannel);
Michael Butlerc82044a2019-06-24 10:36:20 -0700444
Michael Butlerc932ebb2019-04-11 14:24:06 -0700445 if (!fmqResultChannel->isValid()) {
446 LOG(ERROR) << "Unable to create RequestChannelSender";
447 return nullptr;
448 }
Michael Butlerc82044a2019-06-24 10:36:20 -0700449 if (fmqResultChannel->getEventFlagWord() == nullptr) {
450 LOG(ERROR) << "ResultChannelSender::create was passed an MQDescriptor without an EventFlag";
451 return nullptr;
452 }
453
454 return std::make_unique<ResultChannelSender>(std::move(fmqResultChannel));
Michael Butlerc932ebb2019-04-11 14:24:06 -0700455}
456
Michael Butlerc82044a2019-06-24 10:36:20 -0700457ResultChannelSender::ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel)
458 : mFmqResultChannel(std::move(fmqResultChannel)) {}
Michael Butlerc932ebb2019-04-11 14:24:06 -0700459
460bool ResultChannelSender::send(ErrorStatus errorStatus,
461 const std::vector<OutputShape>& outputShapes, Timing timing) {
462 const std::vector<FmqResultDatum> serialized = serialize(errorStatus, outputShapes, timing);
463 return sendPacket(serialized);
464}
465
466bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
Michael Butler3260db92019-04-26 17:51:23 -0700467 if (packet.size() > mFmqResultChannel->availableToWrite()) {
468 LOG(ERROR)
469 << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ";
470 const std::vector<FmqResultDatum> errorPacket =
471 serialize(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
Michael Butlerc82044a2019-06-24 10:36:20 -0700472
473 // Always send the packet with "blocking" because this signals the futex
474 // and unblocks the consumer if it is waiting on the futex.
475 return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size());
Michael Butler3260db92019-04-26 17:51:23 -0700476 }
477
Michael Butlerc82044a2019-06-24 10:36:20 -0700478 // Always send the packet with "blocking" because this signals the futex and
479 // unblocks the consumer if it is waiting on the futex.
480 return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
Michael Butlerc932ebb2019-04-11 14:24:06 -0700481}
482
483// ExecutionBurstServer methods
484
485sp<ExecutionBurstServer> ExecutionBurstServer::create(
486 const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
487 const MQDescriptorSync<FmqResultDatum>& resultChannel,
Michael Butlerc82044a2019-06-24 10:36:20 -0700488 std::shared_ptr<IBurstExecutorWithCache> executorWithCache,
489 std::chrono::microseconds pollingTimeWindow) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700490 // check inputs
491 if (callback == nullptr || executorWithCache == nullptr) {
492 LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
493 return nullptr;
494 }
495
496 // create FMQ objects
497 std::unique_ptr<RequestChannelReceiver> requestChannelReceiver =
Michael Butlerc82044a2019-06-24 10:36:20 -0700498 RequestChannelReceiver::create(requestChannel, pollingTimeWindow);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700499 std::unique_ptr<ResultChannelSender> resultChannelSender =
500 ResultChannelSender::create(resultChannel);
501
502 // check FMQ objects
503 if (!requestChannelReceiver || !resultChannelSender) {
504 LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
505 return nullptr;
506 }
507
508 // make and return context
509 return new ExecutionBurstServer(callback, std::move(requestChannelReceiver),
510 std::move(resultChannelSender), std::move(executorWithCache));
511}
512
513sp<ExecutionBurstServer> ExecutionBurstServer::create(
514 const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
Michael Butlerc82044a2019-06-24 10:36:20 -0700515 const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel,
516 std::chrono::microseconds pollingTimeWindow) {
Michael Butlerc932ebb2019-04-11 14:24:06 -0700517 // check relevant input
518 if (preparedModel == nullptr) {
519 LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
520 return nullptr;
521 }
522
523 // adapt IPreparedModel to have caching
524 const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
525 std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
526
527 // make and return context
528 return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
Michael Butlerc82044a2019-06-24 10:36:20 -0700529 preparedModelAdapter, pollingTimeWindow);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700530}
531
532ExecutionBurstServer::ExecutionBurstServer(
533 const sp<IBurstCallback>& callback, std::unique_ptr<RequestChannelReceiver> requestChannel,
534 std::unique_ptr<ResultChannelSender> resultChannel,
535 std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
536 : mCallback(callback),
537 mRequestChannelReceiver(std::move(requestChannel)),
538 mResultChannelSender(std::move(resultChannel)),
539 mExecutorWithCache(std::move(executorWithCache)) {
540 // TODO: highly document the threading behavior of this class
541 mWorker = std::thread([this] { task(); });
542}
543
544ExecutionBurstServer::~ExecutionBurstServer() {
545 // set teardown flag
546 mTeardown = true;
547 mRequestChannelReceiver->invalidate();
548
549 // wait for task thread to end
550 mWorker.join();
551}
552
553Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
Michael Butlerba59a542019-06-28 17:06:27 -0700554 std::lock_guard<std::mutex> hold(mMutex);
Michael Butlerc932ebb2019-04-11 14:24:06 -0700555 mExecutorWithCache->removeCacheEntry(slot);
556 return Void();
557}
558
559void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
560 const auto slotIsKnown = [this](int32_t slot) {
561 return mExecutorWithCache->isCacheEntryPresent(slot);
562 };
563
564 // find unique unknown slots
565 std::vector<int32_t> unknownSlots = slots;
566 auto unknownSlotsEnd = unknownSlots.end();
567 std::sort(unknownSlots.begin(), unknownSlotsEnd);
568 unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
569 unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
570 unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
571
572 // quick-exit if all slots are known
573 if (unknownSlots.empty()) {
574 return;
575 }
576
577 ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
578 std::vector<hidl_memory> returnedMemories;
579 auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
580 const hidl_vec<hidl_memory>& memories) {
581 errorStatus = status;
582 returnedMemories = memories;
583 };
584
585 const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
586
587 if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
588 returnedMemories.size() != unknownSlots.size()) {
589 LOG(ERROR) << "Error retrieving memories";
590 return;
591 }
592
593 // add memories to unknown slots
594 for (size_t i = 0; i < unknownSlots.size(); ++i) {
595 mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
596 }
597}
598
599void ExecutionBurstServer::task() {
600 // loop until the burst object is being destroyed
601 while (!mTeardown) {
602 // receive request
603 auto arguments = mRequestChannelReceiver->getBlocking();
604
605 // if the request packet was not properly received, return a generic
606 // error and skip the execution
607 //
608 // if the burst is being torn down, skip the execution exection so the
609 // "task" function can end
610 if (!arguments) {
611 if (!mTeardown) {
612 mResultChannelSender->send(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
613 }
614 continue;
615 }
616
617 // otherwise begin tracing execution
618 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
619 "ExecutionBurstServer getting memory, executing, and returning results");
620
621 // unpack the arguments; types are Request, std::vector<int32_t>, and
Michael Butler238fe722019-03-21 12:17:27 -0700622 // MeasureTiming, respectively
Michael Butlerc932ebb2019-04-11 14:24:06 -0700623 const auto [requestWithoutPools, slotsOfPools, measure] = std::move(*arguments);
Michael Butler60296322019-01-17 17:54:51 -0800624
Michael Butler238fe722019-03-21 12:17:27 -0700625 // ensure executor with cache has required memory
626 std::lock_guard<std::mutex> hold(mMutex);
627 ensureCacheEntriesArePresentLocked(slotsOfPools);
628
629 // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
630 // and Timing, respectively
631 const auto [errorStatus, outputShapes, returnedTiming] =
632 mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
Michael Butler60296322019-01-17 17:54:51 -0800633
634 // return result
Michael Butlerc932ebb2019-04-11 14:24:06 -0700635 mResultChannelSender->send(errorStatus, outputShapes, returnedTiming);
Michael Butler60296322019-01-17 17:54:51 -0800636 }
637}
638
Michael Butler3db6fe52019-01-29 11:20:30 -0800639} // namespace android::nn