IVGCVSW-3568 Eliminate rank and axis restriction in Softmax.
* Restriction in axis will be now part of ACL. Currently, ACL only
supports axis = 0, which translates to axis = -1 in ArmNN and Android.
* Beta must be Float16 when input/output are Float16
!armnn:3690
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I2645a005840e17703367b3ec7e9ed91e83a2f6c7
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp
index 0ad50f3..824a8f4 100644
--- a/ConversionUtils_1_2.hpp
+++ b/ConversionUtils_1_2.hpp
@@ -2023,28 +2023,38 @@
}
SoftmaxDescriptor desc;
- if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data))
+ HalOperandType outputType = outputOperand->type;
+
+ // Read beta value
+ if (outputType == HalOperandType::TENSOR_FLOAT16)
{
- return Fail("%s: Operation has invalid inputs", __func__);
+ Half value;
+
+ if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
+ }
+
+ desc.m_Beta = static_cast<float>(value);
+ }
+ else
+ {
+ if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
+ }
}
if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation,
- 2,
- HalOperandType::INT32,
- desc.m_Axis,
- model,
- data))
+ 2,
+ HalOperandType::INT32,
+ desc.m_Axis,
+ model,
+ data))
{
return Fail("%s: Operation has invalid inputs", __func__);
}
- if (input.GetTensorInfo().GetNumDimensions() > 2 ||
- !(desc.m_Axis == 1 ||
- (desc.m_Axis < 0 && static_cast<int>(input.GetTensorInfo().GetNumDimensions()) + desc.m_Axis == 1)))
- {
- return Fail("%s: Unsupported input greater than 2D or axis != 1", __func__);
- }
-
bool isSupported = false;
FORWARD_LAYER_SUPPORT_FUNC(__func__,
IsSoftmaxSupported,