arm_compute v18.02
Change-Id: I7207aa488e5470f235f39b6c188b4678dc38d1a6
diff --git a/src/graph/nodes/DepthwiseConvolutionLayer.cpp b/src/graph/nodes/DepthwiseConvolutionLayer.cpp
index 1209d03..e5101cc 100644
--- a/src/graph/nodes/DepthwiseConvolutionLayer.cpp
+++ b/src/graph/nodes/DepthwiseConvolutionLayer.cpp
@@ -40,10 +40,8 @@
if(_weights.tensor() == nullptr)
{
- TensorShape shape = in->info()->tensor_shape();
- shape.set(Window::DimX, _conv_width);
- shape.set(Window::DimY, _conv_height);
- TensorInfo info = TensorInfo(TensorShape(shape), in->info()->num_channels(), in->info()->data_type(), in->info()->fixed_point_position());
+ TensorShape weights_shape(_conv_width, _conv_height, input->tensor()->info()->tensor_shape().z());
+ TensorInfo info = TensorInfo(TensorShape(weights_shape), in->info()->num_channels(), in->info()->data_type(), in->info()->fixed_point_position());
info.set_quantization_info(_quant_info);
_weights.set_info(std::move(info));
}