Merge "iommu: msm: Program all M2V tables at once"
diff --git a/drivers/iommu/msm_iommu-v1.c b/drivers/iommu/msm_iommu-v1.c
index d0c2bbf..d047f9e 100644
--- a/drivers/iommu/msm_iommu-v1.c
+++ b/drivers/iommu/msm_iommu-v1.c
@@ -476,15 +476,14 @@
 	mb();
 }
 
-static void __release_smg(void __iomem *base, int ctx)
+static void __release_smg(void __iomem *base)
 {
 	int i, smt_size;
 	smt_size = GET_IDR0_NUMSMRG(base);
 
-	/* Invalidate any SMGs associated with this context */
+	/* Invalidate every valid SMG, whichever context bank it maps to */
 	for (i = 0; i < smt_size; i++)
-		if (GET_SMR_VALID(base, i) &&
-		    GET_S2CR_CBNDX(base, i) == ctx)
+		if (GET_SMR_VALID(base, i))
 			SET_SMR_VALID(base, i, 0);
 }
 
@@ -527,17 +526,78 @@
 	}
 }
 
+/*
+ * program_m2v_table - program the M2V (SMR/S2CR) entries of one context.
+ * @dev:  child device of the IOMMU; its drvdata is expected to be a
+ *        struct msm_iommu_ctx_drvdata
+ * @data: IOMMU register base, as handed to device_for_each_child()
+ *
+ * Each stream ID of the context is assigned to the next SMR that is not
+ * valid yet, and its S2CR is pointed at context bank ctx_drvdata->num.
+ *
+ * Always returns 0 so that device_for_each_child() visits every child.
+ */
+static int program_m2v_table(struct device *dev, void *data)
+{
+	void __iomem *base = (void __iomem __force *)data;
+	struct msm_iommu_ctx_drvdata *ctx_drvdata = dev_get_drvdata(dev);
+	u32 *sids;
+	unsigned int ctx;
+	int num = 0, i, smt_size;
+	int len;
+
+	/* Nothing to program for a child without context bank data */
+	if (!ctx_drvdata)
+		return 0;
+
+	sids = ctx_drvdata->sids;
+	ctx = ctx_drvdata->num;
+	len = ctx_drvdata->nsid;
+
+	smt_size = GET_IDR0_NUMSMRG(base);
+	/* Program the M2V tables for this context */
+	for (i = 0; i < len / sizeof(*sids); i++) {
+		/* Find the next SMR that is not valid yet */
+		for (; num < smt_size; num++)
+			if (GET_SMR_VALID(base, num) == 0)
+				break;
+		BUG_ON(num >= smt_size);
+
+		SET_SMR_VALID(base, num, 1);
+		SET_SMR_MASK(base, num, 0);
+		SET_SMR_ID(base, num, sids[i]);
+
+		SET_S2CR_N(base, num, 0);
+		SET_S2CR_CBNDX(base, num, ctx);
+		SET_S2CR_MEMATTR(base, num, 0x0A);
+		/* Set security bit override to be Non-secure */
+		SET_S2CR_NSCFG(base, num, 3);
+	}
+
+	return 0;
+}
+
+/*
+ * Program the M2V tables for every child device of this IOMMU (each one
+ * expected to be a context bank) in a single pass.
+ */
+static void program_all_m2v_tables(struct msm_iommu_drvdata *iommu_drvdata)
+{
+	/* device_for_each_child() takes a plain void *, so drop __iomem */
+	device_for_each_child(iommu_drvdata->dev,
+			      (void __force *)iommu_drvdata->base,
+			      program_m2v_table);
+}
+
 static void __program_context(struct msm_iommu_drvdata *iommu_drvdata,
 			      struct msm_iommu_ctx_drvdata *ctx_drvdata,
-			      struct msm_iommu_priv *priv, bool is_secure)
+			      struct msm_iommu_priv *priv, bool is_secure,
+			      bool program_m2v)
 {
 	unsigned int prrr, nmrr;
-	unsigned int pn;
-	int num = 0, i, smt_size;
+	phys_addr_t pn;
 	void __iomem *base = iommu_drvdata->base;
 	unsigned int ctx = ctx_drvdata->num;
-	u32 *sids = ctx_drvdata->sids;
-	int len = ctx_drvdata->nsid;
 	phys_addr_t pgtable = __pa(priv->pt.fl_table);
 
 	__reset_context(base, ctx);
@@ -578,24 +612,9 @@
 	}
 
 	if (!is_secure) {
-		smt_size = GET_IDR0_NUMSMRG(base);
-		/* Program the M2V tables for this context */
-		for (i = 0; i < len / sizeof(*sids); i++) {
-			for (; num < smt_size; num++)
-				if (GET_SMR_VALID(base, num) == 0)
-					break;
-			BUG_ON(num >= smt_size);
+		if (program_m2v)
+			program_all_m2v_tables(iommu_drvdata);
 
-			SET_SMR_VALID(base, num, 1);
-			SET_SMR_MASK(base, num, 0);
-			SET_SMR_ID(base, num, sids[i]);
-
-			SET_S2CR_N(base, num, 0);
-			SET_S2CR_CBNDX(base, num, ctx);
-			SET_S2CR_MEMATTR(base, num, 0x0A);
-			/* Set security bit override to be Non-secure */
-			SET_S2CR_NSCFG(base, num, 3);
-		}
 		SET_CBAR_N(base, ctx, 0);
 
 		/* Stage 1 Context with Stage 2 bypass */
@@ -669,6 +688,7 @@
 	struct msm_iommu_ctx_drvdata *tmp_drvdata;
 	int ret = 0;
 	int is_secure;
+	bool set_m2v = false;
 
 	mutex_lock(&msm_iommu_lock);
 
@@ -734,11 +754,12 @@
 		}
 		program_iommu_bfb_settings(iommu_drvdata->base,
 					   iommu_drvdata->bfb_settings);
+		set_m2v = true;
 	}
 
 	iommu_halt(iommu_drvdata);
 
-	__program_context(iommu_drvdata, ctx_drvdata, priv, is_secure);
+	__program_context(iommu_drvdata, ctx_drvdata, priv, is_secure, set_m2v);
 
 	iommu_resume(iommu_drvdata);
 
@@ -800,8 +821,11 @@
 	iommu_halt(iommu_drvdata);
 
 	__reset_context(iommu_drvdata->base, ctx_drvdata->num);
-	if (!is_secure)
-		__release_smg(iommu_drvdata->base, ctx_drvdata->num);
+	/*
+	 * Only reset the M2V tables (shared by all contexts) on the last detach
+	 */
+	if (!is_secure && iommu_drvdata->ctx_attach_count == 1)
+		__release_smg(iommu_drvdata->base);
 
 	iommu_resume(iommu_drvdata);