powerpc: split hugepage when using subpage protection

We find all the overlapping VMAs and mark them so that we don't allocate
hugepages in that range. We also split any existing huge pages so that the
normal page hash can be invalidated and new pages faulted in with the new
protection bits.

Signed-off-by: Aneesh Kumar K.V <aneesh.kumar@linux.vnet.ibm.com>
Signed-off-by: Benjamin Herrenschmidt <benh@kernel.crashing.org>
diff --git a/arch/powerpc/mm/subpage-prot.c b/arch/powerpc/mm/subpage-prot.c
index 7c415dd..aa74acb 100644
--- a/arch/powerpc/mm/subpage-prot.c
+++ b/arch/powerpc/mm/subpage-prot.c
@@ -130,6 +130,53 @@
 	up_write(&mm->mmap_sem);
 }
 
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+static int subpage_walk_pmd_entry(pmd_t *pmd, unsigned long addr,
+				  unsigned long end, struct mm_walk *walk)
+{
+	struct vm_area_struct *vma = walk->private;	/* vma from caller */
+	split_huge_page_pmd(vma, addr, pmd);	/* 'end' is not needed */
+	return 0;
+}
+
+/*
+ * Flag each vma that starts below addr + len, beginning with the one
+ * find_vma() returns for addr, as VM_NOHUGEPAGE and run the pmd split
+ * callback over it.
+ * Nothing is done when no such vma exists.
+ */
+static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr,
+				    unsigned long len)
+{
+	unsigned long range_end = addr + len;
+	struct vm_area_struct *cur;
+	struct mm_walk split_walk = {
+		.mm = mm,
+		.pmd_entry = subpage_walk_pmd_entry,
+	};
+
+	/*
+	 * We don't try too hard: every vma is walked from vm_start to
+	 * vm_end rather than only the part inside the range. If the first
+	 * vma already starts at or beyond range_end the range is unmapped
+	 * and the loop exits before touching anything.
+	 */
+	for (cur = find_vma(mm, addr); cur; cur = cur->vm_next) {
+		if (cur->vm_start >= range_end)
+			break;
+		cur->vm_flags |= VM_NOHUGEPAGE;
+		split_walk.private = cur;
+		walk_page_range(cur->vm_start, cur->vm_end, &split_walk);
+	}
+}
+#else
+static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr,
+				    unsigned long len)
+{
+	/* Transparent hugepages are compiled out: nothing to mark or split. */
+}
+#endif
+
 /*
  * Copy in a subpage protection map for an address range.
  * The map has 2 bits per 4k subpage, so 32 bits per 64k page.
@@ -168,6 +215,7 @@
 		return -EFAULT;
 
 	down_write(&mm->mmap_sem);
+	subpage_mark_vma_nohuge(mm, addr, len);
 	for (limit = addr + len; addr < limit; addr = next) {
 		next = pmd_addr_end(addr, limit);
 		err = -ENOMEM;