refactor RSA signing to prep for prehash support (#3240)

diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py
index ba830dd..4c554be 100644
--- a/src/cryptography/hazmat/backends/openssl/rsa.py
+++ b/src/cryptography/hazmat/backends/openssl/rsa.py
@@ -12,7 +12,7 @@
 )
 from cryptography.hazmat.primitives import hashes
 from cryptography.hazmat.primitives.asymmetric import (
-    AsymmetricSignatureContext, AsymmetricVerificationContext, rsa
+    AsymmetricSignatureContext, AsymmetricVerificationContext, rsa,
 )
 from cryptography.hazmat.primitives.asymmetric.padding import (
     AsymmetricPadding, MGF1, OAEP, PKCS1v15, PSS, calculate_max_pss_salt_length
@@ -139,51 +139,127 @@
         raise ValueError("Decryption failed.")
 
 
+def _rsa_sig_determine_padding(backend, private_key, padding, algorithm):
+    if not isinstance(padding, AsymmetricPadding):
+        raise TypeError("Expected provider of AsymmetricPadding.")
+
+    pkey_size = backend._lib.EVP_PKEY_size(private_key._evp_pkey)
+    backend.openssl_assert(pkey_size > 0)
+
+    if isinstance(padding, PKCS1v15):
+        padding_enum = backend._lib.RSA_PKCS1_PADDING
+    elif isinstance(padding, PSS):
+        if not isinstance(padding._mgf, MGF1):
+            raise UnsupportedAlgorithm(
+                "Only MGF1 is supported by this backend.",
+                _Reasons.UNSUPPORTED_MGF
+            )
+
+        # Size of key in bytes - 2 is the maximum
+        # PSS signature length (salt length is checked later)
+        if pkey_size - algorithm.digest_size - 2 < 0:
+            raise ValueError("Digest too large for key size. Use a larger "
+                             "key.")
+
+        if not backend._pss_mgf1_hash_supported(padding._mgf._algorithm):
+            raise UnsupportedAlgorithm(
+                "When OpenSSL is older than 1.0.1 then only SHA1 is "
+                "supported with MGF1.",
+                _Reasons.UNSUPPORTED_HASH
+            )
+
+        padding_enum = backend._lib.RSA_PKCS1_PSS_PADDING
+    else:
+        raise UnsupportedAlgorithm(
+            "{0} is not supported by this backend.".format(padding.name),
+            _Reasons.UNSUPPORTED_PADDING
+        )
+
+    return padding_enum
+
+
+def _rsa_sig_sign(backend, padding, padding_enum, algorithm, private_key,
+                  data):
+    evp_md = backend._lib.EVP_get_digestbyname(
+        algorithm.name.encode("ascii"))
+    backend.openssl_assert(evp_md != backend._ffi.NULL)
+
+    pkey_ctx = backend._lib.EVP_PKEY_CTX_new(
+        private_key._evp_pkey, backend._ffi.NULL
+    )
+    backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
+    pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
+    res = backend._lib.EVP_PKEY_sign_init(pkey_ctx)
+    backend.openssl_assert(res == 1)
+    res = backend._lib.EVP_PKEY_CTX_set_signature_md(
+        pkey_ctx, evp_md)
+    backend.openssl_assert(res > 0)
+
+    res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(
+        pkey_ctx, padding_enum)
+    backend.openssl_assert(res > 0)
+    if isinstance(padding, PSS):
+        res = backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(
+            pkey_ctx,
+            _get_rsa_pss_salt_length(
+                padding,
+                private_key,
+                algorithm,
+            )
+        )
+        backend.openssl_assert(res > 0)
+
+        if backend._lib.Cryptography_HAS_MGF1_MD:
+            # MGF1 MD is configurable in OpenSSL 1.0.1+
+            mgf1_md = backend._lib.EVP_get_digestbyname(
+                padding._mgf._algorithm.name.encode("ascii"))
+            backend.openssl_assert(
+                mgf1_md != backend._ffi.NULL
+            )
+            res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(
+                pkey_ctx, mgf1_md
+            )
+            backend.openssl_assert(res > 0)
+
+    buflen = backend._ffi.new("size_t *")
+    res = backend._lib.EVP_PKEY_sign(
+        pkey_ctx,
+        backend._ffi.NULL,
+        buflen,
+        data,
+        len(data)
+    )
+    backend.openssl_assert(res == 1)
+    buf = backend._ffi.new("unsigned char[]", buflen[0])
+    res = backend._lib.EVP_PKEY_sign(
+        pkey_ctx, buf, buflen, data, len(data))
+    if res != 1:
+        errors = backend._consume_errors()
+        assert errors[0].lib == backend._lib.ERR_LIB_RSA
+        reason = None
+        if (errors[0].reason ==
+                backend._lib.RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE):
+            reason = ("Salt length too long for key size. Try using "
+                      "MAX_LENGTH instead.")
+        else:
+            assert (errors[0].reason ==
+                    backend._lib.RSA_R_DIGEST_TOO_BIG_FOR_RSA_KEY)
+            reason = "Digest too large for key size. Use a larger key."
+        assert reason is not None
+        raise ValueError(reason)
+
+    return backend._ffi.buffer(buf)[:]
+
+
 @utils.register_interface(AsymmetricSignatureContext)
 class _RSASignatureContext(object):
     def __init__(self, backend, private_key, padding, algorithm):
         self._backend = backend
         self._private_key = private_key
 
-        if not isinstance(padding, AsymmetricPadding):
-            raise TypeError("Expected provider of AsymmetricPadding.")
-
-        self._pkey_size = self._backend._lib.EVP_PKEY_size(
-            self._private_key._evp_pkey
+        self._padding_enum = _rsa_sig_determine_padding(
+            backend, private_key, padding, algorithm
         )
-        self._backend.openssl_assert(self._pkey_size > 0)
-
-        if isinstance(padding, PKCS1v15):
-            self._padding_enum = self._backend._lib.RSA_PKCS1_PADDING
-        elif isinstance(padding, PSS):
-            if not isinstance(padding._mgf, MGF1):
-                raise UnsupportedAlgorithm(
-                    "Only MGF1 is supported by this backend.",
-                    _Reasons.UNSUPPORTED_MGF
-                )
-
-            # Size of key in bytes - 2 is the maximum
-            # PSS signature length (salt length is checked later)
-            if self._pkey_size - algorithm.digest_size - 2 < 0:
-                raise ValueError("Digest too large for key size. Use a larger "
-                                 "key.")
-
-            if not self._backend._pss_mgf1_hash_supported(
-                padding._mgf._algorithm
-            ):
-                raise UnsupportedAlgorithm(
-                    "When OpenSSL is older than 1.0.1 then only SHA1 is "
-                    "supported with MGF1.",
-                    _Reasons.UNSUPPORTED_HASH
-                )
-
-            self._padding_enum = self._backend._lib.RSA_PKCS1_PSS_PADDING
-        else:
-            raise UnsupportedAlgorithm(
-                "{0} is not supported by this backend.".format(padding.name),
-                _Reasons.UNSUPPORTED_PADDING
-            )
-
         self._padding = padding
         self._algorithm = algorithm
         self._hash_ctx = hashes.Hash(self._algorithm, self._backend)
@@ -192,76 +268,14 @@
         self._hash_ctx.update(data)
 
     def finalize(self):
-        evp_md = self._backend._lib.EVP_get_digestbyname(
-            self._algorithm.name.encode("ascii"))
-        self._backend.openssl_assert(evp_md != self._backend._ffi.NULL)
-
-        pkey_ctx = self._backend._lib.EVP_PKEY_CTX_new(
-            self._private_key._evp_pkey, self._backend._ffi.NULL
+        return _rsa_sig_sign(
+            self._backend,
+            self._padding,
+            self._padding_enum,
+            self._algorithm,
+            self._private_key,
+            self._hash_ctx.finalize()
         )
-        self._backend.openssl_assert(pkey_ctx != self._backend._ffi.NULL)
-        pkey_ctx = self._backend._ffi.gc(pkey_ctx,
-                                         self._backend._lib.EVP_PKEY_CTX_free)
-        res = self._backend._lib.EVP_PKEY_sign_init(pkey_ctx)
-        self._backend.openssl_assert(res == 1)
-        res = self._backend._lib.EVP_PKEY_CTX_set_signature_md(
-            pkey_ctx, evp_md)
-        self._backend.openssl_assert(res > 0)
-
-        res = self._backend._lib.EVP_PKEY_CTX_set_rsa_padding(
-            pkey_ctx, self._padding_enum)
-        self._backend.openssl_assert(res > 0)
-        if isinstance(self._padding, PSS):
-            res = self._backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(
-                pkey_ctx,
-                _get_rsa_pss_salt_length(
-                    self._padding,
-                    self._private_key,
-                    self._hash_ctx.algorithm,
-                )
-            )
-            self._backend.openssl_assert(res > 0)
-
-            if self._backend._lib.Cryptography_HAS_MGF1_MD:
-                # MGF1 MD is configurable in OpenSSL 1.0.1+
-                mgf1_md = self._backend._lib.EVP_get_digestbyname(
-                    self._padding._mgf._algorithm.name.encode("ascii"))
-                self._backend.openssl_assert(
-                    mgf1_md != self._backend._ffi.NULL
-                )
-                res = self._backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(
-                    pkey_ctx, mgf1_md
-                )
-                self._backend.openssl_assert(res > 0)
-        data_to_sign = self._hash_ctx.finalize()
-        buflen = self._backend._ffi.new("size_t *")
-        res = self._backend._lib.EVP_PKEY_sign(
-            pkey_ctx,
-            self._backend._ffi.NULL,
-            buflen,
-            data_to_sign,
-            len(data_to_sign)
-        )
-        self._backend.openssl_assert(res == 1)
-        buf = self._backend._ffi.new("unsigned char[]", buflen[0])
-        res = self._backend._lib.EVP_PKEY_sign(
-            pkey_ctx, buf, buflen, data_to_sign, len(data_to_sign))
-        if res != 1:
-            errors = self._backend._consume_errors()
-            assert errors[0].lib == self._backend._lib.ERR_LIB_RSA
-            reason = None
-            if (errors[0].reason ==
-                    self._backend._lib.RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE):
-                reason = ("Salt length too long for key size. Try using "
-                          "MAX_LENGTH instead.")
-            else:
-                assert (errors[0].reason ==
-                        self._backend._lib.RSA_R_DIGEST_TOO_BIG_FOR_RSA_KEY)
-                reason = "Digest too large for key size. Use a larger key."
-            assert reason is not None
-            raise ValueError(reason)
-
-        return self._backend._ffi.buffer(buf)[:]
 
 
 @utils.register_interface(AsymmetricVerificationContext)
@@ -460,10 +474,17 @@
         )
 
     def sign(self, data, padding, algorithm):
-        signer = self.signer(padding, algorithm)
-        signer.update(data)
-        signature = signer.finalize()
-        return signature
+        padding_enum = _rsa_sig_determine_padding(
+            self._backend, self, padding, algorithm
+        )
+        hash_ctx = hashes.Hash(algorithm, self._backend)
+        hash_ctx.update(data)
+        data = hash_ctx.finalize()
+
+        return _rsa_sig_sign(
+            self._backend, padding, padding_enum,
+            algorithm, self, data
+        )
 
 
 @utils.register_interface(RSAPublicKeyWithSerialization)