Merge pull request #962 from reaperhulk/rsa-enc

RSA encryption support
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 2fdf8c9..e09fa5d 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -12,6 +12,8 @@
   removed from ``MGF1`` in two releases per our :doc:`/api-stability` policy.
 * Added :class:`~cryptography.hazmat.primitives.ciphers.algorithms.SEED` support.
 * Added :class:`~cryptography.hazmat.primitives.cmac.CMAC`.
+* Added decryption support to :class:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey`
+  and encryption support to :class:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey`.
 
 0.3 - 2014-03-27
 ~~~~~~~~~~~~~~~~
diff --git a/cryptography/hazmat/backends/interfaces.py b/cryptography/hazmat/backends/interfaces.py
index 1ddf078..aaaca5e 100644
--- a/cryptography/hazmat/backends/interfaces.py
+++ b/cryptography/hazmat/backends/interfaces.py
@@ -123,6 +123,12 @@
         Returns decrypted bytes.
         """
 
+    @abc.abstractmethod
+    def encrypt_rsa(self, public_key, plaintext, padding):
+        """
+        Returns encrypted bytes of plaintext using public_key and padding.
+        """
+
 
 @six.add_metaclass(abc.ABCMeta)
 class DSABackend(object):
diff --git a/cryptography/hazmat/backends/openssl/backend.py b/cryptography/hazmat/backends/openssl/backend.py
index 2d7e230..f9154f3 100644
--- a/cryptography/hazmat/backends/openssl/backend.py
+++ b/cryptography/hazmat/backends/openssl/backend.py
@@ -475,6 +475,16 @@
         )
 
     def decrypt_rsa(self, private_key, ciphertext, padding):
+        key_size_bytes = int(math.ceil(private_key.key_size / 8.0))
+        if key_size_bytes != len(ciphertext):  # reject short/long ciphertext
+            raise ValueError("Ciphertext length must be equal to key size.")
+
+        return self._enc_dec_rsa(private_key, ciphertext, padding)
+
+    def encrypt_rsa(self, public_key, plaintext, padding):  # no length check
+        return self._enc_dec_rsa(public_key, plaintext, padding)
+
+    def _enc_dec_rsa(self, key, data, padding):
         if isinstance(padding, PKCS1v15):
             padding_enum = self._lib.RSA_PKCS1_PADDING
         elif isinstance(padding, OAEP):
@@ -508,24 +518,27 @@
                 _Reasons.UNSUPPORTED_PADDING
             )
 
-        key_size_bytes = int(math.ceil(private_key.key_size / 8.0))
-        if key_size_bytes < len(ciphertext):
-            raise ValueError("Ciphertext too large for key size")
-
         if self._lib.Cryptography_HAS_PKEY_CTX:
-            return self._decrypt_rsa_pkey_ctx(private_key, ciphertext,
-                                              padding_enum)
+            return self._enc_dec_rsa_pkey_ctx(key, data, padding_enum)
         else:
-            return self._decrypt_rsa_098(private_key, ciphertext, padding_enum)
+            return self._enc_dec_rsa_098(key, data, padding_enum)
 
-    def _decrypt_rsa_pkey_ctx(self, private_key, ciphertext, padding_enum):
-        evp_pkey = self._rsa_private_key_to_evp_pkey(private_key)
+    def _enc_dec_rsa_pkey_ctx(self, key, data, padding_enum):
+        if isinstance(key, rsa.RSAPublicKey):
+            init = self._lib.EVP_PKEY_encrypt_init
+            crypt = self._lib.Cryptography_EVP_PKEY_encrypt
+            evp_pkey = self._rsa_public_key_to_evp_pkey(key)
+        else:
+            init = self._lib.EVP_PKEY_decrypt_init
+            crypt = self._lib.Cryptography_EVP_PKEY_decrypt
+            evp_pkey = self._rsa_private_key_to_evp_pkey(key)
+
         pkey_ctx = self._lib.EVP_PKEY_CTX_new(
             evp_pkey, self._ffi.NULL
         )
         assert pkey_ctx != self._ffi.NULL
         pkey_ctx = self._ffi.gc(pkey_ctx, self._lib.EVP_PKEY_CTX_free)
-        res = self._lib.EVP_PKEY_decrypt_init(pkey_ctx)
+        res = init(pkey_ctx)
         assert res == 1
         res = self._lib.EVP_PKEY_CTX_set_rsa_padding(
             pkey_ctx, padding_enum)
@@ -534,50 +547,60 @@
         assert buf_size > 0
         outlen = self._ffi.new("size_t *", buf_size)
         buf = self._ffi.new("char[]", buf_size)
-        res = self._lib.Cryptography_EVP_PKEY_decrypt(
+        res = crypt(
             pkey_ctx,
             buf,
             outlen,
-            ciphertext,
-            len(ciphertext)
+            data,
+            len(data)
         )
         if res <= 0:
-            errors = self._consume_errors()
-            assert errors
-            assert errors[0].lib == self._lib.ERR_LIB_RSA
-            assert (
-                errors[0].reason == self._lib.RSA_R_BLOCK_TYPE_IS_NOT_01 or
-                errors[0].reason == self._lib.RSA_R_BLOCK_TYPE_IS_NOT_02
-            )
-            raise ValueError("Decryption failed")
+            self._handle_rsa_enc_dec_error(key)
 
         return self._ffi.buffer(buf)[:outlen[0]]
 
-    def _decrypt_rsa_098(self, private_key, ciphertext, padding_enum):
-        rsa_cdata = self._rsa_cdata_from_private_key(private_key)
+    def _enc_dec_rsa_098(self, key, data, padding_enum):
+        if isinstance(key, rsa.RSAPublicKey):  # public key: encrypt
+            crypt = self._lib.RSA_public_encrypt
+            rsa_cdata = self._rsa_cdata_from_public_key(key)
+        else:  # private key: decrypt
+            crypt = self._lib.RSA_private_decrypt
+            rsa_cdata = self._rsa_cdata_from_private_key(key)
+
         rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
         key_size = self._lib.RSA_size(rsa_cdata)
         assert key_size > 0
         buf = self._ffi.new("unsigned char[]", key_size)
-        res = self._lib.RSA_private_decrypt(
-            len(ciphertext),
-            ciphertext,
+        res = crypt(
+            len(data),
+            data,
             buf,
             rsa_cdata,
             padding_enum
         )
         if res < 0:
-            errors = self._consume_errors()
-            assert errors
-            assert errors[0].lib == self._lib.ERR_LIB_RSA
+            self._handle_rsa_enc_dec_error(key)
+
+        return self._ffi.buffer(buf)[:res]  # res is the output length
+
+    def _handle_rsa_enc_dec_error(self, key):  # always raises
+        errors = self._consume_errors()
+        assert errors
+        assert errors[0].lib == self._lib.ERR_LIB_RSA
+        if isinstance(key, rsa.RSAPublicKey):  # encryption failed
+            assert (errors[0].reason ==
+                    self._lib.RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE)
+            raise ValueError(
+                "Data too long for key size. Encrypt less data or use a "
+                "larger key size"
+            )
+        else:  # NOTE(review): OAEP decode errors may trip assert
             assert (
                 errors[0].reason == self._lib.RSA_R_BLOCK_TYPE_IS_NOT_01 or
                 errors[0].reason == self._lib.RSA_R_BLOCK_TYPE_IS_NOT_02
             )
             raise ValueError("Decryption failed")
 
-        return self._ffi.buffer(buf)[:res]
-
     def cmac_algorithm_supported(self, algorithm):
         return (
             self._lib.Cryptography_HAS_CMAC == 1
diff --git a/cryptography/hazmat/primitives/asymmetric/rsa.py b/cryptography/hazmat/primitives/asymmetric/rsa.py
index cffd4e9..5d3bb36 100644
--- a/cryptography/hazmat/primitives/asymmetric/rsa.py
+++ b/cryptography/hazmat/primitives/asymmetric/rsa.py
@@ -52,6 +52,15 @@
         return backend.create_rsa_verification_ctx(self, signature, padding,
                                                    algorithm)
 
+    def encrypt(self, plaintext, padding, backend):
+        if not isinstance(backend, RSABackend):  # needs encrypt_rsa()
+            raise UnsupportedAlgorithm(
+                "Backend object does not implement RSABackend",
+                _Reasons.BACKEND_MISSING_INTERFACE
+            )
+
+        return backend.encrypt_rsa(self, plaintext, padding)
+
     @property
     def key_size(self):
         return utils.bit_length(self.modulus)
diff --git a/docs/hazmat/backends/interfaces.rst b/docs/hazmat/backends/interfaces.rst
index 11ff930..2f63f3e 100644
--- a/docs/hazmat/backends/interfaces.rst
+++ b/docs/hazmat/backends/interfaces.rst
@@ -275,6 +275,18 @@
             :class:`~cryptography.hazmat.primitives.interfaces.AsymmetricPadding`
             provider.
 
+    .. method:: encrypt_rsa(public_key, plaintext, padding)
+
+        :param public_key: An instance of an
+            :class:`~cryptography.hazmat.primitives.interfaces.RSAPublicKey`
+            provider.
+
+        :param bytes plaintext: The plaintext to encrypt.
+
+        :param padding: An instance of an
+            :class:`~cryptography.hazmat.primitives.interfaces.AsymmetricPadding`
+            provider.
+
 
 .. class:: TraditionalOpenSSLSerializationBackend
 
diff --git a/docs/hazmat/primitives/asymmetric/rsa.rst b/docs/hazmat/primitives/asymmetric/rsa.rst
index 862df63..68ad089 100644
--- a/docs/hazmat/primitives/asymmetric/rsa.rst
+++ b/docs/hazmat/primitives/asymmetric/rsa.rst
@@ -267,6 +267,67 @@
         :raises ValueError: This is raised when the chosen hash algorithm is
             too large for the key size.
 
+    .. method:: encrypt(plaintext, padding, backend)
+
+        .. versionadded:: 0.4
+
+        Encrypt data using the public key. The resulting ciphertext can only
+        be decrypted with the private key.
+
+        :param bytes plaintext: The plaintext to encrypt.
+
+        :param padding: An instance of an
+            :class:`~cryptography.hazmat.primitives.interfaces.AsymmetricPadding`
+            provider.
+
+        :param backend: A
+            :class:`~cryptography.hazmat.backends.interfaces.RSABackend`
+            provider.
+
+        :return bytes: Encrypted data.
+
+        :raises cryptography.exceptions.UnsupportedAlgorithm: This is raised if
+            the provided ``backend`` does not implement
+            :class:`~cryptography.hazmat.backends.interfaces.RSABackend` or if
+            the backend does not support the chosen hash or padding algorithm.
+            If the padding is
+            :class:`~cryptography.hazmat.primitives.asymmetric.padding.OAEP`
+            with the
+            :class:`~cryptography.hazmat.primitives.asymmetric.padding.MGF1`
+            mask generation function it may also refer to the ``MGF1`` hash
+            algorithm.
+
+        :raises TypeError: This is raised when the padding is not an
+            :class:`~cryptography.hazmat.primitives.interfaces.AsymmetricPadding`
+            provider.
+
+        :raises ValueError: This is raised if the data is too large for the
+            key size. If the padding is
+            :class:`~cryptography.hazmat.primitives.asymmetric.padding.OAEP`
+            it may also be raised for invalid label values.
+
+        .. code-block:: python
+
+            from cryptography.hazmat.backends import default_backend
+            from cryptography.hazmat.primitives import hashes
+            from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+            private_key = rsa.RSAPrivateKey.generate(
+                public_exponent=65537,
+                key_size=2048,
+                backend=default_backend()
+            )
+            public_key = private_key.public_key()
+            ciphertext = public_key.encrypt(
+                b"encrypted data",
+                padding.OAEP(
+                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
+                    algorithm=hashes.SHA1(),
+                    label=None
+                ),
+                default_backend()
+            )
+
 
 Handling partial RSA private keys
 ---------------------------------
diff --git a/docs/hazmat/primitives/interfaces.rst b/docs/hazmat/primitives/interfaces.rst
index 3b837a0..c76582c 100644
--- a/docs/hazmat/primitives/interfaces.rst
+++ b/docs/hazmat/primitives/interfaces.rst
@@ -263,6 +263,23 @@
         :returns:
             :class:`~cryptography.hazmat.primitives.interfaces.AsymmetricVerificationContext`
 
+    .. method:: encrypt(plaintext, padding, backend)
+
+        .. versionadded:: 0.4
+
+        Encrypt data with the public key.
+
+        :param bytes plaintext: The plaintext to encrypt.
+
+        :param padding: An instance of an
+            :class:`~cryptography.hazmat.primitives.interfaces.AsymmetricPadding`
+            provider.
+
+        :param backend: A
+            :class:`~cryptography.hazmat.backends.interfaces.RSABackend`
+            provider.
+
+        :return bytes: Encrypted data.
 
     .. attribute:: modulus
 
diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py
index 5851166..bba7d75 100644
--- a/tests/hazmat/backends/test_openssl.py
+++ b/tests/hazmat/backends/test_openssl.py
@@ -301,7 +301,7 @@
         )
         with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH):
             private_key.decrypt(
-                b"ciphertext",
+                b"0" * 64,
                 padding.OAEP(
                     mgf=padding.MGF1(algorithm=hashes.SHA256()),
                     algorithm=hashes.SHA1(),
@@ -318,7 +318,7 @@
         )
         with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH):
             private_key.decrypt(
-                b"ciphertext",
+                b"0" * 64,
                 padding.OAEP(
                     mgf=padding.MGF1(algorithm=hashes.SHA1()),
                     algorithm=hashes.SHA256(),
@@ -335,7 +335,7 @@
         )
         with pytest.raises(ValueError):
             private_key.decrypt(
-                b"ciphertext",
+                b"0" * 64,
                 padding.OAEP(
                     mgf=padding.MGF1(algorithm=hashes.SHA1()),
                     algorithm=hashes.SHA1(),
diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py
index 34b80cc..38a5d0a 100644
--- a/tests/hazmat/primitives/test_rsa.py
+++ b/tests/hazmat/primitives/test_rsa.py
@@ -1260,7 +1260,7 @@
             backend=backend
         )
         with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_PADDING):
-            private_key.decrypt(b"somedata", DummyPadding(), backend)
+            private_key.decrypt(b"0" * 64, DummyPadding(), backend)
 
     def test_decrypt_invalid_decrypt(self, backend):
         private_key = rsa.RSAPrivateKey.generate(
@@ -1355,6 +1355,124 @@
         )
         with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_MGF):
             private_key.decrypt(
+                b"0" * 64,
+                padding.OAEP(
+                    mgf=DummyMGF(),
+                    algorithm=hashes.SHA1(),
+                    label=None
+                ),
+                backend
+            )
+
+
+@pytest.mark.rsa
+class TestRSAEncryption(object):
+    @pytest.mark.parametrize(
+        ("key_size", "pad"),
+        itertools.product(  # odd sizes hit non-byte-aligned moduli
+            (1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1536, 2048),
+            (
+                padding.OAEP(
+                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
+                    algorithm=hashes.SHA1(),
+                    label=None
+                ),
+                padding.PKCS1v15()
+            )
+        )
+    )
+    def test_rsa_encrypt(self, key_size, pad, backend):
+        private_key = rsa.RSAPrivateKey.generate(
+            public_exponent=65537,
+            key_size=key_size,
+            backend=backend
+        )
+        pt = b"encrypt me!"
+        public_key = private_key.public_key()
+        ct = public_key.encrypt(
+            pt,
+            pad,
+            backend
+        )
+        assert ct != pt
+        assert len(ct) == math.ceil(public_key.key_size / 8.0)
+        recovered_pt = private_key.decrypt(
+            ct,
+            pad,
+            backend
+        )
+        assert recovered_pt == pt  # decrypt(encrypt(pt)) == pt
+
+    @pytest.mark.parametrize(
+        ("key_size", "pad"),
+        itertools.product(
+            (1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1536, 2048),
+            (
+                padding.OAEP(
+                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
+                    algorithm=hashes.SHA1(),
+                    label=None
+                ),
+                padding.PKCS1v15()
+            )
+        )
+    )
+    def test_rsa_encrypt_key_too_small(self, key_size, pad, backend):
+        private_key = rsa.RSAPrivateKey.generate(
+            public_exponent=65537,
+            key_size=key_size,
+            backend=backend
+        )
+        public_key = private_key.public_key()
+        # Slightly smaller than the key size, leaving no room for padding.
+        with pytest.raises(ValueError):
+            public_key.encrypt(
+                b"\x00" * (key_size // 8 - 1),
+                pad,
+                backend
+            )
+
+        # Larger than the key size.
+        with pytest.raises(ValueError):
+            public_key.encrypt(
+                b"\x00" * (key_size // 8 + 5),
+                pad,
+                backend
+            )
+
+    def test_rsa_encrypt_invalid_backend(self, backend):
+        pretend_backend = object()  # lacks the RSABackend interface
+        private_key = rsa.RSAPrivateKey.generate(65537, 512, backend)
+        public_key = private_key.public_key()
+
+        with raises_unsupported_algorithm(_Reasons.BACKEND_MISSING_INTERFACE):
+            public_key.encrypt(
+                b"irrelevant",
+                padding.PKCS1v15(),
+                pretend_backend
+            )
+
+    def test_unsupported_padding(self, backend):
+        private_key = rsa.RSAPrivateKey.generate(
+            public_exponent=65537,
+            key_size=512,  # small key: only padding rejection is tested
+            backend=backend
+        )
+        public_key = private_key.public_key()
+
+        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_PADDING):
+            public_key.encrypt(b"somedata", DummyPadding(), backend)
+
+    def test_unsupported_oaep_mgf(self, backend):
+        private_key = rsa.RSAPrivateKey.generate(
+            public_exponent=65537,
+            key_size=512,
+            backend=backend
+        )
+        public_key = private_key.public_key()
+
+        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_MGF):
+            public_key.encrypt(
                 b"ciphertext",
                 padding.OAEP(
                     mgf=DummyMGF(),