IB/mlx5: Fix congestion counters in LAG mode

Congestion counters are counted and queried per physical function.
When working in LAG mode, CNP packets can be sent or received on either
of the two functions, so the congestion counters must be aggregated
across both physical functions.

Fixes: e1f24a79f424 ("IB/mlx5: Support congestion related counters")
Signed-off-by: Majd Dibbiny <majd@mellanox.com>
Reviewed-by: Aviv Heller <avivh@mellanox.com>
Signed-off-by: Leon Romanovsky <leon@kernel.org>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
diff --git a/drivers/infiniband/hw/mlx5/main.c b/drivers/infiniband/hw/mlx5/main.c
index 543d0a4..b4ef4d9 100644
--- a/drivers/infiniband/hw/mlx5/main.c
+++ b/drivers/infiniband/hw/mlx5/main.c
@@ -3737,34 +3737,6 @@ static int mlx5_ib_query_q_counters(struct mlx5_ib_dev *dev,
 	return ret;
 }
 
-static int mlx5_ib_query_cong_counters(struct mlx5_ib_dev *dev,
-				       struct mlx5_ib_port *port,
-				       struct rdma_hw_stats *stats)
-{
-	int outlen = MLX5_ST_SZ_BYTES(query_cong_statistics_out);
-	void *out;
-	int ret, i;
-	int offset = port->cnts.num_q_counters;
-
-	out = kvzalloc(outlen, GFP_KERNEL);
-	if (!out)
-		return -ENOMEM;
-
-	ret = mlx5_cmd_query_cong_counter(dev->mdev, false, out, outlen);
-	if (ret)
-		goto free;
-
-	for (i = 0; i < port->cnts.num_cong_counters; i++) {
-		stats->value[i + offset] =
-			be64_to_cpup((__be64 *)(out +
-				     port->cnts.offsets[i + offset]));
-	}
-
-free:
-	kvfree(out);
-	return ret;
-}
-
 static int mlx5_ib_get_hw_stats(struct ib_device *ibdev,
 				struct rdma_hw_stats *stats,
 				u8 port_num, int index)
@@ -3782,7 +3754,12 @@ static int mlx5_ib_get_hw_stats(struct ib_device *ibdev,
 	num_counters = port->cnts.num_q_counters;
 
 	if (MLX5_CAP_GEN(dev->mdev, cc_query_allowed)) {
-		ret = mlx5_ib_query_cong_counters(dev, port, stats);
+		ret = mlx5_lag_query_cong_counters(dev->mdev,
+						   stats->value +
+						   port->cnts.num_q_counters,
+						   port->cnts.num_cong_counters,
+						   port->cnts.offsets +
+						   port->cnts.num_q_counters);
 		if (ret)
 			return ret;
 		num_counters += port->cnts.num_cong_counters;