NNAPI Burst -- runtime and CTS
The NNAPI is introducing the notion of an "Execution Burst" object (or
more simply a "Burst" object), which is similar to an
ANeuralNetworksExecution, but is intended to be reused across multiple
executions and has lower IPC overheads. It achieves this low IPC
overhead by replacing HIDL HwBinder calls with FMQ messages.
This CL implements the NDK burst functions, implements the path through
the partitioner/scheduler, and creates CTS tests using the burst object.
Bug: 119570067
Test: mma
Test: NeuralNetworksTest_static
Change-Id: I1d2414f454910ad3ba4b2af728ab95ef8b609c9c
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp
index 32231d3..62fad73 100644
--- a/nn/common/ExecutionBurstController.cpp
+++ b/nn/common/ExecutionBurstController.cpp
@@ -14,14 +14,22 @@
* limitations under the License.
*/
+#define LOG_TAG "ExecutionBurstController"
+
#include "ExecutionBurstController.h"
#include <android-base/logging.h>
+#include <string>
namespace android {
namespace nn {
namespace {
+
+using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
+using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
+
constexpr Timing invalidTiming = {UINT64_MAX, UINT64_MAX};
+
} // anonymous namespace
Return<void> ExecutionBurstCallback::getMemories(const hidl_vec<int32_t>& slots,
@@ -110,8 +118,14 @@
using discriminator = FmqResultDatum::hidl_discriminator;
// wait for result packet and read first element of result packet
+ // TODO: have a more elegant way to wait for data, and read it all at once.
+ // For example, EventFlag can be used to directly wait on the futex, and all
+ // the data can be read at once with a non-blocking call to
+ // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+ // MessageQueue::commitRead can be used to avoid an extra copy of the
+ // metadata.
FmqResultDatum datum;
- bool success = false;
+ bool success = true;
if (mUsesFutex) {
success = mFmqResultChannel->readBlocking(&datum, 1);
} else {
@@ -133,7 +147,11 @@
// retrieve remaining elements
// NOTE: all of the data is already available at this point, so there's no
- // need to do a blocking wait to wait for more data
+ // need to do a blocking wait to wait for more data. This is known because
+ // in FMQ, all writes are published (made available) atomically. Currently,
+ // the producer always publishes the entire packet in one function call, so
+ // if the first element of the packet is available, the remaining elements
+ // are also available.
std::vector<FmqResultDatum> packet(count);
packet.front() = datum;
success = mFmqResultChannel->read(packet.data() + 1, packet.size() - 1);
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 64a4ee2..a9af004 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -14,9 +14,13 @@
* limitations under the License.
*/
+#define LOG_TAG "ExecutionBurstServer"
+
#include "ExecutionBurstServer.h"
#include <android-base/logging.h>
+#include <set>
+#include <string>
namespace android {
namespace nn {
@@ -27,31 +31,37 @@
std::lock_guard<std::mutex> guard(mMutex);
// find unique unknown slots
- std::vector<int32_t> unknownSlots = slots;
- std::sort(unknownSlots.begin(), unknownSlots.end());
- auto last = std::unique(unknownSlots.begin(), unknownSlots.end());
- unknownSlots.erase(last, unknownSlots.end());
+ std::set<int32_t> setOfUnknownSlots;
+ for (int32_t slot : slots) {
+ if (mSlotToMemoryCache.find(slot) == mSlotToMemoryCache.end()) {
+ setOfUnknownSlots.insert(slot);
+ }
+ }
+ const std::vector<int32_t> unknownSlots(setOfUnknownSlots.begin(), setOfUnknownSlots.end());
// retrieve unknown slots
- ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
- std::vector<hidl_memory> returnedMemories;
- Return<void> ret = mCallback->getMemories(
- unknownSlots, [&errorStatus, &returnedMemories](ErrorStatus status,
- const hidl_vec<hidl_memory>& memories) {
- errorStatus = status;
- if (status == ErrorStatus::NONE) {
- returnedMemories = memories;
- }
- });
+ if (!unknownSlots.empty()) {
+        LOG(VERBOSE) << "server calling getMemories";
+ ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+ std::vector<hidl_memory> returnedMemories;
+ Return<void> ret = mCallback->getMemories(
+ unknownSlots, [&errorStatus, &returnedMemories](
+ ErrorStatus status, const hidl_vec<hidl_memory>& memories) {
+ errorStatus = status;
+ if (status == ErrorStatus::NONE) {
+ returnedMemories = memories;
+ }
+ });
- if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
- LOG(ERROR) << "Error retrieving memories";
- return {};
- }
+ if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "Error retrieving memories";
+ return {};
+ }
- // add memories to unknown slots
- for (size_t i = 0; i < unknownSlots.size(); ++i) {
- mSlotToMemoryCache[unknownSlots[i]] = returnedMemories[i];
+ // add memories to unknown slots
+ for (size_t i = 0; i < unknownSlots.size(); ++i) {
+ mSlotToMemoryCache[unknownSlots[i]] = returnedMemories[i];
+ }
}
// get all slots
@@ -59,6 +69,7 @@
for (size_t i = 0; i < slots.size(); ++i) {
memories[i] = mSlotToMemoryCache[slots[i]];
}
+
return memories;
}
@@ -85,9 +96,13 @@
mTeardown = true;
// force unblock
+ // ExecutionBurstServer is by default waiting on a request packet. If the
+ // client process destroys its burst object, the server will still be
+ // waiting on the futex (assuming mBlocking is true). This force unblock
+ // wakes up any thread waiting on the futex.
if (mBlocking) {
- // TODO: look for a different/better way to signal/notify the futex to wake
- // up any thread waiting on it
+ // TODO: look for a different/better way to signal/notify the futex to
+ // wake up any thread waiting on it
FmqRequestDatum datum;
datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
/*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
@@ -117,7 +132,13 @@
return {};
}
- // wait for request packet and read first element of result packet
+ // wait for request packet and read first element of request packet
+ // TODO: have a more elegant way to wait for data, and read it all at once.
+ // For example, EventFlag can be used to directly wait on the futex, and all
+ // the data can be read at once with a non-blocking call to
+ // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+ // MessageQueue::commitRead can be used to avoid an extra copy of the
+ // metadata.
FmqRequestDatum datum;
bool success = false;
if (mBlocking) {
@@ -374,6 +395,10 @@
ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
std::vector<OutputShape> outputShapes;
Timing returnedTiming;
+        // This call to IPreparedModel::executeSynchronously occurs entirely
+        // within the same process, so the Return<> transport error is ignored.
+        // TODO: verify it is safe to always ignore the Return<> error here,
+        // or if there is any benefit to checking it via .isOk().
mPreparedModel
->executeSynchronously(request, measure,
[&errorStatus, &outputShapes, &returnedTiming](
diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h
index bf36470..d40806d 100644
--- a/nn/common/include/ExecutionBurstController.h
+++ b/nn/common/include/ExecutionBurstController.h
@@ -35,8 +35,6 @@
using ::android::hardware::MQDescriptorSync;
using FmqRequestChannel = MessageQueue<FmqRequestDatum, kSynchronizedReadWrite>;
using FmqResultChannel = MessageQueue<FmqResultDatum, kSynchronizedReadWrite>;
-using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
-using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
/**
* Number of elements in the FMQ.
@@ -68,7 +66,20 @@
std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
const std::vector<intptr_t>& keys);
+
int32_t getSlot(const hidl_memory& memory, intptr_t key);
+
+ /*
+ * This function performs two different actions:
+ * 1) Removes an entry from the cache (if present), including the copied
+ * hidl_memory. This frees hidl_memory's underlying file descriptor which
+ * was duplicated upon creation.
+ * 2) Return whether a cache entry was removed and which slot was removed if
+     *    found. If the key did not correspond to any entry in the cache, a
+ * slot number of 0 is returned. The slot number and whether the entry
+ * existed is useful so the same slot can be freed in the service's
+ * cache.
+ */
std::pair<bool, int32_t> freeMemory(intptr_t key);
private:
@@ -99,7 +110,10 @@
* Execute a request on a model.
*
* @param request Arguments to be executed on a model.
- * @return status and output shape of the execution.
+ * @param memoryIds Identifiers corresponding to each memory object in the
+ * request's pools.
+     * @return status and output shapes of the execution, and timing
+     *     information if measurement was requested.
*/
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute(
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
diff --git a/nn/common/include/ExecutionBurstServer.h b/nn/common/include/ExecutionBurstServer.h
index 13dfaaf..8d34179 100644
--- a/nn/common/include/ExecutionBurstServer.h
+++ b/nn/common/include/ExecutionBurstServer.h
@@ -99,8 +99,8 @@
* unrecognized slots.
* @param requestChannel Input FMQ channel through which the client passes the
* request to the service.
- * @param requestChannel Output FMQ channel from which the client can retrieve
- * the result of the execution.
+ * @param resultChannel Output FMQ channel from which the client can retrieve
+ * the result of the execution.
* @param preparedModel PreparedModel that the burst object was created from.
* This will be used to synchronously perform the
* execution.
diff --git a/nn/runtime/Android.bp b/nn/runtime/Android.bp
index 4ab0bc3..58cb24f 100644
--- a/nn/runtime/Android.bp
+++ b/nn/runtime/Android.bp
@@ -36,6 +36,7 @@
// openmp: true,
srcs: [
+ "BurstBuilder.cpp",
"Callbacks.cpp",
"CompilationBuilder.cpp",
"ExecutionBuilder.cpp",
@@ -77,6 +78,7 @@
],
shared_libs: [
+ "libfmq",
"libtextclassifier_hash"
],
diff --git a/nn/runtime/BurstBuilder.cpp b/nn/runtime/BurstBuilder.cpp
new file mode 100644
index 0000000..ee4b371
--- /dev/null
+++ b/nn/runtime/BurstBuilder.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "BurstBuilder"
+
+#include "BurstBuilder.h"
+
+#include "CompilationBuilder.h"
+#include "ExecutionBurstController.h"
+
+namespace android {
+namespace nn {
+
+BurstBuilder::BurstBuilder(const CompilationBuilder* compilation,
+ std::vector<std::unique_ptr<ExecutionBurstController>> burstControllers)
+ : mCompilation(compilation), mBurstControllers(std::move(burstControllers)) {}
+
+bool BurstBuilder::tryLock() {
+ const bool alreadyRunning = mCurrentlyRunning.test_and_set();
+ return !alreadyRunning;
+}
+
+void BurstBuilder::unlock() {
+ mCurrentlyRunning.clear();
+}
+
+const CompilationBuilder* BurstBuilder::getCompilation() const {
+ return mCompilation;
+}
+
+ExecutionBurstController* BurstBuilder::getControllerAt(size_t index) const {
+ return index < mBurstControllers.size() ? mBurstControllers[index].get() : nullptr;
+}
+
+} // namespace nn
+} // namespace android
diff --git a/nn/runtime/BurstBuilder.h b/nn/runtime/BurstBuilder.h
new file mode 100644
index 0000000..288cf84
--- /dev/null
+++ b/nn/runtime/BurstBuilder.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_ML_NN_RUNTIME_BURST_BUILDER_H
+#define ANDROID_ML_NN_RUNTIME_BURST_BUILDER_H
+
+#include <atomic>
+#include <memory>
+#include <vector>
+#include "ExecutionBurstController.h"
+
+namespace android {
+namespace nn {
+
+class CompilationBuilder;
+
+/*
+ * TODO: Could we "hide" the per-step burst controller instance inside
+ * StepExecutor? Today it's exposed as a "sibling" to StepExecutor:
+ * ExecutionPlan::next both generates a StepExecutor instance and finds a
+ * pointer to a burst controller; and StepExecutor::startCompute is passed a
+ * pointer to a burst controller. Instead, could ExecutionPlan::next stash the
+ * burst controller in the StepExecutor, so that it doesn't have to be passed
+ * to any of the StepExecutor methods?
+ */
+
+class BurstBuilder {
+ public:
+ BurstBuilder(const CompilationBuilder* compilation,
+ std::vector<std::unique_ptr<ExecutionBurstController>> burstControllers);
+
+ bool tryLock();
+ void unlock();
+
+ const CompilationBuilder* getCompilation() const;
+ ExecutionBurstController* getControllerAt(size_t index) const;
+
+ private:
+ std::atomic_flag mCurrentlyRunning = ATOMIC_FLAG_INIT;
+ const CompilationBuilder* mCompilation;
+ std::vector<std::unique_ptr<ExecutionBurstController>> mBurstControllers;
+};
+
+} // namespace nn
+} // namespace android
+
+#endif // ANDROID_ML_NN_RUNTIME_BURST_BUILDER_H
diff --git a/nn/runtime/CompilationBuilder.cpp b/nn/runtime/CompilationBuilder.cpp
index ccd30d8..5dab730 100644
--- a/nn/runtime/CompilationBuilder.cpp
+++ b/nn/runtime/CompilationBuilder.cpp
@@ -18,7 +18,9 @@
#include "CompilationBuilder.h"
+#include "BurstBuilder.h"
#include "ExecutionBuilder.h"
+#include "ExecutionBurstController.h"
#include "ExecutionPlan.h"
#include "Manager.h"
#include "ModelBuilder.h"
@@ -132,5 +134,21 @@
return (*execution ? ANEURALNETWORKS_NO_ERROR : ANEURALNETWORKS_OUT_OF_MEMORY);
}
+int CompilationBuilder::createBurst(BurstBuilder** burst) {
+ if (!mFinished) {
+ LOG(ERROR) << "ANeuralNetworksBurst_create passed an unfinished compilation";
+ *burst = nullptr;
+ return ANEURALNETWORKS_BAD_STATE;
+ }
+ if (!mPlan.isValid()) {
+ LOG(ERROR) << "ANeuralNetworksBurst_create passed an invalid compilation";
+ *burst = nullptr;
+ return ANEURALNETWORKS_BAD_STATE;
+ }
+ std::vector<std::unique_ptr<ExecutionBurstController>> burstControllers = mPlan.makeBursts();
+ *burst = new (std::nothrow) BurstBuilder(this, std::move(burstControllers));
+ return (*burst ? ANEURALNETWORKS_NO_ERROR : ANEURALNETWORKS_OUT_OF_MEMORY);
+}
+
} // namespace nn
} // namespace android
diff --git a/nn/runtime/CompilationBuilder.h b/nn/runtime/CompilationBuilder.h
index 8f85ca0..5f163ab 100644
--- a/nn/runtime/CompilationBuilder.h
+++ b/nn/runtime/CompilationBuilder.h
@@ -26,6 +26,7 @@
namespace android {
namespace nn {
+class BurstBuilder;
class Device;
class ExecutionBuilder;
class ModelBuilder;
@@ -47,6 +48,8 @@
int createExecution(ExecutionBuilder** execution);
+ int createBurst(BurstBuilder** burst);
+
const ExecutionPlan& forTest_getExecutionPlan() const { return mPlan; }
private:
diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp
index 456c0f6..55572f6 100644
--- a/nn/runtime/ExecutionBuilder.cpp
+++ b/nn/runtime/ExecutionBuilder.cpp
@@ -20,6 +20,7 @@
#include "CompilationBuilder.h"
#include "CpuExecutor.h"
+#include "ExecutionBurstController.h"
#include "HalInterfaces.h"
#include "Manager.h"
#include "ModelBuilder.h"
@@ -132,7 +133,8 @@
}
ExecutionBuilder::ExecutionBuilder(const CompilationBuilder* compilation)
- : mModel(compilation->mModel),
+ : mCompilation(compilation),
+ mModel(compilation->mModel),
mPlan(&compilation->mPlan),
mPartitioning(compilation->mPartitioning),
mInputs(mModel->inputCount()),
@@ -358,7 +360,8 @@
while (true) {
std::shared_ptr<StepExecutor> executor;
VLOG(EXECUTION) << "looking for next StepExecutor";
- int n = plan->next(controller, &executor);
+ ExecutionBurstController* burstController = nullptr;
+ int n = plan->next(controller, &executor, &burstController);
if (n != ANEURALNETWORKS_NO_ERROR) {
if (allowFallback) {
cpuFallbackFull(executionBuilder, executionCallback);
@@ -373,7 +376,7 @@
}
sp<ExecutionCallback> stepCallback;
- n = executor->startCompute(&stepCallback);
+ n = executor->startCompute(&stepCallback, burstController);
if (n != ANEURALNETWORKS_NO_ERROR) {
if (allowFallback) {
if (cpuFallbackPartial(executionBuilder, plan, controller, executionCallback)) {
@@ -409,7 +412,10 @@
}
}
-int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback) {
+int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback,
+ BurstBuilder* burstBuilder) {
+ assert(synchronizationCallback == nullptr || burstBuilder == nullptr);
+
const bool synchronous = (synchronizationCallback == nullptr);
if (!synchronous) {
@@ -439,7 +445,8 @@
// asynchronous thread -- take the asynchronous thread logic out of
// startComputeOnCpu() and use it to wrap the plan-based-path.
const bool allowFallback = DeviceManager::partitioningAllowsFallback(mPartitioning);
- std::shared_ptr<ExecutionPlan::Controller> controller = mPlan->makeController(this);
+ std::shared_ptr<ExecutionPlan::Controller> controller =
+ mPlan->makeController(this, burstBuilder);
if (synchronous) {
VLOG(EXECUTION) << "ExecutionBuilder::compute (synchronous API)";
sp<ExecutionCallback> localSynchronizationCallback = new ExecutionCallback();
@@ -603,7 +610,8 @@
return mDevice->getInterface() == nullptr;
}
-int StepExecutor::startCompute(sp<ExecutionCallback>* synchronizationCallback) {
+int StepExecutor::startCompute(sp<ExecutionCallback>* synchronizationCallback,
+ ExecutionBurstController* burstController) {
if (VLOG_IS_ON(EXECUTION)) {
logArguments("input", mInputs);
logArguments("output", mOutputs);
@@ -611,11 +619,12 @@
if (isCpu()) {
return startComputeOnCpu(synchronizationCallback);
} else {
- return startComputeOnDevice(synchronizationCallback);
+ return startComputeOnDevice(synchronizationCallback, burstController);
}
}
-int StepExecutor::startComputeOnDevice(sp<ExecutionCallback>* synchronizationCallback) {
+int StepExecutor::startComputeOnDevice(sp<ExecutionCallback>* synchronizationCallback,
+ ExecutionBurstController* burstController) {
CHECK(!isCpu());
*synchronizationCallback = nullptr;
@@ -711,7 +720,19 @@
// in the design document.
sp<ExecutionCallback> executionCallback = new ExecutionCallback();
- if (DeviceManager::get()->syncExecHal()) {
+ if (burstController != nullptr) {
+ std::vector<intptr_t> memoryIds(mMemories.size());
+ for (size_t i = 0; i < mMemories.size(); ++i) {
+ memoryIds[i] = reinterpret_cast<intptr_t>(mMemories[i]);
+ }
+
+ VLOG(EXECUTION) << "Before ExecutionBurstController->compute() "
+ << SHOW_IF_DEBUG(toString(request));
+ auto burstExecuteResult =
+ burstController->compute(request, measureTiming(mExecutionBuilder), memoryIds);
+ executionCallback->notify(std::get<0>(burstExecuteResult), std::get<1>(burstExecuteResult),
+ std::get<2>(burstExecuteResult));
+ } else if (DeviceManager::get()->syncExecHal()) {
VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() "
<< SHOW_IF_DEBUG(toString(request));
auto syncExecuteResult =
diff --git a/nn/runtime/ExecutionBuilder.h b/nn/runtime/ExecutionBuilder.h
index 6cc80af..a7a6430 100644
--- a/nn/runtime/ExecutionBuilder.h
+++ b/nn/runtime/ExecutionBuilder.h
@@ -34,8 +34,10 @@
namespace android {
namespace nn {
+class BurstBuilder;
class CompilationBuilder;
class ExecutionPlan;
+class ExecutionBurstController;
class Memory;
class ModelBuilder;
class StepExecutor;
@@ -89,6 +91,7 @@
return compute(synchronizationCallback);
}
int computeSynchronously() { return compute(nullptr); }
+ int burstCompute(BurstBuilder* burst) { return compute(nullptr, burst); }
int getOutputOperandDimensions(uint32_t index, uint32_t* dimensions);
int getOutputOperandRank(uint32_t index, uint32_t* rank);
@@ -97,6 +100,7 @@
bool measureTiming() const { return mMeasureTiming; }
void reportTiming(Timing timing) { mTiming = timing; }
+ const CompilationBuilder* getCompilation() const { return mCompilation; }
const ModelBuilder* getModel() const { return mModel; }
ErrorStatus finish(ErrorStatus error);
@@ -104,8 +108,15 @@
private:
// If a callback is provided, then this is asynchronous. If a callback is
// not provided (i.e., is nullptr), then this is synchronous.
- int compute(sp<ExecutionCallback>* synchronizationCallback);
+ //
+ // If burst is provided, then the burst path will be used. If a burst is not
+ // provided (i.e., is nullptr), then a synchronous execution will occur.
+ //
+    // Providing both synchronizationCallback and burstBuilder is an error.
+ int compute(sp<ExecutionCallback>* synchronizationCallback,
+ BurstBuilder* burstBuilder = nullptr);
+ const CompilationBuilder* mCompilation;
const ModelBuilder* mModel;
const ExecutionPlan* mPlan;
@@ -189,7 +200,8 @@
}
// Executes using the (driver, preparedModel) specified at construction time.
- int startCompute(sp<ExecutionCallback>* synchronizationCallback);
+ int startCompute(sp<ExecutionCallback>* synchronizationCallback,
+ ExecutionBurstController* burstController = nullptr);
// Executes using the CPU, regardless of the (driver,
// preparedModel) specified at construction time.
@@ -199,7 +211,8 @@
private:
int allocatePointerArgumentsToPool(std::vector<ModelArgumentInfo>* args, Memory* memory);
- int startComputeOnDevice(sp<ExecutionCallback>* synchronizationCallback);
+ int startComputeOnDevice(sp<ExecutionCallback>* synchronizationCallback,
+ ExecutionBurstController* burstController = nullptr);
void mapInputOrOutput(const ModelArgumentInfo& builderInputOrOutput,
ModelArgumentInfo* executorInputOrOutput);
diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp
index 2780b48..16fac44 100644
--- a/nn/runtime/ExecutionPlan.cpp
+++ b/nn/runtime/ExecutionPlan.cpp
@@ -18,9 +18,11 @@
#include "ExecutionPlan.h"
+#include "BurstBuilder.h"
#include "Callbacks.h"
#include "CompilationBuilder.h"
#include "ExecutionBuilder.h"
+#include "ExecutionBurstController.h"
#include "Manager.h"
#include "ModelBuilder.h"
#include "Tracing.h"
@@ -504,10 +506,12 @@
ExecutionPlan::Controller::Controller(
const ExecutionPlan* plan, ExecutionBuilder* executionBuilder,
+ const BurstBuilder* burstBuilder,
std::shared_ptr<const SubModelInputsAndOutputsType> subModelInputsAndOutputs,
uint32_t totalSizeOfTemporaries)
: mPlan(plan),
mExecutionBuilder(executionBuilder),
+ mBurstBuilder(burstBuilder),
mSubModelInputsAndOutputs(subModelInputsAndOutputs),
mNextStepIndex(0) {
if (totalSizeOfTemporaries) {
@@ -518,8 +522,45 @@
}
}
+// Attempt to create a burst object for each PreparedModel/Partition. If the
+// burst controller object cannot be made, return a nullptr in its place to
+// indicate the regular execution path should be used. This can occur either
+// because PreparedModel was nullptr (cpu was best choice), or because the
+// IPreparedModel was of insufficient version or failed to configure the burst.
+std::vector<std::unique_ptr<ExecutionBurstController>> ExecutionPlan::makeBursts() const {
+ switch (mState) {
+ // burst object for each partition in the compound case
+ case COMPOUND: {
+ std::vector<std::unique_ptr<ExecutionBurstController>> bursts;
+ bursts.reserve(compound()->mSteps.size());
+ for (const auto& step : compound()->mSteps) {
+ if (const auto preparedModel = step->getPreparedSubModel()) {
+ bursts.push_back(preparedModel->configureExecutionBurst(/*blocking=*/true));
+ } else {
+ bursts.push_back(nullptr);
+ }
+ }
+ return bursts;
+ }
+ // single burst object for the simple case
+ case SIMPLE: {
+ std::vector<std::unique_ptr<ExecutionBurstController>> burst;
+ auto simpleBody = static_cast<const SimpleBody*>(mBody);
+ if (const auto preparedModel = simpleBody->mPreparedModel) {
+ burst.push_back(preparedModel->configureExecutionBurst(/*blocking=*/true));
+ } else {
+ burst.push_back(nullptr);
+ }
+ return burst;
+ }
+ // no burst objects made
+ default:
+ return {};
+ }
+}
+
std::shared_ptr<ExecutionPlan::Controller> ExecutionPlan::makeController(
- ExecutionBuilder* executionBuilder) const {
+ ExecutionBuilder* executionBuilder, const BurstBuilder* burstBuilder) const {
nnAssert(isValid());
// Create the layout for a Memory object big enough for to hold
@@ -569,7 +610,7 @@
}
}
- return std::shared_ptr<Controller>(new Controller(this, executionBuilder,
+ return std::shared_ptr<Controller>(new Controller(this, executionBuilder, burstBuilder,
subModelInputsAndOutputs,
totalSizeOfTemporaries));
}
@@ -598,8 +639,12 @@
}
int ExecutionPlan::next(std::shared_ptr<Controller> controller,
- std::shared_ptr<StepExecutor>* executor) const {
+ std::shared_ptr<StepExecutor>* executor,
+ ExecutionBurstController** burstController) const {
*executor = nullptr;
+ if (burstController != nullptr) {
+ *burstController = nullptr;
+ }
VLOG(EXECUTION) << "ExecutionPlan::next("
<< SHOW_IF_DEBUG(controller << ", " << executor)
@@ -623,6 +668,9 @@
simpleBody->mModel, simpleBody->mDevice,
simpleBody->mPreparedModel);
(*executor)->mapInputsAndOutputsTrivially();
+ if (burstController != nullptr && controller->mBurstBuilder != nullptr) {
+ *burstController = controller->mBurstBuilder->getControllerAt(0);
+ }
controller->mNextStepIndex = 1;
return ANEURALNETWORKS_NO_ERROR;
}
@@ -649,6 +697,9 @@
*executor = std::make_shared<StepExecutor>(controller->mExecutionBuilder, step->getSubModel(),
step->getDevice(), step->getPreparedSubModel());
step->mapInputsAndOutputs(*executor);
+ if (burstController != nullptr && controller->mBurstBuilder != nullptr) {
+ *burstController = controller->mBurstBuilder->getControllerAt(controller->mNextStepIndex);
+ }
if (controller->mSubModelInputsAndOutputs != nullptr) {
{
// Tell executor about temps as submodel outputs.
diff --git a/nn/runtime/ExecutionPlan.h b/nn/runtime/ExecutionPlan.h
index 5f3a250..5be5272 100644
--- a/nn/runtime/ExecutionPlan.h
+++ b/nn/runtime/ExecutionPlan.h
@@ -31,10 +31,12 @@
namespace android {
namespace nn {
+class BurstBuilder;
class CompilationBuilder;
class Device;
class ExecutionBuilder;
class ExecutionPlan;
+class ExecutionBurstController;
class Memory;
class StepExecutor;
@@ -183,19 +185,25 @@
static const size_t kBadStepIndex = ~size_t(0);
Controller(const ExecutionPlan* plan, ExecutionBuilder* executionBuilder,
+ const BurstBuilder* burstBuilder,
std::shared_ptr<const SubModelInputsAndOutputsType> subModelInputsAndOutputs,
uint32_t totalSizeOfTemporaries);
const ExecutionPlan* mPlan;
ExecutionBuilder* mExecutionBuilder;
+ const BurstBuilder* mBurstBuilder;
std::shared_ptr<const SubModelInputsAndOutputsType> mSubModelInputsAndOutputs; // may be nullptr
Memory mTemporaries;
size_t mNextStepIndex;
};
- std::shared_ptr<Controller> makeController(ExecutionBuilder* executionBuilder) const;
+ std::vector<std::unique_ptr<ExecutionBurstController>> makeBursts() const;
- int next(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor) const;
+ std::shared_ptr<Controller> makeController(ExecutionBuilder* executionBuilder,
+ const BurstBuilder* burstBuilder) const;
+
+ int next(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor,
+ ExecutionBurstController** burstController = nullptr) const;
// Create the same executor as the last one created by next().
int fallback(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor) const;
diff --git a/nn/runtime/NeuralNetworks.cpp b/nn/runtime/NeuralNetworks.cpp
index 80ecb99..c02e780 100644
--- a/nn/runtime/NeuralNetworks.cpp
+++ b/nn/runtime/NeuralNetworks.cpp
@@ -22,6 +22,7 @@
#include "NeuralNetworks.h"
+#include "BurstBuilder.h"
#include "Callbacks.h"
#include "CompilationBuilder.h"
#include "ExecutionBuilder.h"
@@ -546,15 +547,18 @@
return ANEURALNETWORKS_UNEXPECTED_NULL;
}
- // TODO in subsequent CL
- return ANEURALNETWORKS_NO_ERROR;
+ CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
+ BurstBuilder* b = nullptr;
+ int result = c->createBurst(&b);
+ *burst = reinterpret_cast<ANeuralNetworksBurst*>(b);
+ return result;
}
void ANeuralNetworksBurst_free(ANeuralNetworksBurst* burst) {
NNTRACE_RT(NNTRACE_PHASE_TERMINATION, "ANeuralNetworksBurst_free");
// No validation. Free of nullptr is valid.
- (void)burst;
- // TODO in subsequent CL
+ BurstBuilder* b = reinterpret_cast<BurstBuilder*>(burst);
+ delete b;
}
int ANeuralNetworksExecution_burstCompute(ANeuralNetworksExecution* execution,
@@ -565,8 +569,27 @@
return ANEURALNETWORKS_UNEXPECTED_NULL;
}
- // TODO in subsequent CL
- return ANEURALNETWORKS_NO_ERROR;
+ ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
+ BurstBuilder* b = reinterpret_cast<BurstBuilder*>(burst);
+
+ if (r->getCompilation() != b->getCompilation()) {
+ LOG(ERROR) << "ANeuralNetworksBurst and ANeuralNetworksExecution "
+ "used in ANeuralNetworksExecution_burstCompute must "
+ "originate from the same ANeuralNetworksCompilation";
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+
+ const bool locked = b->tryLock();
+ if (!locked) {
+ LOG(ERROR) << "ANeuralNetworksBurst is already being used in another "
+ "call to ANeuralNetworksExecution_burstCompute";
+ return ANEURALNETWORKS_BAD_STATE;
+ }
+
+ const int n = r->burstCompute(b);
+ b->unlock();
+
+ return n;
}
int ANeuralNetworksMemory_createFromFd(size_t size, int prot, int fd, size_t offset,
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index 3dba329..62e5aab 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -17,6 +17,7 @@
#include "VersionedInterfaces.h"
#include "Callbacks.h"
+#include "ExecutionBurstController.h"
#include "Tracing.h"
#include "Utils.h"
@@ -80,6 +81,15 @@
}
}
+std::unique_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
+ bool blocking) const {
+ if (mPreparedModelV1_2 != nullptr) {
+ return createExecutionBurstController(mPreparedModelV1_2, blocking);
+ } else {
+ return nullptr;
+ }
+}
+
bool VersionedIPreparedModel::operator==(nullptr_t) const {
return mPreparedModelV1_0 == nullptr;
}
diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h
index 395d448..ea37003 100644
--- a/nn/runtime/VersionedInterfaces.h
+++ b/nn/runtime/VersionedInterfaces.h
@@ -27,6 +27,8 @@
namespace android {
namespace nn {
+class ExecutionBurstController;
+
/**
* Each class (VersionedIDevice, VersionedIPreparedModel) wraps a HIDL interface
* of any version to abstract away version differences. It allows the remainder
@@ -374,6 +376,16 @@
const Request& request, MeasureTiming measure);
/**
+ * Creates a burst controller on a prepared model.
+ *
+ * @param blocking 'true' if the FMQ should block until data is available.
+ * @return ExecutionBurstController Execution burst controller object.
+ * nullptr is returned if the burst cannot
+ * be configured for any reason.
+ */
+ std::unique_ptr<ExecutionBurstController> configureExecutionBurst(bool blocking) const;
+
+ /**
* Returns whether this handle to an IPreparedModel object is valid or not.
*
* @return bool true if V1_0::IPreparedModel (which could be V1_2::IPreparedModel) is
diff --git a/nn/runtime/test/Android.bp b/nn/runtime/test/Android.bp
index 82527a3..ada6ed2 100644
--- a/nn/runtime/test/Android.bp
+++ b/nn/runtime/test/Android.bp
@@ -24,6 +24,7 @@
"libandroid",
"libbase",
"libcutils",
+ "libfmq",
"libhidlbase",
"libhidltransport",
"libhidlmemory",
@@ -107,7 +108,6 @@
"libSampleDriver",
],
shared_libs: [
- "libfmq",
"libcutils",
],
header_libs: [
diff --git a/nn/runtime/test/TestMain.cpp b/nn/runtime/test/TestMain.cpp
index 87528f1..dc32cec 100644
--- a/nn/runtime/test/TestMain.cpp
+++ b/nn/runtime/test/TestMain.cpp
@@ -40,7 +40,8 @@
// non-public DeviceManager::setSyncExecHal(); we assume the setting is always
// true, and if we are asked to set it to false, we return 0 ("success") without
// running tests.
-static int test(bool useCpuOnly, bool computeUsesSynchronousAPI, bool allowSyncExecHal = true) {
+static int test(bool useCpuOnly, bool computeUsesSynchronousAPI, bool allowSyncExecHal = true,
+ bool computeUsesBurstAPI = false) {
#ifdef NNTEST_ONLY_PUBLIC_API
if (useCpuOnly || !allowSyncExecHal) {
return 0;
@@ -51,6 +52,7 @@
#endif
Execution::setComputeUsesSynchronousAPI(computeUsesSynchronousAPI);
+ Execution::setComputeUsesBurstAPI(computeUsesBurstAPI);
LOG(INFO) << "test(useCpuOnly = " << useCpuOnly
<< ", computeUsesSynchronousAPI = " << computeUsesSynchronousAPI
@@ -77,5 +79,13 @@
// so there's no reason to run test(true, *, false) now.
n |= test(false, false, false) | test(false, true, false);
+ // Now try execution using a burst.
+ //
+ // The burst path is off by default in these tests. This is the first case
+ // where it is turned on. Both "computeUsesSynchronousAPI" and
+ // "allowSyncExecHal" are irrelevant here because the burst path is separate
+ // from both.
+ n |= test(false, false, false, true);
+
return n;
}
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.cpp b/nn/runtime/test/TestNeuralNetworksWrapper.cpp
index 056dcb2..9d61f49 100644
--- a/nn/runtime/test/TestNeuralNetworksWrapper.cpp
+++ b/nn/runtime/test/TestNeuralNetworksWrapper.cpp
@@ -20,6 +20,8 @@
namespace nn {
namespace test_wrapper {
+bool Execution::mComputeUsesBurstAPI = false;
+
bool Execution::mComputeUsesSychronousAPI = true;
} // namespace test_wrapper
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.h b/nn/runtime/test/TestNeuralNetworksWrapper.h
index 50292f9..5fdb7c7 100644
--- a/nn/runtime/test/TestNeuralNetworksWrapper.h
+++ b/nn/runtime/test/TestNeuralNetworksWrapper.h
@@ -353,7 +353,7 @@
class Execution {
public:
- Execution(const Compilation* compilation) {
+ Execution(const Compilation* compilation) : mCompilation(compilation->getHandle()) {
int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
if (result != 0) {
// TODO Handle the error
@@ -375,6 +375,8 @@
Execution& operator=(Execution&& other) {
if (this != &other) {
ANeuralNetworksExecution_free(mExecution);
+ mCompilation = other.mCompilation;
+ other.mCompilation = nullptr;
mExecution = other.mExecution;
other.mExecution = nullptr;
}
@@ -413,6 +415,18 @@
}
Result compute() {
+ if (mComputeUsesBurstAPI) {
+ ANeuralNetworksBurst* burst = nullptr;
+ Result result = static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
+ if (result != Result::NO_ERROR) {
+ ANeuralNetworksBurst_free(burst);
+ return result;
+ }
+ result = static_cast<Result>(ANeuralNetworksExecution_burstCompute(mExecution, burst));
+ ANeuralNetworksBurst_free(burst);
+ return result;
+ }
+
if (!mComputeUsesSychronousAPI) {
ANeuralNetworksEvent* event = nullptr;
Result result =
@@ -436,6 +450,8 @@
// computation to complete.
static void setComputeUsesSynchronousAPI(bool val) { mComputeUsesSychronousAPI = val; }
+ static void setComputeUsesBurstAPI(bool val) { mComputeUsesBurstAPI = val; }
+
Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) {
uint32_t rank = 0;
Result result = static_cast<Result>(
@@ -451,8 +467,12 @@
}
private:
+ ANeuralNetworksCompilation* mCompilation = nullptr;
ANeuralNetworksExecution* mExecution = nullptr;
+ // Initialized to false in TestNeuralNetworksWrapper.cpp.
+ static bool mComputeUsesBurstAPI;
+
// Initialized to true in TestNeuralNetworksWrapper.cpp.
static bool mComputeUsesSychronousAPI;
};
diff --git a/nn/runtime/test/TestValidation.cpp b/nn/runtime/test/TestValidation.cpp
index 5a5e70f..be697bf 100644
--- a/nn/runtime/test/TestValidation.cpp
+++ b/nn/runtime/test/TestValidation.cpp
@@ -20,6 +20,7 @@
#include <android/sharedmem.h>
#include <gtest/gtest.h>
#include <sys/mman.h>
+#include <future>
#include <string>
// This file tests all the validations done by the Neural Networks API.
@@ -125,6 +126,20 @@
ANeuralNetworksExecution* mExecution = nullptr;
};
+class ValidationTestBurst : public ValidationTestExecution {
+ protected:
+ virtual void SetUp() {
+ ValidationTestExecution::SetUp();
+
+ ASSERT_EQ(ANeuralNetworksBurst_create(mCompilation, &mBurst), ANEURALNETWORKS_NO_ERROR);
+ }
+ virtual void TearDown() {
+ ANeuralNetworksBurst_free(mBurst);
+ ValidationTestExecution::TearDown();
+ }
+ ANeuralNetworksBurst* mBurst = nullptr;
+};
+
TEST_F(ValidationTest, CreateModel) {
EXPECT_EQ(ANeuralNetworksModel_create(nullptr), ANEURALNETWORKS_UNEXPECTED_NULL);
}
@@ -885,6 +900,72 @@
EXPECT_EQ(dims[0], expectedDims);
}
+TEST_F(ValidationTestBurst, BurstComputeNull) {
+ EXPECT_EQ(ANeuralNetworksExecution_burstCompute(mExecution, nullptr),
+ ANEURALNETWORKS_UNEXPECTED_NULL);
+ EXPECT_EQ(ANeuralNetworksExecution_burstCompute(nullptr, mBurst),
+ ANEURALNETWORKS_UNEXPECTED_NULL);
+}
+
+TEST_F(ValidationTestBurst, BurstComputeDifferentCompilations) {
+ ANeuralNetworksCompilation* secondCompilation;
+ ASSERT_EQ(ANeuralNetworksCompilation_create(mModel, &secondCompilation),
+ ANEURALNETWORKS_NO_ERROR);
+ ASSERT_EQ(ANeuralNetworksCompilation_finish(secondCompilation), ANEURALNETWORKS_NO_ERROR);
+
+ ANeuralNetworksBurst* burst;
+ EXPECT_EQ(ANeuralNetworksBurst_create(secondCompilation, &burst), ANEURALNETWORKS_NO_ERROR);
+
+ EXPECT_EQ(ANeuralNetworksExecution_burstCompute(mExecution, burst), ANEURALNETWORKS_BAD_DATA);
+
+ ANeuralNetworksBurst_free(burst);
+ ANeuralNetworksCompilation_free(secondCompilation);
+}
+
+TEST_F(ValidationTestBurst, BurstComputeConcurrent) {
+ ANeuralNetworksExecution* secondExecution;
+ EXPECT_EQ(ANeuralNetworksExecution_create(mCompilation, &secondExecution),
+ ANEURALNETWORKS_NO_ERROR);
+
+ // set inputs of first execution
+ float inputA0 = 1.0f, inputA1 = 2.0f, outputA0;
+ int32_t inputA2 = 0;
+ EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, &inputA0, sizeof(float)),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 1, nullptr, &inputA1, sizeof(float)),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 2, nullptr, &inputA2, sizeof(int32_t)),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setOutput(mExecution, 0, nullptr, &outputA0, sizeof(float)),
+ ANEURALNETWORKS_NO_ERROR);
+
+ // set inputs of second execution
+ float inputB0 = 1.0f, inputB1 = 2.0f, outputB0;
+ int32_t inputB2 = 0;
+ EXPECT_EQ(
+ ANeuralNetworksExecution_setInput(secondExecution, 0, nullptr, &inputB0, sizeof(float)),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(
+ ANeuralNetworksExecution_setInput(secondExecution, 1, nullptr, &inputB1, sizeof(float)),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setInput(secondExecution, 2, nullptr, &inputB2,
+ sizeof(int32_t)),
+ ANEURALNETWORKS_NO_ERROR);
+ EXPECT_EQ(ANeuralNetworksExecution_setOutput(secondExecution, 0, nullptr, &outputB0,
+ sizeof(float)),
+ ANEURALNETWORKS_NO_ERROR);
+
+ // execute on the same burst concurrently
+ auto first = std::async(std::launch::async, [this] {
+ const int result = ANeuralNetworksExecution_burstCompute(mExecution, mBurst);
+ EXPECT_TRUE(result == ANEURALNETWORKS_BAD_STATE || result == ANEURALNETWORKS_NO_ERROR);
+ });
+ auto second = std::async(std::launch::async, [this, secondExecution] {
+ const int result = ANeuralNetworksExecution_burstCompute(secondExecution, mBurst);
+ EXPECT_TRUE(result == ANEURALNETWORKS_BAD_STATE || result == ANEURALNETWORKS_NO_ERROR);
+ });
+}
+
TEST(ValidationTestIntrospection, GetNumDevices) {
uint32_t numDevices = 0;
EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);