refactor RSA signature verification to prep for prehash support (#3261)

diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py
index 4c554be..8bb8578 100644
--- a/src/cryptography/hazmat/backends/openssl/rsa.py
+++ b/src/cryptography/hazmat/backends/openssl/rsa.py
@@ -139,11 +139,11 @@
         raise ValueError("Decryption failed.")
 
 
-def _rsa_sig_determine_padding(backend, private_key, padding, algorithm):
+def _rsa_sig_determine_padding(backend, key, padding, algorithm):
     if not isinstance(padding, AsymmetricPadding):
         raise TypeError("Expected provider of AsymmetricPadding.")
 
-    pkey_size = backend._lib.EVP_PKEY_size(private_key._evp_pkey)
+    pkey_size = backend._lib.EVP_PKEY_size(key._evp_pkey)
     backend.openssl_assert(pkey_size > 0)
 
     if isinstance(padding, PKCS1v15):
@@ -159,7 +159,7 @@
         # PSS signature length (salt length is checked later)
         if pkey_size - algorithm.digest_size - 2 < 0:
             raise ValueError("Digest too large for key size. Use a larger "
-                             "key.")
+                             "key or different digest.")
 
         if not backend._pss_mgf1_hash_supported(padding._mgf._algorithm):
             raise UnsupportedAlgorithm(
@@ -251,6 +251,65 @@
     return backend._ffi.buffer(buf)[:]
 
 
+def _rsa_sig_verify(backend, padding, padding_enum, algorithm, public_key,
+                    signature, data):
+    evp_md = backend._lib.EVP_get_digestbyname(
+        algorithm.name.encode("ascii"))
+    backend.openssl_assert(evp_md != backend._ffi.NULL)
+
+    pkey_ctx = backend._lib.EVP_PKEY_CTX_new(
+        public_key._evp_pkey, backend._ffi.NULL
+    )
+    backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
+    pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
+    res = backend._lib.EVP_PKEY_verify_init(pkey_ctx)
+    backend.openssl_assert(res == 1)
+    res = backend._lib.EVP_PKEY_CTX_set_signature_md(
+        pkey_ctx, evp_md)
+    backend.openssl_assert(res > 0)
+
+    res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(
+        pkey_ctx, padding_enum)
+    backend.openssl_assert(res > 0)
+    if isinstance(padding, PSS):
+        res = backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(
+            pkey_ctx,
+            _get_rsa_pss_salt_length(
+                padding,
+                public_key,
+                algorithm,
+            )
+        )
+        backend.openssl_assert(res > 0)
+        if backend._lib.Cryptography_HAS_MGF1_MD:
+            # MGF1 MD is configurable in OpenSSL 1.0.1+
+            mgf1_md = backend._lib.EVP_get_digestbyname(
+                padding._mgf._algorithm.name.encode("ascii"))
+            backend.openssl_assert(
+                mgf1_md != backend._ffi.NULL
+            )
+            res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(
+                pkey_ctx, mgf1_md
+            )
+            backend.openssl_assert(res > 0)
+
+    res = backend._lib.EVP_PKEY_verify(
+        pkey_ctx,
+        signature,
+        len(signature),
+        data,
+        len(data)
+    )
+    # The previous call can return negative numbers in the event of an
+    # error. This is not a signature failure but we need to fail if it
+    # occurs.
+    backend.openssl_assert(res >= 0)
+    if res == 0:
+        errors = backend._consume_errors()
+        assert errors
+        raise InvalidSignature
+
+
 @utils.register_interface(AsymmetricSignatureContext)
 class _RSASignatureContext(object):
     def __init__(self, backend, private_key, padding, algorithm):
@@ -284,49 +343,13 @@
         self._backend = backend
         self._public_key = public_key
         self._signature = signature
-
-        if not isinstance(padding, AsymmetricPadding):
-            raise TypeError("Expected provider of AsymmetricPadding.")
-
-        self._pkey_size = self._backend._lib.EVP_PKEY_size(
-            self._public_key._evp_pkey
-        )
-        self._backend.openssl_assert(self._pkey_size > 0)
-
-        if isinstance(padding, PKCS1v15):
-            self._padding_enum = self._backend._lib.RSA_PKCS1_PADDING
-        elif isinstance(padding, PSS):
-            if not isinstance(padding._mgf, MGF1):
-                raise UnsupportedAlgorithm(
-                    "Only MGF1 is supported by this backend.",
-                    _Reasons.UNSUPPORTED_MGF
-                )
-
-            # Size of key in bytes - 2 is the maximum
-            # PSS signature length (salt length is checked later)
-            if self._pkey_size - algorithm.digest_size - 2 < 0:
-                raise ValueError(
-                    "Digest too large for key size. Check that you have the "
-                    "correct key and digest algorithm."
-                )
-
-            if not self._backend._pss_mgf1_hash_supported(
-                padding._mgf._algorithm
-            ):
-                raise UnsupportedAlgorithm(
-                    "When OpenSSL is older than 1.0.1 then only SHA1 is "
-                    "supported with MGF1.",
-                    _Reasons.UNSUPPORTED_HASH
-                )
-
-            self._padding_enum = self._backend._lib.RSA_PKCS1_PSS_PADDING
-        else:
-            raise UnsupportedAlgorithm(
-                "{0} is not supported by this backend.".format(padding.name),
-                _Reasons.UNSUPPORTED_PADDING
-            )
-
         self._padding = padding
+
+        self._padding_enum = _rsa_sig_determine_padding(
+            backend, public_key, padding, algorithm
+        )
+
+        padding = padding
         self._algorithm = algorithm
         self._hash_ctx = hashes.Hash(self._algorithm, self._backend)
 
@@ -334,63 +357,15 @@
         self._hash_ctx.update(data)
 
     def verify(self):
-        evp_md = self._backend._lib.EVP_get_digestbyname(
-            self._algorithm.name.encode("ascii"))
-        self._backend.openssl_assert(evp_md != self._backend._ffi.NULL)
-
-        pkey_ctx = self._backend._lib.EVP_PKEY_CTX_new(
-            self._public_key._evp_pkey, self._backend._ffi.NULL
-        )
-        self._backend.openssl_assert(pkey_ctx != self._backend._ffi.NULL)
-        pkey_ctx = self._backend._ffi.gc(pkey_ctx,
-                                         self._backend._lib.EVP_PKEY_CTX_free)
-        res = self._backend._lib.EVP_PKEY_verify_init(pkey_ctx)
-        self._backend.openssl_assert(res == 1)
-        res = self._backend._lib.EVP_PKEY_CTX_set_signature_md(
-            pkey_ctx, evp_md)
-        self._backend.openssl_assert(res > 0)
-
-        res = self._backend._lib.EVP_PKEY_CTX_set_rsa_padding(
-            pkey_ctx, self._padding_enum)
-        self._backend.openssl_assert(res > 0)
-        if isinstance(self._padding, PSS):
-            res = self._backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(
-                pkey_ctx,
-                _get_rsa_pss_salt_length(
-                    self._padding,
-                    self._public_key,
-                    self._hash_ctx.algorithm,
-                )
-            )
-            self._backend.openssl_assert(res > 0)
-            if self._backend._lib.Cryptography_HAS_MGF1_MD:
-                # MGF1 MD is configurable in OpenSSL 1.0.1+
-                mgf1_md = self._backend._lib.EVP_get_digestbyname(
-                    self._padding._mgf._algorithm.name.encode("ascii"))
-                self._backend.openssl_assert(
-                    mgf1_md != self._backend._ffi.NULL
-                )
-                res = self._backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(
-                    pkey_ctx, mgf1_md
-                )
-                self._backend.openssl_assert(res > 0)
-
-        data_to_verify = self._hash_ctx.finalize()
-        res = self._backend._lib.EVP_PKEY_verify(
-            pkey_ctx,
+        return _rsa_sig_verify(
+            self._backend,
+            self._padding,
+            self._padding_enum,
+            self._algorithm,
+            self._public_key,
             self._signature,
-            len(self._signature),
-            data_to_verify,
-            len(data_to_verify)
+            self._hash_ctx.finalize()
         )
-        # The previous call can return negative numbers in the event of an
-        # error. This is not a signature failure but we need to fail if it
-        # occurs.
-        self._backend.openssl_assert(res >= 0)
-        if res == 0:
-            errors = self._backend._consume_errors()
-            assert errors
-            raise InvalidSignature
 
 
 @utils.register_interface(RSAPrivateKeyWithSerialization)