Propagate ANNMemory_free to IBurstContext::freeMemory
This CL extends nn::Memory to include a reference to all burst objects
that use its memory. When the nn::Memory object is destroyed (via
ANNMemory_free), it signals all these burst objects so that they can
properly clean their memory caches (via IBurstContext::freeMemory).
This CL also provides a more intelligent memory slot allocator within
ExecutionBurstController to reuse slots after they are freed.
This CL includes some additional miscellaneous code cleanup of the
neighboring test cases, e.g., closing file descriptors when they are no
longer needed.
Bug: 128319484
Test: mma
Test: atest NeuralNetworksTest_static
Change-Id: Ibc19059af5194cd3dd58c9a9d8baa54fa6b26de5
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp
index 87f5c1e..286c62a 100644
--- a/nn/common/ExecutionBurstController.cpp
+++ b/nn/common/ExecutionBurstController.cpp
@@ -38,14 +38,15 @@
// get all memories
hidl_vec<hidl_memory> memories(slots.size());
- for (size_t i = 0; i < slots.size(); ++i) {
- // if memory is available, return it; otherwise return error
- auto iter = mSlotToMemoryCache.find(slots[i]);
- if (iter == mSlotToMemoryCache.end()) {
- cb(ErrorStatus::INVALID_ARGUMENT, {});
- return Void();
- }
- memories[i] = iter->second;
+ std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
+ return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
+ });
+
+ // ensure all memories are valid
+ if (!std::all_of(memories.begin(), memories.end(),
+ [](const hidl_memory& memory) { return memory.valid(); })) {
+ cb(ErrorStatus::INVALID_ARGUMENT, {});
+ return Void();
}
// return successful
@@ -70,26 +71,24 @@
intptr_t key) {
std::lock_guard<std::mutex> guard(mMutex);
- auto iter = mMemoryIdToSlotCache.find(key);
- if (iter != mMemoryIdToSlotCache.end()) {
- const int32_t slot = iter->second;
- mMemoryIdToSlotCache.erase(key);
- mSlotToMemoryCache.erase(slot);
- return {true, slot};
- } else {
+ auto iter = mMemoryIdToSlot.find(key);
+ if (iter == mMemoryIdToSlot.end()) {
return {false, 0};
}
+ const int32_t slot = iter->second;
+ mMemoryIdToSlot.erase(key);
+ mMemoryCache[slot] = {};
+ mFreeSlots.push(slot);
+ return {true, slot};
}
int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
intptr_t key) {
- auto iter = mMemoryIdToSlotCache.find(key);
- if (iter == mMemoryIdToSlotCache.end()) {
- const int32_t slot = mNextSlot;
- // TODO: change mNextSlot to uint64_t or maintain a free list of IDs
- mNextSlot = (mNextSlot + 1) % (1 << 30);
- mMemoryIdToSlotCache[key] = slot;
- mSlotToMemoryCache[slot] = memory;
+ auto iter = mMemoryIdToSlot.find(key);
+ if (iter == mMemoryIdToSlot.end()) {
+ const int32_t slot = allocateSlotLocked();
+ mMemoryIdToSlot[key] = slot;
+ mMemoryCache[slot] = memory;
return slot;
} else {
const int32_t slot = iter->second;
@@ -97,6 +96,24 @@
}
}
+// Hands out a slot index for a newly cached memory object.
+//
+// Slots previously released by freeMemory() are reused first; mFreeSlots is
+// a std::stack, so the most recently released slot is returned. When no slot
+// is free, the cache grows by one default-constructed hidl_memory entry and
+// the index of that entry is returned. Because slots are int32_t, growth is
+// CHECK-limited to at most std::numeric_limits<int32_t>::max() entries.
+//
+// NOTE(review): the "Locked" suffix suggests the caller must already hold
+// mMutex -- confirm against all call sites.
+int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
+    // Prefer recycling a released slot over growing the cache.
+    if (!mFreeSlots.empty()) {
+        const int32_t recycledSlot = mFreeSlots.top();
+        mFreeSlots.pop();
+        return recycledSlot;
+    }
+
+    // First use of a brand-new slot: its index is the current cache size.
+    constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
+    const size_t newSlot = mMemoryCache.size();
+    CHECK(newSlot < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
+    mMemoryCache.emplace_back();
+    return static_cast<int32_t>(newSlot);
+}
+
std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
const sp<IPreparedModel>& preparedModel, bool blocking) {
// check inputs
@@ -386,6 +403,8 @@
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
+ std::lock_guard<std::mutex> guard(mMutex);
+
// serialize request
std::vector<FmqRequestDatum> requestData = serialize(request, measure, memoryIds);
@@ -411,6 +430,8 @@
}
void ExecutionBurstController::freeMemory(intptr_t key) {
+ std::lock_guard<std::mutex> guard(mMutex);
+
bool valid;
int32_t slot;
std::tie(valid, slot) = mMemoryCache->freeMemory(key);
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 6ede34d..84bb424 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -32,45 +32,64 @@
const std::vector<int32_t>& slots) {
std::lock_guard<std::mutex> guard(mMutex);
+ const auto slotIsKnown = [this](int32_t slot) {
+ return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
+ };
+
// find unique unknown slots
- std::set<int32_t> setOfUnknownSlots;
- for (int32_t slot : slots) {
- if (mSlotToMemoryCache.find(slot) == mSlotToMemoryCache.end()) {
- setOfUnknownSlots.insert(slot);
- }
- }
- const std::vector<int32_t> vecOfUnknownSlots(setOfUnknownSlots.begin(),
- setOfUnknownSlots.end());
+ std::vector<int32_t> unknownSlots = slots;
+ auto unknownSlotsEnd = unknownSlots.end();
+ std::sort(unknownSlots.begin(), unknownSlotsEnd);
+ unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
+ unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+ unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
// retrieve unknown slots
- if (!vecOfUnknownSlots.empty()) {
+ if (!unknownSlots.empty()) {
ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
std::vector<hidl_memory> returnedMemories;
- Return<void> ret = mCallback->getMemories(
- vecOfUnknownSlots,
- [&errorStatus, &returnedMemories](ErrorStatus status,
- const hidl_vec<hidl_memory>& memories) {
- errorStatus = status;
- if (status == ErrorStatus::NONE) {
- returnedMemories = memories;
- }
- });
+ auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
+ const hidl_vec<hidl_memory>& memories) {
+ errorStatus = status;
+ returnedMemories = memories;
+ };
- if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
+ Return<void> ret = mCallback->getMemories(unknownSlots, cb);
+
+ // Ensure that the memories were successfully returned.
+ // IBurstCallback.hal specifies that the number of memories returned
+ // must match the number of slots requested:
+ // "slots.size() == buffers.size()"
+ if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
+ returnedMemories.size() != unknownSlots.size()) {
LOG(ERROR) << "Error retrieving memories";
return {};
}
+ // resize cache to fit new slots if necessary
+ const int32_t maxUnknownSlot = unknownSlots.back();
+ if (maxUnknownSlot >= mMemoryCache.size()) {
+ mMemoryCache.resize(maxUnknownSlot + 1);
+ }
+
// add memories to unknown slots
- for (size_t i = 0; i < vecOfUnknownSlots.size(); ++i) {
- mSlotToMemoryCache[vecOfUnknownSlots[i]] = returnedMemories[i];
+ for (size_t i = 0; i < unknownSlots.size(); ++i) {
+ mMemoryCache[unknownSlots[i]] = returnedMemories[i];
}
}
// get all slots
hidl_vec<hidl_memory> memories(slots.size());
- for (size_t i = 0; i < slots.size(); ++i) {
- memories[i] = mSlotToMemoryCache[slots[i]];
+ std::transform(slots.begin(), slots.end(), memories.begin(),
+ [this](int32_t slot) { return mMemoryCache[slot]; });
+
+ // Ensure all slots are valid. Although this case is never expected to
+ // occur, theoretically IBurstCallback::getMemories could return invalid
+ // hidl_memory objects that must be protected against.
+ if (!std::all_of(memories.begin(), memories.end(),
+ [](const hidl_memory& memory) { return memory.valid(); })) {
+ LOG(ERROR) << "Error, not all slots are valid!";
+ return {};
}
return memories;
@@ -78,7 +97,9 @@
+// Releases the memory cached for `slot` by overwriting its entry with an
+// empty (invalid) hidl_memory, so a later getMemories() call treats the slot
+// as unknown and re-fetches it. Slots outside the cache -- including
+// negative values -- are ignored.
void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
std::lock_guard<std::mutex> guard(mMutex);
- mSlotToMemoryCache.erase(slot);
+    // Rule out negative slots explicitly before comparing against the unsigned
+    // cache size, instead of relying on int32_t -> size_t wrap-around.
+    if (slot >= 0 && static_cast<size_t>(slot) < mMemoryCache.size()) {
+        mMemoryCache[slot] = {};
+    }
}
sp<ExecutionBurstServer> ExecutionBurstServer::create(
diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h
index 7152325..33564f4 100644
--- a/nn/common/include/ExecutionBurstController.h
+++ b/nn/common/include/ExecutionBurstController.h
@@ -24,6 +24,7 @@
#include <map>
#include <memory>
#include <mutex>
+#include <stack>
#include <tuple>
#include "HalInterfaces.h"
@@ -93,11 +94,12 @@
private:
int32_t getSlotLocked(const hidl_memory& memory, intptr_t key);
+ int32_t allocateSlotLocked();
std::mutex mMutex;
- int32_t mNextSlot = 0;
- std::map<intptr_t, int32_t> mMemoryIdToSlotCache;
- std::map<int32_t, hidl_memory> mSlotToMemoryCache;
+ std::stack<int32_t, std::vector<int32_t>> mFreeSlots;
+ std::map<intptr_t, int32_t> mMemoryIdToSlot;
+ std::vector<hidl_memory> mMemoryCache;
};
public:
@@ -150,6 +152,7 @@
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> deserialize(
const std::vector<FmqResultDatum>& data);
+ std::mutex mMutex;
const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
const sp<IBurstContext> mBurstContext;
diff --git a/nn/common/include/ExecutionBurstServer.h b/nn/common/include/ExecutionBurstServer.h
index 0a3222f..af9b076 100644
--- a/nn/common/include/ExecutionBurstServer.h
+++ b/nn/common/include/ExecutionBurstServer.h
@@ -22,8 +22,7 @@
#include <hidl/MQDescriptor.h>
#include <atomic>
#include <future>
-#include <map>
-#include <set>
+#include <vector>
#include "HalInterfaces.h"
namespace android::nn {
@@ -62,7 +61,7 @@
private:
std::mutex mMutex;
const sp<IBurstCallback> mCallback;
- std::map<int32_t, hidl_memory> mSlotToMemoryCache;
+ std::vector<hidl_memory> mMemoryCache;
};
public: