Merge pull request #112 from reaperhulk/block-cipher-decrypt

Block Cipher Decryption
diff --git a/.coveragerc b/.coveragerc
index 398ff08..b891cb7 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,2 +1,6 @@
 [run]
 branch = True
+
+[report]
+exclude_lines =
+    @abc.abstractmethod
diff --git a/cryptography/bindings/openssl/api.py b/cryptography/bindings/openssl/api.py
index 3d92c14..fedaf9c 100644
--- a/cryptography/bindings/openssl/api.py
+++ b/cryptography/bindings/openssl/api.py
@@ -142,7 +142,27 @@
                 GetCipherByName("des-ede3-{mode.name}")
             )
 
-    def create_block_cipher_context(self, cipher, mode):
+    def create_block_cipher_encrypt_context(self, cipher, mode):
+        ctx, evp, iv_nonce = self._create_block_cipher_context(cipher, mode)
+        res = self.lib.EVP_EncryptInit_ex(ctx, evp, self.ffi.NULL, cipher.key,
+                                          iv_nonce)
+        assert res != 0
+        # We purposely disable padding here as it's handled higher up in the
+        # API.
+        self.lib.EVP_CIPHER_CTX_set_padding(ctx, 0)
+        return ctx
+
+    def create_block_cipher_decrypt_context(self, cipher, mode):
+        ctx, evp, iv_nonce = self._create_block_cipher_context(cipher, mode)
+        res = self.lib.EVP_DecryptInit_ex(ctx, evp, self.ffi.NULL, cipher.key,
+                                          iv_nonce)
+        assert res != 0
+        # We purposely disable padding here as it's handled higher up in the
+        # API.
+        self.lib.EVP_CIPHER_CTX_set_padding(ctx, 0)
+        return ctx
+
+    def _create_block_cipher_context(self, cipher, mode):
         ctx = self.lib.EVP_CIPHER_CTX_new()
         ctx = self.ffi.gc(ctx, self.lib.EVP_CIPHER_CTX_free)
         evp_cipher = self._cipher_registry[type(cipher), type(mode)](
@@ -156,24 +176,21 @@
         else:
             iv_nonce = self.ffi.NULL
 
-        # TODO: Sometimes this needs to be a DecryptInit, when?
-        res = self.lib.EVP_EncryptInit_ex(
-            ctx, evp_cipher, self.ffi.NULL, cipher.key, iv_nonce
-        )
-        assert res != 0
+        return (ctx, evp_cipher, iv_nonce)
 
-        # We purposely disable padding here as it's handled higher up in the
-        # API.
-        self.lib.EVP_CIPHER_CTX_set_padding(ctx, 0)
-        return ctx
-
-    def update_encrypt_context(self, ctx, plaintext):
+    def update_encrypt_context(self, ctx, data):
         block_size = self.lib.EVP_CIPHER_CTX_block_size(ctx)
-        buf = self.ffi.new("unsigned char[]", len(plaintext) + block_size - 1)
+        buf = self.ffi.new("unsigned char[]", len(data) + block_size - 1)
         outlen = self.ffi.new("int *")
-        res = self.lib.EVP_EncryptUpdate(
-            ctx, buf, outlen, plaintext, len(plaintext)
-        )
+        res = self.lib.EVP_EncryptUpdate(ctx, buf, outlen, data, len(data))
+        assert res != 0
+        return self.ffi.buffer(buf)[:outlen[0]]
+
+    def update_decrypt_context(self, ctx, data):
+        block_size = self.lib.EVP_CIPHER_CTX_block_size(ctx)
+        buf = self.ffi.new("unsigned char[]", len(data) + block_size)
+        outlen = self.ffi.new("int *")
+        res = self.lib.EVP_DecryptUpdate(ctx, buf, outlen, data, len(data))
         assert res != 0
         return self.ffi.buffer(buf)[:outlen[0]]
 
@@ -187,6 +204,16 @@
         assert res == 1
         return self.ffi.buffer(buf)[:outlen[0]]
 
+    def finalize_decrypt_context(self, ctx):
+        block_size = self.lib.EVP_CIPHER_CTX_block_size(ctx)
+        buf = self.ffi.new("unsigned char[]", block_size)
+        outlen = self.ffi.new("int *")
+        res = self.lib.EVP_DecryptFinal_ex(ctx, buf, outlen)
+        assert res != 0
+        res = self.lib.EVP_CIPHER_CTX_cleanup(ctx)
+        assert res == 1
+        return self.ffi.buffer(buf)[:outlen[0]]
+
     def supports_hash(self, hash_cls):
         return (self.ffi.NULL !=
                 self.lib.EVP_get_digestbyname(hash_cls.name.encode("ascii")))
diff --git a/cryptography/bindings/openssl/evp.py b/cryptography/bindings/openssl/evp.py
index 2bb5b0f..41df105 100644
--- a/cryptography/bindings/openssl/evp.py
+++ b/cryptography/bindings/openssl/evp.py
@@ -41,6 +41,11 @@
 int EVP_EncryptUpdate(EVP_CIPHER_CTX *, unsigned char *, int *,
                       const unsigned char *, int);
 int EVP_EncryptFinal_ex(EVP_CIPHER_CTX *, unsigned char *, int *);
+int EVP_DecryptInit_ex(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *,
+                       const unsigned char *, const unsigned char *);
+int EVP_DecryptUpdate(EVP_CIPHER_CTX *, unsigned char *, int *,
+                      const unsigned char *, int);
+int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *, unsigned char *, int *);
 int EVP_CIPHER_CTX_cleanup(EVP_CIPHER_CTX *);
 const EVP_CIPHER *EVP_CIPHER_CTX_cipher(const EVP_CIPHER_CTX *);
 int EVP_CIPHER_block_size(const EVP_CIPHER *);
diff --git a/cryptography/primitives/block/base.py b/cryptography/primitives/block/base.py
index 42c1f79..12b6f62 100644
--- a/cryptography/primitives/block/base.py
+++ b/cryptography/primitives/block/base.py
@@ -13,12 +13,7 @@
 
 from __future__ import absolute_import, division, print_function
 
-from enum import Enum
-
-
-class _Operation(Enum):
-    encrypt = 0
-    decrypt = 1
+from cryptography.primitives import interfaces
 
 
 class BlockCipher(object):
@@ -31,30 +26,49 @@
         self.cipher = cipher
         self.mode = mode
         self._api = api
-        self._ctx = api.create_block_cipher_context(cipher, mode)
-        self._operation = None
 
-    def encrypt(self, plaintext):
+    def encryptor(self):
+        return _CipherEncryptionContext(self.cipher, self.mode, self._api)
+
+    def decryptor(self):
+        return _CipherDecryptionContext(self.cipher, self.mode, self._api)
+
+
+@interfaces.register(interfaces.CipherContext)
+class _CipherEncryptionContext(object):
+    def __init__(self, cipher, mode, api):
+        super(_CipherEncryptionContext, self).__init__()
+        self._api = api
+        self._ctx = self._api.create_block_cipher_encrypt_context(cipher, mode)
+
+    def update(self, data):
         if self._ctx is None:
-            raise ValueError("BlockCipher was already finalized")
-
-        if self._operation is None:
-            self._operation = _Operation.encrypt
-        elif self._operation is not _Operation.encrypt:
-            raise ValueError("BlockCipher cannot encrypt when the operation is"
-                             " set to %s" % self._operation.name)
-
-        return self._api.update_encrypt_context(self._ctx, plaintext)
+            raise ValueError("Context was already finalized")
+        return self._api.update_encrypt_context(self._ctx, data)
 
     def finalize(self):
         if self._ctx is None:
-            raise ValueError("BlockCipher was already finalized")
-
-        if self._operation is _Operation.encrypt:
-            result = self._api.finalize_encrypt_context(self._ctx)
-        else:
-            raise ValueError("BlockCipher cannot finalize the unknown "
-                             "operation %s" % self._operation.name)
-
+            raise ValueError("Context was already finalized")
+        data = self._api.finalize_encrypt_context(self._ctx)
         self._ctx = None
-        return result
+        return data
+
+
+@interfaces.register(interfaces.CipherContext)
+class _CipherDecryptionContext(object):
+    def __init__(self, cipher, mode, api):
+        super(_CipherDecryptionContext, self).__init__()
+        self._api = api
+        self._ctx = self._api.create_block_cipher_decrypt_context(cipher, mode)
+
+    def update(self, data):
+        if self._ctx is None:
+            raise ValueError("Context was already finalized")
+        return self._api.update_decrypt_context(self._ctx, data)
+
+    def finalize(self):
+        if self._ctx is None:
+            raise ValueError("Context was already finalized")
+        data = self._api.finalize_decrypt_context(self._ctx)
+        self._ctx = None
+        return data
diff --git a/cryptography/primitives/block/modes.py b/cryptography/primitives/block/modes.py
index 4363180..a933c18 100644
--- a/cryptography/primitives/block/modes.py
+++ b/cryptography/primitives/block/modes.py
@@ -16,14 +16,7 @@
 from cryptography.primitives import interfaces
 
 
-def register(iface):
-    def register_decorator(klass):
-        iface.register(klass)
-        return klass
-    return register_decorator
-
-
-@register(interfaces.ModeWithInitializationVector)
+@interfaces.register(interfaces.ModeWithInitializationVector)
 class CBC(object):
     name = "CBC"
 
@@ -36,7 +29,7 @@
     name = "ECB"
 
 
-@register(interfaces.ModeWithInitializationVector)
+@interfaces.register(interfaces.ModeWithInitializationVector)
 class OFB(object):
     name = "OFB"
 
@@ -45,7 +38,7 @@
         self.initialization_vector = initialization_vector
 
 
-@register(interfaces.ModeWithInitializationVector)
+@interfaces.register(interfaces.ModeWithInitializationVector)
 class CFB(object):
     name = "CFB"
 
@@ -54,7 +47,7 @@
         self.initialization_vector = initialization_vector
 
 
-@register(interfaces.ModeWithNonce)
+@interfaces.register(interfaces.ModeWithNonce)
 class CTR(object):
     name = "CTR"
 
diff --git a/cryptography/primitives/interfaces.py b/cryptography/primitives/interfaces.py
index c1fc991..49c19d0 100644
--- a/cryptography/primitives/interfaces.py
+++ b/cryptography/primitives/interfaces.py
@@ -18,9 +18,30 @@
 import six
 
 
+def register(iface):
+    def register_decorator(klass):
+        iface.register(klass)
+        return klass
+    return register_decorator
+
+
 class ModeWithInitializationVector(six.with_metaclass(abc.ABCMeta)):
     pass
 
 
 class ModeWithNonce(six.with_metaclass(abc.ABCMeta)):
     pass
+
+
+class CipherContext(six.with_metaclass(abc.ABCMeta)):
+    @abc.abstractmethod
+    def update(self, data):
+        """
+        update takes bytes and returns bytes (possibly empty)
+        """
+
+    @abc.abstractmethod
+    def finalize(self):
+        """
+        finalize ends the operation and returns any remaining bytes
+        """
diff --git a/docs/primitives/symmetric-encryption.rst b/docs/primitives/symmetric-encryption.rst
index 96bd68f..73d8ad3 100644
--- a/docs/primitives/symmetric-encryption.rst
+++ b/docs/primitives/symmetric-encryption.rst
@@ -15,29 +15,47 @@
 
     Block ciphers work by encrypting content in chunks, often 64- or 128-bits.
     They combine an underlying algorithm (such as AES), with a mode (such as
-    CBC, CTR, or GCM). A simple example of encrypting content with AES is:
+    CBC, CTR, or GCM). A simple example of encrypting (and then decrypting)
+    content with AES is:
 
     .. doctest::
 
         >>> from cryptography.primitives.block import BlockCipher, ciphers, modes
         >>> cipher = BlockCipher(ciphers.AES(key), modes.CBC(iv))
-        >>> cipher.encrypt(b"a secret message") + cipher.finalize()
-        '...'
+        >>> encryptor = cipher.encryptor()
+        >>> ct = encryptor.update(b"a secret message") + encryptor.finalize()
+        >>> decryptor = cipher.decryptor()
+        >>> decryptor.update(ct) + decryptor.finalize()
+        'a secret message'
 
     :param cipher: One of the ciphers described below.
     :param mode: One of the modes described below.
 
-    ``encrypt()`` should be called repeatedly with new plaintext, and once the
-    full plaintext is fed in, ``finalize()`` should be called.
+    .. method:: encryptor()
 
-    .. method:: encrypt(plaintext)
+        :return :class:`CipherContext`: encryption instance
 
-        :param bytes plaintext: The text you wish to encrypt.
-        :return bytes: Returns the ciphertext that was added.
+    .. method:: decryptor()
+
+        :return :class:`CipherContext`: decryption instance
+
+.. class:: cryptography.primitives.interfaces.CipherContext()
+
+    When calling ``encryptor()`` or ``decryptor()`` on a BlockCipher object you
+    will receive a return object conforming to the CipherContext interface. You
+    can then call ``update(data)`` with data until you have fed everything into
+    the context. Once that is done, call ``finalize()`` to finish the operation and
+    obtain the remainder of the data.
+
+
+    .. method:: update(data)
+
+        :param bytes data: The data you wish to pass into the context.
+        :return bytes: Returns the data that was encrypted or decrypted.
 
     .. method:: finalize()
 
-        :return bytes: Returns the remainder of the ciphertext.
+        :return bytes: Returns the remainder of the data.
 
 Ciphers
 ~~~~~~~
diff --git a/setup.py b/setup.py
index cbbf100..1856cad 100644
--- a/setup.py
+++ b/setup.py
@@ -10,8 +10,6 @@
 # implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import sys
-
 from setuptools import setup, find_packages
 
 
@@ -32,9 +30,6 @@
     CFFI_DEPENDENCY,
 ]
 
-if sys.version_info[:2] < (3, 4):
-    install_requires += ["enum34"]
-
 setup(
     name=about["__title__"],
     version=about["__version__"],
diff --git a/tests/primitives/test_block.py b/tests/primitives/test_block.py
index 9f5905b..5e147a7 100644
--- a/tests/primitives/test_block.py
+++ b/tests/primitives/test_block.py
@@ -15,11 +15,10 @@
 
 import binascii
 
-import pretend
 import pytest
 
+from cryptography.primitives import interfaces
 from cryptography.primitives.block import BlockCipher, ciphers, modes
-from cryptography.primitives.block.base import _Operation
 
 
 class TestBlockCipher(object):
@@ -29,40 +28,42 @@
             modes.CBC(binascii.unhexlify(b"0" * 32))
         )
 
+    def test_creates_encryptor(self):
+        cipher = BlockCipher(
+            ciphers.AES(binascii.unhexlify(b"0" * 32)),
+            modes.CBC(binascii.unhexlify(b"0" * 32))
+        )
+        assert isinstance(cipher.encryptor(), interfaces.CipherContext)
+
+    def test_creates_decryptor(self):
+        cipher = BlockCipher(
+            ciphers.AES(binascii.unhexlify(b"0" * 32)),
+            modes.CBC(binascii.unhexlify(b"0" * 32))
+        )
+        assert isinstance(cipher.decryptor(), interfaces.CipherContext)
+
+
+class TestBlockCipherContext(object):
     def test_use_after_finalize(self, api):
         cipher = BlockCipher(
             ciphers.AES(binascii.unhexlify(b"0" * 32)),
             modes.CBC(binascii.unhexlify(b"0" * 32)),
             api
         )
-        cipher.encrypt(b"a" * 16)
-        cipher.finalize()
+        encryptor = cipher.encryptor()
+        encryptor.update(b"a" * 16)
+        encryptor.finalize()
         with pytest.raises(ValueError):
-            cipher.encrypt(b"b" * 16)
+            encryptor.update(b"b" * 16)
         with pytest.raises(ValueError):
-            cipher.finalize()
-
-    def test_encrypt_with_invalid_operation(self, api):
-        cipher = BlockCipher(
-            ciphers.AES(binascii.unhexlify(b"0" * 32)),
-            modes.CBC(binascii.unhexlify(b"0" * 32)),
-            api
-        )
-        cipher._operation = _Operation.decrypt
-
+            encryptor.finalize()
+        decryptor = cipher.decryptor()
+        decryptor.update(b"a" * 16)
+        decryptor.finalize()
         with pytest.raises(ValueError):
-            cipher.encrypt(b"b" * 16)
-
-    def test_finalize_with_invalid_operation(self, api):
-        cipher = BlockCipher(
-            ciphers.AES(binascii.unhexlify(b"0" * 32)),
-            modes.CBC(binascii.unhexlify(b"0" * 32)),
-            api
-        )
-        cipher._operation = pretend.stub(name="wat")
-
+            decryptor.update(b"b" * 16)
         with pytest.raises(ValueError):
-            cipher.finalize()
+            decryptor.finalize()
 
     def test_unaligned_block_encryption(self, api):
         cipher = BlockCipher(
@@ -70,7 +71,16 @@
             modes.ECB(),
             api
         )
-        ct = cipher.encrypt(b"a" * 15)
+        encryptor = cipher.encryptor()
+        ct = encryptor.update(b"a" * 15)
         assert ct == b""
-        ct += cipher.encrypt(b"a" * 65)
+        ct += encryptor.update(b"a" * 65)
         assert len(ct) == 80
+        ct += encryptor.finalize()
+        decryptor = cipher.decryptor()
+        pt = decryptor.update(ct[:3])
+        assert pt == b""
+        pt += decryptor.update(ct[3:])
+        assert len(pt) == 80
+        assert pt == b"a" * 80
+        decryptor.finalize()
diff --git a/tests/primitives/utils.py b/tests/primitives/utils.py
index a15e773..91ca36d 100644
--- a/tests/primitives/utils.py
+++ b/tests/primitives/utils.py
@@ -37,9 +37,14 @@
         mode_factory(**params),
         api
     )
-    actual_ciphertext = cipher.encrypt(binascii.unhexlify(plaintext))
-    actual_ciphertext += cipher.finalize()
+    encryptor = cipher.encryptor()
+    actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext))
+    actual_ciphertext += encryptor.finalize()
     assert actual_ciphertext == binascii.unhexlify(ciphertext)
+    decryptor = cipher.decryptor()
+    actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext))
+    actual_plaintext += decryptor.finalize()
+    assert actual_plaintext == binascii.unhexlify(plaintext)
 
 
 def generate_hash_test(param_loader, path, file_names, hash_cls,