BurstMemoryCache serialization cleanup
This CL performs prerequisite cleanup for the following two
changes:
(1) Handle case where Burst FMQ packet exceeds FMQ buffer size
(2) Create validation tests for FMQ serialized message format
It does this by making the serialization and deserialization functions
more modular and their usage clearer.
This CL additionally changes the worker from std::async (std::future) to
std::thread. This is because some implementations of std::async use a
ThreadPool/WorkPool, which can cause unintended deadlocks if not handled
carefully. std::thread does not have those same concerns.
Bug: 129779280
Bug: 129157135
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest VtsHalNeuralnetworksV1_0TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_1TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: Ia0842bc62c37c676c4faf83fc177198adac1f832
Merged-In: Ia0842bc62c37c676c4faf83fc177198adac1f832
(cherry picked from commit 1a57745f547572ed4fcfd77570e109b1c2288d5a)
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 08d0e80..067c2f5 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,13 +19,15 @@
#include "ExecutionBurstServer.h"
#include <android-base/logging.h>
-#include <string>
+#include <limits>
#include "Tracing.h"
namespace android::nn {
-
namespace {
+constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
+ std::numeric_limits<uint64_t>::max()};
+
// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
// hidl_memory object, and the execution forwards calls to the provided
@@ -94,347 +96,9 @@
} // anonymous namespace
-sp<ExecutionBurstServer> ExecutionBurstServer::create(
- const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel,
- std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {
- // check inputs
- if (callback == nullptr || executorWithCache == nullptr) {
- LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
- return nullptr;
- }
-
- // create FMQ objects
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
- std::make_unique<FmqRequestChannel>(requestChannel);
- std::unique_ptr<FmqResultChannel> fmqResultChannel =
- std::make_unique<FmqResultChannel>(resultChannel);
-
- // check FMQ objects
- if (!fmqRequestChannel->isValid() || !fmqResultChannel->isValid()) {
- LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
- return nullptr;
- }
-
- // make and return context
- return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
- std::move(fmqResultChannel), std::move(executorWithCache));
-}
-
-sp<ExecutionBurstServer> ExecutionBurstServer::create(
- const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
- // check relevant input
- if (preparedModel == nullptr) {
- LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
- return nullptr;
- }
-
- // adapt IPreparedModel to have caching
- const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
- std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
-
- // make and return context
- return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
- preparedModelAdapter);
-}
-
-ExecutionBurstServer::ExecutionBurstServer(
- const sp<IBurstCallback>& callback, std::unique_ptr<FmqRequestChannel> requestChannel,
- std::unique_ptr<FmqResultChannel> resultChannel,
- std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
- : mCallback(callback),
- mFmqRequestChannel(std::move(requestChannel)),
- mFmqResultChannel(std::move(resultChannel)),
- mExecutorWithCache(std::move(executorWithCache)),
- mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
- // TODO: highly document the threading behavior of this class
- mWorker = std::async(std::launch::async, [this] { task(); });
-}
-
-ExecutionBurstServer::~ExecutionBurstServer() {
- // set teardown flag
- mTeardown = true;
-
- // force unblock
- // ExecutionBurstServer is by default waiting on a request packet. If the
- // client process destroys its burst object, the server will still be
- // waiting on the futex (assuming mBlocking is true). This force unblock
- // wakes up any thread waiting on the futex.
- if (mBlocking) {
- // TODO: look for a different/better way to signal/notify the futex to
- // wake up any thread waiting on it
- FmqRequestDatum datum;
- datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
- /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
- mFmqRequestChannel->writeBlocking(&datum, 1);
- }
-
- // wait for task thread to end
- mWorker.wait();
-}
-
-Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
- std::lock_guard<std::mutex> hold(mMutex);
- mExecutorWithCache->removeCacheEntry(slot);
- return Void();
-}
-
-void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
- const auto slotIsKnown = [this](int32_t slot) {
- return mExecutorWithCache->isCacheEntryPresent(slot);
- };
-
- // find unique unknown slots
- std::vector<int32_t> unknownSlots = slots;
- auto unknownSlotsEnd = unknownSlots.end();
- std::sort(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
- unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
-
- // quick-exit if all slots are known
- if (unknownSlots.empty()) {
- return;
- }
-
- ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
- std::vector<hidl_memory> returnedMemories;
- auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
- const hidl_vec<hidl_memory>& memories) {
- errorStatus = status;
- returnedMemories = memories;
- };
-
- const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
-
- if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
- returnedMemories.size() != unknownSlots.size()) {
- LOG(ERROR) << "Error retrieving memories";
- return;
- }
-
- // add memories to unknown slots
- for (size_t i = 0; i < unknownSlots.size(); ++i) {
- mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
- }
-}
-
-bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
- if (mTeardown) {
- return false;
- }
-
- if (mBlocking) {
- return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
- } else {
- return mFmqResultChannel->write(packet.data(), packet.size());
- }
-}
-
-std::vector<FmqRequestDatum> ExecutionBurstServer::getPacketBlocking() {
- using discriminator = FmqRequestDatum::hidl_discriminator;
-
- if (mTeardown) {
- return {};
- }
-
- // wait for request packet and read first element of request packet
- // TODO: have a more elegant way to wait for data, and read it all at once.
- // For example, EventFlag can be used to directly wait on the futex, and all
- // the data can be read at once with a non-blocking call to
- // MessageQueue::read. For further optimization, MessageQueue::beginRead and
- // MessageQueue::commitRead can be used to avoid an extra copy of the
- // metadata.
- FmqRequestDatum datum;
- bool success = false;
- if (mBlocking) {
- success = mFmqRequestChannel->readBlocking(&datum, 1);
- } else {
- while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
- !mFmqRequestChannel->read(&datum, 1)) {
- }
- }
-
- // terminate loop
- if (mTeardown) {
- return {};
- }
-
- // validate packet information
- if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {};
- }
-
- NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
-
- // unpack packet information
- const auto& packetInfo = datum.packetInformation();
- const size_t count = packetInfo.packetSize;
-
- // retrieve remaining elements
- // NOTE: all of the data is already available at this point, so there's no
- // need to do a blocking wait to wait for more data. This is known because
- // in FMQ, all writes are published (made available) atomically. Currently,
- // the producer always publishes the entire packet in one function call, so
- // if the first element of the packet is available, the remaining elements
- // are also available.
- std::vector<FmqRequestDatum> packet(count);
- packet.front() = datum;
- success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
-
- if (!success) {
- return {};
- }
-
- return packet;
-}
-
-// deserialize request
-std::tuple<Request, std::vector<int32_t>, MeasureTiming> ExecutionBurstServer::deserialize(
- const std::vector<FmqRequestDatum>& data) {
- using discriminator = FmqRequestDatum::hidl_discriminator;
-
- size_t index = 0;
-
- // validate packet information
- if (data[index].getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage packet information
- const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
- index++;
- const uint32_t packetSize = packetInfo.packetSize;
- const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
- const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
- const uint32_t numberOfPools = packetInfo.numberOfPools;
-
- // unpackage input operands
- std::vector<RequestArgument> inputs;
- inputs.reserve(numberOfInputOperands);
- for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
- // validate input operand information
- if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage operand information
- const FmqRequestDatum::OperandInformation& operandInfo =
- data[index].inputOperandInformation();
- index++;
- const bool hasNoValue = operandInfo.hasNoValue;
- const DataLocation location = operandInfo.location;
- const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
-
- // unpackage operand dimensions
- std::vector<uint32_t> dimensions;
- dimensions.reserve(numberOfDimensions);
- for (size_t i = 0; i < numberOfDimensions; ++i) {
- // validate dimension
- if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage dimension
- const uint32_t dimension = data[index].inputOperandDimensionValue();
- index++;
-
- // store result
- dimensions.push_back(dimension);
- }
-
- // store result
- inputs.push_back(
- {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
- }
-
- // unpackage output operands
- std::vector<RequestArgument> outputs;
- outputs.reserve(numberOfOutputOperands);
- for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
- // validate output operand information
- if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage operand information
- const FmqRequestDatum::OperandInformation& operandInfo =
- data[index].outputOperandInformation();
- index++;
- const bool hasNoValue = operandInfo.hasNoValue;
- const DataLocation location = operandInfo.location;
- const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
-
- // unpackage operand dimensions
- std::vector<uint32_t> dimensions;
- dimensions.reserve(numberOfDimensions);
- for (size_t i = 0; i < numberOfDimensions; ++i) {
- // validate dimension
- if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage dimension
- const uint32_t dimension = data[index].outputOperandDimensionValue();
- index++;
-
- // store result
- dimensions.push_back(dimension);
- }
-
- // store result
- outputs.push_back(
- {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
- }
-
- // unpackage pools
- std::vector<int32_t> slots;
- slots.reserve(numberOfPools);
- for (size_t pool = 0; pool < numberOfPools; ++pool) {
- // validate input operand information
- if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage operand information
- const int32_t poolId = data[index].poolIdentifier();
- index++;
-
- // store result
- slots.push_back(poolId);
- }
-
- // validate measureTiming
- if (data[index].getDiscriminator() != discriminator::measureTiming) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage measureTiming
- const MeasureTiming measure = data[index].measureTiming();
- index++;
-
- // validate packet information
- if (index != packetSize) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // return request
- return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}}, std::move(slots), measure};
-}
-
// serialize result
-std::vector<FmqResultDatum> ExecutionBurstServer::serialize(
- ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing) {
+std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
+ const std::vector<OutputShape>& outputShapes, Timing timing) {
// count how many elements need to be sent for a request
size_t count = 2 + outputShapes.size();
for (const auto& outputShape : outputShapes) {
@@ -482,22 +146,423 @@
return data;
}
-void ExecutionBurstServer::task() {
- while (!mTeardown) {
- // receive request
- const std::vector<FmqRequestDatum> requestData = getPacketBlocking();
+// deserialize request
+std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
+ const std::vector<FmqRequestDatum>& data) {
+ using discriminator = FmqRequestDatum::hidl_discriminator;
- // terminate loop
- if (mTeardown) {
- return;
+ size_t index = 0;
+
+ // validate packet information
+ if (data[index].getDiscriminator() != discriminator::packetInformation) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage packet information
+ const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
+ index++;
+ const uint32_t packetSize = packetInfo.packetSize;
+ const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
+ const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
+ const uint32_t numberOfPools = packetInfo.numberOfPools;
+
+ // unpackage input operands
+ std::vector<RequestArgument> inputs;
+ inputs.reserve(numberOfInputOperands);
+ for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
+ // validate input operand information
+ if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
}
- NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
- "ExecutionBurstServer processing packet and returning results");
+ // unpackage operand information
+ const FmqRequestDatum::OperandInformation& operandInfo =
+ data[index].inputOperandInformation();
+ index++;
+ const bool hasNoValue = operandInfo.hasNoValue;
+ const DataLocation location = operandInfo.location;
+ const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
- // continue processing; types are Request, std::vector<int32_t>, and
+ // unpackage operand dimensions
+ std::vector<uint32_t> dimensions;
+ dimensions.reserve(numberOfDimensions);
+ for (size_t i = 0; i < numberOfDimensions; ++i) {
+ // validate dimension
+ if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage dimension
+ const uint32_t dimension = data[index].inputOperandDimensionValue();
+ index++;
+
+ // store result
+ dimensions.push_back(dimension);
+ }
+
+ // store result
+ inputs.push_back(
+ {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
+ }
+
+ // unpackage output operands
+ std::vector<RequestArgument> outputs;
+ outputs.reserve(numberOfOutputOperands);
+ for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
+ // validate output operand information
+ if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage operand information
+ const FmqRequestDatum::OperandInformation& operandInfo =
+ data[index].outputOperandInformation();
+ index++;
+ const bool hasNoValue = operandInfo.hasNoValue;
+ const DataLocation location = operandInfo.location;
+ const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
+
+ // unpackage operand dimensions
+ std::vector<uint32_t> dimensions;
+ dimensions.reserve(numberOfDimensions);
+ for (size_t i = 0; i < numberOfDimensions; ++i) {
+ // validate dimension
+ if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage dimension
+ const uint32_t dimension = data[index].outputOperandDimensionValue();
+ index++;
+
+ // store result
+ dimensions.push_back(dimension);
+ }
+
+ // store result
+ outputs.push_back(
+ {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
+ }
+
+ // unpackage pools
+ std::vector<int32_t> slots;
+ slots.reserve(numberOfPools);
+ for (size_t pool = 0; pool < numberOfPools; ++pool) {
+ // validate pool identifier
+ if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage pool identifier
+ const int32_t poolId = data[index].poolIdentifier();
+ index++;
+
+ // store result
+ slots.push_back(poolId);
+ }
+
+ // validate measureTiming
+ if (data[index].getDiscriminator() != discriminator::measureTiming) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage measureTiming
+ const MeasureTiming measure = data[index].measureTiming();
+ index++;
+
+ // validate packet information
+ if (index != packetSize) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
+ // return request
+ Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}};
+ return std::make_tuple(std::move(request), std::move(slots), measure);
+}
+
+// RequestChannelReceiver methods
+
+std::unique_ptr<RequestChannelReceiver> RequestChannelReceiver::create(
+ const FmqRequestDescriptor& requestChannel) {
+ std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
+ std::make_unique<FmqRequestChannel>(requestChannel);
+ if (!fmqRequestChannel->isValid()) {
+ LOG(ERROR) << "Unable to create RequestChannelReceiver";
+ return nullptr;
+ }
+ const bool blocking = fmqRequestChannel->getEventFlagWord() != nullptr;
+ return std::make_unique<RequestChannelReceiver>(std::move(fmqRequestChannel), blocking);
+}
+
+RequestChannelReceiver::RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
+ bool blocking)
+ : mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {}
+
+std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>>
+RequestChannelReceiver::getBlocking() {
+ const auto packet = getPacketBlocking();
+ if (!packet) {
+ return std::nullopt;
+ }
+
+ return deserialize(*packet);
+}
+
+void RequestChannelReceiver::invalidate() {
+ mTeardown = true;
+
+ // force unblock
+ // ExecutionBurstServer is by default waiting on a request packet. If the
+ // client process destroys its burst object, the server will still be
+ // waiting on the futex (assuming mBlocking is true). This force unblock
+ // wakes up any thread waiting on the futex.
+ if (mBlocking) {
+ // TODO: look for a different/better way to signal/notify the futex to
+ // wake up any thread waiting on it
+ FmqRequestDatum datum;
+ datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
+ /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
+ mFmqRequestChannel->writeBlocking(&datum, 1);
+ }
+}
+
+std::optional<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() {
+ using discriminator = FmqRequestDatum::hidl_discriminator;
+
+ if (mTeardown) {
+ return std::nullopt;
+ }
+
+ // wait for request packet and read first element of request packet
+ // TODO: have a more elegant way to wait for data, and read it all at once.
+ // For example, EventFlag can be used to directly wait on the futex, and all
+ // the data can be read at once with a non-blocking call to
+ // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+ // MessageQueue::commitRead can be used to avoid an extra copy of the
+ // metadata.
+ FmqRequestDatum datum;
+ bool success = false;
+ if (mBlocking) {
+ success = mFmqRequestChannel->readBlocking(&datum, 1);
+ } else {
+ while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
+ !mFmqRequestChannel->read(&datum, 1)) {
+ }
+ }
+
+ // terminate loop
+ if (mTeardown) {
+ return std::nullopt;
+ }
+
+ // validate packet information
+ if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::make_optional<std::vector<FmqRequestDatum>>();
+ }
+
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
+
+ // unpack packet information
+ const auto& packetInfo = datum.packetInformation();
+ const size_t count = packetInfo.packetSize;
+
+ // retrieve remaining elements
+ // NOTE: all of the data is already available at this point, so there's no
+ // need to do a blocking wait to wait for more data. This is known because
+ // in FMQ, all writes are published (made available) atomically. Currently,
+ // the producer always publishes the entire packet in one function call, so
+ // if the first element of the packet is available, the remaining elements
+ // are also available.
+ std::vector<FmqRequestDatum> packet(count);
+ packet.front() = datum;
+ success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
+
+ if (!success) {
+ return std::make_optional<std::vector<FmqRequestDatum>>();
+ }
+
+ return packet;
+}
+
+// ResultChannelSender methods
+
+std::unique_ptr<ResultChannelSender> ResultChannelSender::create(
+ const FmqResultDescriptor& resultChannel) {
+ std::unique_ptr<FmqResultChannel> fmqResultChannel =
+ std::make_unique<FmqResultChannel>(resultChannel);
+ if (!fmqResultChannel->isValid()) {
+ LOG(ERROR) << "Unable to create RequestChannelSender";
+ return nullptr;
+ }
+ const bool blocking = fmqResultChannel->getEventFlagWord() != nullptr;
+ return std::make_unique<ResultChannelSender>(std::move(fmqResultChannel), blocking);
+}
+
+ResultChannelSender::ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel,
+ bool blocking)
+ : mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {}
+
+bool ResultChannelSender::send(ErrorStatus errorStatus,
+ const std::vector<OutputShape>& outputShapes, Timing timing) {
+ const std::vector<FmqResultDatum> serialized = serialize(errorStatus, outputShapes, timing);
+ return sendPacket(serialized);
+}
+
+bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
+ if (mBlocking) {
+ return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
+ } else {
+ return mFmqResultChannel->write(packet.data(), packet.size());
+ }
+}
+
+// ExecutionBurstServer methods
+
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+ const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel,
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {
+ // check inputs
+ if (callback == nullptr || executorWithCache == nullptr) {
+ LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+ return nullptr;
+ }
+
+ // create FMQ objects
+ std::unique_ptr<RequestChannelReceiver> requestChannelReceiver =
+ RequestChannelReceiver::create(requestChannel);
+ std::unique_ptr<ResultChannelSender> resultChannelSender =
+ ResultChannelSender::create(resultChannel);
+
+ // check FMQ objects
+ if (!requestChannelReceiver || !resultChannelSender) {
+ LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
+ return nullptr;
+ }
+
+ // make and return context
+ return new ExecutionBurstServer(callback, std::move(requestChannelReceiver),
+ std::move(resultChannelSender), std::move(executorWithCache));
+}
+
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+ const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+ // check relevant input
+ if (preparedModel == nullptr) {
+ LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+ return nullptr;
+ }
+
+ // adapt IPreparedModel to have caching
+ const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
+ std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
+
+ // make and return context
+ return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
+ preparedModelAdapter);
+}
+
+ExecutionBurstServer::ExecutionBurstServer(
+ const sp<IBurstCallback>& callback, std::unique_ptr<RequestChannelReceiver> requestChannel,
+ std::unique_ptr<ResultChannelSender> resultChannel,
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
+ : mCallback(callback),
+ mRequestChannelReceiver(std::move(requestChannel)),
+ mResultChannelSender(std::move(resultChannel)),
+ mExecutorWithCache(std::move(executorWithCache)) {
+ // TODO: highly document the threading behavior of this class
+ mWorker = std::thread([this] { task(); });
+}
+
+ExecutionBurstServer::~ExecutionBurstServer() {
+ // set teardown flag
+ mTeardown = true;
+ mRequestChannelReceiver->invalidate();
+
+ // wait for task thread to end
+ mWorker.join();
+}
+
+Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
+ mExecutorWithCache->removeCacheEntry(slot);
+ return Void();
+}
+
+void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
+ const auto slotIsKnown = [this](int32_t slot) {
+ return mExecutorWithCache->isCacheEntryPresent(slot);
+ };
+
+ // find unique unknown slots
+ std::vector<int32_t> unknownSlots = slots;
+ auto unknownSlotsEnd = unknownSlots.end();
+ std::sort(unknownSlots.begin(), unknownSlotsEnd);
+ unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
+ unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+ unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+
+ // quick-exit if all slots are known
+ if (unknownSlots.empty()) {
+ return;
+ }
+
+ ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+ std::vector<hidl_memory> returnedMemories;
+ auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
+ const hidl_vec<hidl_memory>& memories) {
+ errorStatus = status;
+ returnedMemories = memories;
+ };
+
+ const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
+
+ if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
+ returnedMemories.size() != unknownSlots.size()) {
+ LOG(ERROR) << "Error retrieving memories";
+ return;
+ }
+
+ // add memories to unknown slots
+ for (size_t i = 0; i < unknownSlots.size(); ++i) {
+ mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
+ }
+}
+
+void ExecutionBurstServer::task() {
+ // loop until the burst object is being destroyed
+ while (!mTeardown) {
+ // receive request
+ auto arguments = mRequestChannelReceiver->getBlocking();
+
+ // if the request packet was not properly received, return a generic
+ // error and skip the execution
+ //
+ // if the burst is being torn down, skip the execution so the
+ // "task" function can end
+ if (!arguments) {
+ if (!mTeardown) {
+ mResultChannelSender->send(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+ }
+ continue;
+ }
+
+ // otherwise begin tracing execution
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
+ "ExecutionBurstServer getting memory, executing, and returning results");
+
+ // unpack the arguments; types are Request, std::vector<int32_t>, and
// MeasureTiming, respectively
- const auto [requestWithoutPools, slotsOfPools, measure] = deserialize(requestData);
+ const auto [requestWithoutPools, slotsOfPools, measure] = std::move(*arguments);
// ensure executor with cache has required memory
std::lock_guard<std::mutex> hold(mMutex);
@@ -509,9 +574,7 @@
mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
// return result
- const std::vector<FmqResultDatum> result =
- serialize(errorStatus, outputShapes, returnedTiming);
- sendPacket(result);
+ mResultChannelSender->send(errorStatus, outputShapes, returnedTiming);
}
}