soc: qcom: fix to avoid memory allocation failures

Currenlty wdsp_glink_write() API allocates 256KB buffer and free
after it is transferred to glink. But this may result in memory
allocation failures in low memory situations. Fix this issue by
allocating the required size and validate it to avoid overflow.

Change-Id: I891e4361aa5f24c5a518b3acb1d7cf6a9baa5bdf
Signed-off-by: Vidyakumar Athota <vathota@codeaurora.org>
diff --git a/drivers/soc/qcom/wcd-dsp-glink.c b/drivers/soc/qcom/wcd-dsp-glink.c
index 3313122..c8bb13d 100644
--- a/drivers/soc/qcom/wcd-dsp-glink.c
+++ b/drivers/soc/qcom/wcd-dsp-glink.c
@@ -21,6 +21,7 @@
 #include <linux/list.h>
 #include <linux/cdev.h>
 #include <linux/platform_device.h>
+#include <linux/vmalloc.h>
 #include <soc/qcom/glink.h>
 #include "sound/wcd-dsp-glink.h"
 
@@ -29,6 +30,10 @@
 #define WDSP_MAX_READ_SIZE (4 * 1024)
 #define WDSP_MAX_NO_OF_INTENTS (20)
 #define WDSP_MAX_NO_OF_CHANNELS (10)
+#define WDSP_WRITE_PKT_SIZE (sizeof(struct wdsp_write_pkt))
+#define WDSP_REG_PKT_SIZE (sizeof(struct wdsp_reg_pkt))
+#define WDSP_CMD_PKT_SIZE (sizeof(struct wdsp_cmd_pkt))
+#define WDSP_CH_CFG_SIZE (sizeof(struct wdsp_glink_ch_cfg))
 
 #define MINOR_NUMBER_COUNT 1
 #define WDSP_EDGE "wdsp"
@@ -183,7 +188,7 @@
 		return;
 	}
 	/* Free tx pkt */
-	kfree(pkt_priv);
+	vfree(pkt_priv);
 }
 
 /*
@@ -201,7 +206,7 @@
 		return;
 	}
 	/* Free tx pkt */
-	kfree(pkt_priv);
+	vfree(pkt_priv);
 }
 
 /*
@@ -519,9 +524,10 @@
  * and register with glink
  * wpriv:     Wdsp_glink private structure.
  * pkt:       Glink registration packet contains glink channel information.
+ * pkt_size:  Size of the pkt.
  */
 static int wdsp_glink_ch_info_init(struct wdsp_glink_priv *wpriv,
-				   struct wdsp_reg_pkt *pkt)
+				   struct wdsp_reg_pkt *pkt, size_t pkt_size)
 {
 	int ret = 0, i, j;
 	struct glink_link_info link_info;
@@ -530,6 +536,7 @@
 	u8 no_of_channels;
 	u8 *payload;
 	u32 ch_size, ch_cfg_size;
+	size_t size = WDSP_WRITE_PKT_SIZE + WDSP_REG_PKT_SIZE;
 
 	mutex_lock(&wpriv->glink_mutex);
 	if (wpriv->ch) {
@@ -542,9 +549,10 @@
 	no_of_channels = pkt->no_of_channels;
 
 	if (no_of_channels > WDSP_MAX_NO_OF_CHANNELS) {
-		dev_info(wpriv->dev, "%s: no_of_channels = %d are limited to %d\n",
-			 __func__, no_of_channels, WDSP_MAX_NO_OF_CHANNELS);
-		no_of_channels = WDSP_MAX_NO_OF_CHANNELS;
+		dev_err(wpriv->dev, "%s: no_of_channels: %d but max allowed are %d\n",
+			__func__, no_of_channels, WDSP_MAX_NO_OF_CHANNELS);
+		ret = -EINVAL;
+		goto done;
 	}
 	ch = kcalloc(no_of_channels, sizeof(struct wdsp_glink_ch *),
 		     GFP_KERNEL);
@@ -558,20 +566,34 @@
 	for (i = 0; i < no_of_channels; i++) {
 		ch_cfg = (struct wdsp_glink_ch_cfg *)payload;
 
+		size += WDSP_CH_CFG_SIZE;
+		if (size > pkt_size) {
+			dev_err(wpriv->dev, "%s: Invalid size = %zd, pkt_size = %zd\n",
+				__func__, size, pkt_size);
+			ret = -EINVAL;
+			goto err_ch_mem;
+		}
 		if (ch_cfg->no_of_intents > WDSP_MAX_NO_OF_INTENTS) {
 			dev_err(wpriv->dev, "%s: Invalid no_of_intents = %d\n",
 				__func__, ch_cfg->no_of_intents);
 			ret = -EINVAL;
 			goto err_ch_mem;
 		}
+		size += (sizeof(u32) * ch_cfg->no_of_intents);
+		if (size > pkt_size) {
+			dev_err(wpriv->dev, "%s: Invalid size = %zd, pkt_size = %zd\n",
+				__func__, size, pkt_size);
+			ret = -EINVAL;
+			goto err_ch_mem;
+		}
 
 		ch_cfg_size = sizeof(struct wdsp_glink_ch_cfg) +
 					(sizeof(u32) * ch_cfg->no_of_intents);
 		ch_size = sizeof(struct wdsp_glink_ch) +
 					(sizeof(u32) * ch_cfg->no_of_intents);
 
-		dev_dbg(wpriv->dev, "%s: channels = %d, ch_cfg_size %d",
-			 __func__, no_of_channels, ch_cfg_size);
+		dev_dbg(wpriv->dev, "%s: channels: %d ch_cfg_size: %d, size: %zd, pkt_size: %zd",
+			 __func__, no_of_channels, ch_cfg_size, size, pkt_size);
 
 		ch[i] = kzalloc(ch_size, GFP_KERNEL);
 		if (!ch[i]) {
@@ -658,7 +680,7 @@
 			 * there won't be any tx_done notification to
 			 * free the buffer.
 			 */
-			kfree(tx_buf);
+			vfree(tx_buf);
 		}
 	} else {
 		mutex_unlock(&tx_buf->ch->mutex);
@@ -668,7 +690,7 @@
 		 * Free tx_buf here as there won't be any tx_done
 		 * notification in this case also.
 		 */
-		kfree(tx_buf);
+		vfree(tx_buf);
 	}
 }
 
@@ -761,6 +783,7 @@
 	struct wdsp_cmd_pkt *cpkt;
 	struct wdsp_glink_tx_buf *tx_buf;
 	struct wdsp_glink_priv *wpriv;
+	size_t pkt_max_size;
 
 	wpriv = (struct wdsp_glink_priv *)file->private_data;
 	if (!wpriv) {
@@ -769,7 +792,7 @@
 		goto done;
 	}
 
-	if ((count < sizeof(struct wdsp_write_pkt)) ||
+	if ((count < WDSP_WRITE_PKT_SIZE) ||
 	    (count > WDSP_MAX_WRITE_SIZE)) {
 		dev_err(wpriv->dev, "%s: Invalid count = %zd\n",
 			__func__, count);
@@ -779,8 +802,8 @@
 
 	dev_dbg(wpriv->dev, "%s: count = %zd\n", __func__, count);
 
-	tx_buf_size = WDSP_MAX_WRITE_SIZE + sizeof(struct wdsp_glink_tx_buf);
-	tx_buf = kzalloc(tx_buf_size, GFP_KERNEL);
+	tx_buf_size = count + sizeof(struct wdsp_glink_tx_buf);
+	tx_buf = vzalloc(tx_buf_size);
 	if (!tx_buf) {
 		ret = -ENOMEM;
 		goto done;
@@ -797,19 +820,20 @@
 	wpkt = (struct wdsp_write_pkt *)tx_buf->buf;
 	switch (wpkt->pkt_type) {
 	case WDSP_REG_PKT:
-		if (count <= (sizeof(struct wdsp_write_pkt) +
-			      sizeof(struct wdsp_reg_pkt))) {
+		if (count < (WDSP_WRITE_PKT_SIZE + WDSP_REG_PKT_SIZE +
+			     WDSP_CH_CFG_SIZE)) {
 			dev_err(wpriv->dev, "%s: Invalid reg pkt size = %zd\n",
 				__func__, count);
 			ret = -EINVAL;
 			goto free_buf;
 		}
 		ret = wdsp_glink_ch_info_init(wpriv,
-					(struct wdsp_reg_pkt *)wpkt->payload);
+					(struct wdsp_reg_pkt *)wpkt->payload,
+					count);
 		if (ret < 0)
 			dev_err(wpriv->dev, "%s: glink register failed, ret = %d\n",
 				__func__, ret);
-		kfree(tx_buf);
+		vfree(tx_buf);
 		break;
 	case WDSP_READY_PKT:
 		ret = wait_event_timeout(wpriv->link_state_wait,
@@ -823,11 +847,10 @@
 			goto free_buf;
 		}
 		ret = 0;
-		kfree(tx_buf);
+		vfree(tx_buf);
 		break;
 	case WDSP_CMD_PKT:
-		if (count <= (sizeof(struct wdsp_write_pkt) +
-			      sizeof(struct wdsp_cmd_pkt))) {
+		if (count <= (WDSP_WRITE_PKT_SIZE + WDSP_CMD_PKT_SIZE)) {
 			dev_err(wpriv->dev, "%s: Invalid cmd pkt size = %zd\n",
 				__func__, count);
 			ret = -EINVAL;
@@ -843,10 +866,18 @@
 			goto free_buf;
 		}
 		mutex_unlock(&wpriv->glink_mutex);
-
 		cpkt = (struct wdsp_cmd_pkt *)wpkt->payload;
-		dev_dbg(wpriv->dev, "%s: requested ch_name: %s\n", __func__,
-			 cpkt->ch_name);
+		pkt_max_size =  sizeof(struct wdsp_write_pkt) +
+					sizeof(struct wdsp_cmd_pkt) +
+					cpkt->payload_size;
+		if (count < pkt_max_size) {
+			dev_err(wpriv->dev, "%s: Invalid cmd pkt count = %zd, pkt_size = %zd\n",
+				__func__, count, pkt_max_size);
+			ret = -EINVAL;
+			goto free_buf;
+		}
+		dev_dbg(wpriv->dev, "%s: requested ch_name: %s, pkt_size: %zd\n",
+			__func__, cpkt->ch_name, pkt_max_size);
 		for (i = 0; i < wpriv->no_of_channels; i++) {
 			if (wpriv->ch && wpriv->ch[i] &&
 				(!strcmp(cpkt->ch_name,
@@ -881,13 +912,13 @@
 	default:
 		dev_err(wpriv->dev, "%s: Invalid packet type\n", __func__);
 		ret = -EINVAL;
-		kfree(tx_buf);
+		vfree(tx_buf);
 		break;
 	}
 	goto done;
 
 free_buf:
-	kfree(tx_buf);
+	vfree(tx_buf);
 
 done:
 	return ret;