Merge branch 'amd-iommu/pagetable' into amd-iommu/2.6.32

Conflicts:
	arch/x86/kernel/amd_iommu.c
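
This merge brings in the IO page table rework from the
amd-iommu/pagetable branch: page tables can now have a variable number
of levels (up to PAGE_MODE_6_LEVEL, i.e. a full 64 bit address space),
the address space is grown on demand by increase_address_space(),
alloc_pte() and fetch_pte() walk the levels generically via
PM_LEVEL_INDEX(), and page table updates are propagated to the hardware
through update_domain() and flush_devices_by_domain(). Device table
entry setup is factored out into set_dte_entry(), and newly created
dma_ops domains now start with a 2-level page table instead of a
3-level one.

The page table walkers now take the target level explicitly, e.g.:

	pte = alloc_pte(&dma_dom->domain, address, PM_MAP_4k,
			&pte_page, gfp);
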
diff --git a/arch/x86/kernel/amd_iommu.c b/arch/x86/kernel/amd_iommu.c
index dc19ed4..98f230f 100644
--- a/arch/x86/kernel/amd_iommu.c
+++ b/arch/x86/kernel/amd_iommu.c
@@ -47,7 +47,6 @@
  */
 static struct protection_domain *pt_domain;
 
-#ifdef CONFIG_IOMMU_API
 static struct iommu_ops amd_iommu_ops;
 
 /*
@@ -60,13 +59,16 @@
 static int dma_ops_unity_map(struct dma_ops_domain *dma_dom,
 			     struct unity_map_entry *e);
 static struct dma_ops_domain *find_protection_domain(u16 devid);
-static u64* alloc_pte(struct protection_domain *dom,
-		      unsigned long address, u64
-		      **pte_page, gfp_t gfp);
+static u64 *alloc_pte(struct protection_domain *domain,
+		      unsigned long address, int end_lvl,
+		      u64 **pte_page, gfp_t gfp);
 static void dma_ops_reserve_addresses(struct dma_ops_domain *dom,
 				      unsigned long start_page,
 				      unsigned int pages);
 static void reset_iommu_command_buffer(struct amd_iommu *iommu);
+static u64 *fetch_pte(struct protection_domain *domain,
+		      unsigned long address, int map_size);
+static void update_domain(struct protection_domain *domain);
 
 #ifdef CONFIG_AMD_IOMMU_STATS
 
@@ -535,12 +537,15 @@
 	}
 }
 
-void amd_iommu_flush_all_devices(void)
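+/*
+ * Flush the device table entries of all devices attached to @domain.
+ * With @domain == NULL the DTEs of all attached devices are flushed.
+ */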
+static void flush_devices_by_domain(struct protection_domain *domain)
 {
 	struct amd_iommu *iommu;
 	int i;
 
 	for (i = 0; i <= amd_iommu_last_bdf; ++i) {
+		if ((domain == NULL && amd_iommu_pd_table[i] == NULL) ||
+		    (domain != NULL && amd_iommu_pd_table[i] != domain))
+			continue;
 
 		iommu = amd_iommu_rlookup_table[i];
 		if (!iommu)
@@ -567,6 +572,11 @@
 	iommu->reset_in_progress = false;
 }
 
+void amd_iommu_flush_all_devices(void)
+{
+	flush_devices_by_domain(NULL);
+}
+
 /****************************************************************************
  *
  * The functions below are used to create the page table mappings for
@@ -584,18 +594,21 @@
 static int iommu_map_page(struct protection_domain *dom,
 			  unsigned long bus_addr,
 			  unsigned long phys_addr,
-			  int prot)
+			  int prot,
+			  int map_size)
 {
 	u64 __pte, *pte;
 
 	bus_addr  = PAGE_ALIGN(bus_addr);
 	phys_addr = PAGE_ALIGN(phys_addr);
 
-	/* only support 512GB address spaces for now */
-	if (bus_addr > IOMMU_MAP_SIZE_L3 || !(prot & IOMMU_PROT_MASK))
+	BUG_ON(!PM_ALIGNED(map_size, bus_addr));
+	BUG_ON(!PM_ALIGNED(map_size, phys_addr));
+
+	if (!(prot & IOMMU_PROT_MASK))
 		return -EINVAL;
 
-	pte = alloc_pte(dom, bus_addr, NULL, GFP_KERNEL);
+	pte = alloc_pte(dom, bus_addr, map_size, NULL, GFP_KERNEL);
+	if (!pte)
+		return -ENOMEM;
 
 	if (IOMMU_PTE_PRESENT(*pte))
 		return -EBUSY;
@@ -608,29 +621,18 @@
 
 	*pte = __pte;
 
+	update_domain(dom);
+
 	return 0;
 }
 
 static void iommu_unmap_page(struct protection_domain *dom,
-			     unsigned long bus_addr)
+			     unsigned long bus_addr, int map_size)
 {
-	u64 *pte;
+	u64 *pte = fetch_pte(dom, bus_addr, map_size);
 
-	pte = &dom->pt_root[IOMMU_PTE_L2_INDEX(bus_addr)];
-
-	if (!IOMMU_PTE_PRESENT(*pte))
-		return;
-
-	pte = IOMMU_PTE_PAGE(*pte);
-	pte = &pte[IOMMU_PTE_L1_INDEX(bus_addr)];
-
-	if (!IOMMU_PTE_PRESENT(*pte))
-		return;
-
-	pte = IOMMU_PTE_PAGE(*pte);
-	pte = &pte[IOMMU_PTE_L1_INDEX(bus_addr)];
-
-	*pte = 0;
+	if (pte)
+		*pte = 0;
 }
 
 /*
@@ -685,7 +687,8 @@
 
 	for (addr = e->address_start; addr < e->address_end;
 	     addr += PAGE_SIZE) {
-		ret = iommu_map_page(&dma_dom->domain, addr, addr, e->prot);
+		ret = iommu_map_page(&dma_dom->domain, addr, addr, e->prot,
+				     PM_MAP_4k);
 		if (ret)
 			return ret;
 		/*
@@ -740,24 +743,29 @@
  * This function checks if there is a PTE for a given dma address. If
  * there is one, it returns the pointer to it.
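+ * The page table walk stops at the level given in @map_size instead of
+ * always walking down to level 0.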
  */
-static u64* fetch_pte(struct protection_domain *domain,
-		      unsigned long address)
+static u64 *fetch_pte(struct protection_domain *domain,
+		      unsigned long address, int map_size)
 {
+	int level;
 	u64 *pte;
 
-	pte = &domain->pt_root[IOMMU_PTE_L2_INDEX(address)];
+	level =  domain->mode - 1;
+	pte   = &domain->pt_root[PM_LEVEL_INDEX(level, address)];
 
-	if (!IOMMU_PTE_PRESENT(*pte))
-		return NULL;
+	while (level > map_size) {
+		if (!IOMMU_PTE_PRESENT(*pte))
+			return NULL;
 
-	pte = IOMMU_PTE_PAGE(*pte);
-	pte = &pte[IOMMU_PTE_L1_INDEX(address)];
+		level -= 1;
 
-	if (!IOMMU_PTE_PRESENT(*pte))
-		return NULL;
+		pte = IOMMU_PTE_PAGE(*pte);
+		pte = &pte[PM_LEVEL_INDEX(level, address)];
 
-	pte = IOMMU_PTE_PAGE(*pte);
-	pte = &pte[IOMMU_PTE_L0_INDEX(address)];
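+		/* Bail out if a leaf entry shows up above the requested level */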
+		if ((PM_PTE_LEVEL(*pte) == 0) && level != map_size) {
+			pte = NULL;
+			break;
+		}
+	}
 
 	return pte;
 }
@@ -797,7 +805,7 @@
 		u64 *pte, *pte_page;
 
 		for (i = 0; i < num_ptes; ++i) {
-			pte = alloc_pte(&dma_dom->domain, address,
+			pte = alloc_pte(&dma_dom->domain, address, PM_MAP_4k,
 					&pte_page, gfp);
 			if (!pte)
 				goto out_free;
@@ -830,16 +838,20 @@
 	for (i = dma_dom->aperture[index]->offset;
 	     i < dma_dom->aperture_size;
 	     i += PAGE_SIZE) {
-		u64 *pte = fetch_pte(&dma_dom->domain, i);
+		u64 *pte = fetch_pte(&dma_dom->domain, i, PM_MAP_4k);
 		if (!pte || !IOMMU_PTE_PRESENT(*pte))
 			continue;
 
 		dma_ops_reserve_addresses(dma_dom, i >> PAGE_SHIFT, 1);
 	}
 
+	update_domain(&dma_dom->domain);
+
 	return 0;
 
 out_free:
+	update_domain(&dma_dom->domain);
+
 	free_page((unsigned long)dma_dom->aperture[index]->bitmap);
 
 	kfree(dma_dom->aperture[index]);
@@ -1079,7 +1091,7 @@
 	dma_dom->domain.id = domain_id_alloc();
 	if (dma_dom->domain.id == 0)
 		goto free_dma_dom;
-	dma_dom->domain.mode = PAGE_MODE_3_LEVEL;
+	dma_dom->domain.mode = PAGE_MODE_2_LEVEL;
 	dma_dom->domain.pt_root = (void *)get_zeroed_page(GFP_KERNEL);
 	dma_dom->domain.flags = PD_DMA_OPS_MASK;
 	dma_dom->domain.priv = dma_dom;
@@ -1133,20 +1145,9 @@
 	return dom;
 }
 
-/*
- * If a device is not yet associated with a domain, this function does
- * assigns it visible for the hardware
- */
-static void __attach_device(struct amd_iommu *iommu,
-			    struct protection_domain *domain,
-			    u16 devid)
+/*
+ * Set up the device table entry for @devid to point to @domain's page
+ * table and record the association in amd_iommu_pd_table.
+ */
+static void set_dte_entry(u16 devid, struct protection_domain *domain)
 {
-	u64 pte_root;
-
-	/* lock domain */
-	spin_lock(&domain->lock);
-
-	pte_root = virt_to_phys(domain->pt_root);
+	u64 pte_root = virt_to_phys(domain->pt_root);
 
 	pte_root |= (domain->mode & DEV_ENTRY_MODE_MASK)
 		    << DEV_ENTRY_MODE_SHIFT;
@@ -1157,6 +1158,21 @@
 	amd_iommu_dev_table[devid].data[0] = lower_32_bits(pte_root);
 
 	amd_iommu_pd_table[devid] = domain;
+}
+
+/*
+ * If a device is not yet associated with a domain, this function
+ * associates it with the domain and makes it visible to the hardware.
+ */
+static void __attach_device(struct amd_iommu *iommu,
+			    struct protection_domain *domain,
+			    u16 devid)
+{
+	/* lock domain */
+	spin_lock(&domain->lock);
+
+	/* update DTE entry */
+	set_dte_entry(devid, domain);
 
 	domain->dev_cnt += 1;
 
@@ -1164,6 +1180,10 @@
 	spin_unlock(&domain->lock);
 }
 
+/*
+ * If a device is not yet associated with a domain, this function
+ * associates it with the domain and makes it visible to the hardware.
+ */
 static void attach_device(struct amd_iommu *iommu,
 			  struct protection_domain *domain,
 			  u16 devid)
@@ -1389,40 +1409,92 @@
 	return 1;
 }
 
+/*
+ * Re-write the device table entries of all devices attached to
+ * @domain, e.g. after the page table root or the domain mode changed.
+ */
+static void update_device_table(struct protection_domain *domain)
+{
+	unsigned long flags;
+	int i;
+
+	for (i = 0; i <= amd_iommu_last_bdf; ++i) {
+		if (amd_iommu_pd_table[i] != domain)
+			continue;
+		write_lock_irqsave(&amd_iommu_devtable_lock, flags);
+		set_dte_entry(i, domain);
+		write_unlock_irqrestore(&amd_iommu_devtable_lock, flags);
+	}
+}
+
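+/*
+ * Propagate page table updates to the hardware: re-write the DTEs of
+ * all devices in the domain, flush those devices and flush the domain
+ * TLB. This is a no-op unless domain->updated is set.
+ */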
+static void update_domain(struct protection_domain *domain)
+{
+	if (!domain->updated)
+		return;
+
+	update_device_table(domain);
+	flush_devices_by_domain(domain);
+	iommu_flush_domain(domain->id);
+
+	domain->updated = false;
+}
+
 /*
- * If the pte_page is not yet allocated this function is called
+ * This function is used to add another level to an IO page table. Adding
+ * another level increases the size of the address space by 9 bits, up
+ * to a maximum of 64 bits.
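+ *
+ * For example, a 3-level page table covers a 39 bit (512GB) address
+ * space; adding a fourth level extends that to 48 bits (256TB).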
  */
-static u64* alloc_pte(struct protection_domain *dom,
-		      unsigned long address, u64 **pte_page, gfp_t gfp)
+static bool increase_address_space(struct protection_domain *domain,
+				   gfp_t gfp)
+{
+	u64 *pte;
+
+	if (domain->mode == PAGE_MODE_6_LEVEL)
+		/* address space is already 64 bits wide */
+		return false;
+
+	pte = (void *)get_zeroed_page(gfp);
+	if (!pte)
+		return false;
+
+	*pte             = PM_LEVEL_PDE(domain->mode,
+					virt_to_phys(domain->pt_root));
+	domain->pt_root  = pte;
+	domain->mode    += 1;
+	domain->updated  = true;
+
+	return true;
+}
+
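+/*
+ * Walk the page table down to @end_lvl for @address, allocating missing
+ * page table pages on the way. The address space is grown first if
+ * @address does not fit into the current page table.
+ */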
+static u64 *alloc_pte(struct protection_domain *domain,
+		      unsigned long address,
+		      int end_lvl,
+		      u64 **pte_page,
+		      gfp_t gfp)
 {
 	u64 *pte, *page;
+	int level;
 
-	pte = &dom->pt_root[IOMMU_PTE_L2_INDEX(address)];
+	while (address > PM_LEVEL_SIZE(domain->mode)) {
+		/* Bail out if the address space cannot grow any further */
+		if (!increase_address_space(domain, gfp))
+			return NULL;
+	}
 
-	if (!IOMMU_PTE_PRESENT(*pte)) {
-		page = (u64 *)get_zeroed_page(gfp);
-		if (!page)
-			return NULL;
-		*pte = IOMMU_L2_PDE(virt_to_phys(page));
+	level =  domain->mode - 1;
+	pte   = &domain->pt_root[PM_LEVEL_INDEX(level, address)];
+
+	while (level > end_lvl) {
+		if (!IOMMU_PTE_PRESENT(*pte)) {
+			page = (u64 *)get_zeroed_page(gfp);
+			if (!page)
+				return NULL;
+			*pte = PM_LEVEL_PDE(level, virt_to_phys(page));
+		}
+
+		level -= 1;
+
+		pte = IOMMU_PTE_PAGE(*pte);
+
+		if (pte_page && level == end_lvl)
+			*pte_page = pte;
+
+		pte = &pte[PM_LEVEL_INDEX(level, address)];
 	}
 
-	pte = IOMMU_PTE_PAGE(*pte);
-	pte = &pte[IOMMU_PTE_L1_INDEX(address)];
-
-	if (!IOMMU_PTE_PRESENT(*pte)) {
-		page = (u64 *)get_zeroed_page(gfp);
-		if (!page)
-			return NULL;
-		*pte = IOMMU_L1_PDE(virt_to_phys(page));
-	}
-
-	pte = IOMMU_PTE_PAGE(*pte);
-
-	if (pte_page)
-		*pte_page = pte;
-
-	pte = &pte[IOMMU_PTE_L0_INDEX(address)];
-
 	return pte;
 }
 
@@ -1441,10 +1513,13 @@
 
 	pte = aperture->pte_pages[APERTURE_PAGE_INDEX(address)];
 	if (!pte) {
-		pte = alloc_pte(&dom->domain, address, &pte_page, GFP_ATOMIC);
+		pte = alloc_pte(&dom->domain, address, PM_MAP_4k, &pte_page,
+				GFP_ATOMIC);
 		aperture->pte_pages[APERTURE_PAGE_INDEX(address)] = pte_page;
 	} else
-		pte += IOMMU_PTE_L0_INDEX(address);
+		pte += PM_LEVEL_INDEX(0, address);
+
+	update_domain(&dom->domain);
 
 	return pte;
 }
@@ -1506,7 +1581,7 @@
 	if (!pte)
 		return;
 
-	pte += IOMMU_PTE_L0_INDEX(address);
+	pte += PM_LEVEL_INDEX(0, address);
 
 	WARN_ON(!*pte);
 
@@ -2240,7 +2315,7 @@
 	paddr &= PAGE_MASK;
 
 	for (i = 0; i < npages; ++i) {
-		ret = iommu_map_page(domain, iova, paddr, prot);
+		ret = iommu_map_page(domain, iova, paddr, prot, PM_MAP_4k);
 		if (ret)
 			return ret;
 
@@ -2261,7 +2336,7 @@
 	iova  &= PAGE_MASK;
 
 	for (i = 0; i < npages; ++i) {
-		iommu_unmap_page(domain, iova);
+		iommu_unmap_page(domain, iova, PM_MAP_4k);
 		iova  += PAGE_SIZE;
 	}
 
@@ -2276,21 +2351,9 @@
 	phys_addr_t paddr;
 	u64 *pte;
 
-	pte = &domain->pt_root[IOMMU_PTE_L2_INDEX(iova)];
+	pte = fetch_pte(domain, iova, PM_MAP_4k);
 
-	if (!IOMMU_PTE_PRESENT(*pte))
-		return 0;
-
-	pte = IOMMU_PTE_PAGE(*pte);
-	pte = &pte[IOMMU_PTE_L1_INDEX(iova)];
-
-	if (!IOMMU_PTE_PRESENT(*pte))
-		return 0;
-
-	pte = IOMMU_PTE_PAGE(*pte);
-	pte = &pte[IOMMU_PTE_L0_INDEX(iova)];
-
-	if (!IOMMU_PTE_PRESENT(*pte))
+	if (!pte || !IOMMU_PTE_PRESENT(*pte))
 		return 0;
 
 	paddr  = *pte & IOMMU_PAGE_MASK;