MIPS: Basic MSA context switching support

This patch adds support for context switching the MSA vector registers.
These 128-bit vector registers are aliased with the FP registers: an FP
register accesses the least significant bits of the vector register with
which it is aliased (i.e. the register with the same index). Because of
this aliasing, and because the scalar FPU must be 64-bit (FR=1) if it is
enabled at the same time as MSA, the kernel enables MSA & scalar FP
together for tasks which use MSA. If we restore the MSA vector context
then we might as well enable the scalar FPU: the only reason to leave it
disabled is to allow lazy restoring of the FP context, but that context
has just been restored as a subset of the vector context. Conversely, if
we restore the FP context for a task that has previously used MSA then
we have to restore the whole vector context anyway (see the comment in
enable_restore_fp_context for details), so we might as well enable MSA
too.

Thus a task which does not use MSA continues to behave as it did before
this patch: the scalar FP context is saved & restored as usual. But once
a task executes an MSA instruction, its vector context is saved &
restored from then on.

Signed-off-by: Paul Burton <paul.burton@imgtec.com>
Cc: linux-mips@linux-mips.org
Patchwork: https://patchwork.linux-mips.org/patch/6431/
Signed-off-by: Ralf Baechle <ralf@linux-mips.org>
diff --git a/arch/mips/include/asm/asmmacro.h b/arch/mips/include/asm/asmmacro.h
index f571db3..fe3b03c 100644
--- a/arch/mips/include/asm/asmmacro.h
+++ b/arch/mips/include/asm/asmmacro.h
@@ -328,4 +328,74 @@
 	.endm
 #endif
 
+	.macro	msa_save_all	thread /* st_d regs 0-31 at THREAD_FPR0..31 */
+	st_d	0, THREAD_FPR0, \thread
+	st_d	1, THREAD_FPR1, \thread
+	st_d	2, THREAD_FPR2, \thread
+	st_d	3, THREAD_FPR3, \thread
+	st_d	4, THREAD_FPR4, \thread
+	st_d	5, THREAD_FPR5, \thread
+	st_d	6, THREAD_FPR6, \thread
+	st_d	7, THREAD_FPR7, \thread
+	st_d	8, THREAD_FPR8, \thread
+	st_d	9, THREAD_FPR9, \thread
+	st_d	10, THREAD_FPR10, \thread
+	st_d	11, THREAD_FPR11, \thread
+	st_d	12, THREAD_FPR12, \thread
+	st_d	13, THREAD_FPR13, \thread
+	st_d	14, THREAD_FPR14, \thread
+	st_d	15, THREAD_FPR15, \thread
+	st_d	16, THREAD_FPR16, \thread
+	st_d	17, THREAD_FPR17, \thread
+	st_d	18, THREAD_FPR18, \thread
+	st_d	19, THREAD_FPR19, \thread
+	st_d	20, THREAD_FPR20, \thread
+	st_d	21, THREAD_FPR21, \thread
+	st_d	22, THREAD_FPR22, \thread
+	st_d	23, THREAD_FPR23, \thread
+	st_d	24, THREAD_FPR24, \thread
+	st_d	25, THREAD_FPR25, \thread
+	st_d	26, THREAD_FPR26, \thread
+	st_d	27, THREAD_FPR27, \thread
+	st_d	28, THREAD_FPR28, \thread
+	st_d	29, THREAD_FPR29, \thread
+	st_d	30, THREAD_FPR30, \thread
+	st_d	31, THREAD_FPR31, \thread
+	.endm	/* presumably saves all MSA vector regs - st_d defined elsewhere */
+
+	.macro	msa_restore_all	thread /* ld_d regs 0-31 at THREAD_FPR0..31 */
+	ld_d	0, THREAD_FPR0, \thread
+	ld_d	1, THREAD_FPR1, \thread
+	ld_d	2, THREAD_FPR2, \thread
+	ld_d	3, THREAD_FPR3, \thread
+	ld_d	4, THREAD_FPR4, \thread
+	ld_d	5, THREAD_FPR5, \thread
+	ld_d	6, THREAD_FPR6, \thread
+	ld_d	7, THREAD_FPR7, \thread
+	ld_d	8, THREAD_FPR8, \thread
+	ld_d	9, THREAD_FPR9, \thread
+	ld_d	10, THREAD_FPR10, \thread
+	ld_d	11, THREAD_FPR11, \thread
+	ld_d	12, THREAD_FPR12, \thread
+	ld_d	13, THREAD_FPR13, \thread
+	ld_d	14, THREAD_FPR14, \thread
+	ld_d	15, THREAD_FPR15, \thread
+	ld_d	16, THREAD_FPR16, \thread
+	ld_d	17, THREAD_FPR17, \thread
+	ld_d	18, THREAD_FPR18, \thread
+	ld_d	19, THREAD_FPR19, \thread
+	ld_d	20, THREAD_FPR20, \thread
+	ld_d	21, THREAD_FPR21, \thread
+	ld_d	22, THREAD_FPR22, \thread
+	ld_d	23, THREAD_FPR23, \thread
+	ld_d	24, THREAD_FPR24, \thread
+	ld_d	25, THREAD_FPR25, \thread
+	ld_d	26, THREAD_FPR26, \thread
+	ld_d	27, THREAD_FPR27, \thread
+	ld_d	28, THREAD_FPR28, \thread
+	ld_d	29, THREAD_FPR29, \thread
+	ld_d	30, THREAD_FPR30, \thread
+	ld_d	31, THREAD_FPR31, \thread
+	.endm	/* presumably reloads all MSA vector regs - ld_d defined elsewhere */
+
 #endif /* _ASM_ASMMACRO_H */
diff --git a/arch/mips/include/asm/msa.h b/arch/mips/include/asm/msa.h
index 0a614fb..a2aba6c 100644
--- a/arch/mips/include/asm/msa.h
+++ b/arch/mips/include/asm/msa.h
@@ -12,6 +12,9 @@
 
 #include <asm/mipsregs.h>
 
+extern void _save_msa(struct task_struct *);
+extern void _restore_msa(struct task_struct *);
+
 static inline void enable_msa(void)
 {
 	if (cpu_has_msa) {
@@ -36,6 +39,31 @@
 	return read_c0_config5() & MIPS_CONF5_MSAEN;
 }
 
+static inline int thread_msa_context_live(void)
+{
+	/*
+	 * Check cpu_has_msa only if it's a compile-time constant: this lets
+	 * the compiler optimise out code for CPUs without MSA without adding
+	 * an extra redundant runtime check for CPUs with MSA.
+	 */
+	if (__builtin_constant_p(cpu_has_msa) && !cpu_has_msa)
+		return 0;	/* MSA known absent at build time: never live */
+
+	return test_thread_flag(TIF_MSA_CTX_LIVE); /* MSA context must be preserved */
+}
+
+static inline void save_msa(struct task_struct *t)
+{
+	if (cpu_has_msa)	/* skip the call entirely on CPUs without MSA */
+		_save_msa(t);	/* extern routine, defined outside this header */
+}
+
+static inline void restore_msa(struct task_struct *t)
+{
+	if (cpu_has_msa)	/* skip the call entirely on CPUs without MSA */
+		_restore_msa(t);	/* extern routine, defined outside this header */
+}
+
 #ifdef TOOLCHAIN_SUPPORTS_MSA
 
 #define __BUILD_MSA_CTL_REG(name, cs)				\
diff --git a/arch/mips/include/asm/processor.h b/arch/mips/include/asm/processor.h
index 50cf4c3..ad70cba 100644
--- a/arch/mips/include/asm/processor.h
+++ b/arch/mips/include/asm/processor.h
@@ -96,7 +96,12 @@
 
 
 #define NUM_FPU_REGS	32
-#define FPU_REG_WIDTH	64
+
+#ifdef CONFIG_CPU_HAS_MSA
+# define FPU_REG_WIDTH	128	/* bits: 128-bit MSA vector regs alias the FPRs */
+#else
+# define FPU_REG_WIDTH	64	/* bits: width used when MSA is not configured */
+#endif
 
 union fpureg {
 	__u32	val32[FPU_REG_WIDTH / 32];
@@ -133,6 +138,7 @@
 struct mips_fpu_struct {
 	union fpureg	fpr[NUM_FPU_REGS];
 	unsigned int	fcr31;
+	unsigned int	msacsr;	/* presumably the saved MSACSR - confirm */
 };
 
 #define NUM_DSP_REGS   6
@@ -310,6 +316,7 @@
 	.fpu			= {				\
 		.fpr		= {{{0,},},},			\
 		.fcr31		= 0,				\
+		.msacsr		= 0,				\
 	},							\
 	/*							\
 	 * FPU affinity state (null if not FPAFF)		\
diff --git a/arch/mips/include/asm/switch_to.h b/arch/mips/include/asm/switch_to.h
index 278d45a..495c104 100644
--- a/arch/mips/include/asm/switch_to.h
+++ b/arch/mips/include/asm/switch_to.h
@@ -16,22 +16,29 @@
 #include <asm/watch.h>
 #include <asm/dsp.h>
 #include <asm/cop2.h>
+#include <asm/msa.h>
 
 struct task_struct;
 
+enum {				/* fp_save values handed to resume() */
+	FP_SAVE_NONE	= 0,	/* neither TIF_USEDFPU nor TIF_USEDMSA was set */
+	FP_SAVE_VECTOR	= -1,	/* TIF_USEDMSA was set; overrides SCALAR */
+	FP_SAVE_SCALAR	= 1,	/* only TIF_USEDFPU was set */
+};
+
 /**
  * resume - resume execution of a task
  * @prev:	The task previously executed.
  * @next:	The task to begin executing.
  * @next_ti:	task_thread_info(next).
- * @usedfpu:	Non-zero if prev's FP context should be saved.
+ * @fp_save:	Which FP context to save for prev, if any: one of FP_SAVE_*.
  *
  * This function is used whilst scheduling to save the context of prev & load
  * the context of next. Returns prev.
  */
 extern asmlinkage struct task_struct *resume(struct task_struct *prev,
 		struct task_struct *next, struct thread_info *next_ti,
-		u32 usedfpu);
+		s32 fp_save);
 
 extern unsigned int ll_bit;
 extern struct task_struct *ll_task;
@@ -75,7 +82,8 @@
 
 #define switch_to(prev, next, last)					\
 do {									\
-	u32 __usedfpu, __c0_stat;					\
+	u32 __c0_stat;							\
+	s32 __fpsave = FP_SAVE_NONE;					\
 	__mips_mt_fpaff_switch_to(prev);				\
 	if (cpu_has_dsp)						\
 		__save_dsp(prev);					\
@@ -88,8 +96,12 @@
 		write_c0_status(__c0_stat & ~ST0_CU2);			\
 	}								\
 	__clear_software_ll_bit();					\
-	__usedfpu = test_and_clear_tsk_thread_flag(prev, TIF_USEDFPU);	\
-	(last) = resume(prev, next, task_thread_info(next), __usedfpu); \
+	if (test_and_clear_tsk_thread_flag(prev, TIF_USEDFPU))		\
+		__fpsave = FP_SAVE_SCALAR;				\
+	if (test_and_clear_tsk_thread_flag(prev, TIF_USEDMSA))		\
+		__fpsave = FP_SAVE_VECTOR;				\
+	(last) = resume(prev, next, task_thread_info(next), __fpsave);	\
+	disable_msa();							\
 } while (0)
 
 #define finish_arch_switch(prev)					\
diff --git a/arch/mips/include/asm/thread_info.h b/arch/mips/include/asm/thread_info.h
index e80ae50..d2d961d 100644
--- a/arch/mips/include/asm/thread_info.h
+++ b/arch/mips/include/asm/thread_info.h
@@ -116,6 +116,8 @@
 #define TIF_LOAD_WATCH		25	/* If set, load watch registers */
 #define TIF_SYSCALL_TRACEPOINT	26	/* syscall tracepoint instrumentation */
 #define TIF_32BIT_FPREGS	27	/* 32-bit floating point registers */
+#define TIF_USEDMSA		29	/* MSA has been used this quantum */
+#define TIF_MSA_CTX_LIVE	30	/* MSA context must be preserved */
 #define TIF_SYSCALL_TRACE	31	/* syscall trace active */
 
 #define _TIF_SYSCALL_TRACE	(1<<TIF_SYSCALL_TRACE)
@@ -133,6 +135,8 @@
 #define _TIF_FPUBOUND		(1<<TIF_FPUBOUND)
 #define _TIF_LOAD_WATCH		(1<<TIF_LOAD_WATCH)
 #define _TIF_32BIT_FPREGS	(1<<TIF_32BIT_FPREGS)
+#define _TIF_USEDMSA		(1<<TIF_USEDMSA)
+#define _TIF_MSA_CTX_LIVE	(1<<TIF_MSA_CTX_LIVE)
 #define _TIF_SYSCALL_TRACEPOINT	(1<<TIF_SYSCALL_TRACEPOINT)
 
 #define _TIF_WORK_SYSCALL_ENTRY	(_TIF_NOHZ | _TIF_SYSCALL_TRACE |	\