Cache ExecutionBurstServer memory resources
This change enables caching of memory resources. Prior to this CL, a
hidl_memory object in an execution burst had to be mapped and unmapped
on each execution, taking ~10us. With this change, a driver
can choose to map the hidl_memory when it is first retrieved and
unmap it when the burst is freed or when IBurstContext::freeMemory is called.
Bug: 126244806
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest VtsHalNeuralnetworksV1_0TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_1TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: I92d7632f2e1e2f5e76c63249d4c615c9af9d1acb
Merged-In: I92d7632f2e1e2f5e76c63249d4c615c9af9d1acb
(cherry picked from commit dfc94136aff39efdfb81436cee7e73c3f6834785)
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 84bb424..08d0e80 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,124 +19,134 @@
#include "ExecutionBurstServer.h"
#include <android-base/logging.h>
-#include <set>
#include <string>
#include "Tracing.h"
namespace android::nn {
-ExecutionBurstServer::BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback)
- : mCallback(callback) {}
+namespace {
-hidl_vec<hidl_memory> ExecutionBurstServer::BurstMemoryCache::getMemories(
- const std::vector<int32_t>& slots) {
- std::lock_guard<std::mutex> guard(mMutex);
+// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
+// used as an IBurstExecutorWithCache. The cache only stores hidl_memory objects
+// (indexed by slot), and execution forwards to the provided IPreparedModel's
+// "executeSynchronously" method, so hidl_memory must be mapped and unmapped for
+// each execution. The IPreparedModel is not owned; the adapter never deletes it.
+class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
+ public:
+ explicit DefaultBurstExecutorWithCache(IPreparedModel* model) : mpPreparedModel(model) {}
- const auto slotIsKnown = [this](int32_t slot) {
+ bool isCacheEntryPresent(int32_t slot) const override {  // negative slot -> huge size_t -> absent
return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
- };
+ }
- // find unique unknown slots
- std::vector<int32_t> unknownSlots = slots;
- auto unknownSlotsEnd = unknownSlots.end();
- std::sort(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
- unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+ void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
+ if (slot < 0) return;  // negative slots cannot index the cache; ignore them
+ const size_t index = static_cast<size_t>(slot);  // widen before +1 so it cannot overflow
+ if (index >= mMemoryCache.size()) mMemoryCache.resize(index + 1);
+ mMemoryCache[index] = memory;
+ }
- // retrieve unknown slots
- if (!unknownSlots.empty()) {
- ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
- std::vector<hidl_memory> returnedMemories;
- auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
- const hidl_vec<hidl_memory>& memories) {
- errorStatus = status;
- returnedMemories = memories;
+ void removeCacheEntry(int32_t slot) override {  // clears the slot; the vector never shrinks
+ if (slot < mMemoryCache.size()) {
+ mMemoryCache[slot] = {};
+ }
+ }
+
+ std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+ const Request& request, const std::vector<int32_t>& slots,
+ MeasureTiming measure) override {
+ // convert slots to pools; uncached or out-of-range slots become empty hidl_memory
+ hidl_vec<hidl_memory> pools(slots.size());
+ std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
+ return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
+ });
+
+ // create full request
+ Request fullRequest = request;
+ fullRequest.pools = std::move(pools);
+
+ // setup execution
+ ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
+ hidl_vec<OutputShape> returnedOutputShapes;
+ Timing returnedTiming = {UINT64_MAX, UINT64_MAX};  // UINT64_MAX: timing unavailable
+ auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
+ ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
+ const Timing& timing) {
+ returnedStatus = status;
+ returnedOutputShapes = outputShapes;
+ returnedTiming = timing;
};
- Return<void> ret = mCallback->getMemories(unknownSlots, cb);
-
- // Ensure that the memories were successfully returned.
- // IBurstCallback.hal specifies the that the number of memories returned
- // must match the number of slots requested:
- // "slots.size() == buffers.size()"
- if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
- returnedMemories.size() != unknownSlots.size()) {
- LOG(ERROR) << "Error retrieving memories";
- return {};
+ // execute; only a transport failure is remapped, driver statuses pass through
+ const Return<void> ret = mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
+ if (!ret.isOk()) {
+ LOG(ERROR) << "DefaultBurstExecutorWithCache::execute -- Error executing";
+ return {ErrorStatus::GENERAL_FAILURE, {}, Timing{UINT64_MAX, UINT64_MAX}};
}
- // resize cache to fit new slots if necessary
- const int32_t maxUnknownSlot = unknownSlots.back();
- if (maxUnknownSlot >= mMemoryCache.size()) {
- mMemoryCache.resize(maxUnknownSlot + 1);
- }
-
- // add memories to unknown slots
- for (size_t i = 0; i < unknownSlots.size(); ++i) {
- mMemoryCache[unknownSlots[i]] = returnedMemories[i];
- }
+ return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
}
- // get all slots
- hidl_vec<hidl_memory> memories(slots.size());
- std::transform(slots.begin(), slots.end(), memories.begin(),
- [this](int32_t slot) { return mMemoryCache[slot]; });
+ private:
+ IPreparedModel* const mpPreparedModel;  // not owned
+ std::vector<hidl_memory> mMemoryCache;  // index == slot
+};
- // Ensure all slots are valid. Although this case is never expected to
- // occur, theoretically IBurstCallback::getMemories could return invalid
- // hidl_memory objects that must be protected against.
- if (!std::all_of(memories.begin(), memories.end(),
- [](const hidl_memory& memory) { return memory.valid(); })) {
- LOG(ERROR) << "Error, not all slots are valid!";
- return {};
- }
-
- return memories;
-}
-
-void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
- std::lock_guard<std::mutex> guard(mMutex);
- if (slot < mMemoryCache.size()) {
- mMemoryCache[slot] = {};
- }
-}
+} // anonymous namespace
sp<ExecutionBurstServer> ExecutionBurstServer::create(
const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+ const MQDescriptorSync<FmqResultDatum>& resultChannel,  // overload taking a caller-supplied executor
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {  // moved into the server below
// check inputs
- if (callback == nullptr || preparedModel == nullptr) {
+ if (callback == nullptr || executorWithCache == nullptr) {
LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
return nullptr;
}
// create FMQ objects
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
- FmqRequestChannel(requestChannel)};
- std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
- FmqResultChannel(resultChannel)};
+ std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
+ std::make_unique<FmqRequestChannel>(requestChannel);  // make_unique never returns nullptr
+ std::unique_ptr<FmqResultChannel> fmqResultChannel =
+ std::make_unique<FmqResultChannel>(resultChannel);  // so only isValid() is checked below
// check FMQ objects
- if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
- !fmqResultChannel->isValid()) {
+ if (!fmqRequestChannel->isValid() || !fmqResultChannel->isValid()) {
LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
return nullptr;
}
// make and return context
return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
- std::move(fmqResultChannel), preparedModel);
+ std::move(fmqResultChannel), std::move(executorWithCache));
}
-ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
- std::unique_ptr<FmqRequestChannel> requestChannel,
- std::unique_ptr<FmqResultChannel> resultChannel,
- IPreparedModel* preparedModel)
- : mMemoryCache(callback),
+sp<ExecutionBurstServer> ExecutionBurstServer::create(  // overload for a bare IPreparedModel
+ const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+ // check relevant input; the other arguments are checked by the overload called below
+ if (preparedModel == nullptr) {
+ LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+ return nullptr;
+ }
+
+ // adapt IPreparedModel to have caching (the adapter does not take ownership)
+ std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
+ std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
+
+ // make and return context; moving the adapter avoids an extra refcount round-trip
+ return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
+ std::move(preparedModelAdapter));
+}
+
+ExecutionBurstServer::ExecutionBurstServer(
+ const sp<IBurstCallback>& callback, std::unique_ptr<FmqRequestChannel> requestChannel,
+ std::unique_ptr<FmqResultChannel> resultChannel,
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
+ : mCallback(callback),
mFmqRequestChannel(std::move(requestChannel)),
mFmqResultChannel(std::move(resultChannel)),
- mPreparedModel(preparedModel),
+ mExecutorWithCache(std::move(executorWithCache)),
mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
// TODO: highly document the threading behavior of this class
mWorker = std::async(std::launch::async, [this] { task(); });
@@ -164,6 +174,52 @@
mWorker.wait();
}
+Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {  // evicts one slot from the executor cache
+ std::lock_guard<std::mutex> hold(mMutex);  // task() holds mMutex while it executes, so this waits
+ mExecutorWithCache->removeCacheEntry(slot);
+ return Void();
+}
+
+void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
+ const auto slotIsKnown = [this](int32_t slot) {  // fetches uncached slots; caller holds mMutex
+ return mExecutorWithCache->isCacheEntryPresent(slot);
+ };
+
+ // find unique unknown slots: sort, deduplicate, then drop slots already cached
+ std::vector<int32_t> unknownSlots = slots;
+ auto unknownSlotsEnd = unknownSlots.end();
+ std::sort(unknownSlots.begin(), unknownSlotsEnd);
+ unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
+ unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+ unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+
+ // quick-exit if all slots are known
+ if (unknownSlots.empty()) {
+ return;
+ }
+
+ ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;  // stays an error if cb is never invoked
+ std::vector<hidl_memory> returnedMemories;
+ auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
+ const hidl_vec<hidl_memory>& memories) {
+ errorStatus = status;
+ returnedMemories = memories;
+ };
+
+ const Return<void> ret = mCallback->getMemories(unknownSlots, cb);  // request only missing slots
+
+ if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||  // IBurstCallback.hal requires
+ returnedMemories.size() != unknownSlots.size()) {  // slots.size() == buffers.size()
+ LOG(ERROR) << "Error retrieving memories";
+ return;  // NOTE(review): failure is only logged; the caller still proceeds to execute
+ }
+
+ // add memories to unknown slots
+ for (size_t i = 0; i < unknownSlots.size(); ++i) {  // positional: memory i belongs to slot i
+ mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
+ }
+}
+
bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
if (mTeardown) {
return false;
@@ -236,17 +292,16 @@
}
// deserialize request
-std::pair<Request, MeasureTiming> ExecutionBurstServer::deserialize(
+std::tuple<Request, std::vector<int32_t>, MeasureTiming> ExecutionBurstServer::deserialize(  // pools left empty
const std::vector<FmqRequestDatum>& data) {
using discriminator = FmqRequestDatum::hidl_discriminator;
- Request request;
size_t index = 0;
// validate packet information
if (data[index].getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};  // ill-formed: empty request and no slots
}
// unpackage packet information
@@ -264,7 +319,7 @@
// validate input operand information
if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage operand information
@@ -282,7 +337,7 @@
// validate dimension
if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage dimension
@@ -305,7 +360,7 @@
// validate output operand information
if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage operand information
@@ -323,7 +378,7 @@
// validate dimension
if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage dimension
@@ -346,7 +401,7 @@
// validate input operand information
if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage operand information
@@ -356,12 +411,11 @@
// store result
slots.push_back(poolId);
}
- hidl_vec<hidl_memory> pools = mMemoryCache.getMemories(slots);
// validate measureTiming
if (data[index].getDiscriminator() != discriminator::measureTiming) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage measureTiming
@@ -371,11 +425,11 @@
// validate packet information
if (index != packetSize) {
LOG(ERROR) << "FMQ Result packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// return request
- return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/std::move(pools)}, measure};
+ return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}}, std::move(slots), measure};  // slots resolved later
}
// serialize result
@@ -428,11 +482,6 @@
return data;
}
-Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
- mMemoryCache.freeMemory(slot);
- return Void();
-}
-
void ExecutionBurstServer::task() {
while (!mTeardown) {
// receive request
@@ -446,29 +495,18 @@
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
"ExecutionBurstServer processing packet and returning results");
- // continue processing
- Request request;
- MeasureTiming measure;
- std::tie(request, measure) = deserialize(requestData);
+ // continue processing; types are Request, std::vector<int32_t>, and
+ // MeasureTiming, respectively
+ const auto [requestWithoutPools, slotsOfPools, measure] = deserialize(requestData);
- // perform computation
- ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
- std::vector<OutputShape> outputShapes;
- Timing returnedTiming;
- // This call to IPreparedModel::executeSynchronously occurs entirely
- // within the same process, so ignore the Return<> errors via .isOk().
- // TODO: verify it is safe to always call isOk() here, or if there is
- // any benefit to checking any potential errors.
- mPreparedModel
- ->executeSynchronously(request, measure,
- [&errorStatus, &outputShapes, &returnedTiming](
- ErrorStatus status,
- const hidl_vec<OutputShape>& shapes, Timing timing) {
- errorStatus = status;
- outputShapes = shapes;
- returnedTiming = timing;
- })
- .isOk();
+ // ensure executor with cache has required memory
+ std::lock_guard<std::mutex> hold(mMutex);
+ ensureCacheEntriesArePresentLocked(slotsOfPools);
+
+ // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
+ // and Timing, respectively
+ const auto [errorStatus, outputShapes, returnedTiming] =
+ mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
// return result
const std::vector<FmqResultDatum> result =