Add TENSOR_QUANT8_ASYMM_SIGNED support for TOPK_V2 op
Bug: 143935115
Test: quantization coupling tests in CTS and VTS
Change-Id: Ibf028b3dba23d0ff381639c5f2830c154b1566d2
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index 169845a..706c01e 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -1741,7 +1741,8 @@
if (inputType == OperandType::TENSOR_FLOAT16 ||
inputType == OperandType::TENSOR_FLOAT32 ||
inputType == OperandType::TENSOR_INT32 ||
- inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+ inputType == OperandType::TENSOR_QUANT8_ASYMM ||
+ inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
inExpectedTypes = {inputType, OperandType::INT32};
outExpectedTypes = {inputType, OperandType::TENSOR_INT32};
} else {
@@ -1749,7 +1750,11 @@
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
- NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
+ if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
+ } else {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
+ }
return validateOperationOperandTypes(operands, inputCount, inputIndexes,
inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
diff --git a/nn/common/operations/TopK_V2.cpp b/nn/common/operations/TopK_V2.cpp
index 010d380..00c88a1 100644
--- a/nn/common/operations/TopK_V2.cpp
+++ b/nn/common/operations/TopK_V2.cpp
@@ -22,6 +22,8 @@
#include "OperationsUtils.h"
#include <algorithm>
+#include <utility>
+#include <vector>
namespace android {
namespace nn {
@@ -93,6 +95,11 @@
reinterpret_cast<uint8_t*>(valuesData), valuesShape,
reinterpret_cast<int32_t*>(indicesData), indicesShape);
} break;
+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: {
+ return evalGeneric(reinterpret_cast<const int8_t*>(inputData), inputShape, k,
+ reinterpret_cast<int8_t*>(valuesData), valuesShape,
+ reinterpret_cast<int32_t*>(indicesData), indicesShape);
+ } break;
default: {
LOG(ERROR) << "Unsupported data type: " << toString(inputShape.type);
return false;