Cache ExecutionBurstServer memory resources

This change enables caching of memory resources. Prior to this CL, a
hidl_memory object in an execution burst had to be mapped and unmapped
on every execution, costing roughly 10us each time. With this change, a
driver can instead map the hidl_memory once, when it is first retrieved,
and unmap it only when the burst is freed or when
IBurstContext::freeMemory is called.

Bug: 126244806
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest VtsHalNeuralnetworksV1_0TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_1TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: I92d7632f2e1e2f5e76c63249d4c615c9af9d1acb
Merged-In: I92d7632f2e1e2f5e76c63249d4c615c9af9d1acb
(cherry picked from commit dfc94136aff39efdfb81436cee7e73c3f6834785)
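
For illustration only (not part of this CL): below is a minimal sketch of a
driver-side executor that takes advantage of the new hook by mapping each
hidl_memory once, in addCacheEntry, and releasing the mapping in
removeCacheEntry. The class name CachingBurstExecutor, its Entry bookkeeping,
and the stubbed execute() body are assumptions made for the example; only the
ExecutionBurstServer::IBurstExecutorWithCache interface it overrides comes
from the code in this patch. The sketch also assumes the driver links against
libhidlmemory for ::android::hardware::mapMemory and that the same HAL type
aliases visible to ExecutionBurstServer.cpp (hidl_memory, hidl_vec, sp,
ErrorStatus, Request, MeasureTiming, OutputShape, Timing) are in scope.

    #include <android/hidl/memory/1.0/IMemory.h>
    #include <hidlmemory/mapping.h>

    #include <tuple>
    #include <vector>

    #include "ExecutionBurstServer.h"

    namespace android::nn {

    // Sketch: map each hidl_memory once when its slot is first cached, and
    // keep the mapping alive until the slot is freed or the burst goes away.
    class CachingBurstExecutor : public ExecutionBurstServer::IBurstExecutorWithCache {
       public:
        bool isCacheEntryPresent(int32_t slot) const override {
            return slot >= 0 && static_cast<size_t>(slot) < mEntries.size() &&
                   mEntries[slot].memory.valid();
        }

        void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
            if (static_cast<size_t>(slot) >= mEntries.size()) {
                mEntries.resize(slot + 1);
            }
            // Map once; later executions reuse this mapping instead of paying
            // the per-execution map/unmap cost.
            mEntries[slot] = Entry{memory, ::android::hardware::mapMemory(memory)};
        }

        void removeCacheEntry(int32_t slot) override {
            if (slot >= 0 && static_cast<size_t>(slot) < mEntries.size()) {
                // Dropping the IMemory reference releases the mapping.
                mEntries[slot] = Entry{};
            }
        }

        std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
                const Request& request, const std::vector<int32_t>& slots,
                MeasureTiming measure) override {
            // A real driver would run the prepared model here against the
            // already-mapped pools in mEntries; this stub only illustrates
            // the cache bookkeeping.
            (void)request;
            (void)slots;
            (void)measure;
            return {ErrorStatus::GENERAL_FAILURE, {}, {}};
        }

       private:
        struct Entry {
            hidl_memory memory;
            sp<::android::hidl::memory::V1_0::IMemory> mapping;
        };
        std::vector<Entry> mEntries;
    };

    }  // namespace android::nn

A driver would hand such an executor to the new ExecutionBurstServer::create
overload that accepts a std::shared_ptr<IBurstExecutorWithCache>; the existing
overload taking an IPreparedModel* keeps the old map-per-execution behavior
via DefaultBurstExecutorWithCache.
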
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 84bb424..08d0e80 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,124 +19,134 @@
 #include "ExecutionBurstServer.h"
 
 #include <android-base/logging.h>
-#include <set>
 #include <string>
 #include "Tracing.h"
 
 namespace android::nn {
 
-ExecutionBurstServer::BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback)
-    : mCallback(callback) {}
+namespace {
 
-hidl_vec<hidl_memory> ExecutionBurstServer::BurstMemoryCache::getMemories(
-        const std::vector<int32_t>& slots) {
-    std::lock_guard<std::mutex> guard(mMutex);
+// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
+// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
+// hidl_memory object, and the execution forwards calls to the provided
+// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory
+// must be mapped and unmapped for each execution.
+class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
+   public:
+    DefaultBurstExecutorWithCache(IPreparedModel* preparedModel) : mpPreparedModel(preparedModel) {}
 
-    const auto slotIsKnown = [this](int32_t slot) {
+    bool isCacheEntryPresent(int32_t slot) const override {
         return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
-    };
+    }
 
-    // find unique unknown slots
-    std::vector<int32_t> unknownSlots = slots;
-    auto unknownSlotsEnd = unknownSlots.end();
-    std::sort(unknownSlots.begin(), unknownSlotsEnd);
-    unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
-    unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
-    unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+    void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
+        if (slot >= mMemoryCache.size()) {
+            mMemoryCache.resize(slot + 1);
+        }
+        mMemoryCache[slot] = memory;
+    }
 
-    // retrieve unknown slots
-    if (!unknownSlots.empty()) {
-        ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
-        std::vector<hidl_memory> returnedMemories;
-        auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
-                                                    const hidl_vec<hidl_memory>& memories) {
-            errorStatus = status;
-            returnedMemories = memories;
+    void removeCacheEntry(int32_t slot) override {
+        if (slot < mMemoryCache.size()) {
+            mMemoryCache[slot] = {};
+        }
+    }
+
+    std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+            const Request& request, const std::vector<int32_t>& slots,
+            MeasureTiming measure) override {
+        // convert slots to pools
+        hidl_vec<hidl_memory> pools(slots.size());
+        std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
+            return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
+        });
+
+        // create full request
+        Request fullRequest = request;
+        fullRequest.pools = std::move(pools);
+
+        // setup execution
+        ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
+        hidl_vec<OutputShape> returnedOutputShapes;
+        Timing returnedTiming;
+        auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
+                          ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
+                          const Timing& timing) {
+            returnedStatus = status;
+            returnedOutputShapes = outputShapes;
+            returnedTiming = timing;
         };
 
-        Return<void> ret = mCallback->getMemories(unknownSlots, cb);
-
-        // Ensure that the memories were successfully returned.
-        // IBurstCallback.hal specifies the that the number of memories returned
-        // must match the number of slots requested:
-        //     "slots.size() == buffers.size()"
-        if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
-            returnedMemories.size() != unknownSlots.size()) {
-            LOG(ERROR) << "Error retrieving memories";
-            return {};
+        // execute
+        const Return<void> ret = mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
+        if (!ret.isOk() || returnedStatus != ErrorStatus::NONE) {
+            LOG(ERROR) << "IPreparedModelAdapter::execute -- Error executing";
+            return {ErrorStatus::GENERAL_FAILURE, {}, {}};
         }
 
-        // resize cache to fit new slots if necessary
-        const int32_t maxUnknownSlot = unknownSlots.back();
-        if (maxUnknownSlot >= mMemoryCache.size()) {
-            mMemoryCache.resize(maxUnknownSlot + 1);
-        }
-
-        // add memories to unknown slots
-        for (size_t i = 0; i < unknownSlots.size(); ++i) {
-            mMemoryCache[unknownSlots[i]] = returnedMemories[i];
-        }
+        return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
     }
 
-    // get all slots
-    hidl_vec<hidl_memory> memories(slots.size());
-    std::transform(slots.begin(), slots.end(), memories.begin(),
-                   [this](int32_t slot) { return mMemoryCache[slot]; });
+   private:
+    IPreparedModel* const mpPreparedModel;
+    std::vector<hidl_memory> mMemoryCache;
+};
 
-    // Ensure all slots are valid. Although this case is never expected to
-    // occur, theoretically IBurstCallback::getMemories could return invalid
-    // hidl_memory objects that must be protected against.
-    if (!std::all_of(memories.begin(), memories.end(),
-                     [](const hidl_memory& memory) { return memory.valid(); })) {
-        LOG(ERROR) << "Error, not all slots are valid!";
-        return {};
-    }
-
-    return memories;
-}
-
-void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
-    std::lock_guard<std::mutex> guard(mMutex);
-    if (slot < mMemoryCache.size()) {
-        mMemoryCache[slot] = {};
-    }
-}
+}  // anonymous namespace
 
 sp<ExecutionBurstServer> ExecutionBurstServer::create(
         const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
-        const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+        const MQDescriptorSync<FmqResultDatum>& resultChannel,
+        std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {
     // check inputs
-    if (callback == nullptr || preparedModel == nullptr) {
+    if (callback == nullptr || executorWithCache == nullptr) {
         LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
         return nullptr;
     }
 
     // create FMQ objects
-    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
-                                                                 FmqRequestChannel(requestChannel)};
-    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
-                                                               FmqResultChannel(resultChannel)};
+    std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
+            std::make_unique<FmqRequestChannel>(requestChannel);
+    std::unique_ptr<FmqResultChannel> fmqResultChannel =
+            std::make_unique<FmqResultChannel>(resultChannel);
 
     // check FMQ objects
-    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
-        !fmqResultChannel->isValid()) {
+    if (!fmqRequestChannel->isValid() || !fmqResultChannel->isValid()) {
         LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
         return nullptr;
     }
 
     // make and return context
     return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
-                                    std::move(fmqResultChannel), preparedModel);
+                                    std::move(fmqResultChannel), std::move(executorWithCache));
 }
 
-ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
-                                           std::unique_ptr<FmqRequestChannel> requestChannel,
-                                           std::unique_ptr<FmqResultChannel> resultChannel,
-                                           IPreparedModel* preparedModel)
-    : mMemoryCache(callback),
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+        const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+        const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+    // check relevant input
+    if (preparedModel == nullptr) {
+        LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+        return nullptr;
+    }
+
+    // adapt IPreparedModel to have caching
+    const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
+            std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
+
+    // make and return context
+    return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
+                                        preparedModelAdapter);
+}
+
+ExecutionBurstServer::ExecutionBurstServer(
+        const sp<IBurstCallback>& callback, std::unique_ptr<FmqRequestChannel> requestChannel,
+        std::unique_ptr<FmqResultChannel> resultChannel,
+        std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
+    : mCallback(callback),
       mFmqRequestChannel(std::move(requestChannel)),
       mFmqResultChannel(std::move(resultChannel)),
-      mPreparedModel(preparedModel),
+      mExecutorWithCache(std::move(executorWithCache)),
       mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
     // TODO: highly document the threading behavior of this class
     mWorker = std::async(std::launch::async, [this] { task(); });
@@ -164,6 +174,52 @@
     mWorker.wait();
 }
 
+Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
+    std::lock_guard<std::mutex> hold(mMutex);
+    mExecutorWithCache->removeCacheEntry(slot);
+    return Void();
+}
+
+void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
+    const auto slotIsKnown = [this](int32_t slot) {
+        return mExecutorWithCache->isCacheEntryPresent(slot);
+    };
+
+    // find unique unknown slots
+    std::vector<int32_t> unknownSlots = slots;
+    auto unknownSlotsEnd = unknownSlots.end();
+    std::sort(unknownSlots.begin(), unknownSlotsEnd);
+    unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
+    unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+    unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+
+    // quick-exit if all slots are known
+    if (unknownSlots.empty()) {
+        return;
+    }
+
+    ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+    std::vector<hidl_memory> returnedMemories;
+    auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
+                                                const hidl_vec<hidl_memory>& memories) {
+        errorStatus = status;
+        returnedMemories = memories;
+    };
+
+    const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
+
+    if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
+        returnedMemories.size() != unknownSlots.size()) {
+        LOG(ERROR) << "Error retrieving memories";
+        return;
+    }
+
+    // add memories to unknown slots
+    for (size_t i = 0; i < unknownSlots.size(); ++i) {
+        mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
+    }
+}
+
 bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
     if (mTeardown) {
         return false;
@@ -236,17 +292,16 @@
 }
 
 // deserialize request
-std::pair<Request, MeasureTiming> ExecutionBurstServer::deserialize(
+std::tuple<Request, std::vector<int32_t>, MeasureTiming> ExecutionBurstServer::deserialize(
         const std::vector<FmqRequestDatum>& data) {
     using discriminator = FmqRequestDatum::hidl_discriminator;
 
-    Request request;
     size_t index = 0;
 
     // validate packet information
     if (data[index].getDiscriminator() != discriminator::packetInformation) {
         LOG(ERROR) << "FMQ Request packet ill-formed";
-        return {{}, MeasureTiming::NO};
+        return {{}, {}, MeasureTiming::NO};
     }
 
     // unpackage packet information
@@ -264,7 +319,7 @@
         // validate input operand information
         if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
             LOG(ERROR) << "FMQ Request packet ill-formed";
-            return {{}, MeasureTiming::NO};
+            return {{}, {}, MeasureTiming::NO};
         }
 
         // unpackage operand information
@@ -282,7 +337,7 @@
             // validate dimension
             if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
                 LOG(ERROR) << "FMQ Request packet ill-formed";
-                return {{}, MeasureTiming::NO};
+                return {{}, {}, MeasureTiming::NO};
             }
 
             // unpackage dimension
@@ -305,7 +360,7 @@
         // validate output operand information
         if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
             LOG(ERROR) << "FMQ Request packet ill-formed";
-            return {{}, MeasureTiming::NO};
+            return {{}, {}, MeasureTiming::NO};
         }
 
         // unpackage operand information
@@ -323,7 +378,7 @@
             // validate dimension
             if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
                 LOG(ERROR) << "FMQ Request packet ill-formed";
-                return {{}, MeasureTiming::NO};
+                return {{}, {}, MeasureTiming::NO};
             }
 
             // unpackage dimension
@@ -346,7 +401,7 @@
         // validate input operand information
         if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
             LOG(ERROR) << "FMQ Request packet ill-formed";
-            return {{}, MeasureTiming::NO};
+            return {{}, {}, MeasureTiming::NO};
         }
 
         // unpackage operand information
@@ -356,12 +411,11 @@
         // store result
         slots.push_back(poolId);
     }
-    hidl_vec<hidl_memory> pools = mMemoryCache.getMemories(slots);
 
     // validate measureTiming
     if (data[index].getDiscriminator() != discriminator::measureTiming) {
         LOG(ERROR) << "FMQ Request packet ill-formed";
-        return {{}, MeasureTiming::NO};
+        return {{}, {}, MeasureTiming::NO};
     }
 
     // unpackage measureTiming
@@ -371,11 +425,11 @@
     // validate packet information
     if (index != packetSize) {
         LOG(ERROR) << "FMQ Result packet ill-formed";
-        return {{}, MeasureTiming::NO};
+        return {{}, {}, MeasureTiming::NO};
     }
 
     // return request
-    return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/std::move(pools)}, measure};
+    return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}}, std::move(slots), measure};
 }
 
 // serialize result
@@ -428,11 +482,6 @@
     return data;
 }
 
-Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
-    mMemoryCache.freeMemory(slot);
-    return Void();
-}
-
 void ExecutionBurstServer::task() {
     while (!mTeardown) {
         // receive request
@@ -446,29 +495,18 @@
         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
                      "ExecutionBurstServer processing packet and returning results");
 
-        // continue processing
-        Request request;
-        MeasureTiming measure;
-        std::tie(request, measure) = deserialize(requestData);
+        // continue processing; types are Request, std::vector<int32_t>, and
+        // MeasureTiming, respectively
+        const auto [requestWithoutPools, slotsOfPools, measure] = deserialize(requestData);
 
-        // perform computation
-        ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
-        std::vector<OutputShape> outputShapes;
-        Timing returnedTiming;
-        // This call to IPreparedModel::executeSynchronously occurs entirely
-        // within the same process, so ignore the Return<> errors via .isOk().
-        // TODO: verify it is safe to always call isOk() here, or if there is
-        // any benefit to checking any potential errors.
-        mPreparedModel
-                ->executeSynchronously(request, measure,
-                                       [&errorStatus, &outputShapes, &returnedTiming](
-                                               ErrorStatus status,
-                                               const hidl_vec<OutputShape>& shapes, Timing timing) {
-                                           errorStatus = status;
-                                           outputShapes = shapes;
-                                           returnedTiming = timing;
-                                       })
-                .isOk();
+        // ensure executor with cache has required memory
+        std::lock_guard<std::mutex> hold(mMutex);
+        ensureCacheEntriesArePresentLocked(slotsOfPools);
+
+        // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
+        // and Timing, respectively
+        const auto [errorStatus, outputShapes, returnedTiming] =
+                mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
 
         // return result
         const std::vector<FmqResultDatum> result =