Merge pull request #949 from reaperhulk/rsa-oaep-decrypt

OAEP support for RSA decryption
diff --git a/cryptography/hazmat/backends/openssl/backend.py b/cryptography/hazmat/backends/openssl/backend.py
index 4c487e4..16b963a 100644
--- a/cryptography/hazmat/backends/openssl/backend.py
+++ b/cryptography/hazmat/backends/openssl/backend.py
@@ -32,7 +32,7 @@
 from cryptography.hazmat.primitives import hashes, interfaces
 from cryptography.hazmat.primitives.asymmetric import dsa, rsa
 from cryptography.hazmat.primitives.asymmetric.padding import (
-    MGF1, PKCS1v15, PSS
+    MGF1, OAEP, PKCS1v15, PSS
 )
 from cryptography.hazmat.primitives.ciphers.algorithms import (
     AES, ARC4, Blowfish, CAST5, Camellia, IDEA, SEED, TripleDES
@@ -477,6 +477,29 @@
     def decrypt_rsa(self, private_key, ciphertext, padding):
         if isinstance(padding, PKCS1v15):
             padding_enum = self._lib.RSA_PKCS1_PADDING
+        elif isinstance(padding, OAEP):
+            padding_enum = self._lib.RSA_PKCS1_OAEP_PADDING
+            if not isinstance(padding._mgf, MGF1):
+                raise UnsupportedAlgorithm(
+                    "Only MGF1 is supported by this backend",
+                    _Reasons.UNSUPPORTED_MGF
+                )
+
+            if not isinstance(padding._mgf._algorithm, hashes.SHA1):
+                raise UnsupportedAlgorithm(
+                    "This backend supports only SHA1 inside MGF1 when "
+                    "using OAEP",
+                    _Reasons.UNSUPPORTED_HASH
+                )
+
+            if padding._label is not None and padding._label != b"":
+                raise ValueError("This backend does not support OAEP labels")
+
+            if not isinstance(padding._algorithm, hashes.SHA1):
+                raise UnsupportedAlgorithm(
+                    "This backend only supports SHA1 when using OAEP",
+                    _Reasons.UNSUPPORTED_HASH
+                )
         else:
             raise UnsupportedAlgorithm(
                 "{0} is not supported by this backend".format(
diff --git a/cryptography/hazmat/primitives/asymmetric/padding.py b/cryptography/hazmat/primitives/asymmetric/padding.py
index 72806a6..dcc6fe0 100644
--- a/cryptography/hazmat/primitives/asymmetric/padding.py
+++ b/cryptography/hazmat/primitives/asymmetric/padding.py
@@ -54,6 +54,19 @@
         self._salt_length = salt_length
 
 
+@utils.register_interface(interfaces.AsymmetricPadding)
+class OAEP(object):
+    name = "EME-OAEP"
+
+    def __init__(self, mgf, algorithm, label):
+        if not isinstance(algorithm, interfaces.HashAlgorithm):
+            raise TypeError("Expected instance of interfaces.HashAlgorithm.")
+
+        self._mgf = mgf
+        self._algorithm = algorithm
+        self._label = label
+
+
 class MGF1(object):
     MAX_LENGTH = object()
 
diff --git a/docs/hazmat/primitives/asymmetric/padding.rst b/docs/hazmat/primitives/asymmetric/padding.rst
index f33ca4e..4008479 100644
--- a/docs/hazmat/primitives/asymmetric/padding.rst
+++ b/docs/hazmat/primitives/asymmetric/padding.rst
@@ -33,6 +33,21 @@
         Pass this attribute to ``salt_length`` to get the maximum salt length
         available.
 
+.. class:: OAEP(mgf, algorithm, label)
+
+    .. versionadded:: 0.4
+
+    OAEP (Optimal Asymmetric Encryption Padding) is a padding scheme defined in
+    :rfc:`3447`. It provides probabilistic encryption and is `proven secure`_
+    against several attack types. This is the `recommended padding algorithm`_
+    for RSA encryption. It cannot be used with RSA signing.
+
+    :param mgf: A mask generation function object. At this time the only
+        supported MGF is :class:`MGF1`.
+    :param algorithm: An instance of a ``HashAlgorithm`` provider.
+    :param bytes label: A label to apply. This is a rarely used field and
+        should typically be set to ``None`` or ``b""``, which are equivalent.
+
 .. class:: PKCS1v15()
 
     .. versionadded:: 0.3
@@ -62,3 +77,4 @@
 .. _`Padding is critical`: http://rdist.root.org/2009/10/06/why-rsa-encryption-padding-is-critical/
 .. _`security proof`: http://eprint.iacr.org/2001/062.pdf
 .. _`recommended padding algorithm`: http://www.daemonology.net/blog/2009-06-11-cryptographic-right-answers.html
+.. _`proven secure`: http://cseweb.ucsd.edu/users/mihir/papers/oae.pdf
diff --git a/docs/hazmat/primitives/asymmetric/rsa.rst b/docs/hazmat/primitives/asymmetric/rsa.rst
index c282d9e..862df63 100644
--- a/docs/hazmat/primitives/asymmetric/rsa.rst
+++ b/docs/hazmat/primitives/asymmetric/rsa.rst
@@ -138,13 +138,37 @@
             the provided ``backend`` does not implement
             :class:`~cryptography.hazmat.backends.interfaces.RSABackend` or if
             the backend does not support the chosen hash or padding algorithm.
+            If the padding is
+            :class:`~cryptography.hazmat.primitives.asymmetric.padding.OAEP`
+            with the
+            :class:`~cryptography.hazmat.primitives.asymmetric.padding.MGF1`
+            mask generation function, the unsupported hash algorithm may
+            also be the one passed to ``MGF1``.
 
         :raises TypeError: This is raised when the padding is not an
             :class:`~cryptography.hazmat.primitives.interfaces.AsymmetricPadding`
             provider.
 
-        :raises ValueError: This is raised when decryption fails or the chosen
-            hash algorithm is too large for the key size.
+        :raises ValueError: This is raised when decryption fails or the data
+            is too large for the key size. If the padding is
+            :class:`~cryptography.hazmat.primitives.asymmetric.padding.OAEP`
+            it may also be raised for invalid label values.
+
+        .. code-block:: python
+
+            from cryptography.hazmat.backends import default_backend
+            from cryptography.hazmat.primitives import hashes
+            from cryptography.hazmat.primitives.asymmetric import padding
+
+            plaintext = private_key.decrypt(
+                ciphertext,
+                padding.OAEP(
+                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
+                    algorithm=hashes.SHA1(),
+                    label=None
+                ),
+                default_backend()
+            )
 
 
 .. class:: RSAPublicKey(public_exponent, modulus)
diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py
index 9853736..5851166 100644
--- a/tests/hazmat/backends/test_openssl.py
+++ b/tests/hazmat/backends/test_openssl.py
@@ -293,6 +293,57 @@
     def test_unsupported_mgf1_hash_algorithm(self):
         assert backend.mgf1_hash_supported(DummyHash()) is False
 
+    def test_unsupported_mgf1_hash_algorithm_decrypt(self):
+        private_key = rsa.RSAPrivateKey.generate(
+            public_exponent=65537,
+            key_size=512,
+            backend=backend
+        )
+        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH):
+            private_key.decrypt(
+                b"ciphertext",
+                padding.OAEP(
+                    mgf=padding.MGF1(algorithm=hashes.SHA256()),
+                    algorithm=hashes.SHA1(),
+                    label=None
+                ),
+                backend
+            )
+
+    def test_unsupported_oaep_hash_algorithm_decrypt(self):
+        private_key = rsa.RSAPrivateKey.generate(
+            public_exponent=65537,
+            key_size=512,
+            backend=backend
+        )
+        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH):
+            private_key.decrypt(
+                b"ciphertext",
+                padding.OAEP(
+                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
+                    algorithm=hashes.SHA256(),
+                    label=None
+                ),
+                backend
+            )
+
+    def test_unsupported_oaep_label_decrypt(self):
+        private_key = rsa.RSAPrivateKey.generate(
+            public_exponent=65537,
+            key_size=512,
+            backend=backend
+        )
+        with pytest.raises(ValueError):
+            private_key.decrypt(
+                b"ciphertext",
+                padding.OAEP(
+                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
+                    algorithm=hashes.SHA1(),
+                    label=b"label"
+                ),
+                backend
+            )
+
 
 @pytest.mark.skipif(
     backend._lib.OPENSSL_VERSION_NUMBER <= 0x10001000,
diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py
index a5266d5..7cf6e2f 100644
--- a/tests/hazmat/primitives/test_rsa.py
+++ b/tests/hazmat/primitives/test_rsa.py
@@ -1227,6 +1227,17 @@
         assert mgf._salt_length == padding.MGF1.MAX_LENGTH
 
 
+class TestOAEP(object):
+    def test_invalid_algorithm(self):
+        mgf = padding.MGF1(hashes.SHA1())
+        with pytest.raises(TypeError):
+            padding.OAEP(
+                mgf=mgf,
+                algorithm=b"",
+                label=None
+            )
+
+
 @pytest.mark.rsa
 class TestRSADecryption(object):
     @pytest.mark.parametrize(
@@ -1320,3 +1331,51 @@
                 padding.PKCS1v15(),
                 pretend_backend
             )
+
+    @pytest.mark.parametrize(
+        "vector",
+        _flatten_pkcs1_examples(load_vectors_from_file(
+            os.path.join(
+                "asymmetric", "RSA", "pkcs-1v2-1d2-vec", "oaep-vect.txt"),
+            load_pkcs1_vectors
+        ))
+    )
+    def test_decrypt_oaep_vectors(self, vector, backend):
+        private, public, example = vector
+        skey = rsa.RSAPrivateKey(
+            p=private["p"],
+            q=private["q"],
+            private_exponent=private["private_exponent"],
+            dmp1=private["dmp1"],
+            dmq1=private["dmq1"],
+            iqmp=private["iqmp"],
+            public_exponent=private["public_exponent"],
+            modulus=private["modulus"]
+        )
+        message = skey.decrypt(
+            binascii.unhexlify(example["encryption"]),
+            padding.OAEP(
+                mgf=padding.MGF1(algorithm=hashes.SHA1()),
+                algorithm=hashes.SHA1(),
+                label=None
+            ),
+            backend
+        )
+        assert message == binascii.unhexlify(example["message"])
+
+    def test_unsupported_oaep_mgf(self, backend):
+        private_key = rsa.RSAPrivateKey.generate(
+            public_exponent=65537,
+            key_size=512,
+            backend=backend
+        )
+        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_MGF):
+            private_key.decrypt(
+                b"ciphertext",
+                padding.OAEP(
+                    mgf=DummyMGF(),
+                    algorithm=hashes.SHA1(),
+                    label=None
+                ),
+                backend
+            )