iommu: msm: Add support for probe deferral

Probe of the IOMMU devices might fail because clocks or
other depedencies might not be available at the time of probe.
In this case the IOMMU driver needs to handle this gracefully
and request probe deferral. This change ensures that the IOMMU
driver handles the probe deferral correctly by cleaning up properly
and returning the proper error code when a probe deferral occurs.

Change-Id: Ia9dd5997500742133f72295fcd90a546e5a70ea1
Signed-off-by: Olav Haugan <ohaugan@codeaurora.org>
diff --git a/arch/arm/mach-msm/iommu_domains.c b/arch/arm/mach-msm/iommu_domains.c
index f4ac37e..12f5a8e 100644
--- a/arch/arm/mach-msm/iommu_domains.c
+++ b/arch/arm/mach-msm/iommu_domains.c
@@ -570,12 +570,14 @@
 			goto out;
 		}
 		ctx = msm_iommu_get_ctx(name);
-		if (!ctx) {
-			pr_err("Unable to find context %s\n", name);
-			ret_val = -EINVAL;
+		if (IS_ERR(ctx)) {
+			ret_val = PTR_ERR(ctx);
 			goto out;
 		}
-		iommu_group_add_device(group, ctx);
+
+		ret_val = iommu_group_add_device(group, ctx);
+		if (ret_val)
+			goto out;
 	}
 out:
 	return ret_val;
@@ -590,7 +592,7 @@
 	struct msm_iova_layout l;
 	struct msm_iova_partition *part = 0;
 	struct iommu_domain *domain = 0;
-	unsigned int *addr_array;
+	unsigned int *addr_array = 0;
 	unsigned int array_size;
 	int domain_no;
 	int secure_domain;
@@ -661,11 +663,46 @@
 	iommu_group_set_iommudata(group, domain, NULL);
 
 free_mem:
+	kfree(addr_array);
 	kfree(part);
 out:
 	return ret_val;
 }
 
+static int __msm_group_get_domain(struct device *dev, void *data)
+{
+	struct msm_iommu_data_entry *list_entry;
+	struct list_head *dev_list = data;
+	int ret_val = 0;
+
+	list_entry = kmalloc(sizeof(*list_entry), GFP_KERNEL);
+	if (list_entry) {
+		list_entry->data = dev;
+		list_add(&list_entry->list, dev_list);
+	} else {
+		ret_val = -ENOMEM;
+	}
+
+	return ret_val;
+}
+
+static void __msm_iommu_group_remove_device(struct iommu_group *grp)
+{
+	struct msm_iommu_data_entry *tmp;
+	struct msm_iommu_data_entry *list_entry;
+	struct list_head dev_list;
+
+	INIT_LIST_HEAD(&dev_list);
+	iommu_group_for_each_dev(grp, &dev_list, __msm_group_get_domain);
+
+	list_for_each_entry_safe(list_entry, tmp, &dev_list, list) {
+		iommu_group_remove_device(list_entry->data);
+		list_del(&list_entry->list);
+		kfree(list_entry);
+	}
+}
+
+
 static int iommu_domain_parse_dt(const struct device_node *dt_node)
 {
 	struct device_node *node;
@@ -674,13 +711,30 @@
 	int ret_val = 0;
 	struct iommu_group *group = 0;
 	const char *name;
+	struct msm_iommu_data_entry *grp_list_entry;
+	struct msm_iommu_data_entry *tmp;
+	struct list_head iommu_group_list;
+	INIT_LIST_HEAD(&iommu_group_list);
 
 	for_each_child_of_node(dt_node, node) {
 		group = iommu_group_alloc();
 		if (IS_ERR(group)) {
 			ret_val = PTR_ERR(group);
-			goto out;
+			group = 0;
+			goto free_group;
 		}
+
+		/* This is only needed to clean up memory if something fails */
+		grp_list_entry = kmalloc(sizeof(*grp_list_entry),
+					   GFP_KERNEL);
+		if (grp_list_entry) {
+			grp_list_entry->data = group;
+			list_add(&grp_list_entry->list, &iommu_group_list);
+		} else {
+			ret_val = -ENOMEM;
+			goto free_group;
+		}
+
 		if (of_property_read_string(node, "label", &name)) {
 			ret_val = -EINVAL;
 			goto free_group;
@@ -696,7 +750,6 @@
 
 		ret_val = find_and_add_contexts(group, node, num_contexts);
 		if (ret_val) {
-			ret_val = -EINVAL;
 			goto free_group;
 		}
 		ret_val = create_and_add_domain(group, node, name);
@@ -704,9 +757,33 @@
 			ret_val = -EINVAL;
 			goto free_group;
 		}
+
+		/* Remove reference to the group that is taken when the group
+		 * is allocated. This will ensure that when all the devices in
+		 * the group are removed the group will be released.
+		 */
+		iommu_group_put(group);
 	}
+
+	list_for_each_entry_safe(grp_list_entry, tmp, &iommu_group_list, list) {
+		list_del(&grp_list_entry->list);
+		kfree(grp_list_entry);
+	}
+	goto out;
+
 free_group:
-	/* No iommu_group_free() function */
+	list_for_each_entry_safe(grp_list_entry, tmp, &iommu_group_list, list) {
+		struct iommu_domain *d;
+
+		d = iommu_group_get_iommudata(grp_list_entry->data);
+		if (d)
+			msm_unregister_domain(d);
+
+		__msm_iommu_group_remove_device(grp_list_entry->data);
+		list_del(&grp_list_entry->list);
+		kfree(grp_list_entry);
+	}
+	iommu_group_put(group);
 out:
 	return ret_val;
 }
diff --git a/drivers/iommu/msm_iommu_dev-v0.c b/drivers/iommu/msm_iommu_dev-v0.c
index 4fca473..4ee65d8 100644
--- a/drivers/iommu/msm_iommu_dev-v0.c
+++ b/drivers/iommu/msm_iommu_dev-v0.c
@@ -80,10 +80,13 @@
 	}
 	mutex_unlock(&iommu_list_lock);
 
-	if (!dev || !dev_get_drvdata(dev))
-		pr_err("Could not find context <%s>\n", ctx_name);
 	put_device(dev);
 
+	if (!dev || !dev_get_drvdata(dev)) {
+		pr_debug("Could not find context <%s>\n", ctx_name);
+		dev = ERR_PTR(-EPROBE_DEFER);
+	}
+
 	return dev;
 }
 EXPORT_SYMBOL(msm_iommu_get_ctx);
@@ -131,6 +134,52 @@
 	mb();
 }
 
+static int __get_clocks(struct platform_device *pdev,
+			struct msm_iommu_drvdata *drvdata,
+			int needs_alt_core_clk)
+{
+	int ret = 0;
+
+	drvdata->pclk = devm_clk_get(&pdev->dev, "iface_clk");
+	if (IS_ERR(drvdata->pclk)) {
+		ret = PTR_ERR(drvdata->pclk);
+		drvdata->pclk = NULL;
+		if (ret != -EPROBE_DEFER) {
+			pr_err("Unable to get %s clock for %s IOMMU device\n",
+				dev_name(&pdev->dev), drvdata->name);
+		}
+		goto fail;
+	}
+
+	drvdata->clk = devm_clk_get(&pdev->dev, "core_clk");
+
+	if (!IS_ERR(drvdata->clk)) {
+		if (clk_get_rate(drvdata->clk) == 0) {
+			ret = clk_round_rate(drvdata->clk, 1000);
+			clk_set_rate(drvdata->clk, ret);
+		}
+	} else {
+		drvdata->clk = NULL;
+	}
+
+	if (needs_alt_core_clk) {
+		drvdata->aclk = devm_clk_get(&pdev->dev, "alt_core_clk");
+		if (IS_ERR(drvdata->aclk)) {
+			ret = PTR_ERR(drvdata->aclk);
+			goto fail;
+		}
+	}
+
+	if (drvdata->aclk && clk_get_rate(drvdata->aclk) == 0) {
+		ret = clk_round_rate(drvdata->aclk, 1000);
+		clk_set_rate(drvdata->aclk, ret);
+	}
+
+	return 0;
+fail:
+	return ret;
+}
+
 #ifdef CONFIG_OF_DEVICE
 
 static int __get_bus_vote_client(struct platform_device *pdev,
@@ -165,13 +214,13 @@
 }
 
 static int msm_iommu_parse_dt(struct platform_device *pdev,
-				struct msm_iommu_drvdata *drvdata,
-				int *needs_alt_core_clk)
+				struct msm_iommu_drvdata *drvdata)
 {
 	struct device_node *child;
 	struct resource *r;
 	u32 glb_offset = 0;
 	int ret = 0;
+	int needs_alt_core_clk;
 
 	ret = __get_bus_vote_client(pdev, drvdata);
 
@@ -212,9 +261,14 @@
 		goto fail;
 	}
 
-	*needs_alt_core_clk = of_property_read_bool(pdev->dev.of_node,
+	needs_alt_core_clk = of_property_read_bool(pdev->dev.of_node,
 						   "qcom,needs-alt-core-clk");
 
+	ret = __get_clocks(pdev, drvdata, needs_alt_core_clk);
+
+	if (ret)
+		goto fail;
+
 	drvdata->sec_id = -1;
 	drvdata->ttbr_split = 0;
 
@@ -235,8 +289,7 @@
 
 #else
 static int msm_iommu_parse_dt(struct platform_device *pdev,
-				struct msm_iommu_drvdata *drvdata,
-				int *needs_alt_core_clk)
+				struct msm_iommu_drvdata *drvdata)
 {
 	return 0;
 }
@@ -246,60 +299,8 @@
 
 }
 
-
 #endif
 
-static int __get_clocks(struct platform_device *pdev,
-			struct msm_iommu_drvdata *drvdata,
-			int needs_alt_core_clk)
-{
-	int ret = 0;
-
-	drvdata->pclk = clk_get(&pdev->dev, "iface_clk");
-	if (IS_ERR(drvdata->pclk)) {
-		ret = PTR_ERR(drvdata->pclk);
-		drvdata->pclk = NULL;
-		pr_err("Unable to get %s clock for %s IOMMU device\n",
-			dev_name(&pdev->dev), drvdata->name);
-		goto fail;
-	}
-
-	drvdata->clk = clk_get(&pdev->dev, "core_clk");
-
-	if (!IS_ERR(drvdata->clk)) {
-		if (clk_get_rate(drvdata->clk) == 0) {
-			ret = clk_round_rate(drvdata->clk, 1000);
-			clk_set_rate(drvdata->clk, ret);
-		}
-	} else {
-		drvdata->clk = NULL;
-	}
-
-	if (needs_alt_core_clk) {
-		drvdata->aclk = devm_clk_get(&pdev->dev, "alt_core_clk");
-		if (IS_ERR(drvdata->aclk))
-			return PTR_ERR(drvdata->aclk);
-	}
-
-	if (drvdata->aclk && clk_get_rate(drvdata->aclk) == 0) {
-		ret = clk_round_rate(drvdata->aclk, 1000);
-		clk_set_rate(drvdata->aclk, ret);
-	}
-
-	return 0;
-fail:
-	return ret;
-}
-
-static void __put_clocks(struct msm_iommu_drvdata *drvdata)
-{
-	if (drvdata->aclk)
-		clk_put(drvdata->aclk);
-	if (drvdata->clk)
-		clk_put(drvdata->clk);
-	clk_put(drvdata->pclk);
-}
-
 /*
  * Do a basic check of the IOMMU by performing an ATS operation
  * on context bank 0.
@@ -391,7 +392,6 @@
 	struct msm_iommu_drvdata *drvdata;
 	struct msm_iommu_dev *iommu_dev = pdev->dev.platform_data;
 	int ret;
-	int needs_alt_core_clk = 0;
 
 	drvdata = devm_kzalloc(&pdev->dev, sizeof(*drvdata), GFP_KERNEL);
 
@@ -401,13 +401,18 @@
 	}
 
 	if (pdev->dev.of_node) {
-		ret = msm_iommu_parse_dt(pdev, drvdata, &needs_alt_core_clk);
+		ret = msm_iommu_parse_dt(pdev, drvdata);
 		if (ret)
 			goto fail;
 	} else if (pdev->dev.platform_data) {
 		struct resource *r, *r2;
 		resource_size_t	len;
 
+		ret = __get_clocks(pdev, drvdata, 0);
+
+		if (ret)
+			goto fail;
+
 		r = platform_get_resource_byname(pdev, IORESOURCE_MEM,
 						"physbase");
 
@@ -418,7 +423,8 @@
 
 		len = resource_size(r);
 
-		r2 = request_mem_region(r->start, len, r->name);
+		r2 = devm_request_mem_region(&pdev->dev, r->start,
+					     len, r->name);
 		if (!r2) {
 			pr_err("Could not request memory region: %pr\n", r);
 			ret = -EBUSY;
@@ -448,11 +454,6 @@
 
 	drvdata->dev = &pdev->dev;
 
-	ret = __get_clocks(pdev, drvdata, needs_alt_core_clk);
-
-	if (ret)
-		goto fail;
-
 	iommu_access_ops_v0.iommu_clk_on(drvdata);
 
 	msm_iommu_reset(drvdata->base, drvdata->glb_base, drvdata->ncb);
@@ -461,14 +462,14 @@
 	if (ret)
 		goto fail_clk;
 
+	iommu_access_ops_v0.iommu_clk_off(drvdata);
+
 	pr_info("device %s mapped at %p, with %d ctx banks\n",
 		drvdata->name, drvdata->base, drvdata->ncb);
 
 	msm_iommu_add_drv(drvdata);
 	platform_set_drvdata(pdev, drvdata);
 
-	iommu_access_ops_v0.iommu_clk_off(drvdata);
-
 	pmon_info = msm_iommu_pm_alloc(&pdev->dev);
 	if (pmon_info != NULL) {
 		ret = msm_iommu_pmon_parse_dt(pdev, pmon_info);
@@ -497,7 +498,6 @@
 
 fail_clk:
 	iommu_access_ops_v0.iommu_clk_off(drvdata);
-	__put_clocks(drvdata);
 fail:
 	__put_bus_vote_client(drvdata);
 fail_mem:
@@ -508,13 +508,13 @@
 {
 	struct msm_iommu_drvdata *drv = NULL;
 
+	msm_iommu_pm_iommu_unregister(&pdev->dev);
+	msm_iommu_pm_free(&pdev->dev);
+
 	drv = platform_get_drvdata(pdev);
 	if (drv) {
 		__put_bus_vote_client(drv);
 		msm_iommu_remove_drv(drv);
-		if (drv->clk)
-			clk_put(drv->clk);
-		clk_put(drv->pclk);
 		platform_set_drvdata(pdev, NULL);
 	}
 	return 0;
@@ -530,26 +530,28 @@
 
 	irq = platform_get_irq(pdev, 0);
 	if (irq > 0) {
-		ret = request_threaded_irq(irq, NULL,
+		ret = devm_request_threaded_irq(&pdev->dev, irq, NULL,
 				msm_iommu_fault_handler,
 				IRQF_ONESHOT | IRQF_SHARED,
 				"msm_iommu_nonsecure_irq", ctx_drvdata);
 		if (ret) {
 			pr_err("Request IRQ %d failed with ret=%d\n", irq, ret);
-			return ret;
+			goto out;
 		}
 	}
 
 	r = platform_get_resource(pdev, IORESOURCE_MEM, 0);
 	if (!r) {
 		pr_err("Could not find reg property for context bank\n");
-		return -EINVAL;
+		ret = -EINVAL;
+		goto out;
 	}
 
 	ret = of_address_to_resource(pdev->dev.parent->of_node, 0, &rp);
 	if (ret) {
 		pr_err("of_address_to_resource failed\n");
-		return -EINVAL;
+		ret = -EINVAL;
+		goto out;
 	}
 
 	/* Calculate the context bank number using the base addresses. CB0
@@ -560,29 +562,34 @@
 	if (of_property_read_string(pdev->dev.of_node, "label",
 					&ctx_drvdata->name)) {
 		pr_err("Could not find label property\n");
-		return -EINVAL;
+		ret = -EINVAL;
+		goto out;
 	}
 
 	if (!of_get_property(pdev->dev.of_node, "qcom,iommu-ctx-mids",
 			     &nmid_array_size)) {
 		pr_err("Could not find iommu-ctx-mids property\n");
-		return -EINVAL;
+		ret = -EINVAL;
+		goto out;
 	}
 	if (nmid_array_size >= sizeof(ctx_drvdata->sids)) {
 		pr_err("Too many mids defined - array size: %u, mids size: %u\n",
 			nmid_array_size, sizeof(ctx_drvdata->sids));
-		return -EINVAL;
+		ret = -EINVAL;
+		goto out;
 	}
 	nmid = nmid_array_size / sizeof(*ctx_drvdata->sids);
 
 	if (of_property_read_u32_array(pdev->dev.of_node, "qcom,iommu-ctx-mids",
 				       ctx_drvdata->sids, nmid)) {
 		pr_err("Could not find iommu-ctx-mids property\n");
-		return -EINVAL;
+		ret = -EINVAL;
+		goto out;
 	}
 	ctx_drvdata->nsid = nmid;
 
-	return 0;
+out:
+	return ret;
 }
 
 static void __program_m2v_tables(struct msm_iommu_drvdata *drvdata,
@@ -632,7 +639,7 @@
 	drvdata = dev_get_drvdata(pdev->dev.parent);
 
 	if (!drvdata) {
-		ret = -ENODEV;
+		ret = -EPROBE_DEFER;
 		goto fail;
 	}
 
@@ -650,8 +657,10 @@
 
 	if (pdev->dev.of_node) {
 		ret = msm_iommu_ctx_parse_dt(pdev, ctx_drvdata);
-		if (ret)
+		if (ret) {
+			platform_set_drvdata(pdev, NULL);
 			goto fail;
+		}
 	} else if (pdev->dev.platform_data) {
 		struct msm_iommu_ctx_dev *c = pdev->dev.platform_data;
 
@@ -673,7 +682,8 @@
 			goto fail;
 		}
 
-		ret = request_threaded_irq(irq, NULL, msm_iommu_fault_handler,
+		ret = devm_request_threaded_irq(&pdev->dev, irq, NULL,
+					msm_iommu_fault_handler,
 					IRQF_ONESHOT | IRQF_SHARED,
 					"msm_iommu_nonsecure_irq", ctx_drvdata);
 
diff --git a/drivers/iommu/msm_iommu_dev-v1.c b/drivers/iommu/msm_iommu_dev-v1.c
index cf05a36..c0e05f4 100644
--- a/drivers/iommu/msm_iommu_dev-v1.c
+++ b/drivers/iommu/msm_iommu_dev-v1.c
@@ -125,7 +125,6 @@
 	struct resource *r;
 
 	drvdata->dev = &pdev->dev;
-	msm_iommu_add_drv(drvdata);
 
 	ret = __get_bus_vote_client(pdev, drvdata);
 
@@ -178,6 +177,7 @@
 	if (ret)
 		pr_err("Failed to create iommu context device\n");
 
+	msm_iommu_add_drv(drvdata);
 fail:
 	__put_bus_vote_client(drvdata);
 	return ret;
@@ -265,7 +265,7 @@
 
 	drvdata->gdsc = devm_regulator_get(&pdev->dev, "vdd");
 	if (IS_ERR(drvdata->gdsc))
-		return -EINVAL;
+		return PTR_ERR(drvdata->gdsc);
 
 	drvdata->alt_gdsc = devm_regulator_get(&pdev->dev, "qcom,alt-vdd");
 	if (IS_ERR(drvdata->alt_gdsc))
@@ -344,9 +344,6 @@
 	if (drv) {
 		__put_bus_vote_client(drv);
 		msm_iommu_remove_drv(drv);
-		if (drv->clk)
-			clk_put(drv->clk);
-		clk_put(drv->pclk);
 		platform_set_drvdata(pdev, NULL);
 	}
 	return 0;
@@ -356,7 +353,7 @@
 				struct msm_iommu_ctx_drvdata *ctx_drvdata)
 {
 	struct resource *r, rp;
-	int irq, ret;
+	int irq = 0, ret = 0;
 	u32 nsid;
 
 	ctx_drvdata->secure_context = of_property_read_bool(pdev->dev.of_node,
@@ -365,25 +362,27 @@
 	if (!ctx_drvdata->secure_context) {
 		irq = platform_get_irq(pdev, 0);
 		if (irq > 0) {
-			ret = request_threaded_irq(irq, NULL,
+			ret = devm_request_threaded_irq(&pdev->dev, irq, NULL,
 					msm_iommu_fault_handler_v2,
 					IRQF_ONESHOT | IRQF_SHARED,
 					"msm_iommu_nonsecure_irq", pdev);
 			if (ret) {
 				pr_err("Request IRQ %d failed with ret=%d\n",
 					irq, ret);
-				return ret;
+				goto out;
 			}
 		}
 	}
 
 	r = platform_get_resource(pdev, IORESOURCE_MEM, 0);
-	if (!r)
-		return -EINVAL;
+	if (!r) {
+		ret = -EINVAL;
+		goto out;
+	}
 
 	ret = of_address_to_resource(pdev->dev.parent->of_node, 0, &rp);
 	if (ret)
-		return -EINVAL;
+		goto out;
 
 	/* Calculate the context bank number using the base addresses. The
 	 * first 8 pages belong to the global address space which is followed
@@ -396,21 +395,26 @@
 					&ctx_drvdata->name))
 		ctx_drvdata->name = dev_name(&pdev->dev);
 
-	if (!of_get_property(pdev->dev.of_node, "qcom,iommu-ctx-sids", &nsid))
-		return -EINVAL;
-
-	if (nsid >= sizeof(ctx_drvdata->sids))
-		return -EINVAL;
+	if (!of_get_property(pdev->dev.of_node, "qcom,iommu-ctx-sids", &nsid)) {
+		ret = -EINVAL;
+		goto out;
+	}
+	if (nsid >= sizeof(ctx_drvdata->sids)) {
+		ret = -EINVAL;
+		goto out;
+	}
 
 	if (of_property_read_u32_array(pdev->dev.of_node, "qcom,iommu-ctx-sids",
 				       ctx_drvdata->sids,
 				       nsid / sizeof(*ctx_drvdata->sids))) {
-		return -EINVAL;
+		ret = -EINVAL;
+		goto out;
 	}
 	ctx_drvdata->nsid = nsid;
 
 	ctx_drvdata->asid = -1;
-	return 0;
+out:
+	return ret;
 }
 
 static int __devinit msm_iommu_ctx_probe(struct platform_device *pdev)
@@ -428,12 +432,14 @@
 
 	ctx_drvdata->pdev = pdev;
 	INIT_LIST_HEAD(&ctx_drvdata->attached_elm);
-	platform_set_drvdata(pdev, ctx_drvdata);
 
 	ret = msm_iommu_ctx_parse_dt(pdev, ctx_drvdata);
-	if (!ret)
+	if (!ret) {
+		platform_set_drvdata(pdev, ctx_drvdata);
+
 		dev_info(&pdev->dev, "context %s using bank %d\n",
 			 ctx_drvdata->name, ctx_drvdata->num);
+	}
 
 	return ret;
 }