arm64: kernel: Add support for User Access Override

'User Access Override' is a new ARMv8.2 feature which allows the
unprivileged load and store instructions (ldtr*/sttr*) to be overridden
to behave like the normal load and store instructions (ldr*/str*).

This patch converts {get,put}_user() and friends to use ldtr*/sttr*
instructions - so that they can only access EL0 memory, then enables
UAO when fs==KERNEL_DS so that these functions can access kernel memory.

This allows user space's read/write permissions to be checked against the
page tables, instead of testing addr<USER_DS, then using the kernel's
read/write permissions.

Signed-off-by: James Morse <james.morse@arm.com>
[catalin.marinas@arm.com: move uao_thread_switch() above dsb()]
Signed-off-by: Catalin Marinas <catalin.marinas@arm.com>
diff --git a/arch/arm64/include/asm/uaccess.h b/arch/arm64/include/asm/uaccess.h
index b2ede967..f973bdc 100644
--- a/arch/arm64/include/asm/uaccess.h
+++ b/arch/arm64/include/asm/uaccess.h
@@ -64,6 +64,16 @@
 static inline void set_fs(mm_segment_t fs)
 {
 	current_thread_info()->addr_limit = fs;
+
+	/*
+	 * Set UAO for KERNEL_DS, clear it otherwise, so that copy_to_user()
+	 * etc can access kernel memory with the unprivileged instructions.
+	 */
+	if (IS_ENABLED(CONFIG_ARM64_UAO) && fs == KERNEL_DS)
+		asm(ALTERNATIVE("nop", SET_PSTATE_UAO(1), ARM64_HAS_UAO));
+	else
+		asm(ALTERNATIVE("nop", SET_PSTATE_UAO(0), ARM64_HAS_UAO,
+				CONFIG_ARM64_UAO));
 }
 
 #define segment_eq(a, b)	((a) == (b))
@@ -113,9 +123,10 @@
  * The "__xxx_error" versions set the third argument to -EFAULT if an error
  * occurs, and leave it unchanged on success.
  */
-#define __get_user_asm(instr, reg, x, addr, err)			\
+#define __get_user_asm(instr, alt_instr, reg, x, addr, err, feature)	\
 	asm volatile(							\
-	"1:	" instr "	" reg "1, [%2]\n"			\
+	"1:"ALTERNATIVE(instr "     " reg "1, [%2]\n",			\
+			alt_instr " " reg "1, [%2]\n", feature)		\
 	"2:\n"								\
 	"	.section .fixup, \"ax\"\n"				\
 	"	.align	2\n"						\
@@ -138,16 +149,20 @@
 			CONFIG_ARM64_PAN));				\
 	switch (sizeof(*(ptr))) {					\
 	case 1:								\
-		__get_user_asm("ldrb", "%w", __gu_val, (ptr), (err));	\
+		__get_user_asm("ldrb", "ldtrb", "%w", __gu_val, (ptr),  \
+			       (err), ARM64_HAS_UAO);			\
 		break;							\
 	case 2:								\
-		__get_user_asm("ldrh", "%w", __gu_val, (ptr), (err));	\
+		__get_user_asm("ldrh", "ldtrh", "%w", __gu_val, (ptr),  \
+			       (err), ARM64_HAS_UAO);			\
 		break;							\
 	case 4:								\
-		__get_user_asm("ldr", "%w", __gu_val, (ptr), (err));	\
+		__get_user_asm("ldr", "ldtr", "%w", __gu_val, (ptr),	\
+			       (err), ARM64_HAS_UAO);			\
 		break;							\
 	case 8:								\
-		__get_user_asm("ldr", "%",  __gu_val, (ptr), (err));	\
+		__get_user_asm("ldr", "ldtr", "%",  __gu_val, (ptr),	\
+			       (err), ARM64_HAS_UAO);			\
 		break;							\
 	default:							\
 		BUILD_BUG();						\
@@ -181,9 +196,10 @@
 		((x) = 0, -EFAULT);					\
 })
 
-#define __put_user_asm(instr, reg, x, addr, err)			\
+#define __put_user_asm(instr, alt_instr, reg, x, addr, err, feature)	\
 	asm volatile(							\
-	"1:	" instr "	" reg "1, [%2]\n"			\
+	"1:"ALTERNATIVE(instr "     " reg "1, [%2]\n",			\
+			alt_instr " " reg "1, [%2]\n", feature)		\
 	"2:\n"								\
 	"	.section .fixup,\"ax\"\n"				\
 	"	.align	2\n"						\
@@ -205,16 +221,20 @@
 			CONFIG_ARM64_PAN));				\
 	switch (sizeof(*(ptr))) {					\
 	case 1:								\
-		__put_user_asm("strb", "%w", __pu_val, (ptr), (err));	\
+		__put_user_asm("strb", "sttrb", "%w", __pu_val, (ptr),	\
+			       (err), ARM64_HAS_UAO);			\
 		break;							\
 	case 2:								\
-		__put_user_asm("strh", "%w", __pu_val, (ptr), (err));	\
+		__put_user_asm("strh", "sttrh", "%w", __pu_val, (ptr),	\
+			       (err), ARM64_HAS_UAO);			\
 		break;							\
 	case 4:								\
-		__put_user_asm("str",  "%w", __pu_val, (ptr), (err));	\
+		__put_user_asm("str", "sttr", "%w", __pu_val, (ptr),	\
+			       (err), ARM64_HAS_UAO);			\
 		break;							\
 	case 8:								\
-		__put_user_asm("str",  "%", __pu_val, (ptr), (err));	\
+		__put_user_asm("str", "sttr", "%", __pu_val, (ptr),	\
+			       (err), ARM64_HAS_UAO);			\
 		break;							\
 	default:							\
 		BUILD_BUG();						\