ext4 crypto: use per-inode tfm structure

As suggested by Herbert Xu, we shouldn't allocate a new tfm each time
we read or write a page.  Instead we can use a single tfm hanging off
the inode's crypt_info structure for all of our encryption needs for
that inode, since the tfm can be used by multiple crypto requests in
parallel.

Also use cmpxchg() to avoid races that could result in the crypt_info
structure getting doubly allocated or doubly freed.

Signed-off-by: Theodore Ts'o <tytso@mit.edu>
diff --git a/fs/ext4/crypto_key.c b/fs/ext4/crypto_key.c
index 858d7d6..442d24e 100644
--- a/fs/ext4/crypto_key.c
+++ b/fs/ext4/crypto_key.c
@@ -84,20 +84,32 @@
 	return res;
 }
 
-void ext4_free_encryption_info(struct inode *inode)
+void ext4_free_crypt_info(struct ext4_crypt_info *ci)
 {
-	struct ext4_inode_info *ei = EXT4_I(inode);
-	struct ext4_crypt_info *ci = ei->i_crypt_info;
-
 	if (!ci)
 		return;
 
 	if (ci->ci_keyring_key)
 		key_put(ci->ci_keyring_key);
 	crypto_free_ablkcipher(ci->ci_ctfm);
-	memzero_explicit(&ci->ci_raw, sizeof(ci->ci_raw));
 	kmem_cache_free(ext4_crypt_info_cachep, ci);
-	ei->i_crypt_info = NULL;
+}
+
+void ext4_free_encryption_info(struct inode *inode,
+			       struct ext4_crypt_info *ci)
+{
+	struct ext4_inode_info *ei = EXT4_I(inode);
+	struct ext4_crypt_info *prev;
+
+	if (ci == NULL)
+		ci = ACCESS_ONCE(ei->i_crypt_info);
+	if (ci == NULL)
+		return;
+	prev = cmpxchg(&ei->i_crypt_info, ci, NULL);
+	if (prev != ci)
+		return;
+
+	ext4_free_crypt_info(ci);
 }
 
 int _ext4_get_encryption_info(struct inode *inode)
@@ -111,6 +123,10 @@
 	struct ext4_encryption_context ctx;
 	struct user_key_payload *ukp;
 	struct ext4_sb_info *sbi = EXT4_SB(inode->i_sb);
+	struct crypto_ablkcipher *ctfm;
+	const char *cipher_str;
+	char raw_key[EXT4_MAX_KEY_SIZE];
+	char mode;
 	int res;
 
 	if (!ext4_read_workqueue) {
@@ -119,11 +135,14 @@
 			return res;
 	}
 
-	if (ei->i_crypt_info) {
-		if (!ei->i_crypt_info->ci_keyring_key ||
-		    key_validate(ei->i_crypt_info->ci_keyring_key) == 0)
+retry:
+	crypt_info = ACCESS_ONCE(ei->i_crypt_info);
+	if (crypt_info) {
+		if (!crypt_info->ci_keyring_key ||
+		    key_validate(crypt_info->ci_keyring_key) == 0)
 			return 0;
-		ext4_free_encryption_info(inode);
+		ext4_free_encryption_info(inode, crypt_info);
+		goto retry;
 	}
 
 	res = ext4_xattr_get(inode, EXT4_XATTR_INDEX_ENCRYPTION,
@@ -144,26 +163,37 @@
 	if (!crypt_info)
 		return -ENOMEM;
 
-	ei->i_crypt_policy_flags = ctx.flags;
 	crypt_info->ci_flags = ctx.flags;
 	crypt_info->ci_data_mode = ctx.contents_encryption_mode;
 	crypt_info->ci_filename_mode = ctx.filenames_encryption_mode;
 	crypt_info->ci_ctfm = NULL;
+	crypt_info->ci_keyring_key = NULL;
 	memcpy(crypt_info->ci_master_key, ctx.master_key_descriptor,
 	       sizeof(crypt_info->ci_master_key));
 	if (S_ISREG(inode->i_mode))
-		crypt_info->ci_size =
-			ext4_encryption_key_size(crypt_info->ci_data_mode);
+		mode = crypt_info->ci_data_mode;
 	else if (S_ISDIR(inode->i_mode) || S_ISLNK(inode->i_mode))
-		crypt_info->ci_size =
-			ext4_encryption_key_size(crypt_info->ci_filename_mode);
+		mode = crypt_info->ci_filename_mode;
 	else
 		BUG();
-	BUG_ON(!crypt_info->ci_size);
-	if (DUMMY_ENCRYPTION_ENABLED(sbi)) {
-		memset(crypt_info->ci_raw, 0x42, EXT4_AES_256_XTS_KEY_SIZE);
+	switch (mode) {
+	case EXT4_ENCRYPTION_MODE_AES_256_XTS:
+		cipher_str = "xts(aes)";
+		break;
+	case EXT4_ENCRYPTION_MODE_AES_256_CTS:
+		cipher_str = "cts(cbc(aes))";
+		break;
+	default:
+		printk_once(KERN_WARNING
+			    "ext4: unsupported key mode %d (ino %u)\n",
+			    mode, (unsigned) inode->i_ino);
+		res = -ENOKEY;
 		goto out;
 	}
+	if (DUMMY_ENCRYPTION_ENABLED(sbi)) {
+		memset(raw_key, 0x42, EXT4_AES_256_XTS_KEY_SIZE);
+		goto got_key;
+	}
 	memcpy(full_key_descriptor, EXT4_KEY_DESC_PREFIX,
 	       EXT4_KEY_DESC_PREFIX_SIZE);
 	sprintf(full_key_descriptor + EXT4_KEY_DESC_PREFIX_SIZE,
@@ -177,6 +207,7 @@
 		keyring_key = NULL;
 		goto out;
 	}
+	crypt_info->ci_keyring_key = keyring_key;
 	BUG_ON(keyring_key->type != &key_type_logon);
 	ukp = ((struct user_key_payload *)keyring_key->payload.data);
 	if (ukp->datalen != sizeof(struct ext4_encryption_key)) {
@@ -188,19 +219,36 @@
 		     EXT4_KEY_DERIVATION_NONCE_SIZE);
 	BUG_ON(master_key->size != EXT4_AES_256_XTS_KEY_SIZE);
 	res = ext4_derive_key_aes(ctx.nonce, master_key->raw,
-				  crypt_info->ci_raw);
-out:
-	if (res < 0) {
-		if (res == -ENOKEY)
-			res = 0;
-		kmem_cache_free(ext4_crypt_info_cachep, crypt_info);
-	} else {
-		ei->i_crypt_info = crypt_info;
-		crypt_info->ci_keyring_key = keyring_key;
-		keyring_key = NULL;
+				  raw_key);
+got_key:
+	ctfm = crypto_alloc_ablkcipher(cipher_str, 0, 0);
+	if (!ctfm || IS_ERR(ctfm)) {
+		res = ctfm ? PTR_ERR(ctfm) : -ENOMEM;
+		printk(KERN_DEBUG
+		       "%s: error %d (inode %u) allocating crypto tfm\n",
+		       __func__, res, (unsigned) inode->i_ino);
+		goto out;
 	}
-	if (keyring_key)
-		key_put(keyring_key);
+	crypt_info->ci_ctfm = ctfm;
+	crypto_ablkcipher_clear_flags(ctfm, ~0);
+	crypto_tfm_set_flags(crypto_ablkcipher_tfm(ctfm),
+			     CRYPTO_TFM_REQ_WEAK_KEY);
+	res = crypto_ablkcipher_setkey(ctfm, raw_key,
+				       ext4_encryption_key_size(mode));
+	if (res)
+		goto out;
+	memzero_explicit(raw_key, sizeof(raw_key));
+	if (cmpxchg(&ei->i_crypt_info, NULL, crypt_info) != NULL) {
+		ext4_free_crypt_info(crypt_info);
+		goto retry;
+	}
+	return 0;
+
+out:
+	if (res == -ENOKEY)
+		res = 0;
+	ext4_free_crypt_info(crypt_info);
+	memzero_explicit(raw_key, sizeof(raw_key));
 	return res;
 }