refactor rsa signature/verification logic to remove duplication (#3903)

diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py
index 05b4e9d..1b6ebfd 100644
--- a/src/cryptography/hazmat/backends/openssl/rsa.py
+++ b/src/cryptography/hazmat/backends/openssl/rsa.py
@@ -187,47 +187,40 @@
     return padding_enum
 
 
-def _rsa_sig_sign(backend, padding, padding_enum, algorithm, private_key,
-                  data):
-    evp_md = backend._lib.EVP_get_digestbyname(
-        algorithm.name.encode("ascii"))
+def _rsa_sig_setup(backend, padding, algorithm, key, data, init_func):
+    padding_enum = _rsa_sig_determine_padding(backend, key, padding, algorithm)
+    evp_md = backend._lib.EVP_get_digestbyname(algorithm.name.encode("ascii"))
     backend.openssl_assert(evp_md != backend._ffi.NULL)
-
-    pkey_ctx = backend._lib.EVP_PKEY_CTX_new(
-        private_key._evp_pkey, backend._ffi.NULL
-    )
+    pkey_ctx = backend._lib.EVP_PKEY_CTX_new(key._evp_pkey, backend._ffi.NULL)
     backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
     pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
-    res = backend._lib.EVP_PKEY_sign_init(pkey_ctx)
+    res = init_func(pkey_ctx)
     backend.openssl_assert(res == 1)
-    res = backend._lib.EVP_PKEY_CTX_set_signature_md(
-        pkey_ctx, evp_md)
+    res = backend._lib.EVP_PKEY_CTX_set_signature_md(pkey_ctx, evp_md)
     backend.openssl_assert(res > 0)
-
-    res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(
-        pkey_ctx, padding_enum)
+    res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, padding_enum)
     backend.openssl_assert(res > 0)
     if isinstance(padding, PSS):
         res = backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(
-            pkey_ctx,
-            _get_rsa_pss_salt_length(
-                padding,
-                private_key,
-                algorithm,
-            )
+            pkey_ctx, _get_rsa_pss_salt_length(padding, key, algorithm)
         )
         backend.openssl_assert(res > 0)
 
         mgf1_md = backend._lib.EVP_get_digestbyname(
-            padding._mgf._algorithm.name.encode("ascii"))
-        backend.openssl_assert(
-            mgf1_md != backend._ffi.NULL
+            padding._mgf._algorithm.name.encode("ascii")
         )
-        res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(
-            pkey_ctx, mgf1_md
-        )
+        backend.openssl_assert(mgf1_md != backend._ffi.NULL)
+        res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1_md)
         backend.openssl_assert(res > 0)
 
+    return pkey_ctx
+
+
+def _rsa_sig_sign(backend, padding, algorithm, private_key, data):
+    pkey_ctx = _rsa_sig_setup(
+        backend, padding, algorithm, private_key, data,
+        backend._lib.EVP_PKEY_sign_init
+    )
     buflen = backend._ffi.new("size_t *")
     res = backend._lib.EVP_PKEY_sign(
         pkey_ctx,
@@ -258,52 +251,13 @@
     return backend._ffi.buffer(buf)[:]
 
 
-def _rsa_sig_verify(backend, padding, padding_enum, algorithm, public_key,
-                    signature, data):
-    evp_md = backend._lib.EVP_get_digestbyname(
-        algorithm.name.encode("ascii"))
-    backend.openssl_assert(evp_md != backend._ffi.NULL)
-
-    pkey_ctx = backend._lib.EVP_PKEY_CTX_new(
-        public_key._evp_pkey, backend._ffi.NULL
+def _rsa_sig_verify(backend, padding, algorithm, public_key, signature, data):
+    pkey_ctx = _rsa_sig_setup(
+        backend, padding, algorithm, public_key, data,
+        backend._lib.EVP_PKEY_verify_init
     )
-    backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
-    pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
-    res = backend._lib.EVP_PKEY_verify_init(pkey_ctx)
-    backend.openssl_assert(res == 1)
-    res = backend._lib.EVP_PKEY_CTX_set_signature_md(
-        pkey_ctx, evp_md)
-    backend.openssl_assert(res > 0)
-
-    res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(
-        pkey_ctx, padding_enum)
-    backend.openssl_assert(res > 0)
-    if isinstance(padding, PSS):
-        res = backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(
-            pkey_ctx,
-            _get_rsa_pss_salt_length(
-                padding,
-                public_key,
-                algorithm,
-            )
-        )
-        backend.openssl_assert(res > 0)
-        mgf1_md = backend._lib.EVP_get_digestbyname(
-            padding._mgf._algorithm.name.encode("ascii"))
-        backend.openssl_assert(
-            mgf1_md != backend._ffi.NULL
-        )
-        res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(
-            pkey_ctx, mgf1_md
-        )
-        backend.openssl_assert(res > 0)
-
     res = backend._lib.EVP_PKEY_verify(
-        pkey_ctx,
-        signature,
-        len(signature),
-        data,
-        len(data)
+        pkey_ctx, signature, len(signature), data, len(data)
     )
     # The previous call can return negative numbers in the event of an
     # error. This is not a signature failure but we need to fail if it
@@ -321,9 +275,10 @@
         self._backend = backend
         self._private_key = private_key
 
-        self._padding_enum = _rsa_sig_determine_padding(
-            backend, private_key, padding, algorithm
-        )
+        # We now call _rsa_sig_determine_padding in _rsa_sig_setup. However
+        # we need to make a pointless call to it here so we maintain the
+        # API of erroring on init with this context if the values are invalid.
+        _rsa_sig_determine_padding(backend, private_key, padding, algorithm)
         self._padding = padding
         self._algorithm = algorithm
         self._hash_ctx = hashes.Hash(self._algorithm, self._backend)
@@ -335,7 +290,6 @@
         return _rsa_sig_sign(
             self._backend,
             self._padding,
-            self._padding_enum,
             self._algorithm,
             self._private_key,
             self._hash_ctx.finalize()
@@ -349,10 +303,10 @@
         self._public_key = public_key
         self._signature = signature
         self._padding = padding
-
-        self._padding_enum = _rsa_sig_determine_padding(
-            backend, public_key, padding, algorithm
-        )
+        # We now call _rsa_sig_determine_padding in _rsa_sig_setup. However
+        # we need to make a pointless call to it here so we maintain the
+        # API of erroring on init with this context if the values are invalid.
+        _rsa_sig_determine_padding(backend, public_key, padding, algorithm)
 
         padding = padding
         self._algorithm = algorithm
@@ -365,7 +319,6 @@
         return _rsa_sig_verify(
             self._backend,
             self._padding,
-            self._padding_enum,
             self._algorithm,
             self._public_key,
             self._signature,
@@ -456,16 +409,10 @@
         )
 
     def sign(self, data, padding, algorithm):
-        padding_enum = _rsa_sig_determine_padding(
-            self._backend, self, padding, algorithm
-        )
         data, algorithm = _calculate_digest_and_algorithm(
             self._backend, data, algorithm
         )
-        return _rsa_sig_sign(
-            self._backend, padding, padding_enum,
-            algorithm, self, data
-        )
+        return _rsa_sig_sign(self._backend, padding, algorithm, self, data)
 
 
 @utils.register_interface(RSAPublicKeyWithSerialization)
@@ -521,13 +468,9 @@
         )
 
     def verify(self, signature, data, padding, algorithm):
-        padding_enum = _rsa_sig_determine_padding(
-            self._backend, self, padding, algorithm
-        )
         data, algorithm = _calculate_digest_and_algorithm(
             self._backend, data, algorithm
         )
         return _rsa_sig_verify(
-            self._backend, padding, padding_enum, algorithm, self,
-            signature, data
+            self._backend, padding, algorithm, self, signature, data
         )