Ensure that a BlockCipher can only be used for one operation
This prevents trying to call encrypt() and then decrypt() on a
block cipher. It also enables finalize() to know what type of
finalization to call.
diff --git a/cryptography/primitives/block/base.py b/cryptography/primitives/block/base.py
index 417b1ad..2a6a5c3 100644
--- a/cryptography/primitives/block/base.py
+++ b/cryptography/primitives/block/base.py
@@ -21,16 +21,29 @@
self.cipher = cipher
self.mode = mode
self._ctx = api.create_block_cipher_context(cipher, mode)
+ self._operation = None
def encrypt(self, plaintext):
if self._ctx is None:
raise ValueError("BlockCipher was already finalized")
+
+ if self._operation is None:
+ self._operation = "encrypt"
+ elif self._operation != "encrypt":
+ raise ValueError("BlockCipher cannot encrypt when the operation is"
+ " set to %s" % self._operation)
+
return api.update_encrypt_context(self._ctx, plaintext)
def finalize(self):
if self._ctx is None:
raise ValueError("BlockCipher was already finalized")
- # TODO: this might be a decrypt context
- result = api.finalize_encrypt_context(self._ctx)
+
+ if self._operation == "encrypt":
+ result = api.finalize_encrypt_context(self._ctx)
+ else:
+ raise ValueError("BlockCipher cannot finalize the unknown "
+ "operation %s" % self._operation)
+
self._ctx = None
return result
diff --git a/tests/primitives/test_block.py b/tests/primitives/test_block.py
index f569343..7dccda4 100644
--- a/tests/primitives/test_block.py
+++ b/tests/primitives/test_block.py
@@ -30,3 +30,23 @@
cipher.encrypt(b"b" * 16)
with pytest.raises(ValueError):
cipher.finalize()
+
+ def test_encrypt_with_invalid_operation(self):
+ cipher = BlockCipher(
+ ciphers.AES(binascii.unhexlify(b"0" * 32)),
+ modes.CBC(binascii.unhexlify(b"0" * 32), padding.NoPadding())
+ )
+ cipher._operation = "decrypt"
+
+ with pytest.raises(ValueError):
+ cipher.encrypt(b"b" * 16)
+
+ def test_finalize_with_invalid_operation(self):
+ cipher = BlockCipher(
+ ciphers.AES(binascii.unhexlify(b"0" * 32)),
+ modes.CBC(binascii.unhexlify(b"0" * 32), padding.NoPadding())
+ )
+ cipher._operation = "wat"
+
+ with pytest.raises(ValueError):
+ cipher.encrypt(b"b" * 16)