random: add backtracking protection to the CRNG

Signed-off-by: Theodore Ts'o <tytso@mit.edu>
diff --git a/drivers/char/random.c b/drivers/char/random.c
index 2a30d97..783dee1 100644
--- a/drivers/char/random.c
+++ b/drivers/char/random.c
@@ -438,7 +438,8 @@
 #define CRNG_INIT_CNT_THRESH (2*CHACHA20_KEY_SIZE)
 static void _extract_crng(struct crng_state *crng,
 			  __u8 out[CHACHA20_BLOCK_SIZE]);
-static void extract_crng(__u8 out[CHACHA20_BLOCK_SIZE]);
+static void _crng_backtrack_protect(struct crng_state *crng,
+				    __u8 tmp[CHACHA20_BLOCK_SIZE], int used);
 static void process_random_ready_list(void);
 
 /**********************************************************************
@@ -826,8 +827,11 @@
 		num = extract_entropy(r, &buf, 32, 16, 0);
 		if (num == 0)
 			return;
-	} else
+	} else {
 		_extract_crng(&primary_crng, buf.block);
+		_crng_backtrack_protect(&primary_crng, buf.block,
+					CHACHA20_KEY_SIZE);
+	}
 	spin_lock_irqsave(&primary_crng.lock, flags);
 	for (i = 0; i < 8; i++) {
 		unsigned long	rv;
@@ -889,9 +893,46 @@
 	_extract_crng(crng, out);
 }
 
+/*
+ * Use the leftover bytes from the CRNG block output (if there is
+ * enough) to mutate the CRNG key to provide backtracking protection.
+ */
+static void _crng_backtrack_protect(struct crng_state *crng,
+				    __u8 tmp[CHACHA20_BLOCK_SIZE], int used)
+{
+	unsigned long	flags;
+	__u32		*s, *d;
+	int		i;
+
+	used = round_up(used, sizeof(__u32));
+	if (used + CHACHA20_KEY_SIZE > CHACHA20_BLOCK_SIZE) {
+		extract_crng(tmp);
+		used = 0;
+	}
+	spin_lock_irqsave(&crng->lock, flags);
+	s = (__u32 *) &tmp[used];
+	d = &crng->state[4];
+	for (i=0; i < 8; i++)
+		*d++ ^= *s++;
+	spin_unlock_irqrestore(&crng->lock, flags);
+}
+
+static void crng_backtrack_protect(__u8 tmp[CHACHA20_BLOCK_SIZE], int used)
+{
+	struct crng_state *crng = NULL;
+
+#ifdef CONFIG_NUMA
+	if (crng_node_pool)
+		crng = crng_node_pool[numa_node_id()];
+	if (crng == NULL)
+#endif
+		crng = &primary_crng;
+	_crng_backtrack_protect(crng, tmp, used);
+}
+
 static ssize_t extract_crng_user(void __user *buf, size_t nbytes)
 {
-	ssize_t ret = 0, i;
+	ssize_t ret = 0, i = CHACHA20_BLOCK_SIZE;
 	__u8 tmp[CHACHA20_BLOCK_SIZE];
 	int large_request = (nbytes > 256);
 
@@ -916,6 +957,7 @@
 		buf += i;
 		ret += i;
 	}
+	crng_backtrack_protect(tmp, i);
 
 	/* Wipe data just written to memory */
 	memzero_explicit(tmp, sizeof(tmp));
@@ -1473,8 +1515,10 @@
 	if (nbytes > 0) {
 		extract_crng(tmp);
 		memcpy(buf, tmp, nbytes);
-		memzero_explicit(tmp, nbytes);
-	}
+		crng_backtrack_protect(tmp, nbytes);
+	} else
+		crng_backtrack_protect(tmp, CHACHA20_BLOCK_SIZE);
+	memzero_explicit(tmp, sizeof(tmp));
 }
 EXPORT_SYMBOL(get_random_bytes);