support prehashing in RSA sign (#3238)

* support prehashing in RSA sign

* check to make sure digest size matches prehashed data provided

* move doctest for prehashed
diff --git a/docs/hazmat/primitives/asymmetric/rsa.rst b/docs/hazmat/primitives/asymmetric/rsa.rst
index d37b40f..b6acab6 100644
--- a/docs/hazmat/primitives/asymmetric/rsa.rst
+++ b/docs/hazmat/primitives/asymmetric/rsa.rst
@@ -564,6 +564,9 @@
     .. method:: sign(data, padding, algorithm)
 
         .. versionadded:: 1.4
+        .. versionchanged:: 1.6
+            :class:`~cryptography.hazmat.primitives.asymmetric.utils.Prehashed`
+            can now be used as an ``algorithm``.
 
         Sign one block of data which can be verified later by others using the
         public key.
@@ -574,7 +577,9 @@
             :class:`~cryptography.hazmat.primitives.asymmetric.padding.AsymmetricPadding`.
 
         :param algorithm: An instance of
-            :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm`.
+            :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm` or
+            :class:`~cryptography.hazmat.primitives.asymmetric.utils.Prehashed`
+            if the ``data`` you want to sign has already been hashed.
 
         :return bytes: Signature.
 
diff --git a/docs/hazmat/primitives/asymmetric/utils.rst b/docs/hazmat/primitives/asymmetric/utils.rst
index 0788359..f29b3e9 100644
--- a/docs/hazmat/primitives/asymmetric/utils.rst
+++ b/docs/hazmat/primitives/asymmetric/utils.rst
@@ -28,3 +28,37 @@
     :param int s: The raw signature value ``s``.
 
     :return bytes: The encoded signature.
+
+.. class:: Prehashed(algorithm)
+
+    .. versionadded:: 1.6
+
+    ``Prehashed`` can be passed as the ``algorithm`` in
+    :meth:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey.sign`
+    if the data to be signed has been hashed beforehand.
+
+    :param algorithm: An instance of
+        :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm`.
+
+    .. doctest::
+
+        >>> import hashlib
+        >>> from cryptography.hazmat.backends import default_backend
+        >>> from cryptography.hazmat.primitives import hashes
+        >>> from cryptography.hazmat.primitives.asymmetric import (
+        ...    padding, rsa, utils
+        ... )
+        >>> private_key = rsa.generate_private_key(
+        ...     public_exponent=65537,
+        ...     key_size=2048,
+        ...     backend=default_backend()
+        ... )
+        >>> prehashed_msg = hashlib.sha256(b"A message I want to sign").digest()
+        >>> signature = private_key.sign(
+        ...     prehashed_msg,
+        ...     padding.PSS(
+        ...         mgf=padding.MGF1(hashes.SHA256()),
+        ...         salt_length=padding.PSS.MAX_LENGTH
+        ...     ),
+        ...     utils.Prehashed(hashes.SHA256())
+        ... )
diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py
index 8bb8578..85d0652 100644
--- a/src/cryptography/hazmat/backends/openssl/rsa.py
+++ b/src/cryptography/hazmat/backends/openssl/rsa.py
@@ -13,6 +13,7 @@
 from cryptography.hazmat.primitives import hashes
 from cryptography.hazmat.primitives.asymmetric import (
     AsymmetricSignatureContext, AsymmetricVerificationContext, rsa,
+    utils as asym_utils
 )
 from cryptography.hazmat.primitives.asymmetric.padding import (
     AsymmetricPadding, MGF1, OAEP, PKCS1v15, PSS, calculate_max_pss_salt_length
@@ -452,9 +453,18 @@
         padding_enum = _rsa_sig_determine_padding(
             self._backend, self, padding, algorithm
         )
-        hash_ctx = hashes.Hash(algorithm, self._backend)
-        hash_ctx.update(data)
-        data = hash_ctx.finalize()
+        if not isinstance(algorithm, asym_utils.Prehashed):
+            hash_ctx = hashes.Hash(algorithm, self._backend)
+            hash_ctx.update(data)
+            data = hash_ctx.finalize()
+        else:
+            algorithm = algorithm._algorithm
+
+        if len(data) != algorithm.digest_size:
+            raise ValueError(
+                "The provided data must be the same length as the hash "
+                "algorithm's digest size."
+            )
 
         return _rsa_sig_sign(
             self._backend, padding, padding_enum,
diff --git a/src/cryptography/hazmat/primitives/asymmetric/utils.py b/src/cryptography/hazmat/primitives/asymmetric/utils.py
index 5b27654..44bf59d 100644
--- a/src/cryptography/hazmat/primitives/asymmetric/utils.py
+++ b/src/cryptography/hazmat/primitives/asymmetric/utils.py
@@ -13,6 +13,7 @@
 import six
 
 from cryptography import utils
+from cryptography.hazmat.primitives import hashes
 
 
 class _DSSSigValue(univ.Sequence):
@@ -69,3 +70,14 @@
     sig.setComponentByName('r', r)
     sig.setComponentByName('s', s)
     return encoder.encode(sig)
+
+
+class Prehashed(object):
+    def __init__(self, algorithm):
+        if not isinstance(algorithm, hashes.HashAlgorithm):
+            raise TypeError("Expected instance of HashAlgorithm.")
+
+        self._algorithm = algorithm
+        self._digest_size = algorithm.digest_size
+
+    digest_size = utils.read_only_property("_digest_size")
diff --git a/tests/hazmat/primitives/test_asym_utils.py b/tests/hazmat/primitives/test_asym_utils.py
index b997113..bd1fa35 100644
--- a/tests/hazmat/primitives/test_asym_utils.py
+++ b/tests/hazmat/primitives/test_asym_utils.py
@@ -7,8 +7,8 @@
 import pytest
 
 from cryptography.hazmat.primitives.asymmetric.utils import (
-    decode_dss_signature, decode_rfc6979_signature,
-    encode_dss_signature, encode_rfc6979_signature
+    Prehashed, decode_dss_signature, decode_rfc6979_signature,
+    encode_dss_signature, encode_rfc6979_signature,
 )
 
 
@@ -76,3 +76,8 @@
         # This is the BER "end-of-contents octets," which older versions of
         # pyasn1 are wrongly willing to return from top-level DER decoding.
         decode_dss_signature(b"\x00\x00")
+
+
+def test_pass_invalid_prehashed_arg():
+    with pytest.raises(TypeError):
+        Prehashed(object())
diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py
index 81e3f94..6ec1799 100644
--- a/tests/hazmat/primitives/test_rsa.py
+++ b/tests/hazmat/primitives/test_rsa.py
@@ -18,7 +18,9 @@
     PEMSerializationBackend, RSABackend
 )
 from cryptography.hazmat.primitives import hashes, serialization
-from cryptography.hazmat.primitives.asymmetric import padding, rsa
+from cryptography.hazmat.primitives.asymmetric import (
+    padding, rsa, utils as asym_utils
+)
 from cryptography.hazmat.primitives.asymmetric.rsa import (
     RSAPrivateNumbers, RSAPublicNumbers
 )
@@ -492,6 +494,43 @@
         verifier.update(message)
         verifier.verify()
 
+    @pytest.mark.supported(
+        only_if=lambda backend: backend.rsa_padding_supported(
+            padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
+        ),
+        skip_message="Does not support PSS."
+    )
+    def test_prehashed_sign(self, backend):
+        private_key = RSA_KEY_512.private_key(backend)
+        message = b"one little message"
+        h = hashes.Hash(hashes.SHA1(), backend)
+        h.update(message)
+        digest = h.finalize()
+        pss = padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
+        prehashed_alg = asym_utils.Prehashed(hashes.SHA1())
+        signature = private_key.sign(digest, pss, prehashed_alg)
+        public_key = private_key.public_key()
+        verifier = public_key.verifier(signature, pss, hashes.SHA1())
+        verifier.update(message)
+        verifier.verify()
+
+    @pytest.mark.supported(
+        only_if=lambda backend: backend.rsa_padding_supported(
+            padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
+        ),
+        skip_message="Does not support PSS."
+    )
+    def test_prehashed_digest_mismatch(self, backend):
+        private_key = RSA_KEY_512.private_key(backend)
+        message = b"one little message"
+        h = hashes.Hash(hashes.SHA512(), backend)
+        h.update(message)
+        digest = h.finalize()
+        pss = padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
+        prehashed_alg = asym_utils.Prehashed(hashes.SHA1())
+        with pytest.raises(ValueError):
+            private_key.sign(digest, pss, prehashed_alg)
+
 
 @pytest.mark.requires_backend_interface(interface=RSABackend)
 class TestRSAVerification(object):