ext4 crypto: use per-inode tfm structure

As suggested by Herbert Xu, we shouldn't allocate a new tfm each time
we read or write a page.  Instead we can use a single tfm hanging off
the inode's crypt_info structure for all of our encryption needs for
that inode, since the tfm can be used by multiple crypto requests in
parallel.

Also use cmpxchg() to avoid races that could result in the crypt_info
structure getting doubly allocated or doubly freed.

Signed-off-by: Theodore Ts'o <tytso@mit.edu>
diff --git a/fs/ext4/crypto.c b/fs/ext4/crypto.c
index 28a0e4bd..c3a9b08 100644
--- a/fs/ext4/crypto.c
+++ b/fs/ext4/crypto.c
@@ -80,8 +80,6 @@
 	ctx->w.bounce_page = NULL;
 	ctx->w.control_page = NULL;
 	if (ctx->flags & EXT4_CTX_REQUIRES_FREE_ENCRYPT_FL) {
-		if (ctx->tfm)
-			crypto_free_tfm(ctx->tfm);
 		kmem_cache_free(ext4_crypto_ctx_cachep, ctx);
 	} else {
 		spin_lock_irqsave(&ext4_crypto_ctx_lock, flags);
@@ -136,36 +134,6 @@
 	}
 	ctx->flags &= ~EXT4_WRITE_PATH_FL;
 
-	/* Allocate a new Crypto API context if we don't already have
-	 * one or if it isn't the right mode. */
-	if (ctx->tfm && (ctx->mode != ci->ci_data_mode)) {
-		crypto_free_tfm(ctx->tfm);
-		ctx->tfm = NULL;
-		ctx->mode = EXT4_ENCRYPTION_MODE_INVALID;
-	}
-	if (!ctx->tfm) {
-		switch (ci->ci_data_mode) {
-		case EXT4_ENCRYPTION_MODE_AES_256_XTS:
-			ctx->tfm = crypto_ablkcipher_tfm(
-				crypto_alloc_ablkcipher("xts(aes)", 0, 0));
-			break;
-		case EXT4_ENCRYPTION_MODE_AES_256_GCM:
-			/* TODO(mhalcrow): AEAD w/ gcm(aes);
-			 * crypto_aead_setauthsize() */
-			ctx->tfm = ERR_PTR(-ENOTSUPP);
-			break;
-		default:
-			BUG();
-		}
-		if (IS_ERR_OR_NULL(ctx->tfm)) {
-			res = PTR_ERR(ctx->tfm);
-			ctx->tfm = NULL;
-			goto out;
-		}
-		ctx->mode = ci->ci_data_mode;
-	}
-	BUG_ON(ci->ci_size != ext4_encryption_key_size(ci->ci_data_mode));
-
 out:
 	if (res) {
 		if (!IS_ERR_OR_NULL(ctx))
@@ -185,11 +153,8 @@
 {
 	struct ext4_crypto_ctx *pos, *n;
 
-	list_for_each_entry_safe(pos, n, &ext4_free_crypto_ctxs, free_list) {
-		if (pos->tfm)
-			crypto_free_tfm(pos->tfm);
+	list_for_each_entry_safe(pos, n, &ext4_free_crypto_ctxs, free_list)
 		kmem_cache_free(ext4_crypto_ctx_cachep, pos);
-	}
 	INIT_LIST_HEAD(&ext4_free_crypto_ctxs);
 	if (ext4_bounce_page_pool)
 		mempool_destroy(ext4_bounce_page_pool);
@@ -303,32 +268,11 @@
 	struct ablkcipher_request *req = NULL;
 	DECLARE_EXT4_COMPLETION_RESULT(ecr);
 	struct scatterlist dst, src;
-	struct ext4_inode_info *ei = EXT4_I(inode);
-	struct crypto_ablkcipher *atfm = __crypto_ablkcipher_cast(ctx->tfm);
+	struct ext4_crypt_info *ci = EXT4_I(inode)->i_crypt_info;
+	struct crypto_ablkcipher *tfm = ci->ci_ctfm;
 	int res = 0;
 
-	BUG_ON(!ctx->tfm);
-	BUG_ON(ctx->mode != ei->i_crypt_info->ci_data_mode);
-
-	if (ctx->mode != EXT4_ENCRYPTION_MODE_AES_256_XTS) {
-		printk_ratelimited(KERN_ERR
-				   "%s: unsupported crypto algorithm: %d\n",
-				   __func__, ctx->mode);
-		return -ENOTSUPP;
-	}
-
-	crypto_ablkcipher_clear_flags(atfm, ~0);
-	crypto_tfm_set_flags(ctx->tfm, CRYPTO_TFM_REQ_WEAK_KEY);
-
-	res = crypto_ablkcipher_setkey(atfm, ei->i_crypt_info->ci_raw,
-				       ei->i_crypt_info->ci_size);
-	if (res) {
-		printk_ratelimited(KERN_ERR
-				   "%s: crypto_ablkcipher_setkey() failed\n",
-				   __func__);
-		return res;
-	}
-	req = ablkcipher_request_alloc(atfm, GFP_NOFS);
+	req = ablkcipher_request_alloc(tfm, GFP_NOFS);
 	if (!req) {
 		printk_ratelimited(KERN_ERR
 				   "%s: crypto_request_alloc() failed\n",
diff --git a/fs/ext4/crypto_fname.c b/fs/ext4/crypto_fname.c
index e63dd29..29a2dc9 100644
--- a/fs/ext4/crypto_fname.c
+++ b/fs/ext4/crypto_fname.c
@@ -252,52 +252,6 @@
 	return cp - dst;
 }
 
-int ext4_setup_fname_crypto(struct inode *inode)
-{
-	struct ext4_inode_info *ei = EXT4_I(inode);
-	struct ext4_crypt_info *ci = ei->i_crypt_info;
-	struct crypto_ablkcipher *ctfm;
-	int res;
-
-	/* Check if the crypto policy is set on the inode */
-	res = ext4_encrypted_inode(inode);
-	if (res == 0)
-		return 0;
-
-	res = ext4_get_encryption_info(inode);
-	if (res < 0)
-		return res;
-	ci = ei->i_crypt_info;
-
-	if (!ci || ci->ci_ctfm)
-		return 0;
-
-	if (ci->ci_filename_mode != EXT4_ENCRYPTION_MODE_AES_256_CTS) {
-		printk_once(KERN_WARNING "ext4: unsupported key mode %d\n",
-			    ci->ci_filename_mode);
-		return -ENOKEY;
-	}
-
-	ctfm = crypto_alloc_ablkcipher("cts(cbc(aes))", 0, 0);
-	if (!ctfm || IS_ERR(ctfm)) {
-		res = ctfm ? PTR_ERR(ctfm) : -ENOMEM;
-		printk(KERN_DEBUG "%s: error (%d) allocating crypto tfm\n",
-		       __func__, res);
-		return res;
-	}
-	crypto_ablkcipher_clear_flags(ctfm, ~0);
-	crypto_tfm_set_flags(crypto_ablkcipher_tfm(ctfm),
-			     CRYPTO_TFM_REQ_WEAK_KEY);
-
-	res = crypto_ablkcipher_setkey(ctfm, ci->ci_raw, ci->ci_size);
-	if (res) {
-		crypto_free_ablkcipher(ctfm);
-		return -EIO;
-	}
-	ci->ci_ctfm = ctfm;
-	return 0;
-}
-
 /**
  * ext4_fname_crypto_round_up() -
  *
@@ -449,7 +403,7 @@
 		fname->disk_name.len = iname->len;
 		goto out;
 	}
-	ret = ext4_setup_fname_crypto(dir);
+	ret = ext4_get_encryption_info(dir);
 	if (ret)
 		return ret;
 	ci = EXT4_I(dir)->i_crypt_info;
diff --git a/fs/ext4/crypto_key.c b/fs/ext4/crypto_key.c
index 858d7d6..442d24e 100644
--- a/fs/ext4/crypto_key.c
+++ b/fs/ext4/crypto_key.c
@@ -84,20 +84,32 @@
 	return res;
 }
 
-void ext4_free_encryption_info(struct inode *inode)
+void ext4_free_crypt_info(struct ext4_crypt_info *ci)
 {
-	struct ext4_inode_info *ei = EXT4_I(inode);
-	struct ext4_crypt_info *ci = ei->i_crypt_info;
-
 	if (!ci)
 		return;
 
 	if (ci->ci_keyring_key)
 		key_put(ci->ci_keyring_key);
 	crypto_free_ablkcipher(ci->ci_ctfm);
-	memzero_explicit(&ci->ci_raw, sizeof(ci->ci_raw));
 	kmem_cache_free(ext4_crypt_info_cachep, ci);
-	ei->i_crypt_info = NULL;
+}
+
+void ext4_free_encryption_info(struct inode *inode,
+			       struct ext4_crypt_info *ci)
+{
+	struct ext4_inode_info *ei = EXT4_I(inode);
+	struct ext4_crypt_info *prev;
+
+	if (ci == NULL)
+		ci = ACCESS_ONCE(ei->i_crypt_info);
+	if (ci == NULL)
+		return;
+	prev = cmpxchg(&ei->i_crypt_info, ci, NULL);
+	if (prev != ci)
+		return;
+
+	ext4_free_crypt_info(ci);
 }
 
 int _ext4_get_encryption_info(struct inode *inode)
@@ -111,6 +123,10 @@
 	struct ext4_encryption_context ctx;
 	struct user_key_payload *ukp;
 	struct ext4_sb_info *sbi = EXT4_SB(inode->i_sb);
+	struct crypto_ablkcipher *ctfm;
+	const char *cipher_str;
+	char raw_key[EXT4_MAX_KEY_SIZE];
+	char mode;
 	int res;
 
 	if (!ext4_read_workqueue) {
@@ -119,11 +135,14 @@
 			return res;
 	}
 
-	if (ei->i_crypt_info) {
-		if (!ei->i_crypt_info->ci_keyring_key ||
-		    key_validate(ei->i_crypt_info->ci_keyring_key) == 0)
+retry:
+	crypt_info = ACCESS_ONCE(ei->i_crypt_info);
+	if (crypt_info) {
+		if (!crypt_info->ci_keyring_key ||
+		    key_validate(crypt_info->ci_keyring_key) == 0)
 			return 0;
-		ext4_free_encryption_info(inode);
+		ext4_free_encryption_info(inode, crypt_info);
+		goto retry;
 	}
 
 	res = ext4_xattr_get(inode, EXT4_XATTR_INDEX_ENCRYPTION,
@@ -144,26 +163,37 @@
 	if (!crypt_info)
 		return -ENOMEM;
 
-	ei->i_crypt_policy_flags = ctx.flags;
 	crypt_info->ci_flags = ctx.flags;
 	crypt_info->ci_data_mode = ctx.contents_encryption_mode;
 	crypt_info->ci_filename_mode = ctx.filenames_encryption_mode;
 	crypt_info->ci_ctfm = NULL;
+	crypt_info->ci_keyring_key = NULL;
 	memcpy(crypt_info->ci_master_key, ctx.master_key_descriptor,
 	       sizeof(crypt_info->ci_master_key));
 	if (S_ISREG(inode->i_mode))
-		crypt_info->ci_size =
-			ext4_encryption_key_size(crypt_info->ci_data_mode);
+		mode = crypt_info->ci_data_mode;
 	else if (S_ISDIR(inode->i_mode) || S_ISLNK(inode->i_mode))
-		crypt_info->ci_size =
-			ext4_encryption_key_size(crypt_info->ci_filename_mode);
+		mode = crypt_info->ci_filename_mode;
 	else
 		BUG();
-	BUG_ON(!crypt_info->ci_size);
-	if (DUMMY_ENCRYPTION_ENABLED(sbi)) {
-		memset(crypt_info->ci_raw, 0x42, EXT4_AES_256_XTS_KEY_SIZE);
+	switch (mode) {
+	case EXT4_ENCRYPTION_MODE_AES_256_XTS:
+		cipher_str = "xts(aes)";
+		break;
+	case EXT4_ENCRYPTION_MODE_AES_256_CTS:
+		cipher_str = "cts(cbc(aes))";
+		break;
+	default:
+		printk_once(KERN_WARNING
+			    "ext4: unsupported key mode %d (ino %u)\n",
+			    mode, (unsigned) inode->i_ino);
+		res = -ENOKEY;
 		goto out;
 	}
+	if (DUMMY_ENCRYPTION_ENABLED(sbi)) {
+		memset(raw_key, 0x42, EXT4_AES_256_XTS_KEY_SIZE);
+		goto got_key;
+	}
 	memcpy(full_key_descriptor, EXT4_KEY_DESC_PREFIX,
 	       EXT4_KEY_DESC_PREFIX_SIZE);
 	sprintf(full_key_descriptor + EXT4_KEY_DESC_PREFIX_SIZE,
@@ -177,6 +207,7 @@
 		keyring_key = NULL;
 		goto out;
 	}
+	crypt_info->ci_keyring_key = keyring_key;
 	BUG_ON(keyring_key->type != &key_type_logon);
 	ukp = ((struct user_key_payload *)keyring_key->payload.data);
 	if (ukp->datalen != sizeof(struct ext4_encryption_key)) {
@@ -188,19 +219,36 @@
 		     EXT4_KEY_DERIVATION_NONCE_SIZE);
 	BUG_ON(master_key->size != EXT4_AES_256_XTS_KEY_SIZE);
 	res = ext4_derive_key_aes(ctx.nonce, master_key->raw,
-				  crypt_info->ci_raw);
-out:
-	if (res < 0) {
-		if (res == -ENOKEY)
-			res = 0;
-		kmem_cache_free(ext4_crypt_info_cachep, crypt_info);
-	} else {
-		ei->i_crypt_info = crypt_info;
-		crypt_info->ci_keyring_key = keyring_key;
-		keyring_key = NULL;
+				  raw_key);
+got_key:
+	ctfm = crypto_alloc_ablkcipher(cipher_str, 0, 0);
+	if (!ctfm || IS_ERR(ctfm)) {
+		res = ctfm ? PTR_ERR(ctfm) : -ENOMEM;
+		printk(KERN_DEBUG
+		       "%s: error %d (inode %u) allocating crypto tfm\n",
+		       __func__, res, (unsigned) inode->i_ino);
+		goto out;
 	}
-	if (keyring_key)
-		key_put(keyring_key);
+	crypt_info->ci_ctfm = ctfm;
+	crypto_ablkcipher_clear_flags(ctfm, ~0);
+	crypto_tfm_set_flags(crypto_ablkcipher_tfm(ctfm),
+			     CRYPTO_TFM_REQ_WEAK_KEY);
+	res = crypto_ablkcipher_setkey(ctfm, raw_key,
+				       ext4_encryption_key_size(mode));
+	if (res)
+		goto out;
+	memzero_explicit(raw_key, sizeof(raw_key));
+	if (cmpxchg(&ei->i_crypt_info, NULL, crypt_info) != NULL) {
+		ext4_free_crypt_info(crypt_info);
+		goto retry;
+	}
+	return 0;
+
+out:
+	if (res == -ENOKEY)
+		res = 0;
+	ext4_free_crypt_info(crypt_info);
+	memzero_explicit(raw_key, sizeof(raw_key));
 	return res;
 }
 
diff --git a/fs/ext4/dir.c b/fs/ext4/dir.c
index 28cb94f..e11e6ae 100644
--- a/fs/ext4/dir.c
+++ b/fs/ext4/dir.c
@@ -133,9 +133,6 @@
 			return err;
 	}
 
-	err = ext4_setup_fname_crypto(inode);
-	if (err)
-		return err;
 	if (ext4_encrypted_inode(inode)) {
 		err = ext4_fname_crypto_alloc_buffer(inode, EXT4_NAME_LEN,
 						     &fname_crypto_str);
diff --git a/fs/ext4/ext4.h b/fs/ext4/ext4.h
index 23e33fb..7435ff2 100644
--- a/fs/ext4/ext4.h
+++ b/fs/ext4/ext4.h
@@ -911,7 +911,6 @@
 
 	/* on-disk additional length */
 	__u16 i_extra_isize;
-	char i_crypt_policy_flags;
 
 	/* Indicate the inline data space. */
 	u16 i_inline_off;
@@ -2105,7 +2104,6 @@
 			   const struct qstr *iname,
 			   struct ext4_str *oname);
 #ifdef CONFIG_EXT4_FS_ENCRYPTION
-int ext4_setup_fname_crypto(struct inode *inode);
 void ext4_fname_crypto_free_buffer(struct ext4_str *crypto_str);
 int ext4_fname_setup_filename(struct inode *dir, const struct qstr *iname,
 			      int lookup, struct ext4_filename *fname);
@@ -2131,7 +2129,8 @@
 
 
 /* crypto_key.c */
-void ext4_free_encryption_info(struct inode *inode);
+void ext4_free_crypt_info(struct ext4_crypt_info *ci);
+void ext4_free_encryption_info(struct inode *inode, struct ext4_crypt_info *ci);
 int _ext4_get_encryption_info(struct inode *inode);
 
 #ifdef CONFIG_EXT4_FS_ENCRYPTION
diff --git a/fs/ext4/ext4_crypto.h b/fs/ext4/ext4_crypto.h
index c5258f2..34e0d24 100644
--- a/fs/ext4/ext4_crypto.h
+++ b/fs/ext4/ext4_crypto.h
@@ -74,13 +74,11 @@
 } __attribute__((__packed__));
 
 struct ext4_crypt_info {
-	unsigned char	ci_size;
 	char		ci_data_mode;
 	char		ci_filename_mode;
 	char		ci_flags;
 	struct crypto_ablkcipher *ci_ctfm;
 	struct key	*ci_keyring_key;
-	char		ci_raw[EXT4_MAX_KEY_SIZE];
 	char		ci_master_key[EXT4_KEY_DESCRIPTOR_SIZE];
 };
 
@@ -89,7 +87,6 @@
 #define EXT4_WRITE_PATH_FL			      0x00000004
 
 struct ext4_crypto_ctx {
-	struct crypto_tfm *tfm;         /* Crypto API context */
 	union {
 		struct {
 			struct page *bounce_page;       /* Ciphertext page */
diff --git a/fs/ext4/namei.c b/fs/ext4/namei.c
index 9bed99f..6ab50f8 100644
--- a/fs/ext4/namei.c
+++ b/fs/ext4/namei.c
@@ -607,11 +607,12 @@
 				char *name;
 				struct ext4_str fname_crypto_str
 					= {.name = NULL, .len = 0};
-				int res;
+				int res = 0;
 
 				name  = de->name;
 				len = de->name_len;
-				res = ext4_setup_fname_crypto(dir);
+				if (ext4_encrypted_inode(inode))
+					res = ext4_get_encryption_info(dir);
 				if (res) {
 					printk(KERN_WARNING "Error setting up"
 					       " fname crypto: %d\n", res);
@@ -953,12 +954,12 @@
 					   EXT4_DIR_REC_LEN(0));
 #ifdef CONFIG_EXT4_FS_ENCRYPTION
 	/* Check if the directory is encrypted */
-	err = ext4_setup_fname_crypto(dir);
-	if (err) {
-		brelse(bh);
-		return err;
-	}
 	if (ext4_encrypted_inode(dir)) {
+		err = ext4_get_encryption_info(dir);
+		if (err < 0) {
+			brelse(bh);
+			return err;
+		}
 		err = ext4_fname_crypto_alloc_buffer(dir, EXT4_NAME_LEN,
 						     &fname_crypto_str);
 		if (err < 0) {
@@ -3108,7 +3109,7 @@
 		err = ext4_inherit_context(dir, inode);
 		if (err)
 			goto err_drop_inode;
-		err = ext4_setup_fname_crypto(inode);
+		err = ext4_get_encryption_info(inode);
 		if (err)
 			goto err_drop_inode;
 		istr.name = (const unsigned char *) symname;
diff --git a/fs/ext4/super.c b/fs/ext4/super.c
index b0bd1c1..56bfc2f 100644
--- a/fs/ext4/super.c
+++ b/fs/ext4/super.c
@@ -959,7 +959,7 @@
 	}
 #ifdef CONFIG_EXT4_FS_ENCRYPTION
 	if (EXT4_I(inode)->i_crypt_info)
-		ext4_free_encryption_info(inode);
+		ext4_free_encryption_info(inode, EXT4_I(inode)->i_crypt_info);
 #endif
 }
 
diff --git a/fs/ext4/symlink.c b/fs/ext4/symlink.c
index 3287088..68e915a 100644
--- a/fs/ext4/symlink.c
+++ b/fs/ext4/symlink.c
@@ -37,7 +37,7 @@
 	if (!ext4_encrypted_inode(inode))
 		return page_follow_link_light(dentry, nd);
 
-	res = ext4_setup_fname_crypto(inode);
+	res = ext4_get_encryption_info(inode);
 	if (res)
 		return ERR_PTR(res);