arm_compute v18.08
diff --git a/src/graph/nodes/FullyConnectedLayer.cpp b/src/graph/nodes/FullyConnectedLayer.cpp
index d94a785..6ea0292 100644
--- a/src/graph/nodes/FullyConnectedLayer.cpp
+++ b/src/graph/nodes/FullyConnectedLayer.cpp
@@ -31,15 +31,17 @@
{
namespace graph
{
-FullyConnectedLayerNode::FullyConnectedLayerNode(unsigned int num_outputs)
- : _num_outputs(num_outputs)
+FullyConnectedLayerNode::FullyConnectedLayerNode(unsigned int num_outputs, QuantizationInfo out_quant_info, FullyConnectedLayerInfo fc_info)
+ : _num_outputs(num_outputs), _out_quant_info(out_quant_info), _info(fc_info) // _out_quant_info is forwarded to compute_output_descriptor(); _info is returned by info()
{
_input_edges.resize(3, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
}
TensorDescriptor FullyConnectedLayerNode::compute_weights_descriptor(const TensorDescriptor &input_descriptor,
- unsigned int num_outputs)
+ unsigned int num_outputs,
+ FullyConnectedLayerInfo fc_info, // fc_info.transpose_weights selects the weights shape
+ QuantizationInfo weights_quant_info) // applied to the weights descriptor only when non-empty
{
unsigned int num_weights = 1;
unsigned int num_dimensions = input_descriptor.shape.num_dimensions();
@@ -56,11 +58,24 @@
TensorDescriptor weights_descriptor = input_descriptor;
weights_descriptor.shape = TensorShape(num_weights, num_outputs);
+ // If transpose_weights is false (weights presumably supplied already transposed), use the transposed shape
+ if(!fc_info.transpose_weights)
+ {
+ weights_descriptor.shape = TensorShape(num_outputs, num_weights);
+ }
+
+ // Override the quantization info inherited from the input only if one was provided
+ if(!weights_quant_info.empty())
+ {
+ weights_descriptor.quant_info = weights_quant_info;
+ }
+
return weights_descriptor;
}
TensorDescriptor FullyConnectedLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
- unsigned int num_outputs)
+ unsigned int num_outputs,
+ QuantizationInfo out_quant_info) // applied to the output descriptor only when non-empty
{
// Note: Only 1D batch space is supported at the moment
unsigned int batches = input_descriptor.shape[1];
@@ -69,12 +84,24 @@
batches = input_descriptor.shape[3];
}
+ // Output shape is TensorShape(num_outputs, batches); all other fields are inherited from the input
TensorDescriptor output_descriptor = input_descriptor;
output_descriptor.shape = TensorShape(num_outputs, batches);
+ // Override the quantization info inherited from the input only if one was provided
+ if(!out_quant_info.empty())
+ {
+ output_descriptor.quant_info = out_quant_info;
+ }
+
return output_descriptor;
}
+FullyConnectedLayerInfo FullyConnectedLayerNode::info() const
+{
+ return _info; // the fc_info passed to the constructor
+}
+
bool FullyConnectedLayerNode::forward_descriptors()
{
if((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID))
@@ -93,7 +120,7 @@
const Tensor *src = input(0);
ARM_COMPUTE_ERROR_ON(src == nullptr);
- return compute_output_descriptor(src->desc(), _num_outputs);
+ return compute_output_descriptor(src->desc(), _num_outputs, _out_quant_info); // forward the node's requested output quantization info
}
NodeType FullyConnectedLayerNode::type() const