Cache ExecutionBurstServer memory resources

This change enables caching of memory resources. Prior to this CL, a
hidl_memory object in an execution burst had to be mapped and unmapped
on each execution, taking ~10us each time. With this change, a driver
can choose to map the hidl_memory when it is first retrieved, and
unmap it when the burst is freed or when IBurstContext::freeMemory is
called.

Bug: 126244806
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest VtsHalNeuralnetworksV1_0TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_1TargetTest (with sample-all)
Test: atest VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: I92d7632f2e1e2f5e76c63249d4c615c9af9d1acb
Merged-In: I92d7632f2e1e2f5e76c63249d4c615c9af9d1acb
(cherry picked from commit dfc94136aff39efdfb81436cee7e73c3f6834785)
diff --git a/nn/common/CpuExecutor.cpp b/nn/common/CpuExecutor.cpp
index 0150c72..e1e0178 100644
--- a/nn/common/CpuExecutor.cpp
+++ b/nn/common/CpuExecutor.cpp
@@ -244,26 +244,103 @@
 
 }  // namespace
 
+// Used to keep a pointer to a memory pool.
+//
+// In the case of an "mmap_fd" pool, owns the mmap region
+// returned by getBuffer() -- i.e., that region is unmapped
+// when the RunTimePoolInfoImpl is destroyed.
+class RunTimePoolInfo::RunTimePoolInfoImpl {
+   public:
+    RunTimePoolInfoImpl(const hidl_memory& hidlMemory, uint8_t* buffer, const sp<IMemory>& memory,
+                        const sp<GraphicBuffer>& graphicBuffer);
+
+    // rule of five...
+    ~RunTimePoolInfoImpl();
+    RunTimePoolInfoImpl(const RunTimePoolInfoImpl&) = delete;
+    RunTimePoolInfoImpl(RunTimePoolInfoImpl&&) noexcept = delete;
+    RunTimePoolInfoImpl& operator=(const RunTimePoolInfoImpl&) = delete;
+    RunTimePoolInfoImpl& operator=(RunTimePoolInfoImpl&&) noexcept = delete;
+
+    uint8_t* getBuffer() const { return mBuffer; }
+
+    bool update() const;
+
+    hidl_memory getHidlMemory() const { return mHidlMemory; }
+
+   private:
+    const hidl_memory mHidlMemory;     // always used
+    uint8_t* const mBuffer = nullptr;  // always used
+    const sp<IMemory> mMemory;         // only used when hidlMemory.name() == "ashmem"
+    const sp<GraphicBuffer>
+            mGraphicBuffer;  // only used when hidlMemory.name() == "hardware_buffer_blob"
+};
+
+RunTimePoolInfo::RunTimePoolInfoImpl::RunTimePoolInfoImpl(const hidl_memory& hidlMemory,
+                                                          uint8_t* buffer,
+                                                          const sp<IMemory>& memory,
+                                                          const sp<GraphicBuffer>& graphicBuffer)
+    : mHidlMemory(hidlMemory), mBuffer(buffer), mMemory(memory), mGraphicBuffer(graphicBuffer) {}
+
+RunTimePoolInfo::RunTimePoolInfoImpl::~RunTimePoolInfoImpl() {
+    if (mBuffer == nullptr) {
+        return;
+    }
+
+    const std::string memType = mHidlMemory.name();
+    if (memType == "ashmem") {
+        // nothing to do
+    } else if (memType == "mmap_fd") {
+        const size_t size = mHidlMemory.size();
+        if (munmap(mBuffer, size)) {
+            LOG(ERROR) << "RunTimePoolInfoImpl::~RunTimePoolInfo(): Can't munmap";
+        }
+    } else if (memType == "hardware_buffer_blob") {
+        mGraphicBuffer->unlock();
+    } else if (memType == "") {
+        // Represents a POINTER argument; nothing to do
+    } else {
+        LOG(ERROR) << "RunTimePoolInfoImpl::~RunTimePoolInfoImpl(): unsupported hidl_memory type";
+    }
+}
+
+// Making sure the output data are correctly updated after execution.
+bool RunTimePoolInfo::RunTimePoolInfoImpl::update() const {
+    const std::string memType = mHidlMemory.name();
+    if (memType == "ashmem") {
+        mMemory->commit();
+        return true;
+    }
+    if (memType == "mmap_fd") {
+        int prot = mHidlMemory.handle()->data[1];
+        if (prot & PROT_WRITE) {
+            const size_t size = mHidlMemory.size();
+            return msync(mBuffer, size, MS_SYNC) == 0;
+        }
+    }
+    // No-op for other types of memory.
+    return true;
+}
+
 // TODO: short term, make share memory mapping and updating a utility function.
 // TODO: long term, implement mmap_fd as a hidl IMemory service.
-RunTimePoolInfo::RunTimePoolInfo(const hidl_memory& hidlMemory, bool* fail) {
-    sp<IMemory> memory;
+std::optional<RunTimePoolInfo> RunTimePoolInfo::createFromHidlMemory(
+        const hidl_memory& hidlMemory) {
     uint8_t* buffer = nullptr;
+    sp<IMemory> memory;
+    sp<GraphicBuffer> graphicBuffer;
 
     const auto& memType = hidlMemory.name();
     if (memType == "ashmem") {
         memory = mapMemory(hidlMemory);
         if (memory == nullptr) {
             LOG(ERROR) << "Can't map shared memory.";
-            if (fail) *fail = true;
-            return;
+            return std::nullopt;
         }
         memory->update();
         buffer = reinterpret_cast<uint8_t*>(static_cast<void*>(memory->getPointer()));
         if (buffer == nullptr) {
             LOG(ERROR) << "Can't access shared memory.";
-            if (fail) *fail = true;
-            return;
+            return std::nullopt;
         }
     } else if (memType == "mmap_fd") {
         size_t size = hidlMemory.size();
@@ -273,8 +350,7 @@
         buffer = static_cast<uint8_t*>(mmap(nullptr, size, prot, MAP_SHARED, fd, offset));
         if (buffer == MAP_FAILED) {
             LOG(ERROR) << "RunTimePoolInfo::set(): Can't mmap the file descriptor.";
-            if (fail) *fail = true;
-            return;
+            return std::nullopt;
         }
     } else if (memType == "hardware_buffer_blob") {
         auto handle = hidlMemory.handle();
@@ -284,107 +360,59 @@
         const uint32_t height = 1;  // height is always 1 for BLOB mode AHardwareBuffer.
         const uint32_t layers = 1;  // layers is always 1 for BLOB mode AHardwareBuffer.
         const uint32_t stride = hidlMemory.size();
-        mGraphicBuffer = new GraphicBuffer(handle, GraphicBuffer::HandleWrapMethod::CLONE_HANDLE,
-                                           width, height, format, layers, usage, stride);
+        graphicBuffer = new GraphicBuffer(handle, GraphicBuffer::HandleWrapMethod::CLONE_HANDLE,
+                                          width, height, format, layers, usage, stride);
         void* gBuffer = nullptr;
-        status_t status = mGraphicBuffer->lock(usage, &gBuffer);
+        status_t status = graphicBuffer->lock(usage, &gBuffer);
         if (status != NO_ERROR) {
             LOG(ERROR) << "RunTimePoolInfo Can't lock the AHardwareBuffer.";
-            if (fail) *fail = true;
-            return;
+            return std::nullopt;
         }
         buffer = static_cast<uint8_t*>(gBuffer);
     } else {
         LOG(ERROR) << "RunTimePoolInfo::set(): unsupported hidl_memory type";
-        if (fail) *fail = true;
-        return;
+        return std::nullopt;
     }
 
-    mHidlMemory = hidlMemory;
-    mBuffer = buffer;
-    mMemory = memory;
+    const auto impl =
+            std::make_shared<const RunTimePoolInfoImpl>(hidlMemory, buffer, memory, graphicBuffer);
+    return {RunTimePoolInfo(impl)};
 }
 
-RunTimePoolInfo::RunTimePoolInfo(uint8_t* buffer) {
-    mBuffer = buffer;
+RunTimePoolInfo RunTimePoolInfo::createFromExistingBuffer(uint8_t* buffer) {
+    const auto impl =
+            std::make_shared<const RunTimePoolInfoImpl>(hidl_memory{}, buffer, nullptr, nullptr);
+    return {impl};
 }
 
-RunTimePoolInfo::RunTimePoolInfo(RunTimePoolInfo&& other) noexcept {
-    moveFrom(std::move(other));
-    other.mBuffer = nullptr;
+RunTimePoolInfo::RunTimePoolInfo(const std::shared_ptr<const RunTimePoolInfoImpl>& impl)
+    : mImpl(impl) {}
+
+uint8_t* RunTimePoolInfo::getBuffer() const {
+    return mImpl->getBuffer();
 }
 
-RunTimePoolInfo& RunTimePoolInfo::operator=(RunTimePoolInfo&& other) noexcept {
-    if (this != &other) {
-        release();
-        moveFrom(std::move(other));
-        other.mBuffer = nullptr;
-    }
-    return *this;
-}
-
-void RunTimePoolInfo::moveFrom(RunTimePoolInfo&& other) {
-    mHidlMemory = std::move(other.mHidlMemory);
-    mBuffer = std::move(other.mBuffer);
-    mMemory = std::move(other.mMemory);
-}
-
-void RunTimePoolInfo::release() {
-    if (mBuffer == nullptr) {
-        return;
-    }
-
-    auto memType = mHidlMemory.name();
-    if (memType == "ashmem") {
-        // nothing to do
-    } else if (memType == "mmap_fd") {
-        size_t size = mHidlMemory.size();
-        if (munmap(mBuffer, size)) {
-            LOG(ERROR) << "RunTimePoolInfo::release(): Can't munmap";
-        }
-    } else if (memType == "hardware_buffer_blob") {
-        mGraphicBuffer->unlock();
-        mGraphicBuffer = nullptr;
-    } else if (memType == "") {
-        // Represents a POINTER argument; nothing to do
-    } else {
-        LOG(ERROR) << "RunTimePoolInfo::release(): unsupported hidl_memory type";
-    }
-
-    mHidlMemory = hidl_memory();
-    mMemory = nullptr;
-    mBuffer = nullptr;
-}
-
-// Making sure the output data are correctly updated after execution.
 bool RunTimePoolInfo::update() const {
-    auto memType = mHidlMemory.name();
-    if (memType == "ashmem") {
-        mMemory->commit();
-        return true;
-    } else if (memType == "mmap_fd") {
-        int prot = mHidlMemory.handle()->data[1];
-        if (prot & PROT_WRITE) {
-            size_t size = mHidlMemory.size();
-            return msync(mBuffer, size, MS_SYNC) == 0;
-        }
-    }
-    // No-op for other types of memory.
-    return true;
+    return mImpl->update();
+}
+
+hidl_memory RunTimePoolInfo::getHidlMemory() const {
+    return mImpl->getHidlMemory();
 }
 
 bool setRunTimePoolInfosFromHidlMemories(std::vector<RunTimePoolInfo>* poolInfos,
                                          const hidl_vec<hidl_memory>& pools) {
+    CHECK(poolInfos != nullptr);
     poolInfos->clear();
     poolInfos->reserve(pools.size());
-    bool fail = false;
     for (const auto& pool : pools) {
-        poolInfos->emplace_back(pool, &fail);
-    }
-    if (fail) {
-        LOG(ERROR) << "Could not map pools";
-        poolInfos->clear();
-        return false;
+        if (std::optional<RunTimePoolInfo> poolInfo = RunTimePoolInfo::createFromHidlMemory(pool)) {
+            poolInfos->push_back(*poolInfo);
+        } else {
+            LOG(ERROR) << "Could not map pools";
+            poolInfos->clear();
+            return false;
+        }
     }
     return true;
 }
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 84bb424..08d0e80 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -19,124 +19,134 @@
 #include "ExecutionBurstServer.h"
 
 #include <android-base/logging.h>
-#include <set>
 #include <string>
 #include "Tracing.h"
 
 namespace android::nn {
 
-ExecutionBurstServer::BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback)
-    : mCallback(callback) {}
+namespace {
 
-hidl_vec<hidl_memory> ExecutionBurstServer::BurstMemoryCache::getMemories(
-        const std::vector<int32_t>& slots) {
-    std::lock_guard<std::mutex> guard(mMutex);
+// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
+// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
+// hidl_memory object, and the execution forwards calls to the provided
+// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory
+// must be mapped and unmapped for each execution.
+class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
+   public:
+    DefaultBurstExecutorWithCache(IPreparedModel* preparedModel) : mpPreparedModel(preparedModel) {}
 
-    const auto slotIsKnown = [this](int32_t slot) {
+    bool isCacheEntryPresent(int32_t slot) const override {
         return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
-    };
+    }
 
-    // find unique unknown slots
-    std::vector<int32_t> unknownSlots = slots;
-    auto unknownSlotsEnd = unknownSlots.end();
-    std::sort(unknownSlots.begin(), unknownSlotsEnd);
-    unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
-    unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
-    unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+    void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
+        if (slot >= mMemoryCache.size()) {
+            mMemoryCache.resize(slot + 1);
+        }
+        mMemoryCache[slot] = memory;
+    }
 
-    // retrieve unknown slots
-    if (!unknownSlots.empty()) {
-        ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
-        std::vector<hidl_memory> returnedMemories;
-        auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
-                                                    const hidl_vec<hidl_memory>& memories) {
-            errorStatus = status;
-            returnedMemories = memories;
+    void removeCacheEntry(int32_t slot) override {
+        if (slot < mMemoryCache.size()) {
+            mMemoryCache[slot] = {};
+        }
+    }
+
+    std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+            const Request& request, const std::vector<int32_t>& slots,
+            MeasureTiming measure) override {
+        // convert slots to pools
+        hidl_vec<hidl_memory> pools(slots.size());
+        std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
+            return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
+        });
+
+        // create full request
+        Request fullRequest = request;
+        fullRequest.pools = std::move(pools);
+
+        // setup execution
+        ErrorStatus returnedStatus = ErrorStatus::GENERAL_FAILURE;
+        hidl_vec<OutputShape> returnedOutputShapes;
+        Timing returnedTiming;
+        auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
+                          ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
+                          const Timing& timing) {
+            returnedStatus = status;
+            returnedOutputShapes = outputShapes;
+            returnedTiming = timing;
         };
 
-        Return<void> ret = mCallback->getMemories(unknownSlots, cb);
-
-        // Ensure that the memories were successfully returned.
-        // IBurstCallback.hal specifies the that the number of memories returned
-        // must match the number of slots requested:
-        //     "slots.size() == buffers.size()"
-        if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
-            returnedMemories.size() != unknownSlots.size()) {
-            LOG(ERROR) << "Error retrieving memories";
-            return {};
+        // execute
+        const Return<void> ret = mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
+        if (!ret.isOk() || returnedStatus != ErrorStatus::NONE) {
+            LOG(ERROR) << "IPreparedModelAdapter::execute -- Error executing";
+            return {ErrorStatus::GENERAL_FAILURE, {}, {}};
         }
 
-        // resize cache to fit new slots if necessary
-        const int32_t maxUnknownSlot = unknownSlots.back();
-        if (maxUnknownSlot >= mMemoryCache.size()) {
-            mMemoryCache.resize(maxUnknownSlot + 1);
-        }
-
-        // add memories to unknown slots
-        for (size_t i = 0; i < unknownSlots.size(); ++i) {
-            mMemoryCache[unknownSlots[i]] = returnedMemories[i];
-        }
+        return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
     }
 
-    // get all slots
-    hidl_vec<hidl_memory> memories(slots.size());
-    std::transform(slots.begin(), slots.end(), memories.begin(),
-                   [this](int32_t slot) { return mMemoryCache[slot]; });
+   private:
+    IPreparedModel* const mpPreparedModel;
+    std::vector<hidl_memory> mMemoryCache;
+};
 
-    // Ensure all slots are valid. Although this case is never expected to
-    // occur, theoretically IBurstCallback::getMemories could return invalid
-    // hidl_memory objects that must be protected against.
-    if (!std::all_of(memories.begin(), memories.end(),
-                     [](const hidl_memory& memory) { return memory.valid(); })) {
-        LOG(ERROR) << "Error, not all slots are valid!";
-        return {};
-    }
-
-    return memories;
-}
-
-void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
-    std::lock_guard<std::mutex> guard(mMutex);
-    if (slot < mMemoryCache.size()) {
-        mMemoryCache[slot] = {};
-    }
-}
+}  // anonymous namespace
 
 sp<ExecutionBurstServer> ExecutionBurstServer::create(
         const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
-        const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+        const MQDescriptorSync<FmqResultDatum>& resultChannel,
+        std::shared_ptr<IBurstExecutorWithCache> executorWithCache) {
     // check inputs
-    if (callback == nullptr || preparedModel == nullptr) {
+    if (callback == nullptr || executorWithCache == nullptr) {
         LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
         return nullptr;
     }
 
     // create FMQ objects
-    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
-                                                                 FmqRequestChannel(requestChannel)};
-    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
-                                                               FmqResultChannel(resultChannel)};
+    std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
+            std::make_unique<FmqRequestChannel>(requestChannel);
+    std::unique_ptr<FmqResultChannel> fmqResultChannel =
+            std::make_unique<FmqResultChannel>(resultChannel);
 
     // check FMQ objects
-    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
-        !fmqResultChannel->isValid()) {
+    if (!fmqRequestChannel->isValid() || !fmqResultChannel->isValid()) {
         LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
         return nullptr;
     }
 
     // make and return context
     return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
-                                    std::move(fmqResultChannel), preparedModel);
+                                    std::move(fmqResultChannel), std::move(executorWithCache));
 }
 
-ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
-                                           std::unique_ptr<FmqRequestChannel> requestChannel,
-                                           std::unique_ptr<FmqResultChannel> resultChannel,
-                                           IPreparedModel* preparedModel)
-    : mMemoryCache(callback),
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+        const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+        const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+    // check relevant input
+    if (preparedModel == nullptr) {
+        LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+        return nullptr;
+    }
+
+    // adapt IPreparedModel to have caching
+    const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
+            std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
+
+    // make and return context
+    return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
+                                        preparedModelAdapter);
+}
+
+ExecutionBurstServer::ExecutionBurstServer(
+        const sp<IBurstCallback>& callback, std::unique_ptr<FmqRequestChannel> requestChannel,
+        std::unique_ptr<FmqResultChannel> resultChannel,
+        std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
+    : mCallback(callback),
       mFmqRequestChannel(std::move(requestChannel)),
       mFmqResultChannel(std::move(resultChannel)),
-      mPreparedModel(preparedModel),
+      mExecutorWithCache(std::move(executorWithCache)),
       mBlocking(mFmqRequestChannel->getEventFlagWord() != nullptr) {
     // TODO: highly document the threading behavior of this class
     mWorker = std::async(std::launch::async, [this] { task(); });
@@ -164,6 +174,52 @@
     mWorker.wait();
 }
 
+Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
+    std::lock_guard<std::mutex> hold(mMutex);
+    mExecutorWithCache->removeCacheEntry(slot);
+    return Void();
+}
+
+void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
+    const auto slotIsKnown = [this](int32_t slot) {
+        return mExecutorWithCache->isCacheEntryPresent(slot);
+    };
+
+    // find unique unknown slots
+    std::vector<int32_t> unknownSlots = slots;
+    auto unknownSlotsEnd = unknownSlots.end();
+    std::sort(unknownSlots.begin(), unknownSlotsEnd);
+    unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
+    unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+    unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+
+    // quick-exit if all slots are known
+    if (unknownSlots.empty()) {
+        return;
+    }
+
+    ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+    std::vector<hidl_memory> returnedMemories;
+    auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
+                                                const hidl_vec<hidl_memory>& memories) {
+        errorStatus = status;
+        returnedMemories = memories;
+    };
+
+    const Return<void> ret = mCallback->getMemories(unknownSlots, cb);
+
+    if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
+        returnedMemories.size() != unknownSlots.size()) {
+        LOG(ERROR) << "Error retrieving memories";
+        return;
+    }
+
+    // add memories to unknown slots
+    for (size_t i = 0; i < unknownSlots.size(); ++i) {
+        mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
+    }
+}
+
 bool ExecutionBurstServer::sendPacket(const std::vector<FmqResultDatum>& packet) {
     if (mTeardown) {
         return false;
@@ -236,17 +292,16 @@
 }
 
 // deserialize request
-std::pair<Request, MeasureTiming> ExecutionBurstServer::deserialize(
+std::tuple<Request, std::vector<int32_t>, MeasureTiming> ExecutionBurstServer::deserialize(
         const std::vector<FmqRequestDatum>& data) {
     using discriminator = FmqRequestDatum::hidl_discriminator;
 
-    Request request;
     size_t index = 0;
 
     // validate packet information
     if (data[index].getDiscriminator() != discriminator::packetInformation) {
         LOG(ERROR) << "FMQ Request packet ill-formed";
-        return {{}, MeasureTiming::NO};
+        return {{}, {}, MeasureTiming::NO};
     }
 
     // unpackage packet information
@@ -264,7 +319,7 @@
         // validate input operand information
         if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
             LOG(ERROR) << "FMQ Request packet ill-formed";
-            return {{}, MeasureTiming::NO};
+            return {{}, {}, MeasureTiming::NO};
         }
 
         // unpackage operand information
@@ -282,7 +337,7 @@
             // validate dimension
             if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
                 LOG(ERROR) << "FMQ Request packet ill-formed";
-                return {{}, MeasureTiming::NO};
+                return {{}, {}, MeasureTiming::NO};
             }
 
             // unpackage dimension
@@ -305,7 +360,7 @@
         // validate output operand information
         if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
             LOG(ERROR) << "FMQ Request packet ill-formed";
-            return {{}, MeasureTiming::NO};
+            return {{}, {}, MeasureTiming::NO};
         }
 
         // unpackage operand information
@@ -323,7 +378,7 @@
             // validate dimension
             if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
                 LOG(ERROR) << "FMQ Request packet ill-formed";
-                return {{}, MeasureTiming::NO};
+                return {{}, {}, MeasureTiming::NO};
             }
 
             // unpackage dimension
@@ -346,7 +401,7 @@
         // validate input operand information
         if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
             LOG(ERROR) << "FMQ Request packet ill-formed";
-            return {{}, MeasureTiming::NO};
+            return {{}, {}, MeasureTiming::NO};
         }
 
         // unpackage operand information
@@ -356,12 +411,11 @@
         // store result
         slots.push_back(poolId);
     }
-    hidl_vec<hidl_memory> pools = mMemoryCache.getMemories(slots);
 
     // validate measureTiming
     if (data[index].getDiscriminator() != discriminator::measureTiming) {
         LOG(ERROR) << "FMQ Request packet ill-formed";
-        return {{}, MeasureTiming::NO};
+        return {{}, {}, MeasureTiming::NO};
     }
 
     // unpackage measureTiming
@@ -371,11 +425,11 @@
     // validate packet information
     if (index != packetSize) {
         LOG(ERROR) << "FMQ Result packet ill-formed";
-        return {{}, MeasureTiming::NO};
+        return {{}, {}, MeasureTiming::NO};
     }
 
     // return request
-    return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/std::move(pools)}, measure};
+    return {{/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}}, std::move(slots), measure};
 }
 
 // serialize result
@@ -428,11 +482,6 @@
     return data;
 }
 
-Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
-    mMemoryCache.freeMemory(slot);
-    return Void();
-}
-
 void ExecutionBurstServer::task() {
     while (!mTeardown) {
         // receive request
@@ -446,29 +495,18 @@
         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
                      "ExecutionBurstServer processing packet and returning results");
 
-        // continue processing
-        Request request;
-        MeasureTiming measure;
-        std::tie(request, measure) = deserialize(requestData);
+        // continue processing; types are Request, std::vector<int32_t>, and
+        // MeasureTiming, respectively
+        const auto [requestWithoutPools, slotsOfPools, measure] = deserialize(requestData);
 
-        // perform computation
-        ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
-        std::vector<OutputShape> outputShapes;
-        Timing returnedTiming;
-        // This call to IPreparedModel::executeSynchronously occurs entirely
-        // within the same process, so ignore the Return<> errors via .isOk().
-        // TODO: verify it is safe to always call isOk() here, or if there is
-        // any benefit to checking any potential errors.
-        mPreparedModel
-                ->executeSynchronously(request, measure,
-                                       [&errorStatus, &outputShapes, &returnedTiming](
-                                               ErrorStatus status,
-                                               const hidl_vec<OutputShape>& shapes, Timing timing) {
-                                           errorStatus = status;
-                                           outputShapes = shapes;
-                                           returnedTiming = timing;
-                                       })
-                .isOk();
+        // ensure executor with cache has required memory
+        std::lock_guard<std::mutex> hold(mMutex);
+        ensureCacheEntriesArePresentLocked(slotsOfPools);
+
+        // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
+        // and Timing, respectively
+        const auto [errorStatus, outputShapes, returnedTiming] =
+                mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
 
         // return result
         const std::vector<FmqResultDatum> result =
diff --git a/nn/common/include/CpuExecutor.h b/nn/common/include/CpuExecutor.h
index 7a36ebd..6a551d2 100644
--- a/nn/common/include/CpuExecutor.h
+++ b/nn/common/include/CpuExecutor.h
@@ -25,6 +25,7 @@
 #include <android-base/macros.h>
 #include <ui/GraphicBuffer.h>
 #include <algorithm>
+#include <optional>
 #include <vector>
 
 namespace android {
@@ -91,40 +92,29 @@
 
 // Used to keep a pointer to each of the memory pools.
 //
-// In the case of an "mmap_fd" pool, owns the mmap region
-// returned by getBuffer() -- i.e., that region goes away
-// when the RunTimePoolInfo is destroyed or is assigned to.
+// RunTimePoolInfo references a region of memory. Other RunTimePoolInfo objects
+// may reference the same region of memory by either:
+// (1) copying an existing RunTimePoolInfo object, or
+// (2) creating multiple RunTimePoolInfo objects from the same memory resource
+//     (e.g., "createFromHidlMemory" or "createFromExistingBuffer")
+//
+// If the underlying region of memory is mapped by "createFromHidlMemory", the
+// mapping will be sustained until it is no longer referenced by any
+// RunTimePoolInfo objects.
 class RunTimePoolInfo {
-public:
-    // If "fail" is not nullptr, and construction fails, then set *fail = true.
-    // If construction succeeds, leave *fail unchanged.
-    // getBuffer() == nullptr IFF construction fails.
-    explicit RunTimePoolInfo(const hidl_memory& hidlMemory, bool* fail);
+   public:
+    static std::optional<RunTimePoolInfo> createFromHidlMemory(const hidl_memory& hidlMemory);
+    static RunTimePoolInfo createFromExistingBuffer(uint8_t* buffer);
 
-    explicit RunTimePoolInfo(uint8_t* buffer);
-
-    // Implement move
-    RunTimePoolInfo(RunTimePoolInfo&& other) noexcept;
-    RunTimePoolInfo& operator=(RunTimePoolInfo&& other) noexcept;
-
-    // Forbid copy
-    RunTimePoolInfo(const RunTimePoolInfo&) = delete;
-    RunTimePoolInfo& operator=(const RunTimePoolInfo&) = delete;
-
-    ~RunTimePoolInfo() { release(); }
-
-    uint8_t* getBuffer() const { return mBuffer; }
-
+    uint8_t* getBuffer() const;
     bool update() const;
+    hidl_memory getHidlMemory() const;
 
-private:
-    void release();
-    void moveFrom(RunTimePoolInfo&& other);
+   private:
+    class RunTimePoolInfoImpl;
+    RunTimePoolInfo(const std::shared_ptr<const RunTimePoolInfoImpl>& impl);
 
-    hidl_memory mHidlMemory;           // always used
-    uint8_t* mBuffer = nullptr;        // always used
-    sp<IMemory> mMemory;               // only used when hidlMemory.name() == "ashmem"
-    sp<GraphicBuffer> mGraphicBuffer;  // only used when hidlMemory.name() == "hardware_buffer_blob"
+    std::shared_ptr<const RunTimePoolInfoImpl> mImpl;
 };
 
 bool setRunTimePoolInfosFromHidlMemories(std::vector<RunTimePoolInfo>* poolInfos,
diff --git a/nn/common/include/ExecutionBurstServer.h b/nn/common/include/ExecutionBurstServer.h
index af9b076..4338f3e 100644
--- a/nn/common/include/ExecutionBurstServer.h
+++ b/nn/common/include/ExecutionBurstServer.h
@@ -22,6 +22,7 @@
 #include <hidl/MQDescriptor.h>
 #include <atomic>
 #include <future>
+#include <memory>
 #include <vector>
 #include "HalInterfaces.h"
 
@@ -43,28 +44,96 @@
 class ExecutionBurstServer : public IBurstContext {
     DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);
 
+   public:
     /**
-     * BurstMemoryCache is responsible for managing the local memory cache of
-     * the burst object. If the ExecutionBurstServer requests a memory key that
-     * is unrecognized, the BurstMemoryCache object will retrieve the memory
-     * from the client, transparent from the ExecutionBurstServer object.
+     * IBurstExecutorWithCache is a callback object passed to
+     * ExecutionBurstServer's factory function that is used to perform an
+     * execution. Because some memory resources are needed across multiple
+     * executions, this object also contains a local cache that can directly be
+     * used in the execution.
+     *
+     * ExecutionBurstServer will never access its IBurstExecutorWithCache object
+     * with concurrent calls.
      */
-    class BurstMemoryCache {
-        DISALLOW_IMPLICIT_CONSTRUCTORS(BurstMemoryCache);
+    class IBurstExecutorWithCache {
+        DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache);
 
        public:
-        BurstMemoryCache(const sp<IBurstCallback>& callback);
+        IBurstExecutorWithCache() = default;
+        virtual ~IBurstExecutorWithCache() = default;
 
-        hidl_vec<hidl_memory> getMemories(const std::vector<int32_t>& slots);
-        void freeMemory(int32_t slot);
+        /**
+         * Checks if a cache entry specified by a slot is present in the cache.
+         *
+         * @param slot Identifier of the cache entry.
+         * @return 'true' if the cache entry is present in the cache, 'false'
+         *     otherwise.
+         */
+        virtual bool isCacheEntryPresent(int32_t slot) const = 0;
 
-       private:
-        std::mutex mMutex;
-        const sp<IBurstCallback> mCallback;
-        std::vector<hidl_memory> mMemoryCache;
+        /**
+         * Adds an entry specified by a slot to the cache.
+         *
+         * The caller of this function must ensure that the cache entry that is
+         * being added is not already present in the cache. This can be checked
+         * via isCacheEntryPresent.
+         *
+         * @param memory Memory resource to be cached.
+         * @param slot Slot identifier corresponding to the memory resource.
+         */
+        virtual void addCacheEntry(const hidl_memory& memory, int32_t slot) = 0;
+
+        /**
+         * Removes an entry specified by a slot from the cache.
+         *
+         * If the cache entry corresponding to the slot number does not exist,
+         * the call does nothing.
+         *
+         * @param slot Slot identifier corresponding to the memory resource.
+         */
+        virtual void removeCacheEntry(int32_t slot) = 0;
+
+        /**
+         * Perform an execution.
+         *
+         * @param request Request object with inputs and outputs specified.
+         *     Request::pools is empty, and DataLocation::poolIndex instead
+         *     refers to the 'slots' argument as if it were Request::pools.
+         * @param slots Slots corresponding to the cached memory entries to be
+         *     used.
+         * @param measure Whether timing information is requested for the
+         *     execution.
+         * @return Result of the execution, including the status of the
+         *     execution, dynamic output shapes, and any timing information.
+         */
+        virtual std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+                const Request& request, const std::vector<int32_t>& slots,
+                MeasureTiming measure) = 0;
     };
 
-   public:
+    /**
+     * Create automated context to manage FMQ-based executions.
+     *
+     * This function is intended to be used by a service to automatically:
+     * 1) Receive data from a provided FMQ
+     * 2) Execute a model with the given information
+     * 3) Send the result to the created FMQ
+     *
+     * @param callback Callback used to retrieve memories corresponding to
+     *     unrecognized slots.
+     * @param requestChannel Input FMQ channel through which the client passes the
+     *     request to the service.
+     * @param resultChannel Output FMQ channel from which the client can retrieve
+     *     the result of the execution.
+     * @param executorWithCache Object which maintains a local cache of the
+     *     memory pools and executes using the cached memory pools.
+     * @result IBurstContext Handle to the burst context.
+     */
+    static sp<ExecutionBurstServer> create(
+            const sp<IBurstCallback>& callback, const FmqRequestDescriptor& requestChannel,
+            const FmqResultDescriptor& resultChannel,
+            std::shared_ptr<IBurstExecutorWithCache> executorWithCache);
+
     /**
      * Create automated context to manage FMQ-based executions.
      *
@@ -80,7 +149,8 @@
      * @param resultChannel Output FMQ channel from which the client can retrieve
      *     the result of the execution.
      * @param preparedModel PreparedModel that the burst object was created from.
-     *     This will be used to synchronously perform the execution.
+     *     IPreparedModel::executeSynchronously will be used to perform the
+     *     execution.
      * @result IBurstContext Handle to the burst context.
      */
     static sp<ExecutionBurstServer> create(const sp<IBurstCallback>& callback,
@@ -91,7 +161,7 @@
     ExecutionBurstServer(const sp<IBurstCallback>& callback,
                          std::unique_ptr<FmqRequestChannel> requestChannel,
                          std::unique_ptr<FmqResultChannel> resultChannel,
-                         IPreparedModel* preparedModel);
+                         std::shared_ptr<IBurstExecutorWithCache> cachedExecutor);
     ~ExecutionBurstServer();
 
     Return<void> freeMemory(int32_t slot) override;
@@ -102,15 +172,32 @@
     std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
                                           const std::vector<OutputShape>& outputShapes,
                                           Timing timing);
-    std::pair<Request, MeasureTiming> deserialize(const std::vector<FmqRequestDatum>& data);
+
+    // Deserializes the request FMQ data. The three resulting fields are the
+    // Request object (where Request::pools is empty), slot identifiers (which
+    // are stand-ins for Request::pools), and whether timing information must
+    // be collected for the run.
+    std::tuple<Request, std::vector<int32_t>, MeasureTiming> deserialize(
+            const std::vector<FmqRequestDatum>& data);
+
+    // Ensures that every cache entry identified by 'slots' is present in
+    // mExecutorWithCache. Entries that are not present are retrieved (via
+    // IBurstCallback::getMemories) and added to mExecutorWithCache.
+    //
+    // This method is locked via mMutex when it is called.
+    void ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots);
+
+    // Work loop that will continue processing execution requests until the
+    // ExecutionBurstServer object is freed.
     void task();
 
-    BurstMemoryCache mMemoryCache;
+    std::mutex mMutex;
     std::atomic<bool> mTeardown{false};
     std::future<void> mWorker;
+    const sp<IBurstCallback> mCallback;
     const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
     const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
-    IPreparedModel* mPreparedModel;
+    const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache;
     const bool mBlocking;
 };
 
diff --git a/nn/driver/sample/SampleDriver.cpp b/nn/driver/sample/SampleDriver.cpp
index 8372f4b..2d68b90 100644
--- a/nn/driver/sample/SampleDriver.cpp
+++ b/nn/driver/sample/SampleDriver.cpp
@@ -27,6 +27,7 @@
 #include <android-base/logging.h>
 #include <hidl/LegacySupport.h>
 #include <chrono>
+#include <optional>
 #include <thread>
 
 namespace android {
@@ -344,6 +345,94 @@
     return Void();
 }
 
+// BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
+// the mapping until either (1) the memory is freed in the runtime, or (2) the
+// burst object is destroyed. This allows for subsequent executions operating on
+// pools that have been used before to reuse the mapping instead of mapping and
+// unmapping the memory on each execution.
+class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
+   public:
+    BurstExecutorWithCache(const Model& model, const SampleDriver* driver,
+                           const std::vector<RunTimePoolInfo>& poolInfos)
+        : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
+
+    bool isCacheEntryPresent(int32_t slot) const override {
+        return static_cast<size_t>(slot) < mMemoryCache.size() && mMemoryCache[slot].has_value();
+    }
+
+    void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
+        // A negative slot would wrap around in the size computation and index
+        // out of bounds, so it is ignored; execute() rejects slots not present.
+        if (slot < 0) return;
+        const size_t index = static_cast<size_t>(slot);
+        if (index >= mMemoryCache.size()) mMemoryCache.resize(index + 1);
+        mMemoryCache[index] = RunTimePoolInfo::createFromHidlMemory(memory);
+    }
+
+    void removeCacheEntry(int32_t slot) override {
+        // Slots outside the cache (including negative ones) are ignored.
+        if (static_cast<size_t>(slot) < mMemoryCache.size()) mMemoryCache[slot] = std::nullopt;
+    }
+
+    std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+            const Request& request, const std::vector<int32_t>& slots,
+            MeasureTiming measure) override {
+        NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
+                     "BurstExecutorWithCache::execute");
+
+        time_point driverStart, driverEnd, deviceStart, deviceEnd;
+        if (measure == MeasureTiming::YES) driverStart = now();
+
+        // ensure all relevant pools are valid
+        if (!std::all_of(slots.begin(), slots.end(),
+                         [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
+            return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
+        }
+
+        // finish the request object (for validation)
+        hidl_vec<hidl_memory> pools(slots.size());
+        std::transform(slots.begin(), slots.end(), pools.begin(),
+                       [this](int32_t slot) { return mMemoryCache[slot]->getHidlMemory(); });
+        Request fullRequest = request;
+        fullRequest.pools = std::move(pools);
+
+        // validate request object against the model
+        if (!validateRequest(fullRequest, mModel)) {
+            return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
+        }
+
+        // select relevant entries from cache
+        std::vector<RunTimePoolInfo> requestPoolInfos;
+        requestPoolInfos.reserve(slots.size());
+        std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
+                       [this](int32_t slot) { return *mMemoryCache[slot]; });
+
+        // execution
+        CpuExecutor executor = mDriver->getExecutor();
+        if (measure == MeasureTiming::YES) deviceStart = now();
+        int n = executor.run(mModel, request, mModelPoolInfos, requestPoolInfos);
+        if (measure == MeasureTiming::YES) deviceEnd = now();
+        VLOG(DRIVER) << "executor.run returned " << n;
+        ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
+        hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
+        if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
+            driverEnd = now();
+            Timing timing = {
+                    .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
+                    .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
+            VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
+            return std::make_tuple(executionStatus, outputShapes, timing);
+        }
+        return std::make_tuple(executionStatus, outputShapes, kNoTiming);
+    }
+
+   private:
+    const Model mModel;
+    const SampleDriver* const mDriver;
+    const std::vector<RunTimePoolInfo> mModelPoolInfos;
+    std::vector<std::optional<RunTimePoolInfo>> mMemoryCache;  // cached requestPoolInfos
+};
+
 Return<void> SamplePreparedModel::configureExecutionBurst(
         const sp<V1_2::IBurstCallback>& callback,
         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
@@ -352,8 +441,17 @@
     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                  "SampleDriver::configureExecutionBurst");
 
-    const sp<V1_2::IBurstContext> burst =
-            ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);
+    // Alternatively, the burst could be configured via:
+    // const sp<V1_2::IBurstContext> burst =
+    //         ExecutionBurstServer::create(callback, requestChannel,
+    //                                      resultChannel, this);
+    //
+    // However, this alternative representation does not include a memory map
+    // caching optimization, and adds overhead.
+    const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
+            std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
+    const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
+            callback, requestChannel, resultChannel, executorWithCache);
 
     if (burst == nullptr) {
         cb(ErrorStatus::GENERAL_FAILURE, {});
diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp
index 68d019d..8be9b2f 100644
--- a/nn/runtime/ExecutionBuilder.cpp
+++ b/nn/runtime/ExecutionBuilder.cpp
@@ -29,6 +29,7 @@
 #include "Utils.h"
 
 #include <mutex>
+#include <optional>
 #include <thread>
 #include <vector>
 
@@ -1021,12 +1022,13 @@
 
     std::vector<RunTimePoolInfo> requestPoolInfos;
     requestPoolInfos.reserve(mMemories.size());
-    bool fail = false;
     for (const Memory* mem : mMemories) {
-        requestPoolInfos.emplace_back(mem->getHidlMemory(), &fail);
-    }
-    if (fail) {
-        return ANEURALNETWORKS_UNMAPPABLE;
+        if (std::optional<RunTimePoolInfo> poolInfo =
+                    RunTimePoolInfo::createFromHidlMemory(mem->getHidlMemory())) {
+            requestPoolInfos.emplace_back(*poolInfo);
+        } else {
+            return ANEURALNETWORKS_UNMAPPABLE;
+        }
     }
     // Create as many pools as there are input / output.
     auto fixPointerArguments = [&requestPoolInfos](std::vector<ModelArgumentInfo>& argumentInfos) {
@@ -1035,7 +1037,8 @@
                 argumentInfo.locationAndLength.poolIndex =
                         static_cast<uint32_t>(requestPoolInfos.size());
                 argumentInfo.locationAndLength.offset = 0;
-                requestPoolInfos.emplace_back(static_cast<uint8_t*>(argumentInfo.buffer));
+                requestPoolInfos.emplace_back(RunTimePoolInfo::createFromExistingBuffer(
+                        static_cast<uint8_t*>(argumentInfo.buffer)));
             }
         }
     };
diff --git a/nn/runtime/test/TestIntrospectionControl.cpp b/nn/runtime/test/TestIntrospectionControl.cpp
index acd6e7f..0b7ad48 100644
--- a/nn/runtime/test/TestIntrospectionControl.cpp
+++ b/nn/runtime/test/TestIntrospectionControl.cpp
@@ -15,6 +15,7 @@
  */
 
 #include "CompilationBuilder.h"
+#include "ExecutionBurstServer.h"
 #include "HalInterfaces.h"
 #include "Manager.h"
 #include "NeuralNetworks.h"
@@ -26,6 +27,7 @@
 
 #include <gtest/gtest.h>
 
+#include <iterator>
 #include <map>
 #include <queue>
 #include <set>
@@ -38,6 +40,7 @@
 using Device = nn::Device;
 using DeviceManager = nn::DeviceManager;
 using ExecutePreference = nn::test_wrapper::ExecutePreference;
+using ExecutionBurstServer = nn::ExecutionBurstServer;
 using HidlModel = hardware::neuralnetworks::V1_2::Model;
 using HidlToken = hardware::hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>;
 using PreparedModelCallback = hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
@@ -264,6 +267,13 @@
     NEW   // new enough to support timing (1.2 or later)
 };
 
+std::ostream& operator<<(std::ostream& os, DriverKind kind) {
+    static constexpr const char* kNames[] = {"CPU", "OLD", "NEW"};  // indexed by enumerator value
+    const auto position = static_cast<uint32_t>(kind);
+    CHECK(position < std::size(kNames));  // every printed value needs a table entry
+    return os << kNames[position];
+}
+
 enum class Success {
     // ASYNC: Return ErrorStatus::NONE; notify ErrorStatus::NONE and timing
     // SYNC, BURST: Return ErrorStatus::NONE and timing
@@ -282,6 +292,14 @@
     FAIL_WAIT
 };
 
+std::ostream& operator<<(std::ostream& os, Success success) {
+    static constexpr const char* kNames[] = {"PASS_NEITHER", "PASS_DEVICE", "PASS_DRIVER",
+                                             "PASS_BOTH", "PASS_CPU", "FAIL_LAUNCH", "FAIL_WAIT"};
+    const auto position = static_cast<uint32_t>(success);
+    CHECK(position < std::size(kNames));  // every printed value needs a table entry
+    return os << kNames[position];
+}
+
 std::map<Success, Timing> expectedTimingMap = {
         {Success::PASS_NEITHER, kBadTiming},
         {Success::PASS_DEVICE,
@@ -297,6 +315,13 @@
 
 enum class Compute { ASYNC, SYNC, BURST };
 
+std::ostream& operator<<(std::ostream& os, Compute compute) {
+    static constexpr const char* kNames[] = {"ASYNC", "SYNC", "BURST"};  // indexed by enumerator
+    const auto position = static_cast<uint32_t>(compute);
+    CHECK(position < std::size(kNames));  // every printed value needs a table entry
+    return os << kNames[position];
+}
+
 // For these tests we don't care about actually running an inference -- we
 // just want to dummy up execution status and timing results.
 class TestPreparedModel12 : public SamplePreparedModel {
@@ -369,8 +394,20 @@
         }
     }
 
-    // SampleDriver's burst execution uses executeSynchronously(), so we can
-    // rely on that, rather than having to implement configureExecutionBurst().
+    // ExecutionBurstServer::create has an overload that will use
+    // IPreparedModel::executeSynchronously(), so we can rely on that, rather
+    // than having to implement ExecutionBurstServer::IBurstExecutorWithCache.
+    Return<void> configureExecutionBurst(
+            const sp<V1_2::IBurstCallback>& callback,
+            const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
+            const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
+            configureExecutionBurst_cb cb) override {
+        const sp<V1_2::IBurstContext> context =
+                ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);
+        // A null context is reported through the callback as GENERAL_FAILURE.
+        cb(context != nullptr ? ErrorStatus::NONE : ErrorStatus::GENERAL_FAILURE, context);
+        return Void();
+    }
 
    private:
     Success mSuccess;