arm64: mm: use fixmap when creating page tables

As a preparatory step to allow us to allocate early page tables from
unmapped memory using memblock_alloc, modify the __create_mapping
callees to map and unmap the tables they modify using fixmap entries.

All but the top-level pgd initialisation is performed via the fixmap.
Subsequent patches will inject the pgd physical address, and migrate to
using the FIX_PGD slot.

Signed-off-by: Mark Rutland <mark.rutland@arm.com>
Reviewed-by: Catalin Marinas <catalin.marinas@arm.com>
Tested-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Reviewed-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Tested-by: Jeremy Linton <jeremy.linton@arm.com>
Cc: Laura Abbott <labbott@fedoraproject.org>
Cc: Will Deacon <will.deacon@arm.com>
Signed-off-by: Catalin Marinas <catalin.marinas@arm.com>
diff --git a/arch/arm64/mm/mmu.c b/arch/arm64/mm/mmu.c
index 8879aed..5248b02 100644
--- a/arch/arm64/mm/mmu.c
+++ b/arch/arm64/mm/mmu.c
@@ -63,19 +63,30 @@
 }
 EXPORT_SYMBOL(phys_mem_access_prot);
 
-static void __init *early_pgtable_alloc(void)
+static phys_addr_t __init early_pgtable_alloc(void)
 {
 	phys_addr_t phys;
 	void *ptr;
 
 	phys = memblock_alloc(PAGE_SIZE, PAGE_SIZE);
 	BUG_ON(!phys);
-	ptr = __va(phys);
+
+	/*
+	 * The FIX_{PGD,PUD,PMD} slots may be in active use, but the FIX_PTE
+	 * slot will be free, so we can (ab)use the FIX_PTE slot to initialise
+	 * any level of table.
+	 */
+	ptr = pte_set_fixmap(phys);
+
 	memset(ptr, 0, PAGE_SIZE);
 
-	/* Ensure the zeroed page is visible to the page table walker */
-	dsb(ishst);
-	return ptr;
+	/*
+	 * Implicit barriers also ensure the zeroed page is visible to the page
+	 * table walker
+	 */
+	pte_clear_fixmap();
+
+	return phys;
 }
 
 /*
@@ -99,24 +110,28 @@
 static void alloc_init_pte(pmd_t *pmd, unsigned long addr,
 				  unsigned long end, unsigned long pfn,
 				  pgprot_t prot,
-				  void *(*pgtable_alloc)(void))
+				  phys_addr_t (*pgtable_alloc)(void))
 {
 	pte_t *pte;
 
 	if (pmd_none(*pmd) || pmd_sect(*pmd)) {
-		pte = pgtable_alloc();
+		phys_addr_t pte_phys = pgtable_alloc();
+		pte = pte_set_fixmap(pte_phys);
 		if (pmd_sect(*pmd))
 			split_pmd(pmd, pte);
-		__pmd_populate(pmd, __pa(pte), PMD_TYPE_TABLE);
+		__pmd_populate(pmd, pte_phys, PMD_TYPE_TABLE);
 		flush_tlb_all();
+		pte_clear_fixmap();
 	}
 	BUG_ON(pmd_bad(*pmd));
 
-	pte = pte_offset_kernel(pmd, addr);
+	pte = pte_set_fixmap_offset(pmd, addr);
 	do {
 		set_pte(pte, pfn_pte(pfn, prot));
 		pfn++;
 	} while (pte++, addr += PAGE_SIZE, addr != end);
+
+	pte_clear_fixmap();
 }
 
 static void split_pud(pud_t *old_pud, pmd_t *pmd)
@@ -134,7 +149,7 @@
 static void alloc_init_pmd(struct mm_struct *mm, pud_t *pud,
 				  unsigned long addr, unsigned long end,
 				  phys_addr_t phys, pgprot_t prot,
-				  void *(*pgtable_alloc)(void))
+				  phys_addr_t (*pgtable_alloc)(void))
 {
 	pmd_t *pmd;
 	unsigned long next;
@@ -143,7 +158,8 @@
 	 * Check for initial section mappings in the pgd/pud and remove them.
 	 */
 	if (pud_none(*pud) || pud_sect(*pud)) {
-		pmd = pgtable_alloc();
+		phys_addr_t pmd_phys = pgtable_alloc();
+		pmd = pmd_set_fixmap(pmd_phys);
 		if (pud_sect(*pud)) {
 			/*
 			 * need to have the 1G of mappings continue to be
@@ -151,12 +167,13 @@
 			 */
 			split_pud(pud, pmd);
 		}
-		pud_populate(mm, pud, pmd);
+		__pud_populate(pud, pmd_phys, PUD_TYPE_TABLE);
 		flush_tlb_all();
+		pmd_clear_fixmap();
 	}
 	BUG_ON(pud_bad(*pud));
 
-	pmd = pmd_offset(pud, addr);
+	pmd = pmd_set_fixmap_offset(pud, addr);
 	do {
 		next = pmd_addr_end(addr, end);
 		/* try section mapping first */
@@ -182,6 +199,8 @@
 		}
 		phys += next - addr;
 	} while (pmd++, addr = next, addr != end);
+
+	pmd_clear_fixmap();
 }
 
 static inline bool use_1G_block(unsigned long addr, unsigned long next,
@@ -199,18 +218,18 @@
 static void alloc_init_pud(struct mm_struct *mm, pgd_t *pgd,
 				  unsigned long addr, unsigned long end,
 				  phys_addr_t phys, pgprot_t prot,
-				  void *(*pgtable_alloc)(void))
+				  phys_addr_t (*pgtable_alloc)(void))
 {
 	pud_t *pud;
 	unsigned long next;
 
 	if (pgd_none(*pgd)) {
-		pud = pgtable_alloc();
-		pgd_populate(mm, pgd, pud);
+		phys_addr_t pud_phys = pgtable_alloc();
+		__pgd_populate(pgd, pud_phys, PUD_TYPE_TABLE);
 	}
 	BUG_ON(pgd_bad(*pgd));
 
-	pud = pud_offset(pgd, addr);
+	pud = pud_set_fixmap_offset(pgd, addr);
 	do {
 		next = pud_addr_end(addr, end);
 
@@ -243,6 +262,8 @@
 		}
 		phys += next - addr;
 	} while (pud++, addr = next, addr != end);
+
+	pud_clear_fixmap();
 }
 
 /*
@@ -252,7 +273,7 @@
 static void  __create_mapping(struct mm_struct *mm, pgd_t *pgd,
 				    phys_addr_t phys, unsigned long virt,
 				    phys_addr_t size, pgprot_t prot,
-				    void *(*pgtable_alloc)(void))
+				    phys_addr_t (*pgtable_alloc)(void))
 {
 	unsigned long addr, length, end, next;
 
@@ -275,14 +296,14 @@
 	} while (pgd++, addr = next, addr != end);
 }
 
-static void *late_pgtable_alloc(void)
+static phys_addr_t late_pgtable_alloc(void)
 {
 	void *ptr = (void *)__get_free_page(PGALLOC_GFP);
 	BUG_ON(!ptr);
 
 	/* Ensure the zeroed page is visible to the page table walker */
 	dsb(ishst);
-	return ptr;
+	return __pa(ptr);
 }
 
 static void __init create_mapping(phys_addr_t phys, unsigned long virt,