Memory Domain Runtime: HAL APIs.
Modify NNAPI runtime in accordance with the HAL API changes:
- Upgrade Request to V1_3.
- Add compliance, conversion, and validation utilities for V1_0 and V1_3 Request.
- Add compliance test.
- Add dummy allocate method to SampleDriver.
Bug: 141353602
Bug: 141363565
Test: m
Test: NNT_static
Change-Id: I863c674f960d465168cb10cef19ee1c3c6e46e92
diff --git a/nn/common/CpuExecutor.cpp b/nn/common/CpuExecutor.cpp
index ca3ba70..d147deb 100644
--- a/nn/common/CpuExecutor.cpp
+++ b/nn/common/CpuExecutor.cpp
@@ -24,6 +24,7 @@
#include <Eigen/Core>
#include <memory>
+#include <utility>
#include <vector>
// b/109953668, disable OpenMP
@@ -437,6 +438,29 @@
return true;
}
+// Maps every V1_3 memory pool in `pools` into `poolInfos`, replacing its previous contents.
+// Only hidlMemory pools are supported. If any pool is of another kind (e.g. a buffer token) or
+// cannot be mapped, an error is logged, `poolInfos` is cleared, and false is returned.
+bool setRunTimePoolInfosFromMemoryPools(std::vector<RunTimePoolInfo>* poolInfos,
+ const hidl_vec<Request::MemoryPool>& pools) {
+ CHECK(poolInfos != nullptr);
+ poolInfos->clear();
+ poolInfos->reserve(pools.size());
+ for (const auto& pool : pools) {
+ // Only hidlMemory pools can be mapped here; any other pool kind is rejected.
+ if (pool.getDiscriminator() != Request::MemoryPool::hidl_discriminator::hidlMemory) {
+ LOG(ERROR) << "Unknown memory token";
+ poolInfos->clear();
+ return false;
+ }
+ if (std::optional<RunTimePoolInfo> poolInfo =
+ RunTimePoolInfo::createFromHidlMemory(pool.hidlMemory())) {
+ // The optional is a local temporary, so move its payload instead of copying it.
+ poolInfos->push_back(std::move(*poolInfo));
+ } else {
+ LOG(ERROR) << "Could not map pools";
+ poolInfos->clear();
+ return false;
+ }
+ }
+ return true;
+}
+
template <typename T>
inline bool convertToNhwcImpl(T* to, const T* from, const std::vector<uint32_t>& fromDim) {
uint32_t spatialSize = fromDim[2] * fromDim[3];
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp
index c71ff7f..1136024 100644
--- a/nn/common/ExecutionBurstController.cpp
+++ b/nn/common/ExecutionBurstController.cpp
@@ -63,7 +63,7 @@
} // anonymous namespace
// serialize a request into a packet
-std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
+// The burst path carries a V1_0::Request; its memory pools are referenced by `slots`.
+std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, MeasureTiming measure,
const std::vector<int32_t>& slots) {
// count how many elements need to be sent for a request
size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
@@ -356,7 +356,7 @@
RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
: mFmqRequestChannel(std::move(fmqRequestChannel)) {}
-bool RequestChannelSender::send(const Request& request, MeasureTiming measure,
+bool RequestChannelSender::send(const V1_0::Request& request, MeasureTiming measure,
const std::vector<int32_t>& slots) {
const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
return sendPacket(serialized);
@@ -572,7 +572,8 @@
}
std::tuple<int, std::vector<OutputShape>, Timing, bool> ExecutionBurstController::compute(
- const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
+ const V1_0::Request& request, MeasureTiming measure,
+ const std::vector<intptr_t>& memoryIds) {
// This is the first point when we know an execution is occurring, so begin
// to collect systraces. Note that the first point we can begin collecting
// systraces in ExecutionBurstServer is when the RequestChannelReceiver
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp
index 0bcb57d..890b653 100644
--- a/nn/common/ExecutionBurstServer.cpp
+++ b/nn/common/ExecutionBurstServer.cpp
@@ -63,7 +63,7 @@
void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
- const Request& request, const std::vector<int32_t>& slots,
+ const V1_0::Request& request, const std::vector<int32_t>& slots,
MeasureTiming measure) override {
// convert slots to pools
hidl_vec<hidl_memory> pools(slots.size());
@@ -71,7 +71,7 @@
[this](int32_t slot) { return mMemoryCache[slot]; });
// create full request
- Request fullRequest = request;
+ // Burst requests are V1_0::Request; their hidl_memory pools come from the cached slots above.
+ V1_0::Request fullRequest = request;
fullRequest.pools = std::move(pools);
// setup execution
@@ -156,7 +156,7 @@
}
// deserialize request
-std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
+std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, MeasureTiming>> deserialize(
const std::vector<FmqRequestDatum>& data) {
using discriminator = FmqRequestDatum::hidl_discriminator;
@@ -299,7 +299,7 @@
}
// return request
- Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}};
+ V1_0::Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}};
return std::make_tuple(std::move(request), std::move(slots), measure);
}
@@ -328,7 +328,7 @@
std::chrono::microseconds pollingTimeWindow)
: mFmqRequestChannel(std::move(fmqRequestChannel)), kPollingTimeWindow(pollingTimeWindow) {}
-std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>>
+std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, MeasureTiming>>
RequestChannelReceiver::getBlocking() {
const auto packet = getPacketBlocking();
if (!packet) {
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index 5c4b333..4fe58be 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -2816,6 +2816,57 @@
return model;
}
+// A V1_0::Request is trivially compliant with V1_0.
+bool compliantWithV1_0(const V1_0::Request& /*request*/) {
+ return true;
+}
+
+// A V1_3::Request is compliant with V1_0 only if every pool is a hidlMemory pool;
+// buffer-token pools have no V1_0 representation.
+bool compliantWithV1_0(const V1_3::Request& request) {
+ return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
+ return pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory;
+ });
+}
+
+// Converts a single V1_3 memory pool to V1_0. A buffer-token pool has no V1_0 equivalent and is
+// mapped to an empty hidl_memory.
+static hidl_memory convertToV1_0(const V1_3::Request::MemoryPool& pool) {
+ switch (pool.getDiscriminator()) {
+ case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
+ return pool.hidlMemory();
+ case V1_3::Request::MemoryPool::hidl_discriminator::token:
+ return hidl_memory{};
+ }
+ // Only reached if the discriminator holds a value not handled above. Without this, control
+ // would fall off the end of a non-void function (undefined behavior).
+ LOG(FATAL) << "unknown MemoryPool discriminator";
+ return hidl_memory{};
+}
+
+// Wraps a V1_0 shared-memory pool as a V1_3 hidlMemory pool.
+static V1_3::Request::MemoryPool convertToV1_3(const hidl_memory& pool) {
+ V1_3::Request::MemoryPool ret;
+ ret.hidlMemory(pool);
+ return ret;
+}
+
+// Identity conversion.
+V1_0::Request convertToV1_0(const V1_0::Request& request) {
+ return request;
+}
+
+// Converts a V1_3 request to V1_0. A non-compliant request (one holding buffer tokens) is still
+// converted after logging an error; each token pool becomes an empty hidl_memory.
+V1_0::Request convertToV1_0(const V1_3::Request& request) {
+ if (!compliantWithV1_0(request)) {
+ LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
+ << " from V1_3::Request to V1_0::Request";
+ }
+ hidl_vec<hidl_memory> pools(request.pools.size());
+ std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
+ [](const auto& pool) { return convertToV1_0(pool); });
+ return {.inputs = request.inputs, .outputs = request.outputs, .pools = std::move(pools)};
+}
+
+// Converts a V1_0 request to V1_3; every pool becomes a hidlMemory pool.
+V1_3::Request convertToV1_3(const V1_0::Request& request) {
+ hidl_vec<V1_3::Request::MemoryPool> pools(request.pools.size());
+ std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
+ [](const auto& pool) { return convertToV1_3(pool); });
+ return {.inputs = request.inputs, .outputs = request.outputs, .pools = std::move(pools)};
+}
+
+// Identity conversion.
+V1_3::Request convertToV1_3(const V1_3::Request& request) {
+ return request;
+}
+
#ifdef NN_DEBUGGABLE
uint32_t getProp(const char* str, uint32_t defaultValue) {
const std::string propStr = android::base::GetProperty(str, "");
diff --git a/nn/common/ValidateHal.cpp b/nn/common/ValidateHal.cpp
index 8ec6320..b418640 100644
--- a/nn/common/ValidateHal.cpp
+++ b/nn/common/ValidateHal.cpp
@@ -60,7 +60,21 @@
mPoolSizes[i] = pools[i].size();
}
}
- bool validate(const DataLocation& location) {
+ // Records the size of each V1_3 memory pool; buffer-token pools are recorded with size 0.
+ MemoryAccessVerifier(const hidl_vec<V1_3::Request::MemoryPool>& pools)
+ : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
+ for (size_t i = 0; i < mPoolCount; i++) {
+ switch (pools[i].getDiscriminator()) {
+ case Request::MemoryPool::hidl_discriminator::hidlMemory:
+ mPoolSizes[i] = pools[i].hidlMemory().size();
+ break;
+ case Request::MemoryPool::hidl_discriminator::token:
+ // Set size to 0 to enforce length == 0 && offset == 0.
+ mPoolSizes[i] = 0;
+ break;
+ }
+ }
+ }
+ // const so that a single verifier can be passed by const reference to
+ // validateRequestArguments.
+ bool validate(const DataLocation& location) const {
if (location.poolIndex >= mPoolCount) {
LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
return false;
@@ -473,9 +487,21 @@
return true;
}
-static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) {
+// Validates a V1_3 memory pool. A hidlMemory pool is checked by the hidl_memory overload; for a
+// buffer token only a nonzero value is checked (the driver must validate the token itself).
+bool validatePool(const V1_3::Request::MemoryPool& pool, HalVersion ver) {
+ switch (pool.getDiscriminator()) {
+ case Request::MemoryPool::hidl_discriminator::hidlMemory:
+ return validatePool(pool.hidlMemory(), ver);
+ case Request::MemoryPool::hidl_discriminator::token:
+ return pool.token() > 0;
+ }
+ LOG(FATAL) << "unknown MemoryPool discriminator";
+ return false;
+}
+
+// Validates every pool. T_MemoryPool is hidl_memory for V1_0 requests and
+// V1_3::Request::MemoryPool for V1_3 requests; overload resolution picks the matching validatePool.
+template <class T_MemoryPool>
+static bool validatePools(const hidl_vec<T_MemoryPool>& pools, HalVersion ver) {
return std::all_of(pools.begin(), pools.end(),
- [ver](const hidl_memory& pool) { return validatePool(pool, ver); });
+ [ver](const auto& pool) { return validatePool(pool, ver); });
}
static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
@@ -536,9 +562,8 @@
static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
const hidl_vec<uint32_t>& operandIndexes,
const hidl_vec<Operand>& operands,
- const hidl_vec<hidl_memory>& pools, bool allowUnspecified,
- const char* type) {
- MemoryAccessVerifier poolVerifier(pools);
+ const MemoryAccessVerifier& poolVerifier,
+ bool allowUnspecified, const char* type) {
// The request should specify as many arguments as were described in the model.
const size_t requestArgumentCount = requestArguments.size();
if (requestArgumentCount != operandIndexes.size()) {
@@ -609,22 +634,27 @@
return true;
}
-template <class T_Model>
-bool validateRequest(const Request& request, const T_Model& model) {
+// The pool verifier is built once from request.pools and shared by input and output validation.
+template <class T_Request, class T_Model>
+bool validateRequest(const T_Request& request, const T_Model& model) {
HalVersion version = ModelToHalVersion<T_Model>::version;
+ MemoryAccessVerifier poolVerifier(request.pools);
return (validateRequestArguments(request.inputs, model.inputIndexes,
- convertToV1_3(model.operands), request.pools,
+ convertToV1_3(model.operands), poolVerifier,
/*allowUnspecified=*/false, "input") &&
validateRequestArguments(request.outputs, model.outputIndexes,
- convertToV1_3(model.operands), request.pools,
+ convertToV1_3(model.operands), poolVerifier,
/*allowUnspecified=*/version >= HalVersion::V1_2, "output") &&
validatePools(request.pools, version));
}
-template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model);
-template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model);
-template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model);
-template bool validateRequest<V1_3::Model>(const Request& request, const V1_3::Model& model);
+template bool validateRequest<V1_0::Request, V1_0::Model>(const V1_0::Request& request,
+ const V1_0::Model& model);
+template bool validateRequest<V1_0::Request, V1_1::Model>(const V1_0::Request& request,
+ const V1_1::Model& model);
+template bool validateRequest<V1_0::Request, V1_2::Model>(const V1_0::Request& request,
+ const V1_2::Model& model);
+template bool validateRequest<V1_3::Request, V1_3::Model>(const V1_3::Request& request,
+ const V1_3::Model& model);
bool validateExecutionPreference(ExecutionPreference preference) {
return preference == ExecutionPreference::LOW_POWER ||
diff --git a/nn/common/include/CpuExecutor.h b/nn/common/include/CpuExecutor.h
index 02eca1d..93241ad 100644
--- a/nn/common/include/CpuExecutor.h
+++ b/nn/common/include/CpuExecutor.h
@@ -121,6 +121,9 @@
bool setRunTimePoolInfosFromHidlMemories(std::vector<RunTimePoolInfo>* poolInfos,
const hal::hidl_vec<hal::hidl_memory>& pools);
+// Maps each V1_3 memory pool into poolInfos. Returns false and leaves poolInfos empty if any
+// pool is not a hidlMemory pool (e.g. a buffer token) or cannot be mapped.
+bool setRunTimePoolInfosFromMemoryPools(std::vector<RunTimePoolInfo>* poolInfos,
+ const hal::hidl_vec<hal::Request::MemoryPool>& pools);
+
// This class is used to execute a model on the CPU.
class CpuExecutor {
public:
diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h
index d4bc449..15db0fc 100644
--- a/nn/common/include/ExecutionBurstController.h
+++ b/nn/common/include/ExecutionBurstController.h
@@ -51,7 +51,8 @@
* request.
* @return Serialized FMQ request data.
+ * Note: request is a hal::V1_0::Request; a hal::V1_3::Request can be converted with
+ * convertToV1_0 (declared in Utils.h).
*/
-std::vector<hal::FmqRequestDatum> serialize(const hal::Request& request, hal::MeasureTiming measure,
+std::vector<hal::FmqRequestDatum> serialize(const hal::V1_0::Request& request,
+ hal::MeasureTiming measure,
const std::vector<int32_t>& slots);
/**
@@ -160,7 +161,7 @@
* the request.
* @return 'true' on successful send, 'false' otherwise.
*/
- bool send(const hal::Request& request, hal::MeasureTiming measure,
+ bool send(const hal::V1_0::Request& request, hal::MeasureTiming measure,
const std::vector<int32_t>& slots);
/**
@@ -298,7 +299,7 @@
* different path (e.g., IPreparedModel::executeSynchronously)
+ * Note: request is a hal::V1_0::Request, whose memory pools are hidl_memory objects.
*/
std::tuple<int, std::vector<hal::OutputShape>, hal::Timing, bool> compute(
- const hal::Request& request, hal::MeasureTiming measure,
+ const hal::V1_0::Request& request, hal::MeasureTiming measure,
const std::vector<intptr_t>& memoryIds);
/**
diff --git a/nn/common/include/ExecutionBurstServer.h b/nn/common/include/ExecutionBurstServer.h
index 7f631cf..5bac095 100644
--- a/nn/common/include/ExecutionBurstServer.h
+++ b/nn/common/include/ExecutionBurstServer.h
@@ -60,7 +60,7 @@
* @param data Serialized FMQ request data.
* @return Request object if successfully deserialized, std::nullopt otherwise.
+ * The returned hal::V1_0::Request has empty pools; its memory is identified by the returned slots.
*/
-std::optional<std::tuple<hal::Request, std::vector<int32_t>, hal::MeasureTiming>> deserialize(
+std::optional<std::tuple<hal::V1_0::Request, std::vector<int32_t>, hal::MeasureTiming>> deserialize(
const std::vector<hal::FmqRequestDatum>& data);
/**
@@ -103,7 +103,8 @@
* @return Request object if successfully received, std::nullopt if error or
* if the receiver object was invalidated.
*/
- std::optional<std::tuple<hal::Request, std::vector<int32_t>, hal::MeasureTiming>> getBlocking();
+ std::optional<std::tuple<hal::V1_0::Request, std::vector<int32_t>, hal::MeasureTiming>>
+ getBlocking();
/**
* Method to mark the channel as invalid, unblocking any current or future
@@ -233,7 +234,7 @@
* execution, dynamic output shapes, and any timing information.
+ * Note: request is a hal::V1_0::Request whose memory pools are identified by slots.
*/
virtual std::tuple<hal::ErrorStatus, hal::hidl_vec<hal::OutputShape>, hal::Timing> execute(
- const hal::Request& request, const std::vector<int32_t>& slots,
+ const hal::V1_0::Request& request, const std::vector<int32_t>& slots,
hal::MeasureTiming measure) = 0;
};
diff --git a/nn/common/include/HalInterfaces.h b/nn/common/include/HalInterfaces.h
index 3f4490f..23ffaab 100644
--- a/nn/common/include/HalInterfaces.h
+++ b/nn/common/include/HalInterfaces.h
@@ -64,7 +64,6 @@
using V1_0::FusedActivationFunc;
using V1_0::OperandLifeTime;
using V1_0::PerformanceInfo;
-using V1_0::Request;
using V1_0::RequestArgument;
using V1_1::ExecutionPreference;
using V1_2::Constant;
@@ -80,7 +79,10 @@
using V1_2::OutputShape;
using V1_2::SymmPerChannelQuantParams;
using V1_2::Timing;
+using V1_3::BufferDesc;
+using V1_3::BufferRole;
using V1_3::Capabilities;
+using V1_3::IBuffer;
using V1_3::IDevice;
using V1_3::IPreparedModel;
using V1_3::IPreparedModelCallback;
@@ -90,6 +92,7 @@
using V1_3::OperandTypeRange;
using V1_3::Operation;
using V1_3::OperationType;
+// hal::Request is the V1_3 type; code that needs the V1_0 type must name V1_0::Request explicitly.
+using V1_3::Request;
using CacheToken =
hardware::hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h
index 146ae21..bd4401e 100644
--- a/nn/common/include/Utils.h
+++ b/nn/common/include/Utils.h
@@ -437,6 +437,14 @@
hal::hidl_vec<hal::V1_3::Operand> convertToV1_3(const hal::hidl_vec<hal::V1_2::Operand>& operands);
hal::hidl_vec<hal::V1_3::Operand> convertToV1_3(const hal::hidl_vec<hal::V1_3::Operand>& operands);
+// Returns whether the request can be represented as a V1_0::Request, i.e. every memory pool
+// is a hidlMemory pool rather than a buffer token.
+bool compliantWithV1_0(const hal::V1_0::Request& request);
+bool compliantWithV1_0(const hal::V1_3::Request& request);
+
+// Converts a request between V1_0 and V1_3. Converting a non-compliant V1_3 request to V1_0
+// logs an error and maps each buffer-token pool to an empty hidl_memory.
+hal::V1_0::Request convertToV1_0(const hal::V1_0::Request& request);
+hal::V1_0::Request convertToV1_0(const hal::V1_3::Request& request);
+hal::V1_3::Request convertToV1_3(const hal::V1_0::Request& request);
+hal::V1_3::Request convertToV1_3(const hal::V1_3::Request& request);
+
#ifdef NN_DEBUGGABLE
uint32_t getProp(const char* str, uint32_t defaultValue = 0);
#endif // NN_DEBUGGABLE
diff --git a/nn/common/include/ValidateHal.h b/nn/common/include/ValidateHal.h
index bfe1483..733c8b9 100644
--- a/nn/common/include/ValidateHal.h
+++ b/nn/common/include/ValidateHal.h
@@ -44,8 +44,11 @@
// IMPORTANT: This function cannot validate that OEM operation and operands
// are correctly defined, as these are specific to each implementation.
// Each driver should do their own validation of OEM types.
-template <class T_Model>
-bool validateRequest(const hal::Request& request, const T_Model& model);
+// For HAL version 1.3 or higher, this function cannot validate that the
+// buffer tokens are valid. Each driver should do their own validation of
+// buffer tokens.
+template <class T_Request, class T_Model>
+bool validateRequest(const T_Request& request, const T_Model& model);
// Verfies that the execution preference is valid.
bool validateExecutionPreference(hal::ExecutionPreference preference);
@@ -60,6 +63,7 @@
// Verfies that the memory pool is valid in the specified HAL version.
bool validatePool(const hal::hidl_memory& pool, HalVersion ver = HalVersion::LATEST);
+// Verifies a V1_3 memory pool: shared memory is validated like a hidl_memory pool, while a
+// buffer token is only required to be nonzero (its validity must be checked by the driver).
+bool validatePool(const hal::V1_3::Request::MemoryPool& pool, HalVersion ver = HalVersion::LATEST);
} // namespace nn
} // namespace android