Adjust code in response to Burst validation tests
This CL includes three changes in response to the Burst validation
tests:
(1) it cleans up the validation code in both ExecutionBurstServer and
ExecutionBurstController to handle previously missed edge cases
(2) it changes IBurstExecutorWithCache's caches to use a map instead of
an indexed vector, to handle cases where clients other than the NN
runtime use sparse slots
(3) it combines the ValidationTest calls of "validateModel" and
"validateRequest" into a generic "validateEverything"
Bug: 129779280
Bug: 129157135
Test: mma
Test: VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: Ie06235ee19ea7d8173f7ac02f620c83e15705370
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 067c2f5..e9dc179 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,7 +19,10 @@
#include "ExecutionBurstServer.h"
#include <android-base/logging.h>
+
#include <limits>
+#include <map>
+
#include "Tracing.h"
namespace android::nn {
@@ -38,30 +41,26 @@
DefaultBurstExecutorWithCache(IPreparedModel* preparedModel) : mpPreparedModel(preparedModel) {}
bool isCacheEntryPresent(int32_t slot) const override {
- return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
+ const auto it = mMemoryCache.find(slot);
+ if (it == mMemoryCache.end()) {
+ return false;
+ }
+ return it->second.valid();
}
void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
- if (slot >= mMemoryCache.size()) {
- mMemoryCache.resize(slot + 1);
- }
mMemoryCache[slot] = memory;
}
- void removeCacheEntry(int32_t slot) override {
- if (slot < mMemoryCache.size()) {
- mMemoryCache[slot] = {};
- }
- }
+ void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
const Request& request, const std::vector<int32_t>& slots,
MeasureTiming measure) override {
// convert slots to pools
hidl_vec<hidl_memory> pools(slots.size());
- std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
- return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
- });
+ std::transform(slots.begin(), slots.end(), pools.begin(),
+ [this](int32_t slot) { return mMemoryCache[slot]; });
// create full request
Request fullRequest = request;
@@ -91,7 +90,7 @@
private:
IPreparedModel* const mpPreparedModel;
- std::vector<hidl_memory> mMemoryCache;
+ std::map<int32_t, hidl_memory> mMemoryCache;
};
} // anonymous namespace
@@ -154,7 +153,7 @@
size_t index = 0;
// validate packet information
- if (data[index].getDiscriminator() != discriminator::packetInformation) {
+ if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
@@ -167,6 +166,12 @@
const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
const uint32_t numberOfPools = packetInfo.numberOfPools;
+ // verify packet size
+ if (data.size() != packetSize) {
+ LOG(ERROR) << "FMQ Request packet ill-formed";
+ return std::nullopt;
+ }
+
// unpackage input operands
std::vector<RequestArgument> inputs;
inputs.reserve(numberOfInputOperands);
@@ -342,12 +347,6 @@
}
// wait for request packet and read first element of request packet
- // TODO: have a more elegant way to wait for data, and read it all at once.
- // For example, EventFlag can be used to directly wait on the futex, and all
- // the data can be read at once with a non-blocking call to
- // MessageQueue::read. For further optimization, MessageQueue::beginRead and
- // MessageQueue::commitRead can be used to avoid an extra copy of the
- // metadata.
FmqRequestDatum datum;
bool success = false;
if (mBlocking) {
@@ -358,23 +357,8 @@
}
}
- // terminate loop
- if (mTeardown) {
- return std::nullopt;
- }
-
- // validate packet information
- if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::make_optional<std::vector<FmqRequestDatum>>();
- }
-
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
- // unpack packet information
- const auto& packetInfo = datum.packetInformation();
- const size_t count = packetInfo.packetSize;
-
// retrieve remaining elements
// NOTE: all of the data is already available at this point, so there's no
// need to do a blocking wait to wait for more data. This is known because
@@ -382,12 +366,20 @@
// the producer always publishes the entire packet in one function call, so
// if the first element of the packet is available, the remaining elements
// are also available.
- std::vector<FmqRequestDatum> packet(count);
+ const size_t count = mFmqRequestChannel->availableToRead();
+ std::vector<FmqRequestDatum> packet(count + 1);
packet.front() = datum;
- success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
+ success &= mFmqRequestChannel->read(packet.data() + 1, count);
+ // terminate loop
+ if (mTeardown) {
+ return std::nullopt;
+ }
+
+ // ensure packet was successfully received
if (!success) {
- return std::make_optional<std::vector<FmqRequestDatum>>();
+ LOG(ERROR) << "Error receiving packet";
+ return std::nullopt;
}
return packet;
@@ -418,6 +410,14 @@
}
bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
+ if (packet.size() > mFmqResultChannel->availableToWrite()) {
+ LOG(ERROR)
+ << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ";
+ const std::vector<FmqResultDatum> errorPacket =
+ serialize(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+ return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size());
+ }
+
if (mBlocking) {
return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
} else {