Adds float16 support for RNN.
Bug: 118608628
Test: NeuralNetworksTest_static
Change-Id: I9e87c27169a046e10357bba27ee9999595dd7170
Merged-In: I9e87c27169a046e10357bba27ee9999595dd7170
(cherry picked from commit d49668d85f5776f19caa9461eee569c3e4a27c44)
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index 8d98417..f40a458 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -1696,19 +1696,38 @@
logInvalidInOutNumber(6, 2);
return ANEURALNETWORKS_BAD_DATA;
}
- std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_FLOAT32,
- OperandType::TENSOR_FLOAT32,
- OperandType::TENSOR_FLOAT32,
- OperandType::TENSOR_FLOAT32,
- OperandType::TENSOR_FLOAT32,
- OperandType::INT32};
- std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_FLOAT32,
- OperandType::TENSOR_FLOAT32};
- NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ OperandType inputType = operands[inputIndexes[0]].type;
+ std::vector<OperandType> inExpectedTypes;
+ std::vector<OperandType> outExpectedTypes;
+ if (inputType == OperandType::TENSOR_FLOAT32) {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_FLOAT32, OperandType::INT32,
+ };
+ outExpectedTypes = {
+ OperandType::TENSOR_FLOAT32,
+ OperandType::TENSOR_FLOAT32,
+ };
+ } else if (inputType == OperandType::TENSOR_FLOAT16) {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_FLOAT16, OperandType::INT32,
+ };
+ outExpectedTypes = {
+ OperandType::TENSOR_FLOAT16,
+ OperandType::TENSOR_FLOAT16,
+ };
+ } else {
+ LOG(ERROR) << "Unsupported input tensor type for operation "
+ << getOperationName(opType);
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_SVDF: {