Merge pull request #808 from public/tidy-rsa

RSAPrivateKey to evp_pkey utility method
diff --git a/cryptography/hazmat/backends/openssl/backend.py b/cryptography/hazmat/backends/openssl/backend.py
index 251dd09..a68bc08 100644
--- a/cryptography/hazmat/backends/openssl/backend.py
+++ b/cryptography/hazmat/backends/openssl/backend.py
@@ -301,21 +301,69 @@
         )
         assert res == 1
 
+        return self._rsa_cdata_to_private_key(ctx)
+
+    def _new_evp_pkey(self):
+        """Return a new EVP_PKEY that is freed on garbage collection."""
+        evp_pkey = self._lib.EVP_PKEY_new()
+        assert evp_pkey != self._ffi.NULL
+        # Use this backend's own bindings rather than the global
+        # ``backend`` instance, so any Backend object works here.
+        return self._ffi.gc(evp_pkey, self._lib.EVP_PKEY_free)
+
+    def _rsa_cdata_to_evp_pkey(self, rsa_cdata):
+        """Wrap a raw (non-GC'd) RSA cdata in a GC'd EVP_PKEY.
+
+        Blinding is turned on for the RSA struct, and ownership of
+        ``rsa_cdata`` moves to the returned EVP_PKEY: the caller must
+        not free it afterwards.
+        """
+        evp_pkey = self._new_evp_pkey()
+
+        # EVP_PKEY_assign_RSA takes ownership of rsa_cdata (it does not
+        # bump the refcount the way EVP_PKEY_set1_RSA does), so freeing
+        # the EVP_PKEY also frees the RSA struct. Assign before turning
+        # on blinding so the RSA struct is reclaimed if that step fails.
+        res = self._lib.EVP_PKEY_assign_RSA(evp_pkey, rsa_cdata)
+        if res != 1:
+            # Ownership was not transferred, so free the RSA struct here.
+            self._lib.RSA_free(rsa_cdata)
+        assert res == 1
+
+        res = self._lib.RSA_blinding_on(rsa_cdata, self._ffi.NULL)
+        assert res == 1
+
+        return evp_pkey
+
+    def _rsa_private_key_to_evp_pkey(self, private_key):
+        """Convert an RSA private key object into a GC'd EVP_PKEY."""
+        rsa_cdata = self._rsa_cdata_from_private_key(private_key)
+        return self._rsa_cdata_to_evp_pkey(rsa_cdata)
+
+    def _rsa_public_key_to_evp_pkey(self, public_key):
+        """Convert an RSA public key object into a GC'd EVP_PKEY."""
+        rsa_cdata = self._rsa_cdata_from_public_key(public_key)
+        return self._rsa_cdata_to_evp_pkey(rsa_cdata)
+
+    def _rsa_cdata_to_private_key(self, cdata):
+        """Build an ``rsa.RSAPrivateKey`` from an OpenSSL RSA struct."""
         return rsa.RSAPrivateKey(
-            p=self._bn_to_int(ctx.p),
-            q=self._bn_to_int(ctx.q),
-            dmp1=self._bn_to_int(ctx.dmp1),
-            dmq1=self._bn_to_int(ctx.dmq1),
-            iqmp=self._bn_to_int(ctx.iqmp),
-            private_exponent=self._bn_to_int(ctx.d),
-            public_exponent=self._bn_to_int(ctx.e),
-            modulus=self._bn_to_int(ctx.n),
+            p=self._bn_to_int(cdata.p),
+            q=self._bn_to_int(cdata.q),
+            dmp1=self._bn_to_int(cdata.dmp1),
+            dmq1=self._bn_to_int(cdata.dmq1),
+            iqmp=self._bn_to_int(cdata.iqmp),
+            private_exponent=self._bn_to_int(cdata.d),
+            public_exponent=self._bn_to_int(cdata.e),
+            modulus=self._bn_to_int(cdata.n),
         )
 
     def _rsa_cdata_from_private_key(self, private_key):
+        # Returns a raw RSA cdata that is *not* garbage collected; the
+        # caller must free it, e.g. by handing ownership to an EVP_PKEY
+        # via EVP_PKEY_assign_RSA.
         ctx = self._lib.RSA_new()
         assert ctx != self._ffi.NULL
-        ctx = self._ffi.gc(ctx, self._lib.RSA_free)
         ctx.p = self._int_to_bn(private_key.p)
         ctx.q = self._int_to_bn(private_key.q)
         ctx.d = self._int_to_bn(private_key.d)
@@ -327,9 +360,11 @@
         return ctx
 
     def _rsa_cdata_from_public_key(self, public_key):
+        # Returns a raw RSA cdata (only ``e`` and ``n`` are set) that is
+        # *not* garbage collected; the caller must free it, e.g. by handing
+        # ownership to an EVP_PKEY via EVP_PKEY_assign_RSA.
         ctx = self._lib.RSA_new()
         assert ctx != self._ffi.NULL
-        ctx = self._ffi.gc(ctx, self._lib.RSA_free)
         ctx.e = self._int_to_bn(public_key.e)
         ctx.n = self._int_to_bn(public_key.n)
         return ctx
@@ -657,24 +692,21 @@
     def finalize(self):
         if self._hash_ctx is None:
             raise AlreadyFinalized("Context has already been finalized")
-        evp_pkey = self._backend._lib.EVP_PKEY_new()
-        assert evp_pkey != self._backend._ffi.NULL
-        evp_pkey = backend._ffi.gc(evp_pkey, backend._lib.EVP_PKEY_free)
-        rsa_cdata = backend._rsa_cdata_from_private_key(self._private_key)
-        res = self._backend._lib.RSA_blinding_on(
-            rsa_cdata, self._backend._ffi.NULL)
-        assert res == 1
-        res = self._backend._lib.EVP_PKEY_set1_RSA(evp_pkey, rsa_cdata)
-        assert res == 1
+
+        # The returned EVP_PKEY owns its RSA struct and is released by the
+        # garbage collector, so no separate RSA cdata has to be kept alive.
+        evp_pkey = self._backend._rsa_private_key_to_evp_pkey(
+            self._private_key)
+
         evp_md = self._backend._lib.EVP_get_digestbyname(
             self._algorithm.name.encode("ascii"))
         assert evp_md != self._backend._ffi.NULL
         pkey_size = self._backend._lib.EVP_PKEY_size(evp_pkey)
         assert pkey_size > 0
 
-        return self._finalize_method(evp_pkey, pkey_size, rsa_cdata, evp_md)
+        return self._finalize_method(evp_pkey, pkey_size, evp_md)
 
-    def _finalize_pkey_ctx(self, evp_pkey, pkey_size, rsa_cdata, evp_md):
+    def _finalize_pkey_ctx(self, evp_pkey, pkey_size, evp_md):
         pkey_ctx = self._backend._lib.EVP_PKEY_CTX_new(
             evp_pkey, self._backend._ffi.NULL
         )
@@ -705,7 +735,8 @@
         assert res == 1
         return self._backend._ffi.buffer(buf)[:]
 
-    def _finalize_pkcs1(self, evp_pkey, pkey_size, rsa_cdata, evp_md):
+    def _finalize_pkcs1(self, evp_pkey, pkey_size, evp_md):
+        """Sign with EVP_SignFinal into a ``pkey_size``-byte buffer."""
         sig_buf = self._backend._ffi.new("char[]", pkey_size)
         sig_len = self._backend._ffi.new("unsigned int *")
         res = self._backend._lib.EVP_SignFinal(
@@ -753,22 +783,18 @@
         if self._hash_ctx is None:
             raise AlreadyFinalized("Context has already been finalized")
 
-        evp_pkey = self._backend._lib.EVP_PKEY_new()
-        assert evp_pkey != self._backend._ffi.NULL
-        evp_pkey = backend._ffi.gc(evp_pkey, backend._lib.EVP_PKEY_free)
-        rsa_cdata = backend._rsa_cdata_from_public_key(self._public_key)
-        res = self._backend._lib.RSA_blinding_on(
-            rsa_cdata, self._backend._ffi.NULL)
-        assert res == 1
-        res = self._backend._lib.EVP_PKEY_set1_RSA(evp_pkey, rsa_cdata)
-        assert res == 1
+        # The returned EVP_PKEY owns its RSA struct and is released by the
+        # garbage collector, so no separate RSA cdata has to be kept alive.
+        evp_pkey = self._backend._rsa_public_key_to_evp_pkey(
+            self._public_key)
+
         evp_md = self._backend._lib.EVP_get_digestbyname(
             self._algorithm.name.encode("ascii"))
         assert evp_md != self._backend._ffi.NULL
 
-        self._verify_method(rsa_cdata, evp_pkey, evp_md)
+        self._verify_method(evp_pkey, evp_md)
 
-    def _verify_pkey_ctx(self, rsa_cdata, evp_pkey, evp_md):
+    def _verify_pkey_ctx(self, evp_pkey, evp_md):
         pkey_ctx = self._backend._lib.EVP_PKEY_CTX_new(
             evp_pkey, self._backend._ffi.NULL
         )
@@ -800,7 +824,8 @@
             assert errors
             raise InvalidSignature
 
-    def _verify_pkcs1(self, rsa_cdata, evp_pkey, evp_md):
+    def _verify_pkcs1(self, evp_pkey, evp_md):
+        """Check ``self._signature`` with EVP_VerifyFinal."""
         res = self._backend._lib.EVP_VerifyFinal(
             self._hash_ctx._ctx,
             self._signature,