Adjust code in response to Burst validation tests

This CL includes three changes in response to the Burst validation
tests:
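(1) it cleans up the validation code in both ExecutionBurstServer and
    ExecutionBurstController to handle previously missed edge cases
    (empty packets, packet-size mismatches, and packets larger than the
    space available in the FMQ)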
(2) it changes the IBurstExecutorWithCache memory cache (in
    DefaultBurstExecutorWithCache) to use a map instead of an indexed
    vector, to handle cases where non-NN-runtime clients use sparse slots
(3) it combines the ValidationTest calls to "validateModel" and
    "validateRequest" into a single generic "validateEverything" call

Bug: 129779280
Bug: 129157135
Test: mma
Test: VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: Ie06235ee19ea7d8173f7ac02f620c83e15705370
Merged-In: Ie06235ee19ea7d8173f7ac02f620c83e15705370
(cherry picked from commit 3260db96845a6dd3becba100aa9f019e8fc504db)
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp
index 37969c3..2846c06 100644
--- a/nn/common/ExecutionBurstController.cpp
+++ b/nn/common/ExecutionBurstController.cpp
@@ -125,7 +125,7 @@
     size_t index = 0;
 
     // validate packet information
-    if (data[index].getDiscriminator() != discriminator::packetInformation) {
+    if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
         LOG(ERROR) << "FMQ Result packet ill-formed";
         return std::nullopt;
     }
@@ -137,6 +137,12 @@
     const ErrorStatus errorStatus = packetInfo.errorStatus;
     const uint32_t numberOfOperands = packetInfo.numberOfOperands;
 
+    // verify packet size
+    if (data.size() != packetSize) {
+        LOG(ERROR) << "FMQ Result packet ill-formed";
+        return std::nullopt;
+    }
+
     // unpackage operands
     for (size_t operand = 0; operand < numberOfOperands; ++operand) {
         // validate operand information
@@ -242,34 +248,21 @@
 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
     using discriminator = FmqResultDatum::hidl_discriminator;
 
+    if (mTeardown) {
+        return std::nullopt;
+    }
+
     // wait for result packet and read first element of result packet
-    // TODO: have a more elegant way to wait for data, and read it all at once.
-    // For example, EventFlag can be used to directly wait on the futex, and all
-    // the data can be read at once with a non-blocking call to
-    // MessageQueue::read. For further optimization, MessageQueue::beginRead and
-    // MessageQueue::commitRead can be used to avoid an extra copy of the
-    // metadata.
     FmqResultDatum datum;
     bool success = true;
     if (mBlocking) {
         success = mFmqResultChannel->readBlocking(&datum, 1);
     } else {
-        // TODO: better handle the case where the service crashes after
-        // receiving the Request but before returning the result.
-        while (!mFmqResultChannel->read(&datum, 1)) {
+        while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
+               !mFmqResultChannel->read(&datum, 1)) {
         }
     }
 
-    // validate packet information
-    if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
-        LOG(ERROR) << "FMQ Result packet ill-formed";
-        return std::nullopt;
-    }
-
-    // unpack packet information
-    const auto& packetInfo = datum.packetInformation();
-    const size_t count = packetInfo.packetSize;
-
     // retrieve remaining elements
     // NOTE: all of the data is already available at this point, so there's no
     // need to do a blocking wait to wait for more data. This is known because
@@ -277,11 +270,19 @@
     // the producer always publishes the entire packet in one function call, so
     // if the first element of the packet is available, the remaining elements
     // are also available.
-    std::vector<FmqResultDatum> packet(count);
+    const size_t count = mFmqResultChannel->availableToRead();
+    std::vector<FmqResultDatum> packet(count + 1);
     packet.front() = datum;
-    success = mFmqResultChannel->read(packet.data() + 1, packet.size() - 1);
+    success &= mFmqResultChannel->read(packet.data() + 1, count);
 
+    // terminate loop
+    if (mTeardown) {
+        return std::nullopt;
+    }
+
+    // ensure packet was successfully received
     if (!success) {
+        LOG(ERROR) << "Error receiving packet";
         return std::nullopt;
     }
 
@@ -313,7 +314,11 @@
 }
 
 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
-    // TODO: handle the case where the serialziation exceeds FMQ channel length
+    if (packet.size() > mFmqRequestChannel->availableToWrite()) {
+        LOG(ERROR)
+                << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
+        return false;
+    }
 
     if (mBlocking) {
         return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 067c2f5..e9dc179 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,7 +19,10 @@
 #include "ExecutionBurstServer.h"
 
 #include <android-base/logging.h>
+
 #include <limits>
+#include <map>
+
 #include "Tracing.h"
 
 namespace android::nn {
@@ -38,30 +41,26 @@
     DefaultBurstExecutorWithCache(IPreparedModel* preparedModel) : mpPreparedModel(preparedModel) {}
 
     bool isCacheEntryPresent(int32_t slot) const override {
-        return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
+        const auto it = mMemoryCache.find(slot);
+        if (it == mMemoryCache.end()) {
+            return false;
+        }
+        return it->second.valid();
     }
 
     void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
-        if (slot >= mMemoryCache.size()) {
-            mMemoryCache.resize(slot + 1);
-        }
         mMemoryCache[slot] = memory;
     }
 
-    void removeCacheEntry(int32_t slot) override {
-        if (slot < mMemoryCache.size()) {
-            mMemoryCache[slot] = {};
-        }
-    }
+    void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
 
     std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
             const Request& request, const std::vector<int32_t>& slots,
             MeasureTiming measure) override {
         // convert slots to pools
         hidl_vec<hidl_memory> pools(slots.size());
-        std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
-            return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
-        });
+        std::transform(slots.begin(), slots.end(), pools.begin(),
+                       [this](int32_t slot) { return mMemoryCache[slot]; });
 
         // create full request
         Request fullRequest = request;
@@ -91,7 +90,7 @@
 
    private:
     IPreparedModel* const mpPreparedModel;
-    std::vector<hidl_memory> mMemoryCache;
+    std::map<int32_t, hidl_memory> mMemoryCache;
 };
 
 }  // anonymous namespace
@@ -154,7 +153,7 @@
     size_t index = 0;
 
     // validate packet information
-    if (data[index].getDiscriminator() != discriminator::packetInformation) {
+    if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
         LOG(ERROR) << "FMQ Request packet ill-formed";
         return std::nullopt;
     }
@@ -167,6 +166,12 @@
     const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
     const uint32_t numberOfPools = packetInfo.numberOfPools;
 
+    // verify packet size
+    if (data.size() != packetSize) {
+        LOG(ERROR) << "FMQ Request packet ill-formed";
+        return std::nullopt;
+    }
+
     // unpackage input operands
     std::vector<RequestArgument> inputs;
     inputs.reserve(numberOfInputOperands);
@@ -342,12 +347,6 @@
     }
 
     // wait for request packet and read first element of request packet
-    // TODO: have a more elegant way to wait for data, and read it all at once.
-    // For example, EventFlag can be used to directly wait on the futex, and all
-    // the data can be read at once with a non-blocking call to
-    // MessageQueue::read. For further optimization, MessageQueue::beginRead and
-    // MessageQueue::commitRead can be used to avoid an extra copy of the
-    // metadata.
     FmqRequestDatum datum;
     bool success = false;
     if (mBlocking) {
@@ -358,23 +357,8 @@
         }
     }
 
-    // terminate loop
-    if (mTeardown) {
-        return std::nullopt;
-    }
-
-    // validate packet information
-    if (!success || datum.getDiscriminator() != discriminator::packetInformation) {
-        LOG(ERROR) << "FMQ Request packet ill-formed";
-        return std::make_optional<std::vector<FmqRequestDatum>>();
-    }
-
     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
 
-    // unpack packet information
-    const auto& packetInfo = datum.packetInformation();
-    const size_t count = packetInfo.packetSize;
-
     // retrieve remaining elements
     // NOTE: all of the data is already available at this point, so there's no
     // need to do a blocking wait to wait for more data. This is known because
@@ -382,12 +366,20 @@
     // the producer always publishes the entire packet in one function call, so
     // if the first element of the packet is available, the remaining elements
     // are also available.
-    std::vector<FmqRequestDatum> packet(count);
+    const size_t count = mFmqRequestChannel->availableToRead();
+    std::vector<FmqRequestDatum> packet(count + 1);
     packet.front() = datum;
-    success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
+    success &= mFmqRequestChannel->read(packet.data() + 1, count);
 
+    // terminate loop
+    if (mTeardown) {
+        return std::nullopt;
+    }
+
+    // ensure packet was successfully received
     if (!success) {
-        return std::make_optional<std::vector<FmqRequestDatum>>();
+        LOG(ERROR) << "Error receiving packet";
+        return std::nullopt;
     }
 
     return packet;
@@ -418,6 +410,14 @@
 }
 
 bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
+    if (packet.size() > mFmqResultChannel->availableToWrite()) {
+        LOG(ERROR)
+                << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ";
+        const std::vector<FmqResultDatum> errorPacket =
+                serialize(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+        return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size());
+    }
+
     if (mBlocking) {
         return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
     } else {
diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h
index 6a72720..1154b5c 100644
--- a/nn/common/include/ExecutionBurstController.h
+++ b/nn/common/include/ExecutionBurstController.h
@@ -175,6 +175,7 @@
 class ExecutionBurstController {
     DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController);
 
+   public:
     /**
      * NN runtime burst callback object and memory cache.
      *
@@ -191,18 +192,12 @@
      * buffer, they must use the same key.
      */
     class ExecutionBurstCallback : public IBurstCallback {
-        DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);
-
        public:
-        ExecutionBurstCallback() = default;
-
         Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
 
         std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
                                       const std::vector<intptr_t>& keys);
 
-        int32_t getSlot(const hidl_memory& memory, intptr_t key);
-
         /*
          * This function performs two different actions:
          * 1) Removes an entry from the cache (if present), including the local
@@ -227,7 +222,6 @@
         std::vector<hidl_memory> mMemoryCache;
     };
 
-   public:
     /**
      * Creates a burst controller on a prepared model.
      *