Adjust code in response to Burst validation tests
This CL includes three changes in response to the Burst validation
tests:
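(1) it cleans up the validation code in both ExecutionBurstServer and
ExecutionBurstController to handle previously missed edge cases, such as
an empty packet or a packet whose actual size differs from the packetSize
it declares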
(2) it alters the memory cache of the default IBurstExecutorWithCache
implementation to use a map keyed by slot instead of a slot-indexed
vector, to handle cases where clients other than the NN runtime use
sparse slots
(3) it combines the ValidationTest calls of "validateModel" and
"validateRequest" into a generic "validateEverything"
Bug: 129779280
Bug: 129157135
Test: mma
Test: VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: Ie06235ee19ea7d8173f7ac02f620c83e15705370
Merged-In: Ie06235ee19ea7d8173f7ac02f620c83e15705370
(cherry picked from commit 3260db96845a6dd3becba100aa9f019e8fc504db)
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp
index 37969c3..2846c06 100644
--- a/nn/common/ExecutionBurstController.cpp
+++ b/nn/common/ExecutionBurstController.cpp
@@ -125,7 +125,7 @@
size_t index = 0;
// validate packet information
- if (data[index].getDiscriminator() != discriminator::packetInformation) {
+ if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
@@ -137,6 +137,12 @@
const ErrorStatus errorStatus = packetInfo.errorStatus;
const uint32_t numberOfOperands = packetInfo.numberOfOperands;
+ // verify packet size
+ if (data.size() != packetSize) {
+ LOG(ERROR) << "FMQ Result packet ill-formed";
+ return std::nullopt;
+ }
+
// unpackage operands
for (size_t operand = 0; operand < numberOfOperands; ++operand) {
// validate operand information
@@ -242,34 +248,21 @@
std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
using discriminator = FmqResultDatum::hidl_discriminator;
+ if (mTeardown) {
+ return std::nullopt;
+ }
+
// wait for result packet and read first element of result packet
- // TODO: have a more elegant way to wait for data, and read it all at once.
- // For example, EventFlag can be used to directly wait on the futex, and all
- // the data can be read at once with a non-blocking call to
- // MessageQueue::read. For further optimization, MessageQueue::beginRead and
- // MessageQueue::commitRead can be used to avoid an extra copy of the
- // metadata.
FmqResultDatum datum;
bool success = true;
if (mBlocking) {
success = mFmqResultChannel->readBlocking(&datum, 1);
} else {
- // TODO: better handle the case where the service crashes after
- // receiving the Request but before returning the result.
- while (!mFmqResultChannel->read(&datum, 1)) {
+ while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
+ !mFmqResultChannel->read(&datum, 1)) {
}
}
- // validate packet information
- if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return std::nullopt;
- }
-
- // unpack packet information
- const auto& packetInfo = datum.packetInformation();
- const size_t count = packetInfo.packetSize;
-
// retrieve remaining elements
// NOTE: all of the data is already available at this point, so there's no
// need to do a blocking wait to wait for more data. This is known because
@@ -277,11 +270,19 @@
// the producer always publishes the entire packet in one function call, so
// if the first element of the packet is available, the remaining elements
// are also available.
- std::vector<FmqResultDatum> packet(count);
+ const size_t count = mFmqResultChannel->availableToRead();
+ std::vector<FmqResultDatum> packet(count + 1);
packet.front() = datum;
- success = mFmqResultChannel->read(packet.data() + 1, packet.size() - 1);
+ success &= mFmqResultChannel->read(packet.data() + 1, count);
+ // terminate loop
+ if (mTeardown) {
+ return std::nullopt;
+ }
+
+ // ensure packet was successfully received
if (!success) {
+ LOG(ERROR) << "Error receiving packet";
return std::nullopt;
}
@@ -313,7 +314,11 @@
}
bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
- // TODO: handle the case where the serialziation exceeds FMQ channel length
+ if (packet.size() > mFmqRequestChannel->availableToWrite()) {
+ LOG(ERROR)
+ << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
+ return false;
+ }
if (mBlocking) {
return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 067c2f5..e9dc179 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,7 +19,10 @@
#include "ExecutionBurstServer.h"
#include <android-base/logging.h>
+
#include <limits>
+#include <map>
+
#include "Tracing.h"
namespace android::nn {
@@ -38,30 +41,26 @@
DefaultBurstExecutorWithCache(IPreparedModel* preparedModel) : mpPreparedModel(preparedModel) {}
bool isCacheEntryPresent(int32_t slot) const override {
- return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
+ const auto it = mMemoryCache.find(slot);
+ if (it == mMemoryCache.end()) {
+ return false;
+ }
+ return it->second.valid();
}
void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
- if (slot >= mMemoryCache.size()) {
- mMemoryCache.resize(slot + 1);
- }
mMemoryCache[slot] = memory;
}
- void removeCacheEntry(int32_t slot) override {
- if (slot < mMemoryCache.size()) {
- mMemoryCache[slot] = {};
- }
- }
+ void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
const Request& request, const std::vector<int32_t>& slots,
MeasureTiming measure) override {
// convert slots to pools
hidl_vec<hidl_memory> pools(slots.size());
- std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
- return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
- });
+ std::transform(slots.begin(), slots.end(), pools.begin(),
+ [this](int32_t slot) { return mMemoryCache[slot]; });
// create full request
Request fullRequest = request;
@@ -91,7 +90,7 @@
private:
IPreparedModel* const mpPreparedModel;
- std::vector<hidl_memory> mMemoryCache;
+ std::map<int32_t, hidl_memory> mMemoryCache;
};
} // anonymous namespace
@@ -154,7 +153,7 @@
size_t index = 0;
// validate packet information
- if (data[index].getDiscriminator() != discriminator::packetInformation) {
+ if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
@@ -167,6 +166,12 @@
const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
const uint32_t numberOfPools = packetInfo.numberOfPools;
+ // verify packet size
+ if (data.size() != packetSize) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
// unpackage input operands
std::vector<RequestArgument> inputs;
inputs.reserve(numberOfInputOperands);
@@ -342,12 +347,6 @@
}
// wait for request packet and read first element of request packet
- // TODO: have a more elegant way to wait for data, and read it all at once.
- // For example, EventFlag can be used to directly wait on the futex, and all
- // the data can be read at once with a non-blocking call to
- // MessageQueue::read. For further optimization, MessageQueue::beginRead and
- // MessageQueue::commitRead can be used to avoid an extra copy of the
- // metadata.
FmqRequestDatum datum;
bool success = false;
if (mBlocking) {
@@ -358,23 +357,8 @@
}
}
- // terminate loop
- if (mTeardown) {
- return std::nullopt;
- }
-
- // validate packet information
- if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::make_optional<std::vector<FmqRequestDatum>>();
- }
-
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
- // unpack packet information
- const auto& packetInfo = datum.packetInformation();
- const size_t count = packetInfo.packetSize;
-
// retrieve remaining elements
// NOTE: all of the data is already available at this point, so there's no
// need to do a blocking wait to wait for more data. This is known because
@@ -382,12 +366,20 @@
// the producer always publishes the entire packet in one function call, so
// if the first element of the packet is available, the remaining elements
// are also available.
- std::vector<FmqRequestDatum> packet(count);
+ const size_t count = mFmqRequestChannel->availableToRead();
+ std::vector<FmqRequestDatum> packet(count + 1);
packet.front() = datum;
- success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
+ success &= mFmqRequestChannel->read(packet.data() + 1, count);
+ // terminate loop
+ if (mTeardown) {
+ return std::nullopt;
+ }
+
+ // ensure packet was successfully received
if (!success) {
- return std::make_optional<std::vector<FmqRequestDatum>>();
+ LOG(ERROR) << "Error receiving packet";
+ return std::nullopt;
}
return packet;
@@ -418,6 +410,14 @@
}
bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
+ if (packet.size() > mFmqResultChannel->availableToWrite()) {
+ LOG(ERROR)
+ << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ";
+ const std::vector<FmqResultDatum> errorPacket =
+ serialize(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+ return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size());
+ }
+
if (mBlocking) {
return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
} else {
diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h
index 6a72720..1154b5c 100644
--- a/nn/common/include/ExecutionBurstController.h
+++ b/nn/common/include/ExecutionBurstController.h
@@ -175,6 +175,7 @@
class ExecutionBurstController {
DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController);
+ public:
/**
* NN runtime burst callback object and memory cache.
*
@@ -191,18 +192,12 @@
* buffer, they must use the same key.
*/
class ExecutionBurstCallback : public IBurstCallback {
- DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);
-
public:
- ExecutionBurstCallback() = default;
-
Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
const std::vector<intptr_t>& keys);
- int32_t getSlot(const hidl_memory& memory, intptr_t key);
-
/*
* This function performs two different actions:
* 1) Removes an entry from the cache (if present), including the local
@@ -227,7 +222,6 @@
std::vector<hidl_memory> mMemoryCache;
};
- public:
/**
* Creates a burst controller on a prepared model.
*