Cache ExecutionBurstServer memory resources
This change enables caching of memory resources. Prior to this CL, a
hidl_memory object in an execution burst had to be mapped and unmapped
on each execution, taking ~10us of time. With this change, a driver
can choose to map the hidl_memory when it is first retrieved, and to
unmap it when the burst is freed or when IBurstContext::freeMemory is called.
Bug: 126244806
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest VtsHalNeuralnetworksV1_0TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_1TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: I92d7632f2e1e2f5e76c63249d4c615c9af9d1acb
Merged-In: I92d7632f2e1e2f5e76c63249d4c615c9af9d1acb
(cherry picked from commit dfc94136aff39efdfb81436cee7e73c3f6834785)
diff --git a/nn/common/CpuExecutor.cpp b/nn/common/CpuExecutor.cpp
index 0150c72..e1e0178 100644
--- a/nn/common/CpuExecutor.cpp
+++ b/nn/common/CpuExecutor.cpp
@@ -244,26 +244,103 @@
} // namespace
+// Used to keep a pointer to a memory pool.
+//
+// In the case of an "mmap_fd" pool, owns the mmap region
+// returned by getBuffer() -- i.e., that region goes away
+// when the RunTimePoolInfoImpl is destroyed.
+class RunTimePoolInfo::RunTimePoolInfoImpl {
+ public:
+ RunTimePoolInfoImpl(const hidl_memory& hidlMemory, uint8_t* buffer, const sp<IMemory>& memory,
+ const sp<GraphicBuffer>& graphicBuffer);
+
+ // Rule of five: the destructor releases the pool's mapping, so copying and moving are disabled.
+ ~RunTimePoolInfoImpl();
+ RunTimePoolInfoImpl(const RunTimePoolInfoImpl&) = delete;
+ RunTimePoolInfoImpl(RunTimePoolInfoImpl&&) noexcept = delete;
+ RunTimePoolInfoImpl& operator=(const RunTimePoolInfoImpl&) = delete;
+ RunTimePoolInfoImpl& operator=(RunTimePoolInfoImpl&&) noexcept = delete;
+
+ uint8_t* getBuffer() const { return mBuffer; } // pointer handed to the constructor; not re-derived
+
+ bool update() const; // commits/syncs the pool after an execution; false only if msync fails
+
+ hidl_memory getHidlMemory() const { return mHidlMemory; } // returns a copy of the descriptor
+
+ private:
+ const hidl_memory mHidlMemory; // always used
+ uint8_t* const mBuffer = nullptr; // always used
+ const sp<IMemory> mMemory; // only used when hidlMemory.name() == "ashmem"
+ const sp<GraphicBuffer>
+ mGraphicBuffer; // only used when hidlMemory.name() == "hardware_buffer_blob"
+};
+
+// Records the pool descriptor together with the resources that were already
+// acquired for it; the constructor itself performs no mapping or locking.
+RunTimePoolInfo::RunTimePoolInfoImpl::RunTimePoolInfoImpl(
+        const hidl_memory& hidlMemory, uint8_t* buffer, const sp<IMemory>& memory,
+        const sp<GraphicBuffer>& graphicBuffer)
+    : mHidlMemory(hidlMemory),
+      mBuffer(buffer),
+      mMemory(memory),
+      mGraphicBuffer(graphicBuffer) {}
+
+// Releases the resource backing mBuffer, selected by the hidl_memory type name.
+// A null mBuffer means there is nothing to release.
+RunTimePoolInfo::RunTimePoolInfoImpl::~RunTimePoolInfoImpl() {
+    if (mBuffer == nullptr) {
+        return;
+    }
+
+    const std::string memType = mHidlMemory.name();
+    if (memType == "ashmem") {
+        // nothing to do
+    } else if (memType == "mmap_fd") {
+        // Undo the mmap() of the pool's file descriptor.
+        const size_t size = mHidlMemory.size();
+        if (munmap(mBuffer, size)) {
+            LOG(ERROR) << "RunTimePoolInfoImpl::~RunTimePoolInfoImpl(): Can't munmap";
+        }
+    } else if (memType == "hardware_buffer_blob") {
+        // Release the lock that produced mBuffer.
+        mGraphicBuffer->unlock();
+    } else if (memType == "") {
+        // Represents a POINTER argument; nothing to do
+    } else {
+        LOG(ERROR) << "RunTimePoolInfoImpl::~RunTimePoolInfoImpl(): unsupported hidl_memory type";
+    }
+}
+
+// Pushes data written through this pool back to its backing store once an
+// execution has finished. Returns false only when msync() fails.
+bool RunTimePoolInfo::RunTimePoolInfoImpl::update() const {
+    const std::string poolType = mHidlMemory.name();
+
+    if (poolType == "mmap_fd") {
+        // data[1] of the handle is read as the mmap protection flags; a
+        // region that is not writable has nothing to sync.
+        const int protection = mHidlMemory.handle()->data[1];
+        const bool writable = (protection & PROT_WRITE) != 0;
+        return writable ? msync(mBuffer, mHidlMemory.size(), MS_SYNC) == 0 : true;
+    }
+
+    if (poolType == "ashmem") {
+        mMemory->commit();
+    }
+
+    // Every other memory type is a no-op.
+    return true;
+}
+
// TODO: short term, make share memory mapping and updating a utility function.
// TODO: long term, implement mmap_fd as a hidl IMemory service.
-RunTimePoolInfo::RunTimePoolInfo(const hidl_memory& hidlMemory, bool* fail) {
- sp<IMemory> memory;
+std::optional<RunTimePoolInfo> RunTimePoolInfo::createFromHidlMemory(
+ const hidl_memory& hidlMemory) {
uint8_t* buffer = nullptr;
+ sp<IMemory> memory;
+ sp<GraphicBuffer> graphicBuffer;
const auto& memType = hidlMemory.name();
if (memType == "ashmem") {
memory = mapMemory(hidlMemory);
if (memory == nullptr) {
LOG(ERROR) << "Can't map shared memory.";
- if (fail) *fail = true;
- return;
+ return std::nullopt;
}
memory->update();
buffer = reinterpret_cast<uint8_t*>(static_cast<void*>(memory->getPointer()));
if (buffer == nullptr) {
LOG(ERROR) << "Can't access shared memory.";
- if (fail) *fail = true;
- return;
+ return std::nullopt;
}
} else if (memType == "mmap_fd") {
size_t size = hidlMemory.size();
@@ -273,8 +350,7 @@
buffer = static_cast<uint8_t*>(mmap(nullptr, size, prot, MAP_SHARED, fd, offset));
if (buffer == MAP_FAILED) {
LOG(ERROR) << "RunTimePoolInfo::set(): Can't mmap the file descriptor.";
- if (fail) *fail = true;
- return;
+ return std::nullopt;
}
} else if (memType == "hardware_buffer_blob") {
auto handle = hidlMemory.handle();
@@ -284,107 +360,59 @@
const uint32_t height = 1; // height is always 1 for BLOB mode AHardwareBuffer.
const uint32_t layers = 1; // layers is always 1 for BLOB mode AHardwareBuffer.
const uint32_t stride = hidlMemory.size();
- mGraphicBuffer = new GraphicBuffer(handle, GraphicBuffer::HandleWrapMethod::CLONE_HANDLE,
- width, height, format, layers, usage, stride);
+ graphicBuffer = new GraphicBuffer(handle, GraphicBuffer::HandleWrapMethod::CLONE_HANDLE,
+ width, height, format, layers, usage, stride);
void* gBuffer = nullptr;
- status_t status = mGraphicBuffer->lock(usage, &gBuffer);
+ status_t status = graphicBuffer->lock(usage, &gBuffer);
if (status != NO_ERROR) {
LOG(ERROR) << "RunTimePoolInfo Can't lock the AHardwareBuffer.";
- if (fail) *fail = true;
- return;
+ return std::nullopt;
}
buffer = static_cast<uint8_t*>(gBuffer);
} else {
LOG(ERROR) << "RunTimePoolInfo::set(): unsupported hidl_memory type";
- if (fail) *fail = true;
- return;
+ return std::nullopt;
}
- mHidlMemory = hidlMemory;
- mBuffer = buffer;
- mMemory = memory;
+ const auto impl =
+ std::make_shared<const RunTimePoolInfoImpl>(hidlMemory, buffer, memory, graphicBuffer);
+ return {RunTimePoolInfo(impl)};
}
-RunTimePoolInfo::RunTimePoolInfo(uint8_t* buffer) {
- mBuffer = buffer;
+RunTimePoolInfo RunTimePoolInfo::createFromExistingBuffer(uint8_t* buffer) {
+ const auto impl =
+ std::make_shared<const RunTimePoolInfoImpl>(hidl_memory{}, buffer, nullptr, nullptr);
+ return {impl};
}
-RunTimePoolInfo::RunTimePoolInfo(RunTimePoolInfo&& other) noexcept {
- moveFrom(std::move(other));
- other.mBuffer = nullptr;
+RunTimePoolInfo::RunTimePoolInfo(const std::shared_ptr<const RunTimePoolInfoImpl>& impl)
+ : mImpl(impl) {}
+
+uint8_t* RunTimePoolInfo::getBuffer() const {
+ return mImpl->getBuffer();
}
-RunTimePoolInfo& RunTimePoolInfo::operator=(RunTimePoolInfo&& other) noexcept {
- if (this != &other) {
- release();
- moveFrom(std::move(other));
- other.mBuffer = nullptr;
- }
- return *this;
-}
-
-void RunTimePoolInfo::moveFrom(RunTimePoolInfo&& other) {
- mHidlMemory = std::move(other.mHidlMemory);
- mBuffer = std::move(other.mBuffer);
- mMemory = std::move(other.mMemory);
-}
-
-void RunTimePoolInfo::release() {
- if (mBuffer == nullptr) {
- return;
- }
-
- auto memType = mHidlMemory.name();
- if (memType == "ashmem") {
- // nothing to do
- } else if (memType == "mmap_fd") {
- size_t size = mHidlMemory.size();
- if (munmap(mBuffer, size)) {
- LOG(ERROR) << "RunTimePoolInfo::release(): Can't munmap";
- }
- } else if (memType == "hardware_buffer_blob") {
- mGraphicBuffer->unlock();
- mGraphicBuffer = nullptr;
- } else if (memType == "") {
- // Represents a POINTER argument; nothing to do
- } else {
- LOG(ERROR) << "RunTimePoolInfo::release(): unsupported hidl_memory type";
- }
-
- mHidlMemory = hidl_memory();
- mMemory = nullptr;
- mBuffer = nullptr;
-}
-
-// Making sure the output data are correctly updated after execution.
bool RunTimePoolInfo::update() const {
- auto memType = mHidlMemory.name();
- if (memType == "ashmem") {
- mMemory->commit();
- return true;
- } else if (memType == "mmap_fd") {
- int prot = mHidlMemory.handle()->data[1];
- if (prot & PROT_WRITE) {
- size_t size = mHidlMemory.size();
- return msync(mBuffer, size, MS_SYNC) == 0;
- }
- }
- // No-op for other types of memory.
- return true;
+ return mImpl->update();
+}
+
+hidl_memory RunTimePoolInfo::getHidlMemory() const {
+ return mImpl->getHidlMemory();
}
bool setRunTimePoolInfosFromHidlMemories(std::vector<RunTimePoolInfo>* poolInfos,
const hidl_vec<hidl_memory>& pools) {
+ CHECK(poolInfos != nullptr);
poolInfos->clear();
poolInfos->reserve(pools.size());
- bool fail = false;
for (const auto& pool : pools) {
- poolInfos->emplace_back(pool, &fail);
- }
- if (fail) {
- LOG(ERROR) << "Could not map pools";
- poolInfos->clear();
- return false;
+ if (std::optional<RunTimePoolInfo> poolInfo = RunTimePoolInfo::createFromHidlMemory(pool)) {
+ poolInfos->push_back(*poolInfo);
+ } else {
+ LOG(ERROR) << "Could not map pools";
+ poolInfos->clear();
+ return false;
+ }
}
return true;
}
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 84bb424..08d0e80 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,124 +19,134 @@
#include "ExecutionBurstServer.h"
#include <android-base/logging.h>
-#include <set>
#include <string>
#include "Tracing.h"
namespace android::nn {
-ExecutionBurstServer::BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback)
- : mCallback(callback) {}
+namespace {
-hidl_vec<hidl_memory> ExecutionBurstServer::BurstMemoryCache::getMemories(
- const std::vector<int32_t>& slots) {
- std::lock_guard<std::mutex> guard(mMutex);
+// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
+// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
+// hidl_memory object, and the execution forwards calls to the provided
+// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory
+// must be mapped and unmapped for each execution.
+class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
+ public:
+ DefaultBurstExecutorWithCache(IPreparedModel* preparedModel) : mpPreparedModel(preparedModel) {}
- const auto slotIsKnown = [this](int32_t slot) {
+ bool isCacheEntryPresent(int32_t slot) const override {
return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
- };
+ }
- // find unique unknown slots
- std::vector<int32_t> unknownSlots = slots;
- auto unknownSlotsEnd = unknownSlots.end();
- std::sort(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
- unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+    // Stores `memory` at index `slot`, growing the cache when `slot` lies past
+    // its current end.
+    //
+    // Negative slots are rejected. Compared against size() they would convert
+    // to a huge unsigned value and take the resize branch with `slot + 1`; for
+    // -1 that empties the cache and then writes out of bounds.
+    void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
+        if (slot < 0) {
+            LOG(ERROR) << "DefaultBurstExecutorWithCache::addCacheEntry -- invalid slot " << slot;
+            return;
+        }
+        const size_t index = static_cast<size_t>(slot);
+        if (index >= mMemoryCache.size()) {
+            mMemoryCache.resize(index + 1);
+        }
+        mMemoryCache[index] = memory;
+    }
- // retrieve unknown slots
- if (!unknownSlots.empty()) {
- ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
- std::vector<hidl_memory> returnedMemories;
- auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
- const hidl_vec<hidl_memory>& memories) {
- errorStatus = status;
- returnedMemories = memories;
+ void removeCacheEntry(int32_t slot) override {
+ if (slot < mMemoryCache.size()) {
+ mMemoryCache[slot] = {};
+ }
+ }
+
+    // Runs one execution through IPreparedModel::executeSynchronously.
+    //
+    // Each entry of `slots` is replaced by its cached hidl_memory to rebuild
+    // Request::pools; a slot with no cache entry becomes an empty hidl_memory.
+    //
+    // A transport failure is reported as GENERAL_FAILURE. A failure reported by
+    // the prepared model itself is forwarded unchanged together with the output
+    // shapes it supplied (timing is left default-initialized), so the caller
+    // can tell error statuses apart instead of always seeing GENERAL_FAILURE.
+    std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+            const Request& request, const std::vector<int32_t>& slots,
+            MeasureTiming measure) override {
+        // convert slots to pools
+        hidl_vec<hidl_memory> pools(slots.size());
+        std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
+            const bool inRange = slot >= 0 && static_cast<size_t>(slot) < mMemoryCache.size();
+            return inRange ? mMemoryCache[slot] : hidl_memory{};
+        });
+
+        // create full request
+        Request fullRequest = request;
+        fullRequest.pools = std::move(pools);
+
+        // setup execution
+        ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
+        hidl_vec<OutputShape> returnedOutputShapes;
+        Timing returnedTiming;
+        auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
+                          ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
+                          const Timing& timing) {
+            returnedStatus = status;
+            returnedOutputShapes = outputShapes;
+            returnedTiming = timing;
};
- Return<void> ret = mCallback->getMemories(unknownSlots, cb);
-
- // Ensure that the memories were successfully returned.
- // IBurstCallback.hal specifies the that the number of memories returned
- // must match the number of slots requested:
- // "slots.size() == buffers.size()"
- if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
- returnedMemories.size() != unknownSlots.size()) {
- LOG(ERROR) << "Error retrieving memories";
- return {};
+        // execute
+        const Return<void> ret = mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
+        if (!ret.isOk() || returnedStatus != ErrorStatus::NONE) {
+            LOG(ERROR) << "DefaultBurstExecutorWithCache::execute -- Error executing";
+            if (!ret.isOk()) {
+                // The call itself failed, so any status captured by cb is not trusted.
+                return {ErrorStatus::GENERAL_FAILURE, {}, {}};
+            }
+            return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), Timing{});
}
- // resize cache to fit new slots if necessary
- const int32_t maxUnknownSlot = unknownSlots.back();
- if (maxUnknownSlot >= mMemoryCache.size()) {
- mMemoryCache.resize(maxUnknownSlot + 1);
- }
-
- // add memories to unknown slots
- for (size_t i = 0; i < unknownSlots.size(); ++i) {
- mMemoryCache[unknownSlots[i]] = returnedMemories[i];
- }
+        return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
}
- // get all slots
- hidl_vec<hidl_memory> memories(slots.size());
- std::transform(slots.begin(), slots.end(), memories.begin(),
- [this](int32_t slot) { return mMemoryCache[slot]; });
+ private:
+ IPreparedModel* const mpPreparedModel;
+ std::vector<hidl_memory> mMemoryCache;
+};
- // Ensure all slots are valid. Although this case is never expected to
- // occur, theoretically IBurstCallback::getMemories could return invalid
- // hidl_memory objects that must be protected against.
- if (!std::all_of(memories.begin(), memories.end(),
- [](const hidl_memory& memory) { return memory.valid(); })) {
- LOG(ERROR) << "Error, not all slots are valid!";
- return {};
- }
-
- return memories;
-}
-
-void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
- std::lock_guard<std::mutex> guard(mMutex);
- if (slot < mMemoryCache.size()) {
- mMemoryCache[slot] = {};
- }
-}
+} // anonymous namespace
sp<ExecutionBurstServer> ExecutionBurstServer::create(
const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+ const MQDescriptorSync<FmqResultDatum>& resultChannel,
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {
// check inputs
- if (callback == nullptr || preparedModel == nullptr) {
+ if (callback == nullptr || executorWithCache == nullptr) {
LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
return nullptr;
}
// create FMQ objects
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
- FmqRequestChannel(requestChannel)};
- std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
- FmqResultChannel(resultChannel)};
+ std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
+ std::make_unique<FmqRequestChannel>(requestChannel);
+ std::unique_ptr<FmqResultChannel> fmqResultChannel =
+ std::make_unique<FmqResultChannel>(resultChannel);
// check FMQ objects
- if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
- !fmqResultChannel->isValid()) {
+ if (!fmqRequestChannel->isValid() || !fmqResultChannel->isValid()) {
LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
return nullptr;
}
// make and return context
return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
- std::move(fmqResultChannel), preparedModel);
+ std::move(fmqResultChannel), std::move(executorWithCache));
}
-ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
- std::unique_ptr<FmqRequestChannel> requestChannel,
- std::unique_ptr<FmqResultChannel> resultChannel,
- IPreparedModel* preparedModel)
- : mMemoryCache(callback),
+// Convenience factory: wraps `preparedModel` in a DefaultBurstExecutorWithCache
+// and hands off to the executor-based create() overload, which performs the
+// remaining argument checks and the FMQ setup.
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+        const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+        const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+    // Only the argument consumed here is validated here.
+    if (preparedModel == nullptr) {
+        LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+        return nullptr;
+    }
+
+    // Adapt the IPreparedModel so it can serve as the caching executor.
+    std::shared_ptr<IBurstExecutorWithCache> executor =
+            std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
+
+    return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
+                                        std::move(executor));
+}
+
+ExecutionBurstServer::ExecutionBurstServer(
+ const sp<IBurstCallback>& callback, std::unique_ptr<FmqRequestChannel> requestChannel,
+ std::unique_ptr<FmqResultChannel> resultChannel,
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
+ : mCallback(callback),
mFmqRequestChannel(std::move(requestChannel)),
mFmqResultChannel(std::move(resultChannel)),
- mPreparedModel(preparedModel),
+ mExecutorWithCache(std::move(executorWithCache)),
mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
// TODO: highly document the threading behavior of this class
mWorker = std::async(std::launch::async, [this] { task(); });
@@ -164,6 +174,52 @@
mWorker.wait();
}
+// Drops the executor's cache entry for `slot`. mMutex is taken so the removal
+// cannot interleave with the worker's cache lookup and execution.
+Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
+    {
+        const std::lock_guard<std::mutex> lock(mMutex);
+        mExecutorWithCache->removeCacheEntry(slot);
+    }
+    return Void();
+}
+
+// Makes sure the executor has a cache entry for every slot in `slots`. Slots it
+// does not know are requested from the client with one getMemories() call and
+// added to the executor. On any retrieval failure the error is logged and the
+// cache is left untouched.
+void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
+    // Reduce the request to the distinct slots the executor is missing.
+    std::vector<int32_t> missing(slots);
+    std::sort(missing.begin(), missing.end());
+    missing.erase(std::unique(missing.begin(), missing.end()), missing.end());
+    missing.erase(std::remove_if(missing.begin(), missing.end(),
+                                 [this](int32_t slot) {
+                                     return mExecutorWithCache->isCacheEntryPresent(slot);
+                                 }),
+                  missing.end());
+
+    // Nothing to fetch when every slot is already cached.
+    if (missing.empty()) {
+        return;
+    }
+
+    // Ask the client for the missing memories.
+    ErrorStatus fetchStatus = ErrorStatus::GENERAL_FAILURE;
+    std::vector<hidl_memory> fetched;
+    const Return<void> ret = mCallback->getMemories(
+            missing,
+            [&fetchStatus, &fetched](ErrorStatus status, const hidl_vec<hidl_memory>& memories) {
+                fetchStatus = status;
+                fetched = memories;
+            });
+
+    // One memory must come back per requested slot.
+    const bool fetchSucceeded = ret.isOk() && fetchStatus == ErrorStatus::NONE &&
+                                fetched.size() == missing.size();
+    if (!fetchSucceeded) {
+        LOG(ERROR) << "Error retrieving memories";
+        return;
+    }
+
+    // Pair each returned memory with its slot, in request order.
+    auto memory = fetched.begin();
+    for (const int32_t slot : missing) {
+        mExecutorWithCache->addCacheEntry(*memory++, slot);
+    }
+}
+
bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
if (mTeardown) {
return false;
@@ -236,17 +292,16 @@
}
// deserialize request
-std::pair<Request, MeasureTiming> ExecutionBurstServer::deserialize(
+std::tuple<Request, std::vector<int32_t>, MeasureTiming> ExecutionBurstServer::deserialize(
const std::vector<FmqRequestDatum>& data) {
using discriminator = FmqRequestDatum::hidl_discriminator;
- Request request;
size_t index = 0;
// validate packet information
if (data[index].getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage packet information
@@ -264,7 +319,7 @@
// validate input operand information
if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage operand information
@@ -282,7 +337,7 @@
// validate dimension
if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage dimension
@@ -305,7 +360,7 @@
// validate output operand information
if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage operand information
@@ -323,7 +378,7 @@
// validate dimension
if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage dimension
@@ -346,7 +401,7 @@
// validate input operand information
if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage operand information
@@ -356,12 +411,11 @@
// store result
slots.push_back(poolId);
}
- hidl_vec<hidl_memory> pools = mMemoryCache.getMemories(slots);
// validate measureTiming
if (data[index].getDiscriminator() != discriminator::measureTiming) {
LOG(ERROR) << "FMQ Request packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// unpackage measureTiming
@@ -371,11 +425,11 @@
// validate packet information
if (index != packetSize) {
LOG(ERROR) << "FMQ Result packet ill-formed";
- return {{}, MeasureTiming::NO};
+ return {{}, {}, MeasureTiming::NO};
}
// return request
- return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/std::move(pools)}, measure};
+ return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}}, std::move(slots), measure};
}
// serialize result
@@ -428,11 +482,6 @@
return data;
}
-Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
- mMemoryCache.freeMemory(slot);
- return Void();
-}
-
void ExecutionBurstServer::task() {
while (!mTeardown) {
// receive request
@@ -446,29 +495,18 @@
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
"ExecutionBurstServer processing packet and returning results");
- // continue processing
- Request request;
- MeasureTiming measure;
- std::tie(request, measure) = deserialize(requestData);
+ // continue processing; types are Request, std::vector<int32_t>, and
+ // MeasureTiming, respectively
+ const auto [requestWithoutPools, slotsOfPools, measure] = deserialize(requestData);
- // perform computation
- ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
- std::vector<OutputShape> outputShapes;
- Timing returnedTiming;
- // This call to IPreparedModel::executeSynchronously occurs entirely
- // within the same process, so ignore the Return<> errors via .isOk().
- // TODO: verify it is safe to always call isOk() here, or if there is
- // any benefit to checking any potential errors.
- mPreparedModel
- ->executeSynchronously(request, measure,
- [&errorStatus, &outputShapes, &returnedTiming](
- ErrorStatus status,
- const hidl_vec<OutputShape>& shapes, Timing timing) {
- errorStatus = status;
- outputShapes = shapes;
- returnedTiming = timing;
- })
- .isOk();
+ // ensure executor with cache has required memory
+ std::lock_guard<std::mutex> hold(mMutex);
+ ensureCacheEntriesArePresentLocked(slotsOfPools);
+
+ // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
+ // and Timing, respectively
+ const auto [errorStatus, outputShapes, returnedTiming] =
+ mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
// return result
const std::vector<FmqResultDatum> result =
diff --git a/nn/common/include/CpuExecutor.h b/nn/common/include/CpuExecutor.h
index 7a36ebd..6a551d2 100644
--- a/nn/common/include/CpuExecutor.h
+++ b/nn/common/include/CpuExecutor.h
@@ -25,6 +25,7 @@
#include <android-base/macros.h>
#include <ui/GraphicBuffer.h>
#include <algorithm>
+#include <optional>
#include <vector>
namespace android {
@@ -91,40 +92,29 @@
// Used to keep a pointer to each of the memory pools.
//
-// In the case of an "mmap_fd" pool, owns the mmap region
-// returned by getBuffer() -- i.e., that region goes away
-// when the RunTimePoolInfo is destroyed or is assigned to.
+// RunTimePoolInfo references a region of memory. Other RunTimePoolInfo objects
+// may reference the same region of memory by either:
+// (1) copying an existing RunTimePoolInfo object, or
+// (2) creating multiple RunTimePoolInfo objects from the same memory resource
+// (e.g., "createFromHidlMemory" or "createFromExistingBuffer")
+//
+// If the underlying region of memory is mapped by "createFromHidlMemory", the
+// mapping will be sustained until it is no longer referenced by any
+// RunTimePoolInfo objects.
class RunTimePoolInfo {
-public:
- // If "fail" is not nullptr, and construction fails, then set *fail = true.
- // If construction succeeds, leave *fail unchanged.
- // getBuffer() == nullptr IFF construction fails.
- explicit RunTimePoolInfo(const hidl_memory& hidlMemory, bool* fail);
+ public:
+ static std::optional<RunTimePoolInfo> createFromHidlMemory(const hidl_memory& hidlMemory);
+ static RunTimePoolInfo createFromExistingBuffer(uint8_t* buffer);
- explicit RunTimePoolInfo(uint8_t* buffer);
-
- // Implement move
- RunTimePoolInfo(RunTimePoolInfo&& other) noexcept;
- RunTimePoolInfo& operator=(RunTimePoolInfo&& other) noexcept;
-
- // Forbid copy
- RunTimePoolInfo(const RunTimePoolInfo&) = delete;
- RunTimePoolInfo& operator=(const RunTimePoolInfo&) = delete;
-
- ~RunTimePoolInfo() { release(); }
-
- uint8_t* getBuffer() const { return mBuffer; }
-
+ uint8_t* getBuffer() const;
bool update() const;
+ hidl_memory getHidlMemory() const;
-private:
- void release();
- void moveFrom(RunTimePoolInfo&& other);
+ private:
+ class RunTimePoolInfoImpl;
+ RunTimePoolInfo(const std::shared_ptr<const RunTimePoolInfoImpl>& impl);
- hidl_memory mHidlMemory; // always used
- uint8_t* mBuffer = nullptr; // always used
- sp<IMemory> mMemory; // only used when hidlMemory.name() == "ashmem"
- sp<GraphicBuffer> mGraphicBuffer; // only used when hidlMemory.name() == "hardware_buffer_blob"
+ std::shared_ptr<const RunTimePoolInfoImpl> mImpl;
};
bool setRunTimePoolInfosFromHidlMemories(std::vector<RunTimePoolInfo>* poolInfos,
diff --git a/nn/common/include/ExecutionBurstServer.h b/nn/common/include/ExecutionBurstServer.h
index af9b076..4338f3e 100644
--- a/nn/common/include/ExecutionBurstServer.h
+++ b/nn/common/include/ExecutionBurstServer.h
@@ -22,6 +22,7 @@
#include <hidl/MQDescriptor.h>
#include <atomic>
#include <future>
+#include <memory>
#include <vector>
#include "HalInterfaces.h"
@@ -43,28 +44,96 @@
class ExecutionBurstServer : public IBurstContext {
DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);
+ public:
/**
- * BurstMemoryCache is responsible for managing the local memory cache of
- * the burst object. If the ExecutionBurstServer requests a memory key that
- * is unrecognized, the BurstMemoryCache object will retrieve the memory
- * from the client, transparent from the ExecutionBurstServer object.
+ * IBurstExecutorWithCache is a callback object passed to
+ * ExecutionBurstServer's factory function that is used to perform an
+ * execution. Because some memory resources are needed across multiple
+ * executions, this object also contains a local cache that can directly be
+ * used in the execution.
+ *
+ * ExecutionBurstServer will never access its IBurstExecutorWithCache object
+ * with concurrent calls.
*/
- class BurstMemoryCache {
- DISALLOW_IMPLICIT_CONSTRUCTORS(BurstMemoryCache);
+ class IBurstExecutorWithCache {
+ DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache);
public:
- BurstMemoryCache(const sp<IBurstCallback>& callback);
+ IBurstExecutorWithCache() = default;
+ virtual ~IBurstExecutorWithCache() = default;
- hidl_vec<hidl_memory> getMemories(const std::vector<int32_t>& slots);
- void freeMemory(int32_t slot);
+ /**
+ * Checks if a cache entry specified by a slot is present in the cache.
+ *
+ * @param slot Identifier of the cache entry.
+ * @return 'true' if the cache entry is present in the cache, 'false'
+ * otherwise.
+ */
+ virtual bool isCacheEntryPresent(int32_t slot) const = 0;
- private:
- std::mutex mMutex;
- const sp<IBurstCallback> mCallback;
- std::vector<hidl_memory> mMemoryCache;
+ /**
+ * Adds an entry specified by a slot to the cache.
+ *
+ * The caller of this function must ensure that the cache entry that is
+ * being added is not already present in the cache. This can be checked
+ * via isCacheEntryPresent.
+ *
+ * @param memory Memory resource to be cached.
+ * @param slot Slot identifier corresponding to the memory resource.
+ */
+ virtual void addCacheEntry(const hidl_memory& memory, int32_t slot) = 0;
+
+ /**
+ * Removes an entry specified by a slot from the cache.
+ *
+ * If the cache entry corresponding to the slot number does not exist,
+ * the call does nothing.
+ *
+ * @param slot Slot identifier corresponding to the memory resource.
+ */
+ virtual void removeCacheEntry(int32_t slot) = 0;
+
+ /**
+ * Perform an execution.
+ *
+ * @param request Request object with inputs and outputs specified.
+ * Request::pools is empty, and DataLocation::poolIndex instead
+ * refers to the 'slots' argument as if it were Request::pools.
+ * @param slots Slots corresponding to the cached memory entries to be
+ * used.
+ * @param measure Whether timing information is requested for the
+ * execution.
+ * @return Result of the execution, including the status of the
+ * execution, dynamic output shapes, and any timing information.
+ */
+ virtual std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+ const Request& request, const std::vector<int32_t>& slots,
+ MeasureTiming measure) = 0;
};
- public:
+ /**
+ * Create automated context to manage FMQ-based executions.
+ *
+ * This function is intended to be used by a service to automatically:
+ * 1) Receive data from a provided FMQ
+ * 2) Execute a model with the given information
+ * 3) Send the result to the created FMQ
+ *
+ * @param callback Callback used to retrieve memories corresponding to
+ * unrecognized slots.
+ * @param requestChannel Input FMQ channel through which the client passes the
+ * request to the service.
+ * @param resultChannel Output FMQ channel from which the client can retrieve
+ * the result of the execution.
+ * @param executorWithCache Object which maintains a local cache of the
+ * memory pools and executes using the cached memory pools.
+ * @result IBurstContext Handle to the burst context.
+ */
+ static sp<ExecutionBurstServer> create(
+ const sp<IBurstCallback>& callback, const FmqRequestDescriptor& requestChannel,
+ const FmqResultDescriptor& resultChannel,
+ std::shared_ptr<IBurstExecutorWithCache> executorWithCache);
+
/**
* Create automated context to manage FMQ-based executions.
*
@@ -80,7 +149,8 @@
* @param resultChannel Output FMQ channel from which the client can retrieve
* the result of the execution.
* @param preparedModel PreparedModel that the burst object was created from.
- * This will be used to synchronously perform the execution.
+ * IPreparedModel::executeSynchronously will be used to perform the
+ * execution.
* @result IBurstContext Handle to the burst context.
*/
static sp<ExecutionBurstServer> create(const sp<IBurstCallback>& callback,
@@ -91,7 +161,7 @@
ExecutionBurstServer(const sp<IBurstCallback>& callback,
std::unique_ptr<FmqRequestChannel> requestChannel,
std::unique_ptr<FmqResultChannel> resultChannel,
- IPreparedModel* preparedModel);
+ std::shared_ptr<IBurstExecutorWithCache> cachedExecutor);
~ExecutionBurstServer();
Return<void> freeMemory(int32_t slot) override;
@@ -102,15 +172,32 @@
std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
const std::vector<OutputShape>& outputShapes,
Timing timing);
- std::pair<Request, MeasureTiming> deserialize(const std::vector<FmqRequestDatum>& data);
+
+ // Deserializes the request FMQ data. The three resulting fields are the
+ // Request object (where Request::pools is empty), slot identifiers (which
+ // are stand-ins for Request::pools), and whether timing information must
+ // be collected for the run.
+ std::tuple<Request, std::vector<int32_t>, MeasureTiming> deserialize(
+ const std::vector<FmqRequestDatum>& data);
+
+ // Ensures that every cache entry identified by "slots" is present in
+ // mExecutorWithCache. Entries that are not present are retrieved (via
+ // IBurstCallback::getMemories) and added to mExecutorWithCache.
+ //
+ // The caller must hold mMutex when calling this method.
+ void ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots);
+
+ // Work loop that will continue processing execution requests until the
+ // ExecutionBurstServer object is freed.
void task();
- BurstMemoryCache mMemoryCache;
+ std::mutex mMutex;
std::atomic<bool> mTeardown{false};
std::future<void> mWorker;
+ const sp<IBurstCallback> mCallback;
const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
- IPreparedModel* mPreparedModel;
+ const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache;
const bool mBlocking;
};