NNAPI Burst object cleanup

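This CL addresses follow-up review comments from ag/6154003 and
ag/6575732:
* Nest BurstMemoryCache inside ExecutionBurstServer and replace the
  free function createBurstContext with ExecutionBurstServer::create.
* Only request slots that are not already cached, and skip the
  IBurstCallback::getMemories call entirely when none are missing.
* Add a LOG_TAG, NNTRACE tracing, and explanatory comments to
  ExecutionBurstServer.cpp.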

Bug: 119570067
Test: mma
Test: atest NeuralNetworksTest_static
Change-Id: I1a2bd4c9d97296f50d6ef9bb86515ea8e9a54515
Merged-In: I1a2bd4c9d97296f50d6ef9bb86515ea8e9a54515
(cherry picked from commit 3db6fe510dcc3c6076e3894814f954f7b8e2008e)
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 64a4ee2..6ede34d 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -14,44 +14,57 @@
  * limitations under the License.
  */
 
+#define LOG_TAG "ExecutionBurstServer"
+
 #include "ExecutionBurstServer.h"
 
 #include <android-base/logging.h>
+#include <set>
+#include <string>
+#include "Tracing.h"
 
-namespace android {
-namespace nn {
+namespace android::nn {
 
-BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback) : mCallback(callback) {}
+ExecutionBurstServer::BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback)
+    : mCallback(callback) {}
 
-hidl_vec<hidl_memory> BurstMemoryCache::getMemories(const std::vector<int32_t>& slots) {
+hidl_vec<hidl_memory> ExecutionBurstServer::BurstMemoryCache::getMemories(
+        const std::vector<int32_t>& slots) {
     std::lock_guard<std::mutex> guard(mMutex);
 
     // find unique unknown slots
-    std::vector<int32_t> unknownSlots = slots;
-    std::sort(unknownSlots.begin(), unknownSlots.end());
-    auto last = std::unique(unknownSlots.begin(), unknownSlots.end());
-    unknownSlots.erase(last, unknownSlots.end());
+    std::set<int32_t> setOfUnknownSlots;
+    for (int32_t slot : slots) {
+        if (mSlotToMemoryCache.find(slot) == mSlotToMemoryCache.end()) {
+            setOfUnknownSlots.insert(slot);
+        }
+    }
+    const std::vector<int32_t> vecOfUnknownSlots(setOfUnknownSlots.begin(),
+                                                 setOfUnknownSlots.end());
 
     // retrieve unknown slots
-    ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
-    std::vector<hidl_memory> returnedMemories;
-    Return<void> ret = mCallback->getMemories(
-            unknownSlots, [&errorStatus, &returnedMemories](ErrorStatus status,
-                                                            const hidl_vec<hidl_memory>& memories) {
-                errorStatus = status;
-                if (status == ErrorStatus::NONE) {
-                    returnedMemories = memories;
-                }
-            });
+    if (!vecOfUnknownSlots.empty()) {
+        ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+        std::vector<hidl_memory> returnedMemories;
+        Return<void> ret = mCallback->getMemories(
+                vecOfUnknownSlots,
+                [&errorStatus, &returnedMemories](ErrorStatus status,
+                                                  const hidl_vec<hidl_memory>& memories) {
+                    errorStatus = status;
+                    if (status == ErrorStatus::NONE) returnedMemories = memories;
+                });
 
-    if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
-        LOG(ERROR) << "Error retrieving memories";
-        return {};
-    }
+        // guard the indexing below: the callback must return one memory per requested slot
+        if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
+            returnedMemories.size() != vecOfUnknownSlots.size()) {
+            LOG(ERROR) << "Error retrieving memories";
+            return {};
+        }
 
-    // add memories to unknown slots
-    for (size_t i = 0; i < unknownSlots.size(); ++i) {
-        mSlotToMemoryCache[unknownSlots[i]] = returnedMemories[i];
+        // add memories to unknown slots
+        for (size_t i = 0; i < vecOfUnknownSlots.size(); ++i) {
+            mSlotToMemoryCache[vecOfUnknownSlots[i]] = returnedMemories[i];
+        }
     }
 
     // get all slots
@@ -59,14 +72,42 @@
     for (size_t i = 0; i < slots.size(); ++i) {
         memories[i] = mSlotToMemoryCache[slots[i]];
     }
+
     return memories;
 }
 
-void BurstMemoryCache::freeMemory(int32_t slot) {
+void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
     std::lock_guard<std::mutex> guard(mMutex);
     mSlotToMemoryCache.erase(slot);
 }
 
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+        const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+        const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+    // check inputs: both the callback and the prepared model are required
+    if (callback == nullptr || preparedModel == nullptr) {
+        LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+        return nullptr;
+    }
+
+    // create FMQ objects; nothrow new yields nullptr on allocation failure
+    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
+                                                                 FmqRequestChannel(requestChannel)};
+    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
+                                                               FmqResultChannel(resultChannel)};
+
+    // check FMQ objects: reject failed allocations and queues that report !isValid()
+    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
+        !fmqResultChannel->isValid()) {
+        LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
+        return nullptr;
+    }
+
+    // make and return the burst server, transferring ownership of both FMQ objects
+    return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
+                                    std::move(fmqResultChannel), preparedModel);
+}
+
 ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
                                            std::unique_ptr<FmqRequestChannel> requestChannel,
                                            std::unique_ptr<FmqResultChannel> resultChannel,
@@ -85,9 +126,13 @@
     mTeardown = true;
 
     // force unblock
+    // ExecutionBurstServer is by default waiting on a request packet. If the
+    // client process destroys its burst object, the server will still be
+    // waiting on the futex (assuming mBlocking is true). This force unblock
+    // wakes up any thread waiting on the futex.
     if (mBlocking) {
-        // TODO: look for a different/better way to signal/notify the futex to wake
-        // up any thread waiting on it
+        // TODO: look for a different/better way to signal/notify the futex to
+        // wake up any thread waiting on it
         FmqRequestDatum datum;
         datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
                                  /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
@@ -117,7 +162,13 @@
         return {};
     }
 
-    // wait for request packet and read first element of result packet
+    // wait for request packet and read first element of request packet
+    // TODO: have a more elegant way to wait for data, and read it all at once.
+    // For example, EventFlag can be used to directly wait on the futex, and all
+    // the data can be read at once with a non-blocking call to
+    // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+    // MessageQueue::commitRead can be used to avoid an extra copy of the
+    // metadata.
     FmqRequestDatum datum;
     bool success = false;
     if (mBlocking) {
@@ -139,13 +190,19 @@
         return {};
     }
 
+    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
+
     // unpack packet information
     const auto& packetInfo = datum.packetInformation();
     const size_t count = packetInfo.packetSize;
 
     // retrieve remaining elements
     // NOTE: all of the data is already available at this point, so there's no
-    // need to do a blocking wait to wait for more data
+    // need to do a blocking wait to wait for more data. This is known because
+    // in FMQ, all writes are published (made available) atomically. Currently,
+    // the producer always publishes the entire packet in one function call, so
+    // if the first element of the packet is available, the remaining elements
+    // are also available.
     std::vector<FmqRequestDatum> packet(count);
     packet.front() = datum;
     success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
@@ -365,6 +422,9 @@
             return;
         }
 
+        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
+                     "ExecutionBurstServer processing packet and returning results");
+
         // continue processing
         Request request;
         MeasureTiming measure;
@@ -374,6 +434,10 @@
         ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
         std::vector<OutputShape> outputShapes;
         Timing returnedTiming;
+        // This call to IPreparedModel::executeSynchronously occurs entirely
+        // within the same process, so ignore the Return<> errors via .isOk().
+        // TODO: verify it is safe to always call isOk() here, or if there is
+        // any benefit to checking any potential errors.
         mPreparedModel
                 ->executeSynchronously(request, measure,
                                        [&errorStatus, &outputShapes, &returnedTiming](
@@ -392,33 +456,4 @@
     }
 }
 
-sp<IBurstContext> createBurstContext(const sp<IBurstCallback>& callback,
-                                     const MQDescriptorSync<FmqRequestDatum>& requestChannel,
-                                     const MQDescriptorSync<FmqResultDatum>& resultChannel,
-                                     IPreparedModel* preparedModel) {
-    // check inputs
-    if (callback == nullptr || preparedModel == nullptr) {
-        LOG(ERROR) << "createExecutionBurstServer passed a nullptr";
-        return nullptr;
-    }
-
-    // create FMQ objects
-    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
-                                                                 FmqRequestChannel(requestChannel)};
-    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
-                                                               FmqResultChannel(resultChannel)};
-
-    // check FMQ objects
-    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
-        !fmqResultChannel->isValid()) {
-        LOG(ERROR) << "createExecutionBurstServer failed to create FastMessageQueue";
-        return nullptr;
-    }
-
-    // make and return context
-    return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
-                                    std::move(fmqResultChannel), preparedModel);
-}
-
-}  // namespace nn
-}  // namespace android
+}  // namespace android::nn