BurstMemoryCache serialization cleanup
This CL offers some prerequisite cleanup for the following two
subsequent changes:
(1) Handle case where Burst FMQ packet exceeds FMQ buffer size
(2) Create validation tests for FMQ serialized message format
It does this by making the serialization and deserialization functions
more modular and their usage clearer.
This CL additionally changes the worker from std::future to std::thread.
This is because some implementations of std::future use a
ThreadPool/WorkPool, which can cause unintended deadlocks if not handled
carefully. std::thread does not have those same concerns.
Bug: 129779280
Bug: 129157135
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest VtsHalNeuralnetworksV1_0TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_1TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: Ia0842bc62c37c676c4faf83fc177198adac1f832
Merged-In: Ia0842bc62c37c676c4faf83fc177198adac1f832
(cherry picked from commit 1a57745f547572ed4fcfd77570e109b1c2288d5a)
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp
index 286c62a..37969c3 100644
--- a/nn/common/ExecutionBurstController.cpp
+++ b/nn/common/ExecutionBurstController.cpp
@@ -19,19 +19,309 @@
#include "ExecutionBurstController.h"
#include <android-base/logging.h>
+#include <limits>
#include <string>
#include "Tracing.h"
namespace android::nn {
namespace {
+using ::android::hardware::MQDescriptorSync;
using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
-constexpr Timing kInvalidTiming = {UINT64_MAX, UINT64_MAX};
+constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
+ std::numeric_limits<uint64_t>::max()};
} // anonymous namespace
+// serialize a request into a packet
+std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
+ const std::vector<int32_t>& slots) {
+ // count how many elements need to be sent for a request
+ size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
+ for (const auto& input : request.inputs) {
+ count += input.dimensions.size();
+ }
+ for (const auto& output : request.outputs) {
+ count += output.dimensions.size();
+ }
+
+ // create buffer to temporarily store elements
+ std::vector<FmqRequestDatum> data;
+ data.reserve(count);
+
+ // package packetInfo
+ {
+ FmqRequestDatum datum;
+ datum.packetInformation(
+ {/*.packetSize=*/static_cast<uint32_t>(count),
+ /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
+ /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
+ /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
+ data.push_back(datum);
+ }
+
+ // package input data
+ for (const auto& input : request.inputs) {
+ // package operand information
+ FmqRequestDatum datum;
+ datum.inputOperandInformation(
+ {/*.hasNoValue=*/input.hasNoValue,
+ /*.location=*/input.location,
+ /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
+ data.push_back(datum);
+
+ // package operand dimensions
+ for (uint32_t dimension : input.dimensions) {
+ FmqRequestDatum datum;
+ datum.inputOperandDimensionValue(dimension);
+ data.push_back(datum);
+ }
+ }
+
+ // package output data
+ for (const auto& output : request.outputs) {
+ // package operand information
+ FmqRequestDatum datum;
+ datum.outputOperandInformation(
+ {/*.hasNoValue=*/output.hasNoValue,
+ /*.location=*/output.location,
+ /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
+ data.push_back(datum);
+
+ // package operand dimensions
+ for (uint32_t dimension : output.dimensions) {
+ FmqRequestDatum datum;
+ datum.outputOperandDimensionValue(dimension);
+ data.push_back(datum);
+ }
+ }
+
+ // package pool identifier
+ for (int32_t slot : slots) {
+ FmqRequestDatum datum;
+ datum.poolIdentifier(slot);
+ data.push_back(datum);
+ }
+
+ // package measureTiming
+ {
+ FmqRequestDatum datum;
+ datum.measureTiming(measure);
+ data.push_back(datum);
+ }
+
+ // return packet
+ return data;
+}
+
+// deserialize a packet into the result
+std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
+ const std::vector<FmqResultDatum>& data) {
+ using discriminator = FmqResultDatum::hidl_discriminator;
+
+ std::vector<OutputShape> outputShapes;
+ size_t index = 0;
+
+ // validate packet information
+ if (data[index].getDiscriminator() != discriminator::packetInformation) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage packet information
+ const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
+ index++;
+ const uint32_t packetSize = packetInfo.packetSize;
+ const ErrorStatus errorStatus = packetInfo.errorStatus;
+ const uint32_t numberOfOperands = packetInfo.numberOfOperands;
+
+ // unpackage operands
+ for (size_t operand = 0; operand < numberOfOperands; ++operand) {
+ // validate operand information
+ if (data[index].getDiscriminator() != discriminator::operandInformation) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage operand information
+ const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
+ index++;
+ const bool isSufficient = operandInfo.isSufficient;
+ const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
+
+ // unpackage operand dimensions
+ std::vector<uint32_t> dimensions;
+ dimensions.reserve(numberOfDimensions);
+ for (size_t i = 0; i < numberOfDimensions; ++i) {
+ // validate dimension
+ if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage dimension
+ const uint32_t dimension = data[index].operandDimensionValue();
+ index++;
+
+ // store result
+ dimensions.push_back(dimension);
+ }
+
+ // store result
+ outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
+ }
+
+ // validate execution timing
+ if (data[index].getDiscriminator() != discriminator::executionTiming) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage execution timing
+ const Timing timing = data[index].executionTiming();
+ index++;
+
+ // validate packet information
+ if (index != packetSize) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
+ // return result
+ return std::make_tuple(errorStatus, std::move(outputShapes), timing);
+}
+
+std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
+ResultChannelReceiver::create(size_t channelLength, bool blocking) {
+ std::unique_ptr<FmqResultChannel> fmqResultChannel =
+ std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/blocking);
+ if (!fmqResultChannel->isValid()) {
+ LOG(ERROR) << "Unable to create ResultChannelReceiver";
+ return {nullptr, nullptr};
+ }
+ const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
+ return std::make_pair(
+ std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), blocking),
+ descriptor);
+}
+
+ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
+ bool blocking)
+ : mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {}
+
+std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>>
+ResultChannelReceiver::getBlocking() {
+ const auto packet = getPacketBlocking();
+ if (!packet) {
+ return std::nullopt;
+ }
+
+ return deserialize(*packet);
+}
+
+void ResultChannelReceiver::invalidate() {
+ mTeardown = true;
+
+ // force unblock
+ // ExecutionBurstServer is by default waiting on a request packet. If the
+ // client process destroys its burst object, the server will still be
+ // waiting on the futex (assuming mBlocking is true). This force unblock
+ // wakes up any thread waiting on the futex.
+ if (mBlocking) {
+ // TODO: look for a different/better way to signal/notify the futex to
+ // wake up any thread waiting on it
+ FmqResultDatum datum;
+ datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/ErrorStatus::GENERAL_FAILURE,
+ /*.numberOfOperands=*/0});
+ mFmqResultChannel->writeBlocking(&datum, 1);
+ }
+}
+
+std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
+ using discriminator = FmqResultDatum::hidl_discriminator;
+
+ // wait for result packet and read first element of result packet
+ // TODO: have a more elegant way to wait for data, and read it all at once.
+ // For example, EventFlag can be used to directly wait on the futex, and all
+ // the data can be read at once with a non-blocking call to
+ // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+ // MessageQueue::commitRead can be used to avoid an extra copy of the
+ // metadata.
+ FmqResultDatum datum;
+ bool success = true;
+ if (mBlocking) {
+ success = mFmqResultChannel->readBlocking(&datum, 1);
+ } else {
+ // TODO: better handle the case where the service crashes after
+ // receiving the Request but before returning the result.
+ while (!mFmqResultChannel->read(&datum, 1)) {
+ }
+ }
+
+ // validate packet information
+ if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpack packet information
+ const auto& packetInfo = datum.packetInformation();
+ const size_t count = packetInfo.packetSize;
+
+ // retrieve remaining elements
+ // NOTE: all of the data is already available at this point, so there's no
+ // need to do a blocking wait to wait for more data. This is known because
+ // in FMQ, all writes are published (made available) atomically. Currently,
+ // the producer always publishes the entire packet in one function call, so
+ // if the first element of the packet is available, the remaining elements
+ // are also available.
+ std::vector<FmqResultDatum> packet(count);
+ packet.front() = datum;
+ success = mFmqResultChannel->read(packet.data() + 1, packet.size() - 1);
+
+ if (!success) {
+ return std::nullopt;
+ }
+
+ return packet;
+}
+
+std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
+RequestChannelSender::create(size_t channelLength, bool blocking) {
+ std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
+ std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/blocking);
+ if (!fmqRequestChannel->isValid()) {
+ LOG(ERROR) << "Unable to create RequestChannelSender";
+ return {nullptr, nullptr};
+ }
+ const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
+ return std::make_pair(
+ std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel), blocking),
+ descriptor);
+}
+
+RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
+ bool blocking)
+ : mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {}
+
+bool RequestChannelSender::send(const Request& request, MeasureTiming measure,
+ const std::vector<int32_t>& slots) {
+ const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
+ return sendPacket(serialized);
+}
+
+bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
+ // TODO: handle the case where the serialziation exceeds FMQ channel length
+
+ if (mBlocking) {
+ return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
+ } else {
+ return mFmqRequestChannel->write(packet.data(), packet.size());
+ }
+}
+
Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
const hidl_vec<int32_t>& slots, getMemories_cb cb) {
std::lock_guard<std::mutex> guard(mMutex);
@@ -130,35 +420,35 @@
}
// create FMQ objects
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow) FmqRequestChannel(
- kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
- std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow) FmqResultChannel(
- kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
+ auto [fmqRequestChannel, fmqRequestDescriptor] =
+ RequestChannelSender::create(kExecutionBurstChannelLength, blocking);
+ auto [fmqResultChannel, fmqResultDescriptor] =
+ ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking);
// check FMQ objects
- if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
- !fmqResultChannel->isValid()) {
+ if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestDescriptor || !fmqResultDescriptor) {
LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
return nullptr;
}
- // descriptors
- const FmqRequestDescriptor& fmqRequestDescriptor = *fmqRequestChannel->getDesc();
- const FmqResultDescriptor& fmqResultDescriptor = *fmqResultChannel->getDesc();
-
// configure burst
ErrorStatus errorStatus;
sp<IBurstContext> burstContext;
- Return<void> ret = preparedModel->configureExecutionBurst(
- callback, fmqRequestDescriptor, fmqResultDescriptor,
+ const Return<void> ret = preparedModel->configureExecutionBurst(
+ callback, *fmqRequestDescriptor, *fmqResultDescriptor,
[&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
errorStatus = status;
burstContext = context;
});
// check burst
+ if (!ret.isOk()) {
+ LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
+ << ret.description();
+ return nullptr;
+ }
if (errorStatus != ErrorStatus::NONE) {
- LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with "
+ LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
<< toString(errorStatus);
return nullptr;
}
@@ -168,236 +458,18 @@
}
// make and return controller
- return std::make_unique<ExecutionBurstController>(std::move(fmqRequestChannel),
- std::move(fmqResultChannel), burstContext,
- callback, blocking);
+ return std::make_unique<ExecutionBurstController>(
+ std::move(fmqRequestChannel), std::move(fmqResultChannel), burstContext, callback);
}
ExecutionBurstController::ExecutionBurstController(
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
- std::unique_ptr<FmqResultChannel> fmqResultChannel, const sp<IBurstContext>& burstContext,
- const sp<ExecutionBurstCallback>& callback, bool blocking)
- : mFmqRequestChannel(std::move(fmqRequestChannel)),
- mFmqResultChannel(std::move(fmqResultChannel)),
+ std::unique_ptr<RequestChannelSender> requestChannelSender,
+ std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
+ const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback)
+ : mRequestChannelSender(std::move(requestChannelSender)),
+ mResultChannelReceiver(std::move(resultChannelReceiver)),
mBurstContext(burstContext),
- mMemoryCache(callback),
- mUsesFutex(blocking) {}
-
-bool ExecutionBurstController::sendPacket(const std::vector<FmqRequestDatum>& packet) {
- if (mUsesFutex) {
- return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
- } else {
- return mFmqRequestChannel->write(packet.data(), packet.size());
- }
-}
-
-std::vector<FmqResultDatum> ExecutionBurstController::getPacketBlocking() {
- using discriminator = FmqResultDatum::hidl_discriminator;
-
- // wait for result packet and read first element of result packet
- // TODO: have a more elegant way to wait for data, and read it all at once.
- // For example, EventFlag can be used to directly wait on the futex, and all
- // the data can be read at once with a non-blocking call to
- // MessageQueue::read. For further optimization, MessageQueue::beginRead and
- // MessageQueue::commitRead can be used to avoid an extra copy of the
- // metadata.
- FmqResultDatum datum;
- bool success = true;
- if (mUsesFutex) {
- success = mFmqResultChannel->readBlocking(&datum, 1);
- } else {
- // TODO: better handle the case where the service crashes after
- // receiving the Request but before returning the result.
- while (!mFmqResultChannel->read(&datum, 1)) {
- }
- }
-
- // validate packet information
- if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return {};
- }
-
- // unpack packet information
- const auto& packetInfo = datum.packetInformation();
- const size_t count = packetInfo.packetSize;
-
- // retrieve remaining elements
- // NOTE: all of the data is already available at this point, so there's no
- // need to do a blocking wait to wait for more data. This is known because
- // in FMQ, all writes are published (made available) atomically. Currently,
- // the producer always publishes the entire packet in one function call, so
- // if the first element of the packet is available, the remaining elements
- // are also available.
- std::vector<FmqResultDatum> packet(count);
- packet.front() = datum;
- success = mFmqResultChannel->read(packet.data() + 1, packet.size() - 1);
-
- if (!success) {
- return {};
- }
-
- return packet;
-}
-
-// serialize a request into a packet
-std::vector<FmqRequestDatum> ExecutionBurstController::serialize(
- const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
- // count how many elements need to be sent for a request
- size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
- for (const auto& input : request.inputs) {
- count += input.dimensions.size();
- }
- for (const auto& output : request.outputs) {
- count += output.dimensions.size();
- }
-
- // create buffer to temporarily store elements
- std::vector<FmqRequestDatum> data;
- data.reserve(count);
-
- // package packetInfo
- {
- FmqRequestDatum datum;
- datum.packetInformation(
- {/*.packetSize=*/static_cast<uint32_t>(count),
- /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
- /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
- /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
- data.push_back(datum);
- }
-
- // package input data
- for (const auto& input : request.inputs) {
- // package operand information
- FmqRequestDatum datum;
- datum.inputOperandInformation(
- {/*.hasNoValue=*/input.hasNoValue,
- /*.location=*/input.location,
- /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
- data.push_back(datum);
-
- // package operand dimensions
- for (uint32_t dimension : input.dimensions) {
- FmqRequestDatum datum;
- datum.inputOperandDimensionValue(dimension);
- data.push_back(datum);
- }
- }
-
- // package output data
- for (const auto& output : request.outputs) {
- // package operand information
- FmqRequestDatum datum;
- datum.outputOperandInformation(
- {/*.hasNoValue=*/output.hasNoValue,
- /*.location=*/output.location,
- /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
- data.push_back(datum);
-
- // package operand dimensions
- for (uint32_t dimension : output.dimensions) {
- FmqRequestDatum datum;
- datum.outputOperandDimensionValue(dimension);
- data.push_back(datum);
- }
- }
-
- // package pool identifier
- const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
- for (int32_t slot : slots) {
- FmqRequestDatum datum;
- datum.poolIdentifier(slot);
- data.push_back(datum);
- }
-
- // package measureTiming
- {
- FmqRequestDatum datum;
- datum.measureTiming(measure);
- data.push_back(datum);
- }
-
- // return packet
- return data;
-}
-
-// deserialize a packet into the result
-std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::deserialize(
- const std::vector<FmqResultDatum>& data) {
- using discriminator = FmqResultDatum::hidl_discriminator;
-
- std::vector<OutputShape> outputShapes;
- size_t index = 0;
-
- // validate packet information
- if (data[index].getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
- }
-
- // unpackage packet information
- const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
- index++;
- const uint32_t packetSize = packetInfo.packetSize;
- const ErrorStatus errorStatus = packetInfo.errorStatus;
- const uint32_t numberOfOperands = packetInfo.numberOfOperands;
-
- // unpackage operands
- for (size_t operand = 0; operand < numberOfOperands; ++operand) {
- // validate operand information
- if (data[index].getDiscriminator() != discriminator::operandInformation) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
- }
-
- // unpackage operand information
- const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
- index++;
- const bool isSufficient = operandInfo.isSufficient;
- const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
-
- // unpackage operand dimensions
- std::vector<uint32_t> dimensions;
- dimensions.reserve(numberOfDimensions);
- for (size_t i = 0; i < numberOfDimensions; ++i) {
- // validate dimension
- if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
- }
-
- // unpackage dimension
- const uint32_t dimension = data[index].operandDimensionValue();
- index++;
-
- // store result
- dimensions.push_back(dimension);
- }
-
- // store result
- outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
- }
-
- // validate execution timing
- if (data[index].getDiscriminator() != discriminator::executionTiming) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
- }
-
- // unpackage execution timing
- const Timing timing = data[index].executionTiming();
- index++;
-
- // validate packet information
- if (index != packetSize) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
- }
-
- // return result
- return std::make_tuple(errorStatus, std::move(outputShapes), timing);
-}
+ mMemoryCache(callback) {}
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute(
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
@@ -405,28 +477,21 @@
std::lock_guard<std::mutex> guard(mMutex);
- // serialize request
- std::vector<FmqRequestDatum> requestData = serialize(request, measure, memoryIds);
-
- // TODO: handle the case where the serialziation exceeds
- // kExecutionBurstChannelLength
-
// send request packet
- bool success = sendPacket(requestData);
+ const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
+ const bool success = mRequestChannelSender->send(request, measure, slots);
if (!success) {
LOG(ERROR) << "Error sending FMQ packet";
- return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
}
// get result packet
- const std::vector<FmqResultDatum> resultData = getPacketBlocking();
- if (resultData.empty()) {
+ const auto result = mResultChannelReceiver->getBlocking();
+ if (!result) {
LOG(ERROR) << "Error retrieving FMQ packet";
- return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
}
-
- // deserialize result
- return deserialize(resultData);
+ return *result;
}
void ExecutionBurstController::freeMemory(intptr_t key) {
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 08d0e80..067c2f5 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,13 +19,15 @@
#include "ExecutionBurstServer.h"
#include <android-base/logging.h>
-#include <string>
+#include <limits>
#include "Tracing.h"
namespace android::nn {
-
namespace {
+constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
+ std::numeric_limits<uint64_t>::max()};
+
// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
// hidl_memory object, and the execution forwards calls to the provided
@@ -94,347 +96,9 @@
} // anonymous namespace
-sp<ExecutionBurstServer> ExecutionBurstServer::create(
- const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel,
- std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {
- // check inputs
- if (callback == nullptr || executorWithCache == nullptr) {
- LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
- return nullptr;
- }
-
- // create FMQ objects
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
- std::make_unique<FmqRequestChannel>(requestChannel);
- std::unique_ptr<FmqResultChannel> fmqResultChannel =
- std::make_unique<FmqResultChannel>(resultChannel);
-
- // check FMQ objects
- if (!fmqRequestChannel->isValid() || !fmqResultChannel->isValid()) {
- LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
- return nullptr;
- }
-
- // make and return context
- return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
- std::move(fmqResultChannel), std::move(executorWithCache));
-}
-
-sp<ExecutionBurstServer> ExecutionBurstServer::create(
- const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
- // check relevant input
- if (preparedModel == nullptr) {
- LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
- return nullptr;
- }
-
- // adapt IPreparedModel to have caching
- const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
- std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
-
- // make and return context
- return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
- preparedModelAdapter);
-}
-
-ExecutionBurstServer::ExecutionBurstServer(
- const sp<IBurstCallback>& callback, std::unique_ptr<FmqRequestChannel> requestChannel,
- std::unique_ptr<FmqResultChannel> resultChannel,
- std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
- : mCallback(callback),
- mFmqRequestChannel(std::move(requestChannel)),
- mFmqResultChannel(std::move(resultChannel)),
- mExecutorWithCache(std::move(executorWithCache)),
- mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
- // TODO: highly document the threading behavior of this class
- mWorker = std::async(std::launch::async, [this] { task(); });
-}
-
-ExecutionBurstServer::~ExecutionBurstServer() {
- // set teardown flag
- mTeardown = true;
-
- // force unblock
- // ExecutionBurstServer is by default waiting on a request packet. If the
- // client process destroys its burst object, the server will still be
- // waiting on the futex (assuming mBlocking is true). This force unblock
- // wakes up any thread waiting on the futex.
- if (mBlocking) {
- // TODO: look for a different/better way to signal/notify the futex to
- // wake up any thread waiting on it
- FmqRequestDatum datum;
- datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
- /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
- mFmqRequestChannel->writeBlocking(&datum, 1);
- }
-
- // wait for task thread to end
- mWorker.wait();
-}
-
-Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
- std::lock_guard<std::mutex> hold(mMutex);
- mExecutorWithCache->removeCacheEntry(slot);
- return Void();
-}
-
-void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
- const auto slotIsKnown = [this](int32_t slot) {
- return mExecutorWithCache->isCacheEntryPresent(slot);
- };
-
- // find unique unknown slots
- std::vector<int32_t> unknownSlots = slots;
- auto unknownSlotsEnd = unknownSlots.end();
- std::sort(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
- unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
-
- // quick-exit if all slots are known
- if (unknownSlots.empty()) {
- return;
- }
-
- ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
- std::vector<hidl_memory> returnedMemories;
- auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
- const hidl_vec<hidl_memory>& memories) {
- errorStatus = status;
- returnedMemories = memories;
- };
-
- const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
-
- if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
- returnedMemories.size() != unknownSlots.size()) {
- LOG(ERROR) << "Error retrieving memories";
- return;
- }
-
- // add memories to unknown slots
- for (size_t i = 0; i < unknownSlots.size(); ++i) {
- mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
- }
-}
-
-bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
- if (mTeardown) {
- return false;
- }
-
- if (mBlocking) {
- return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
- } else {
- return mFmqResultChannel->write(packet.data(), packet.size());
- }
-}
-
-std::vector<FmqRequestDatum> ExecutionBurstServer::getPacketBlocking() {
- using discriminator = FmqRequestDatum::hidl_discriminator;
-
- if (mTeardown) {
- return {};
- }
-
- // wait for request packet and read first element of request packet
- // TODO: have a more elegant way to wait for data, and read it all at once.
- // For example, EventFlag can be used to directly wait on the futex, and all
- // the data can be read at once with a non-blocking call to
- // MessageQueue::read. For further optimization, MessageQueue::beginRead and
- // MessageQueue::commitRead can be used to avoid an extra copy of the
- // metadata.
- FmqRequestDatum datum;
- bool success = false;
- if (mBlocking) {
- success = mFmqRequestChannel->readBlocking(&datum, 1);
- } else {
- while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
- !mFmqRequestChannel->read(&datum, 1)) {
- }
- }
-
- // terminate loop
- if (mTeardown) {
- return {};
- }
-
- // validate packet information
- if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {};
- }
-
- NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
-
- // unpack packet information
- const auto& packetInfo = datum.packetInformation();
- const size_t count = packetInfo.packetSize;
-
- // retrieve remaining elements
- // NOTE: all of the data is already available at this point, so there's no
- // need to do a blocking wait to wait for more data. This is known because
- // in FMQ, all writes are published (made available) atomically. Currently,
- // the producer always publishes the entire packet in one function call, so
- // if the first element of the packet is available, the remaining elements
- // are also available.
- std::vector<FmqRequestDatum> packet(count);
- packet.front() = datum;
- success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
-
- if (!success) {
- return {};
- }
-
- return packet;
-}
-
-// deserialize request
-std::tuple<Request, std::vector<int32_t>, MeasureTiming> ExecutionBurstServer::deserialize(
- const std::vector<FmqRequestDatum>& data) {
- using discriminator = FmqRequestDatum::hidl_discriminator;
-
- size_t index = 0;
-
- // validate packet information
- if (data[index].getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage packet information
- const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
- index++;
- const uint32_t packetSize = packetInfo.packetSize;
- const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
- const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
- const uint32_t numberOfPools = packetInfo.numberOfPools;
-
- // unpackage input operands
- std::vector<RequestArgument> inputs;
- inputs.reserve(numberOfInputOperands);
- for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
- // validate input operand information
- if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage operand information
- const FmqRequestDatum::OperandInformation& operandInfo =
- data[index].inputOperandInformation();
- index++;
- const bool hasNoValue = operandInfo.hasNoValue;
- const DataLocation location = operandInfo.location;
- const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
-
- // unpackage operand dimensions
- std::vector<uint32_t> dimensions;
- dimensions.reserve(numberOfDimensions);
- for (size_t i = 0; i < numberOfDimensions; ++i) {
- // validate dimension
- if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage dimension
- const uint32_t dimension = data[index].inputOperandDimensionValue();
- index++;
-
- // store result
- dimensions.push_back(dimension);
- }
-
- // store result
- inputs.push_back(
- {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
- }
-
- // unpackage output operands
- std::vector<RequestArgument> outputs;
- outputs.reserve(numberOfOutputOperands);
- for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
- // validate output operand information
- if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage operand information
- const FmqRequestDatum::OperandInformation& operandInfo =
- data[index].outputOperandInformation();
- index++;
- const bool hasNoValue = operandInfo.hasNoValue;
- const DataLocation location = operandInfo.location;
- const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
-
- // unpackage operand dimensions
- std::vector<uint32_t> dimensions;
- dimensions.reserve(numberOfDimensions);
- for (size_t i = 0; i < numberOfDimensions; ++i) {
- // validate dimension
- if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage dimension
- const uint32_t dimension = data[index].outputOperandDimensionValue();
- index++;
-
- // store result
- dimensions.push_back(dimension);
- }
-
- // store result
- outputs.push_back(
- {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
- }
-
- // unpackage pools
- std::vector<int32_t> slots;
- slots.reserve(numberOfPools);
- for (size_t pool = 0; pool < numberOfPools; ++pool) {
- // validate input operand information
- if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage operand information
- const int32_t poolId = data[index].poolIdentifier();
- index++;
-
- // store result
- slots.push_back(poolId);
- }
-
- // validate measureTiming
- if (data[index].getDiscriminator() != discriminator::measureTiming) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // unpackage measureTiming
- const MeasureTiming measure = data[index].measureTiming();
- index++;
-
- // validate packet information
- if (index != packetSize) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return {{}, {}, MeasureTiming::NO};
- }
-
- // return request
- return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}}, std::move(slots), measure};
-}
-
// serialize result
-std::vector<FmqResultDatum> ExecutionBurstServer::serialize(
- ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing) {
+std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
+ const std::vector<OutputShape>& outputShapes, Timing timing) {
// count how many elements need to be sent for a request
size_t count = 2 + outputShapes.size();
for (const auto& outputShape : outputShapes) {
@@ -482,22 +146,423 @@
return data;
}
-void ExecutionBurstServer::task() {
- while (!mTeardown) {
- // receive request
- const std::vector<FmqRequestDatum> requestData = getPacketBlocking();
+// deserialize request
+std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
+ const std::vector<FmqRequestDatum>& data) {
+ using discriminator = FmqRequestDatum::hidl_discriminator;
- // terminate loop
- if (mTeardown) {
- return;
+ size_t index = 0;
+
+ // validate packet information
+ if (data[index].getDiscriminator() != discriminator::packetInformation) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage packet information
+ const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
+ index++;
+ const uint32_t packetSize = packetInfo.packetSize;
+ const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
+ const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
+ const uint32_t numberOfPools = packetInfo.numberOfPools;
+
+ // unpackage input operands
+ std::vector<RequestArgument> inputs;
+ inputs.reserve(numberOfInputOperands);
+ for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
+ // validate input operand information
+ if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
}
- NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
- "ExecutionBurstServer processing packet and returning results");
+ // unpackage operand information
+ const FmqRequestDatum::OperandInformation& operandInfo =
+ data[index].inputOperandInformation();
+ index++;
+ const bool hasNoValue = operandInfo.hasNoValue;
+ const DataLocation location = operandInfo.location;
+ const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
- // continue processing; types are Request, std::vector<int32_t>, and
+ // unpackage operand dimensions
+ std::vector<uint32_t> dimensions;
+ dimensions.reserve(numberOfDimensions);
+ for (size_t i = 0; i < numberOfDimensions; ++i) {
+ // validate dimension
+ if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage dimension
+ const uint32_t dimension = data[index].inputOperandDimensionValue();
+ index++;
+
+ // store result
+ dimensions.push_back(dimension);
+ }
+
+ // store result
+ inputs.push_back(
+ {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
+ }
+
+ // unpackage output operands
+ std::vector<RequestArgument> outputs;
+ outputs.reserve(numberOfOutputOperands);
+ for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
+ // validate output operand information
+ if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage operand information
+ const FmqRequestDatum::OperandInformation& operandInfo =
+ data[index].outputOperandInformation();
+ index++;
+ const bool hasNoValue = operandInfo.hasNoValue;
+ const DataLocation location = operandInfo.location;
+ const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
+
+ // unpackage operand dimensions
+ std::vector<uint32_t> dimensions;
+ dimensions.reserve(numberOfDimensions);
+ for (size_t i = 0; i < numberOfDimensions; ++i) {
+ // validate dimension
+ if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage dimension
+ const uint32_t dimension = data[index].outputOperandDimensionValue();
+ index++;
+
+ // store result
+ dimensions.push_back(dimension);
+ }
+
+ // store result
+ outputs.push_back(
+ {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
+ }
+
+ // unpackage pools
+ std::vector<int32_t> slots;
+ slots.reserve(numberOfPools);
+ for (size_t pool = 0; pool < numberOfPools; ++pool) {
+ // validate input operand information
+ if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage operand information
+ const int32_t poolId = data[index].poolIdentifier();
+ index++;
+
+ // store result
+ slots.push_back(poolId);
+ }
+
+ // validate measureTiming
+ if (data[index].getDiscriminator() != discriminator::measureTiming) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
+ // unpackage measureTiming
+ const MeasureTiming measure = data[index].measureTiming();
+ index++;
+
+ // validate packet information
+ if (index != packetSize) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
+ // return request
+ Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}};
+ return std::make_tuple(std::move(request), std::move(slots), measure);
+}
+
+// RequestChannelReceiver methods
+
+std::unique_ptr<RequestChannelReceiver> RequestChannelReceiver::create(
+ const FmqRequestDescriptor& requestChannel) {
+ std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
+ std::make_unique<FmqRequestChannel>(requestChannel);
+ if (!fmqRequestChannel->isValid()) {
+ LOG(ERROR) << "Unable to create RequestChannelReceiver";
+ return nullptr;
+ }
+ const bool blocking = fmqRequestChannel->getEventFlagWord() != nullptr;
+ return std::make_unique<RequestChannelReceiver>(std::move(fmqRequestChannel), blocking);
+}
+
+RequestChannelReceiver::RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
+ bool blocking)
+ : mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {}
+
+std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>>
+RequestChannelReceiver::getBlocking() {
+ const auto packet = getPacketBlocking();
+ if (!packet) {
+ return std::nullopt;
+ }
+
+ return deserialize(*packet);
+}
+
+void RequestChannelReceiver::invalidate() {
+ mTeardown = true;
+
+ // force unblock
+ // ExecutionBurstServer is by default waiting on a request packet. If the
+ // client process destroys its burst object, the server will still be
+ // waiting on the futex (assuming mBlocking is true). This force unblock
+ // wakes up any thread waiting on the futex.
+ if (mBlocking) {
+ // TODO: look for a different/better way to signal/notify the futex to
+ // wake up any thread waiting on it
+ FmqRequestDatum datum;
+ datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
+ /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
+ mFmqRequestChannel->writeBlocking(&datum, 1);
+ }
+}
+
+std::optional<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() {
+ using discriminator = FmqRequestDatum::hidl_discriminator;
+
+ if (mTeardown) {
+ return std::nullopt;
+ }
+
+ // wait for request packet and read first element of request packet
+ // TODO: have a more elegant way to wait for data, and read it all at once.
+ // For example, EventFlag can be used to directly wait on the futex, and all
+ // the data can be read at once with a non-blocking call to
+ // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+ // MessageQueue::commitRead can be used to avoid an extra copy of the
+ // metadata.
+ FmqRequestDatum datum;
+ bool success = false;
+ if (mBlocking) {
+ success = mFmqRequestChannel->readBlocking(&datum, 1);
+ } else {
+ while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
+ !mFmqRequestChannel->read(&datum, 1)) {
+ }
+ }
+
+ // terminate loop
+ if (mTeardown) {
+ return std::nullopt;
+ }
+
+ // validate packet information
+ if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::make_optional<std::vector<FmqRequestDatum>>();
+ }
+
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
+
+ // unpack packet information
+ const auto& packetInfo = datum.packetInformation();
+ const size_t count = packetInfo.packetSize;
+
+ // retrieve remaining elements
+ // NOTE: all of the data is already available at this point, so there's no
+ // need to do a blocking wait to wait for more data. This is known because
+ // in FMQ, all writes are published (made available) atomically. Currently,
+ // the producer always publishes the entire packet in one function call, so
+ // if the first element of the packet is available, the remaining elements
+ // are also available.
+ std::vector<FmqRequestDatum> packet(count);
+ packet.front() = datum;
+ success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
+
+ if (!success) {
+ return std::make_optional<std::vector<FmqRequestDatum>>();
+ }
+
+ return packet;
+}
+
+// ResultChannelSender methods
+
+std::unique_ptr<ResultChannelSender> ResultChannelSender::create(
+ const FmqResultDescriptor& resultChannel) {
+ std::unique_ptr<FmqResultChannel> fmqResultChannel =
+ std::make_unique<FmqResultChannel>(resultChannel);
+ if (!fmqResultChannel->isValid()) {
+ LOG(ERROR) << "Unable to create RequestChannelSender";
+ return nullptr;
+ }
+ const bool blocking = fmqResultChannel->getEventFlagWord() != nullptr;
+ return std::make_unique<ResultChannelSender>(std::move(fmqResultChannel), blocking);
+}
+
+ResultChannelSender::ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel,
+ bool blocking)
+ : mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {}
+
+bool ResultChannelSender::send(ErrorStatus errorStatus,
+ const std::vector<OutputShape>& outputShapes, Timing timing) {
+ const std::vector<FmqResultDatum> serialized = serialize(errorStatus, outputShapes, timing);
+ return sendPacket(serialized);
+}
+
+bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
+ if (mBlocking) {
+ return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
+ } else {
+ return mFmqResultChannel->write(packet.data(), packet.size());
+ }
+}
+
+// ExecutionBurstServer methods
+
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+ const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel,
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {
+ // check inputs
+ if (callback == nullptr || executorWithCache == nullptr) {
+ LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+ return nullptr;
+ }
+
+ // create FMQ objects
+ std::unique_ptr<RequestChannelReceiver> requestChannelReceiver =
+ RequestChannelReceiver::create(requestChannel);
+ std::unique_ptr<ResultChannelSender> resultChannelSender =
+ ResultChannelSender::create(resultChannel);
+
+ // check FMQ objects
+ if (!requestChannelReceiver || !resultChannelSender) {
+ LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
+ return nullptr;
+ }
+
+ // make and return context
+ return new ExecutionBurstServer(callback, std::move(requestChannelReceiver),
+ std::move(resultChannelSender), std::move(executorWithCache));
+}
+
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+ const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+ // check relevant input
+ if (preparedModel == nullptr) {
+ LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+ return nullptr;
+ }
+
+ // adapt IPreparedModel to have caching
+ const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
+ std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
+
+ // make and return context
+ return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
+ preparedModelAdapter);
+}
+
+ExecutionBurstServer::ExecutionBurstServer(
+ const sp<IBurstCallback>& callback, std::unique_ptr<RequestChannelReceiver> requestChannel,
+ std::unique_ptr<ResultChannelSender> resultChannel,
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
+ : mCallback(callback),
+ mRequestChannelReceiver(std::move(requestChannel)),
+ mResultChannelSender(std::move(resultChannel)),
+ mExecutorWithCache(std::move(executorWithCache)) {
+ // TODO: highly document the threading behavior of this class
+ mWorker = std::thread([this] { task(); });
+}
+
+ExecutionBurstServer::~ExecutionBurstServer() {
+ // set teardown flag
+ mTeardown = true;
+ mRequestChannelReceiver->invalidate();
+
+ // wait for task thread to end
+ mWorker.join();
+}
+
+Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
+ mExecutorWithCache->removeCacheEntry(slot);
+ return Void();
+}
+
+void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
+ const auto slotIsKnown = [this](int32_t slot) {
+ return mExecutorWithCache->isCacheEntryPresent(slot);
+ };
+
+ // find unique unknown slots
+ std::vector<int32_t> unknownSlots = slots;
+ auto unknownSlotsEnd = unknownSlots.end();
+ std::sort(unknownSlots.begin(), unknownSlotsEnd);
+ unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
+ unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+ unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+
+ // quick-exit if all slots are known
+ if (unknownSlots.empty()) {
+ return;
+ }
+
+ ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+ std::vector<hidl_memory> returnedMemories;
+ auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
+ const hidl_vec<hidl_memory>& memories) {
+ errorStatus = status;
+ returnedMemories = memories;
+ };
+
+ const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
+
+ if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
+ returnedMemories.size() != unknownSlots.size()) {
+ LOG(ERROR) << "Error retrieving memories";
+ return;
+ }
+
+ // add memories to unknown slots
+ for (size_t i = 0; i < unknownSlots.size(); ++i) {
+ mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
+ }
+}
+
+void ExecutionBurstServer::task() {
+ // loop until the burst object is being destroyed
+ while (!mTeardown) {
+ // receive request
+ auto arguments = mRequestChannelReceiver->getBlocking();
+
+ // if the request packet was not properly received, return a generic
+ // error and skip the execution
+ //
+ // if the burst is being torn down, skip the execution exection so the
+ // "task" function can end
+ if (!arguments) {
+ if (!mTeardown) {
+ mResultChannelSender->send(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+ }
+ continue;
+ }
+
+ // otherwise begin tracing execution
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
+ "ExecutionBurstServer getting memory, executing, and returning results");
+
+ // unpack the arguments; types are Request, std::vector<int32_t>, and
// MeasureTiming, respectively
- const auto [requestWithoutPools, slotsOfPools, measure] = deserialize(requestData);
+ const auto [requestWithoutPools, slotsOfPools, measure] = std::move(*arguments);
// ensure executor with cache has required memory
std::lock_guard<std::mutex> hold(mMutex);
@@ -509,9 +574,7 @@
mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
// return result
- const std::vector<FmqResultDatum> result =
- serialize(errorStatus, outputShapes, returnedTiming);
- sendPacket(result);
+ mResultChannelSender->send(errorStatus, outputShapes, returnedTiming);
}
}
diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h
index 33564f4..6a72720 100644
--- a/nn/common/include/ExecutionBurstController.h
+++ b/nn/common/include/ExecutionBurstController.h
@@ -17,31 +17,156 @@
#ifndef ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
#define ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
+#include "HalInterfaces.h"
+
#include <android-base/macros.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
+
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <stack>
#include <tuple>
-#include "HalInterfaces.h"
namespace android::nn {
-using ::android::hardware::kSynchronizedReadWrite;
-using ::android::hardware::MessageQueue;
-using ::android::hardware::MQDescriptorSync;
-using FmqRequestChannel = MessageQueue<FmqRequestDatum, kSynchronizedReadWrite>;
-using FmqResultChannel = MessageQueue<FmqResultDatum, kSynchronizedReadWrite>;
-
/**
* Number of elements in the FMQ.
*/
constexpr const size_t kExecutionBurstChannelLength = 1024;
/**
+ * Function to serialize a request.
+ *
+ * Prefer calling RequestChannelSender::send.
+ *
+ * @param request Request object without the pool information.
+ * @param measure Whether to collect timing information for the execution.
+ * @param memoryIds Slot identifiers corresponding to memory resources for the
+ * request.
+ * @return Serialized FMQ request data.
+ */
+std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
+ const std::vector<int32_t>& slots);
+
+/**
+ * Deserialize the FMQ result data.
+ *
+ * The three resulting fields are the status of the execution, the dynamic
+ * shapes of the output tensors, and the timing information of the execution.
+ *
+ * @param data Serialized FMQ result data.
+ * @return Result object if successfully deserialized, std::nullopt otherwise.
+ */
+std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
+ const std::vector<FmqResultDatum>& data);
+
+/**
+ * ResultChannelReceiver is responsible for waiting on the channel until the
+ * packet is available, extracting the packet from the channel, and
+ * deserializing the packet.
+ *
+ * Because the receiver can wait on a packet that may never come (e.g., because
+ * the sending side of the packet has been closed), this object can be
+ * invalidating, unblocking the receiver.
+ */
+class ResultChannelReceiver {
+ using FmqResultDescriptor = ::android::hardware::MQDescriptorSync<FmqResultDatum>;
+ using FmqResultChannel =
+ hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>;
+
+ public:
+ /**
+ * Create the receiving end of a result channel.
+ *
+ * Prefer this call over the constructor.
+ *
+ * @param channelLength Number of elements in the FMQ.
+ * @param blocking 'true' if FMQ should use futex, 'false' if it should
+ * spin-wait.
+ * @return A pair of ResultChannelReceiver and the FMQ descriptor on
+ * successful creation, both nullptr otherwise.
+ */
+ static std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*> create(
+ size_t channelLength, bool blocking);
+
+ /**
+ * Get the result from the channel.
+ *
+ * This method will block until either:
+ * 1) The packet has been retrieved, or
+ * 2) The receiver has been invalidated
+ *
+ * @return Result object if successfully received, std::nullopt if error or
+ * if the receiver object was invalidated.
+ */
+ std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> getBlocking();
+
+ /**
+ * Method to mark the channel as invalid, unblocking any current or future
+ * calls to ResultChannelReceiver::getBlocking.
+ */
+ void invalidate();
+
+ ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking);
+
+ private:
+ std::optional<std::vector<FmqResultDatum>> getPacketBlocking();
+
+ const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
+ std::atomic<bool> mTeardown{false};
+ const bool mBlocking;
+};
+
+/**
+ * RequestChannelSender is responsible for serializing the result packet of
+ * information, sending it on the result channel, and signaling that the data is
+ * available.
+ */
+class RequestChannelSender {
+ using FmqRequestDescriptor = ::android::hardware::MQDescriptorSync<FmqRequestDatum>;
+ using FmqRequestChannel =
+ hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>;
+
+ public:
+ /**
+ * Create the sending end of a request channel.
+ *
+ * Prefer this call over the constructor.
+ *
+ * @param channelLength Number of elements in the FMQ.
+ * @param blocking 'true' if FMQ should use futex, 'false' if it should
+ * spin-wait.
+ * @return A pair of ResultChannelReceiver and the FMQ descriptor on
+ * successful creation, both nullptr otherwise.
+ */
+ static std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*> create(
+ size_t channelLength, bool blocking);
+
+ /**
+ * Send the request to the channel.
+ *
+ * @param request Request object without the pool information.
+ * @param measure Whether to collect timing information for the execution.
+ * @param memoryIds Slot identifiers corresponding to memory resources for
+ * the request.
+ * @return 'true' on successful send, 'false' otherwise.
+ */
+ bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots);
+
+ // prefer calling RequestChannelSender::send
+ bool sendPacket(const std::vector<FmqRequestDatum>& packet);
+
+ RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking);
+
+ private:
+ const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
+ const bool mBlocking;
+};
+
+/**
* The ExecutionBurstController class manages both the serialization and
* deserialization of data across FMQ, making it appear to the runtime as a
* regular synchronous inference. Additionally, this class manages the burst's
@@ -119,10 +244,10 @@
static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel,
bool blocking);
- ExecutionBurstController(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
- std::unique_ptr<FmqResultChannel> fmqResultChannel,
+ ExecutionBurstController(std::unique_ptr<RequestChannelSender> requestChannelSender,
+ std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
const sp<IBurstContext>& burstContext,
- const sp<ExecutionBurstCallback>& callback, bool blocking);
+ const sp<ExecutionBurstCallback>& callback);
/**
* Execute a request on a model.
@@ -145,19 +270,11 @@
void freeMemory(intptr_t key);
private:
- std::vector<FmqResultDatum> getPacketBlocking();
- bool sendPacket(const std::vector<FmqRequestDatum>& packet);
- std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
- const std::vector<intptr_t>& memoryIds);
- std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> deserialize(
- const std::vector<FmqResultDatum>& data);
-
std::mutex mMutex;
- const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
- const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
+ const std::unique_ptr<RequestChannelSender> mRequestChannelSender;
+ const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver;
const sp<IBurstContext> mBurstContext;
const sp<ExecutionBurstCallback> mMemoryCache;
- const bool mUsesFutex;
};
} // namespace android::nn
diff --git a/nn/common/include/ExecutionBurstServer.h b/nn/common/include/ExecutionBurstServer.h
index 4338f3e..f2a9843 100644
--- a/nn/common/include/ExecutionBurstServer.h
+++ b/nn/common/include/ExecutionBurstServer.h
@@ -17,26 +17,144 @@
#ifndef ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
#define ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
+#include "HalInterfaces.h"
+
#include <android-base/macros.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
+
#include <atomic>
-#include <future>
#include <memory>
+#include <optional>
+#include <thread>
#include <vector>
-#include "HalInterfaces.h"
namespace android::nn {
-using ::android::hardware::kSynchronizedReadWrite;
-using ::android::hardware::MessageQueue;
using ::android::hardware::MQDescriptorSync;
-using FmqRequestChannel = MessageQueue<FmqRequestDatum, kSynchronizedReadWrite>;
-using FmqResultChannel = MessageQueue<FmqResultDatum, kSynchronizedReadWrite>;
using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
/**
+ * Function to serialize results.
+ *
+ * Prefer calling ResultChannelSender::send.
+ *
+ * @param errorStatus Status of the execution.
+ * @param outputShapes Dynamic shapes of the output tensors.
+ * @param timing Timing information of the execution.
+ * @return Serialized FMQ result data.
+ */
+std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
+ const std::vector<OutputShape>& outputShapes, Timing timing);
+
+/**
+ * Deserialize the FMQ request data.
+ *
+ * The three resulting fields are the Request object (where Request::pools is
+ * empty), slot identifiers (which are stand-ins for Request::pools), and
+ * whether timing information must be collected for the run.
+ *
+ * @param data Serialized FMQ request data.
+ * @return Request object if successfully deserialized, std::nullopt otherwise.
+ */
+std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
+ const std::vector<FmqRequestDatum>& data);
+
+/**
+ * RequestChannelReceiver is responsible for waiting on the channel until the
+ * packet is available, extracting the packet from the channel, and
+ * deserializing the packet.
+ *
+ * Because the receiver can wait on a packet that may never come (e.g., because
+ * the sending side of the packet has been closed), this object can be
+ * invalidating, unblocking the receiver.
+ */
+class RequestChannelReceiver {
+ using FmqRequestChannel =
+ hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>;
+
+ public:
+ /**
+ * Create the receiving end of a request channel.
+ *
+ * Prefer this call over the constructor.
+ *
+ * @param requestChannel Descriptor for the request channel.
+ * @return RequestChannelReceiver on successful creation, nullptr otherwise.
+ */
+ static std::unique_ptr<RequestChannelReceiver> create(
+ const FmqRequestDescriptor& requestChannel);
+
+ /**
+ * Get the request from the channel.
+ *
+ * This method will block until either:
+ * 1) The packet has been retrieved, or
+ * 2) The receiver has been invalidated
+ *
+ * @return Request object if successfully received, std::nullopt if error or
+ * if the receiver object was invalidated.
+ */
+ std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> getBlocking();
+
+ /**
+ * Method to mark the channel as invalid, unblocking any current or future
+ * calls to RequestChannelReceiver::getBlocking.
+ */
+ void invalidate();
+
+ RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking);
+
+ private:
+ std::optional<std::vector<FmqRequestDatum>> getPacketBlocking();
+
+ const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
+ std::atomic<bool> mTeardown{false};
+ const bool mBlocking;
+};
+
+/**
+ * ResultChannelSender is responsible for serializing the result packet of
+ * information, sending it on the result channel, and signaling that the data is
+ * available.
+ */
+class ResultChannelSender {
+ using FmqResultChannel =
+ hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>;
+
+ public:
+ /**
+ * Create the sending end of a result channel.
+ *
+ * Prefer this call over the constructor.
+ *
+ * @param resultChannel Descriptor for the result channel.
+ * @return ResultChannelSender on successful creation, nullptr otherwise.
+ */
+ static std::unique_ptr<ResultChannelSender> create(const FmqResultDescriptor& resultChannel);
+
+ /**
+ * Send the result to the channel.
+ *
+ * @param errorStatus Status of the execution.
+ * @param outputShapes Dynamic shapes of the output tensors.
+ * @param timing Timing information of the execution.
+ * @return 'true' on successful send, 'false' otherwise.
+ */
+ bool send(ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing);
+
+ // prefer calling ResultChannelSender::send
+ bool sendPacket(const std::vector<FmqResultDatum>& packet);
+
+ ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking);
+
+ private:
+ const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
+ const bool mBlocking;
+};
+
+/**
* The ExecutionBurstServer class is responsible for waiting for and
* deserializing a request object from a FMQ, performing the inference, and
* serializing the result back across another FMQ.
@@ -159,27 +277,15 @@
IPreparedModel* preparedModel);
ExecutionBurstServer(const sp<IBurstCallback>& callback,
- std::unique_ptr<FmqRequestChannel> requestChannel,
- std::unique_ptr<FmqResultChannel> resultChannel,
+ std::unique_ptr<RequestChannelReceiver> requestChannel,
+ std::unique_ptr<ResultChannelSender> resultChannel,
std::shared_ptr<IBurstExecutorWithCache> cachedExecutor);
~ExecutionBurstServer();
+ // Used by the NN runtime to preemptively remove any stored memory.
Return<void> freeMemory(int32_t slot) override;
private:
- bool sendPacket(const std::vector<FmqResultDatum>& packet);
- std::vector<FmqRequestDatum> getPacketBlocking();
- std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
- const std::vector<OutputShape>& outputShapes,
- Timing timing);
-
- // Deserializes the request FMQ data. The three resulting fields are the
- // Request object (where Request::pools is empty), slot identifiers (which
- // are stand-ins for Request::pools), and whether timing information must
- // be collected for the run.
- std::tuple<Request, std::vector<int32_t>, MeasureTiming> deserialize(
- const std::vector<FmqRequestDatum>& data);
-
// Ensures all cache entries contained in mExecutorWithCache are present in
// the cache. If they are not present, they are retrieved (via
// IBurstCallback::getMemories) and added to mExecutorWithCache.
@@ -191,14 +297,13 @@
// ExecutionBurstServer object is freed.
void task();
+ std::thread mWorker;
std::mutex mMutex;
std::atomic<bool> mTeardown{false};
- std::future<void> mWorker;
const sp<IBurstCallback> mCallback;
- const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
- const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
+ const std::unique_ptr<RequestChannelReceiver> mRequestChannelReceiver;
+ const std::unique_ptr<ResultChannelSender> mResultChannelSender;
const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache;
- const bool mBlocking;
};
} // namespace android::nn