Avoid sending AHWB requests to 1.0 and 1.1 drivers.
AHardwareBuffer (AHWB) support is a feature introduced in 1.2. Prior to
this CL, requests with an AHWB as a memory pool were sent to 1.0 and 1.1
drivers. This CL modifies the compliance check so that requests using
AHWBs are no longer sent to 1.0 and 1.1 drivers. Using AHWBs on
compilations with 1.0 and 1.1 drivers will result in a CPU fallback.
Bug: 155686276
Test: NNTS
Change-Id: Ib00a2d0d24d4a8c385b5992a1168e20c4a6bb786
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index cd97ffa..81e5cf1 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -21,6 +21,8 @@
#include <android-base/logging.h>
#include <android-base/properties.h>
#include <android-base/strings.h>
+#include <errno.h>
+#include <poll.h>
#include <sys/system_properties.h>
#include <algorithm>
@@ -32,9 +34,6 @@
#include <utility>
#include <vector>
-#include <errno.h>
-#include <poll.h>
-
#include "ControlFlow.h"
#include "NeuralNetworks.h"
#include "NeuralNetworksOEM.h"
@@ -3100,7 +3099,22 @@
bool compliantWithV1_0(const V1_3::Request& request) {
return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
- return pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory;
+ if (pool.getDiscriminator() != V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) {
+ return false;
+ }
+ const auto& name = pool.hidlMemory().name();
+ return name == "ashmem" || name == "mmap_fd";
+ });
+}
+
+bool compliantWithV1_2(const V1_3::Request& request) {
+ return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
+ if (pool.getDiscriminator() != V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) {
+ return false;
+ }
+ const auto& name = pool.hidlMemory().name();
+ return name == "ashmem" || name == "mmap_fd" || name == "hardware_buffer_blob" ||
+ name == "hardware_buffer";
});
}
@@ -3123,17 +3137,29 @@
return request;
}
-V1_0::Request convertToV1_0(const V1_3::Request& request) {
- if (!compliantWithV1_0(request)) {
- LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
- << " from V1_3::Request to V1_0::Request";
- }
+static V1_0::Request uncheckedConvertToV1_0(const V1_3::Request& request) {
hidl_vec<hidl_memory> pools(request.pools.size());
std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
[](const auto& pool) { return convertToV1_0(pool); });
return {.inputs = request.inputs, .outputs = request.outputs, .pools = std::move(pools)};
}
+V1_0::Request convertToV1_0(const V1_3::Request& request) {
+ if (!compliantWithV1_0(request)) {
+ LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
+ << " from V1_3::Request to V1_0::Request of version 1.0";
+ }
+ return uncheckedConvertToV1_0(request);
+}
+
+V1_0::Request convertToV1_2(const V1_3::Request& request) {
+ if (!compliantWithV1_2(request)) {
+ LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
+ << " from V1_3::Request to V1_0::Request of version 1.2";
+ }
+ return uncheckedConvertToV1_0(request);
+}
+
V1_3::Request convertToV1_3(const V1_0::Request& request) {
hidl_vec<V1_3::Request::MemoryPool> pools(request.pools.size());
std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h
index 24e6921..ca11c5e 100644
--- a/nn/common/include/Utils.h
+++ b/nn/common/include/Utils.h
@@ -530,9 +530,11 @@
bool compliantWithV1_0(const hal::V1_0::Request& request);
bool compliantWithV1_0(const hal::V1_3::Request& request);
+bool compliantWithV1_2(const hal::V1_3::Request& request);
hal::V1_0::Request convertToV1_0(const hal::V1_0::Request& request);
hal::V1_0::Request convertToV1_0(const hal::V1_3::Request& request);
+hal::V1_0::Request convertToV1_2(const hal::V1_3::Request& request);
hal::V1_3::Request convertToV1_3(const hal::V1_0::Request& request);
hal::V1_3::Request convertToV1_3(const hal::V1_3::Request& request);
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 310710e..634cd2a 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -379,9 +379,9 @@
const bool burstCompute = (burstController != nullptr);
bool burstFallback = true;
if (burstCompute) {
- const bool compliant = compliantWithV1_0(request);
+ const bool compliant = compliantWithV1_2(request);
if (compliant) {
- V1_0::Request request10 = convertToV1_0(request);
+ V1_0::Request request12 = convertToV1_2(request);
std::vector<intptr_t> memoryIds;
memoryIds.reserve(localMemories.size());
for (const Memory* memory : localMemories) {
@@ -390,9 +390,9 @@
}
VLOG(EXECUTION) << "Before ExecutionBurstController->compute() "
- << SHOW_IF_DEBUG(toString(request10));
+ << SHOW_IF_DEBUG(toString(request12));
std::tie(n, outputShapes, timing, burstFallback) =
- burstController->compute(request10, measure, memoryIds);
+ burstController->compute(request12, measure, memoryIds);
}
}
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index 3ae950e..33d290c 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -241,17 +241,16 @@
return getResults(*callback);
}
- const bool compliant = compliantWithV1_0(request);
- if (!compliant) {
- LOG(ERROR) << "Could not handle execute or execute_1_2!";
- return failWithStatus(ErrorStatus::GENERAL_FAILURE);
- }
- const V1_0::Request request10 = convertToV1_0(request);
-
// version 1.2 HAL
if (mPreparedModelV1_2 != nullptr) {
+ const bool compliant = compliantWithV1_2(request);
+ if (!compliant) {
+ LOG(ERROR) << "Could not handle execute_1_2!";
+ return failWithStatus(ErrorStatus::GENERAL_FAILURE);
+ }
+ const V1_0::Request request12 = convertToV1_2(request);
Return<V1_0::ErrorStatus> ret =
- mPreparedModelV1_2->execute_1_2(request10, measure, callback);
+ mPreparedModelV1_2->execute_1_2(request12, measure, callback);
if (ret.isDeadObject()) {
LOG(ERROR) << "execute_1_2 failure: " << ret.description();
return failDeadObject();
@@ -271,6 +270,12 @@
// version 1.0 HAL
if (mPreparedModelV1_0 != nullptr) {
+ const bool compliant = compliantWithV1_0(request);
+ if (!compliant) {
+ LOG(ERROR) << "Could not handle execute!";
+ return failWithStatus(ErrorStatus::GENERAL_FAILURE);
+ }
+ const V1_0::Request request10 = convertToV1_0(request);
Return<V1_0::ErrorStatus> ret = mPreparedModelV1_0->execute(request10, callback);
if (ret.isDeadObject()) {
LOG(ERROR) << "execute failure: " << ret.description();
@@ -324,16 +329,16 @@
// version 1.2 HAL
if (mPreparedModelV1_2 != nullptr) {
- const bool compliant = compliantWithV1_0(request);
+ const bool compliant = compliantWithV1_2(request);
if (!compliant) {
LOG(ERROR) << "Could not handle executeSynchronously!";
return kFailure;
}
- const V1_0::Request request10 = convertToV1_0(request);
+ const V1_0::Request request12 = convertToV1_2(request);
std::tuple<int, std::vector<OutputShape>, Timing> result;
Return<void> ret = mPreparedModelV1_2->executeSynchronously(
- request10, measure,
+ request12, measure,
[&result](V1_0::ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
const Timing& timing) {
result = getExecutionResult(convertToV1_3(error), outputShapes, timing);
diff --git a/nn/runtime/test/TestCompliance.cpp b/nn/runtime/test/TestCompliance.cpp
index 53bff03..db5ab4d 100644
--- a/nn/runtime/test/TestCompliance.cpp
+++ b/nn/runtime/test/TestCompliance.cpp
@@ -18,6 +18,7 @@
#include "GeneratedTestUtils.h"
#include "HalInterfaces.h"
+#include "Memory.h"
#include "MemoryUtils.h"
#include "ModelBuilder.h"
#include "TestNeuralNetworksWrapper.h"
@@ -71,8 +72,14 @@
ASSERT_TRUE(compliantWithV1_0(hidlModel));
}
+static void testAvailableSinceV1_2(const Request& request) {
+ ASSERT_FALSE(compliantWithV1_0(request));
+ ASSERT_TRUE(compliantWithV1_2(request));
+}
+
static void testAvailableSinceV1_3(const Request& request) {
ASSERT_FALSE(compliantWithV1_0(request));
+ ASSERT_FALSE(compliantWithV1_2(request));
}
static const WrapperOperandType kTypeTensorFloat(WrapperType::TENSOR_FLOAT32, {1});
@@ -126,7 +133,7 @@
testAvailableSinceV1_2(model);
}
-TEST_F(ComplianceTest, HardwareBuffer) {
+TEST_F(ComplianceTest, HardwareBufferModel) {
const size_t memorySize = 20;
AHardwareBuffer_Desc desc{
.width = memorySize,
@@ -157,6 +164,29 @@
AHardwareBuffer_release(buffer);
}
+TEST_F(ComplianceTest, HardwareBufferRequest) {
+ const auto [n, ahwb] = MemoryRuntimeAHWB::create(1024);
+ ASSERT_EQ(n, ANEURALNETWORKS_NO_ERROR);
+ Request::MemoryPool sharedMemoryPool, ahwbMemoryPool = ahwb->getMemoryPool();
+ sharedMemoryPool.hidlMemory(allocateSharedMemory(1024));
+ ASSERT_TRUE(sharedMemoryPool.hidlMemory().valid());
+ ASSERT_TRUE(ahwbMemoryPool.hidlMemory().valid());
+
+ // AHardwareBuffer as input.
+ testAvailableSinceV1_2(Request{
+ .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}},
+ .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}},
+ .pools = {ahwbMemoryPool, sharedMemoryPool},
+ });
+
+ // AHardwareBuffer as output.
+ testAvailableSinceV1_2(Request{
+ .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}},
+ .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}},
+ .pools = {sharedMemoryPool, ahwbMemoryPool},
+ });
+}
+
TEST_F(ComplianceTest, DeviceMemory) {
Request::MemoryPool sharedMemoryPool, deviceMemoryPool;
sharedMemoryPool.hidlMemory(allocateSharedMemory(1024));