Merge pull request #2047 from reaperhulk/evp-pkey

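Create the EVP_PKEY when DSA/EC (and RSA) key objects are constructed, pass it into the key constructor, and reuse it for serialization instead of rebuilding it on every private_bytes/public_bytes call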
diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py
index 665771a..2fe8832 100644
--- a/src/cryptography/hazmat/backends/openssl/backend.py
+++ b/src/cryptography/hazmat/backends/openssl/backend.py
@@ -388,8 +388,9 @@
             rsa_cdata, key_size, bn, self._ffi.NULL
         )
         assert res == 1
+        evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
 
-        return _RSAPrivateKey(self, rsa_cdata)
+        return _RSAPrivateKey(self, rsa_cdata, evp_pkey)
 
     def generate_rsa_parameters_supported(self, public_exponent, key_size):
         return (public_exponent >= 3 and public_exponent & 1 != 0 and
@@ -419,8 +420,9 @@
         rsa_cdata.n = self._int_to_bn(numbers.public_numbers.n)
         res = self._lib.RSA_blinding_on(rsa_cdata, self._ffi.NULL)
         assert res == 1
+        evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
 
-        return _RSAPrivateKey(self, rsa_cdata)
+        return _RSAPrivateKey(self, rsa_cdata, evp_pkey)
 
     def load_rsa_public_numbers(self, numbers):
         rsa._check_public_key_components(numbers.e, numbers.n)
@@ -431,8 +433,17 @@
         rsa_cdata.n = self._int_to_bn(numbers.n)
         res = self._lib.RSA_blinding_on(rsa_cdata, self._ffi.NULL)
         assert res == 1
+        evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
 
-        return _RSAPublicKey(self, rsa_cdata)
+        return _RSAPublicKey(self, rsa_cdata, evp_pkey)
+
+    def _rsa_cdata_to_evp_pkey(self, rsa_cdata):
+        evp_pkey = self._lib.EVP_PKEY_new()
+        assert evp_pkey != self._ffi.NULL
+        evp_pkey = self._ffi.gc(evp_pkey, self._lib.EVP_PKEY_free)
+        res = self._lib.EVP_PKEY_set1_RSA(evp_pkey, rsa_cdata)
+        assert res == 1
+        return evp_pkey
 
     def _bytes_to_bio(self, data):
         """
@@ -483,18 +494,18 @@
             rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
             assert rsa_cdata != self._ffi.NULL
             rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
-            return _RSAPrivateKey(self, rsa_cdata)
+            return _RSAPrivateKey(self, rsa_cdata, evp_pkey)
         elif key_type == self._lib.EVP_PKEY_DSA:
             dsa_cdata = self._lib.EVP_PKEY_get1_DSA(evp_pkey)
             assert dsa_cdata != self._ffi.NULL
             dsa_cdata = self._ffi.gc(dsa_cdata, self._lib.DSA_free)
-            return _DSAPrivateKey(self, dsa_cdata)
+            return _DSAPrivateKey(self, dsa_cdata, evp_pkey)
         elif (self._lib.Cryptography_HAS_EC == 1 and
               key_type == self._lib.EVP_PKEY_EC):
             ec_cdata = self._lib.EVP_PKEY_get1_EC_KEY(evp_pkey)
             assert ec_cdata != self._ffi.NULL
             ec_cdata = self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
-            return _EllipticCurvePrivateKey(self, ec_cdata)
+            return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
         else:
             raise UnsupportedAlgorithm("Unsupported key type.")
 
@@ -510,18 +521,18 @@
             rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
             assert rsa_cdata != self._ffi.NULL
             rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
-            return _RSAPublicKey(self, rsa_cdata)
+            return _RSAPublicKey(self, rsa_cdata, evp_pkey)
         elif key_type == self._lib.EVP_PKEY_DSA:
             dsa_cdata = self._lib.EVP_PKEY_get1_DSA(evp_pkey)
             assert dsa_cdata != self._ffi.NULL
             dsa_cdata = self._ffi.gc(dsa_cdata, self._lib.DSA_free)
-            return _DSAPublicKey(self, dsa_cdata)
+            return _DSAPublicKey(self, dsa_cdata, evp_pkey)
         elif (self._lib.Cryptography_HAS_EC == 1 and
               key_type == self._lib.EVP_PKEY_EC):
             ec_cdata = self._lib.EVP_PKEY_get1_EC_KEY(evp_pkey)
             assert ec_cdata != self._ffi.NULL
             ec_cdata = self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
-            return _EllipticCurvePublicKey(self, ec_cdata)
+            return _EllipticCurvePublicKey(self, ec_cdata, evp_pkey)
         else:
             raise UnsupportedAlgorithm("Unsupported key type.")
 
@@ -615,8 +626,9 @@
         ctx.g = self._lib.BN_dup(parameters._dsa_cdata.g)
 
         self._lib.DSA_generate_key(ctx)
+        evp_pkey = self._dsa_cdata_to_evp_pkey(ctx)
 
-        return _DSAPrivateKey(self, ctx)
+        return _DSAPrivateKey(self, ctx, evp_pkey)
 
     def generate_dsa_private_key_and_parameters(self, key_size):
         parameters = self.generate_dsa_parameters(key_size)
@@ -636,7 +648,9 @@
         dsa_cdata.pub_key = self._int_to_bn(numbers.public_numbers.y)
         dsa_cdata.priv_key = self._int_to_bn(numbers.x)
 
-        return _DSAPrivateKey(self, dsa_cdata)
+        evp_pkey = self._dsa_cdata_to_evp_pkey(dsa_cdata)
+
+        return _DSAPrivateKey(self, dsa_cdata, evp_pkey)
 
     def load_dsa_public_numbers(self, numbers):
         dsa._check_dsa_parameters(numbers.parameter_numbers)
@@ -649,7 +663,9 @@
         dsa_cdata.g = self._int_to_bn(numbers.parameter_numbers.g)
         dsa_cdata.pub_key = self._int_to_bn(numbers.y)
 
-        return _DSAPublicKey(self, dsa_cdata)
+        evp_pkey = self._dsa_cdata_to_evp_pkey(dsa_cdata)
+
+        return _DSAPublicKey(self, dsa_cdata, evp_pkey)
 
     def load_dsa_parameter_numbers(self, numbers):
         dsa._check_dsa_parameters(numbers)
@@ -663,6 +679,14 @@
 
         return _DSAParameters(self, dsa_cdata)
 
+    def _dsa_cdata_to_evp_pkey(self, dsa_cdata):
+        evp_pkey = self._lib.EVP_PKEY_new()
+        assert evp_pkey != self._ffi.NULL
+        evp_pkey = self._ffi.gc(evp_pkey, self._lib.EVP_PKEY_free)
+        res = self._lib.EVP_PKEY_set1_DSA(evp_pkey, dsa_cdata)
+        assert res == 1
+        return evp_pkey
+
     def dsa_hash_supported(self, algorithm):
         if self._lib.OPENSSL_VERSION_NUMBER < 0x1000000f:
             return isinstance(algorithm, hashes.SHA1)
@@ -714,7 +738,8 @@
             )
             if rsa_cdata != self._ffi.NULL:
                 rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
-                return _RSAPublicKey(self, rsa_cdata)
+                evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
+                return _RSAPublicKey(self, rsa_cdata, evp_pkey)
             else:
                 self._handle_key_loading_error()
 
@@ -796,7 +821,8 @@
             )
             if rsa_cdata != self._ffi.NULL:
                 rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
-                return _RSAPublicKey(self, rsa_cdata)
+                evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
+                return _RSAPublicKey(self, rsa_cdata, evp_pkey)
             else:
                 self._handle_key_loading_error()
 
@@ -1000,7 +1026,9 @@
             res = self._lib.EC_KEY_check_key(ec_cdata)
             assert res == 1
 
-            return _EllipticCurvePrivateKey(self, ec_cdata)
+            evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)
+
+            return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
         else:
             raise UnsupportedAlgorithm(
                 "Backend object does not support {0}.".format(curve.name),
@@ -1022,8 +1050,9 @@
         res = self._lib.EC_KEY_set_private_key(
             ec_cdata, self._int_to_bn(numbers.private_value))
         assert res == 1
+        evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)
 
-        return _EllipticCurvePrivateKey(self, ec_cdata)
+        return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
 
     def load_elliptic_curve_public_numbers(self, numbers):
         curve_nid = self._elliptic_curve_to_nid(numbers.curve)
@@ -1034,8 +1063,18 @@
 
         ec_cdata = self._ec_key_set_public_key_affine_coordinates(
             ec_cdata, numbers.x, numbers.y)
+        evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)
 
-        return _EllipticCurvePublicKey(self, ec_cdata)
+        return _EllipticCurvePublicKey(self, ec_cdata, evp_pkey)
+
+    def _ec_cdata_to_evp_pkey(self, ec_cdata):
+        """Return a new GC-managed EVP_PKEY holding a reference to ec_cdata."""
+        evp_pkey = self._lib.EVP_PKEY_new()
+        assert evp_pkey != self._ffi.NULL
+        evp_pkey = self._ffi.gc(evp_pkey, self._lib.EVP_PKEY_free)
+        res = self._lib.EVP_PKEY_set1_EC_KEY(evp_pkey, ec_cdata)
+        assert res == 1
+        return evp_pkey
 
     def _elliptic_curve_to_nid(self, curve):
         """
diff --git a/src/cryptography/hazmat/backends/openssl/dsa.py b/src/cryptography/hazmat/backends/openssl/dsa.py
index 254d29e..f84857f 100644
--- a/src/cryptography/hazmat/backends/openssl/dsa.py
+++ b/src/cryptography/hazmat/backends/openssl/dsa.py
@@ -107,9 +107,10 @@
 
 @utils.register_interface(dsa.DSAPrivateKeyWithSerialization)
 class _DSAPrivateKey(object):
-    def __init__(self, backend, dsa_cdata):
+    def __init__(self, backend, dsa_cdata, evp_pkey):
         self._backend = backend
         self._dsa_cdata = dsa_cdata
+        self._evp_pkey = evp_pkey
         self._key_size = self._backend._lib.BN_num_bits(self._dsa_cdata.p)
 
     key_size = utils.read_only_property("_key_size")
@@ -140,7 +141,8 @@
         dsa_cdata.q = self._backend._lib.BN_dup(self._dsa_cdata.q)
         dsa_cdata.g = self._backend._lib.BN_dup(self._dsa_cdata.g)
         dsa_cdata.pub_key = self._backend._lib.BN_dup(self._dsa_cdata.pub_key)
-        return _DSAPublicKey(self._backend, dsa_cdata)
+        evp_pkey = self._backend._dsa_cdata_to_evp_pkey(dsa_cdata)
+        return _DSAPublicKey(self._backend, dsa_cdata, evp_pkey)
 
     def parameters(self):
         dsa_cdata = self._backend._lib.DSA_new()
@@ -154,27 +156,21 @@
         return _DSAParameters(self._backend, dsa_cdata)
 
     def private_bytes(self, encoding, format, encryption_algorithm):
-        evp_pkey = self._backend._lib.EVP_PKEY_new()
-        assert evp_pkey != self._backend._ffi.NULL
-        evp_pkey = self._backend._ffi.gc(
-            evp_pkey, self._backend._lib.EVP_PKEY_free
-        )
-        res = self._backend._lib.EVP_PKEY_set1_DSA(evp_pkey, self._dsa_cdata)
-        assert res == 1
         return self._backend._private_key_bytes(
             encoding,
             format,
             encryption_algorithm,
-            evp_pkey,
+            self._evp_pkey,
             self._dsa_cdata
         )
 
 
 @utils.register_interface(dsa.DSAPublicKeyWithSerialization)
 class _DSAPublicKey(object):
-    def __init__(self, backend, dsa_cdata):
+    def __init__(self, backend, dsa_cdata, evp_pkey):
         self._backend = backend
         self._dsa_cdata = dsa_cdata
+        self._evp_pkey = evp_pkey
         self._key_size = self._backend._lib.BN_num_bits(self._dsa_cdata.p)
 
     key_size = utils.read_only_property("_key_size")
@@ -211,16 +207,9 @@
                 "DSA public keys do not support PKCS1 serialization"
             )
 
-        evp_pkey = self._backend._lib.EVP_PKEY_new()
-        assert evp_pkey != self._backend._ffi.NULL
-        evp_pkey = self._backend._ffi.gc(
-            evp_pkey, self._backend._lib.EVP_PKEY_free
-        )
-        res = self._backend._lib.EVP_PKEY_set1_DSA(evp_pkey, self._dsa_cdata)
-        assert res == 1
         return self._backend._public_key_bytes(
             encoding,
             format,
-            evp_pkey,
+            self._evp_pkey,
             None
         )
diff --git a/src/cryptography/hazmat/backends/openssl/ec.py b/src/cryptography/hazmat/backends/openssl/ec.py
index c2af2be..7d3afb9 100644
--- a/src/cryptography/hazmat/backends/openssl/ec.py
+++ b/src/cryptography/hazmat/backends/openssl/ec.py
@@ -150,10 +150,11 @@
 
 @utils.register_interface(ec.EllipticCurvePrivateKeyWithSerialization)
 class _EllipticCurvePrivateKey(object):
-    def __init__(self, backend, ec_key_cdata):
+    def __init__(self, backend, ec_key_cdata, evp_pkey):
         self._backend = backend
         _mark_asn1_named_ec_curve(backend, ec_key_cdata)
         self._ec_key = ec_key_cdata
+        self._evp_pkey = evp_pkey
 
         sn = _ec_key_curve_sn(backend, ec_key_cdata)
         self._curve = _sn_to_elliptic_curve(backend, sn)
@@ -188,9 +189,9 @@
         res = self._backend._lib.EC_KEY_set_public_key(public_ec_key, point)
         assert res == 1
 
-        return _EllipticCurvePublicKey(
-            self._backend, public_ec_key
-        )
+        evp_pkey = self._backend._ec_cdata_to_evp_pkey(public_ec_key)
+
+        return _EllipticCurvePublicKey(self._backend, public_ec_key, evp_pkey)
 
     def private_numbers(self):
         bn = self._backend._lib.EC_KEY_get0_private_key(self._ec_key)
@@ -201,28 +202,22 @@
         )
 
     def private_bytes(self, encoding, format, encryption_algorithm):
-        evp_pkey = self._backend._lib.EVP_PKEY_new()
-        assert evp_pkey != self._backend._ffi.NULL
-        evp_pkey = self._backend._ffi.gc(
-            evp_pkey, self._backend._lib.EVP_PKEY_free
-        )
-        res = self._backend._lib.EVP_PKEY_set1_EC_KEY(evp_pkey, self._ec_key)
-        assert res == 1
         return self._backend._private_key_bytes(
             encoding,
             format,
             encryption_algorithm,
-            evp_pkey,
+            self._evp_pkey,
             self._ec_key
         )
 
 
 @utils.register_interface(ec.EllipticCurvePublicKeyWithSerialization)
 class _EllipticCurvePublicKey(object):
-    def __init__(self, backend, ec_key_cdata):
+    def __init__(self, backend, ec_key_cdata, evp_pkey):
         self._backend = backend
         _mark_asn1_named_ec_curve(backend, ec_key_cdata)
         self._ec_key = ec_key_cdata
+        self._evp_pkey = evp_pkey
 
         sn = _ec_key_curve_sn(backend, ec_key_cdata)
         self._curve = _sn_to_elliptic_curve(backend, sn)
@@ -268,16 +263,9 @@
                 "EC public keys do not support PKCS1 serialization"
             )
 
-        evp_pkey = self._backend._lib.EVP_PKEY_new()
-        assert evp_pkey != self._backend._ffi.NULL
-        evp_pkey = self._backend._ffi.gc(
-            evp_pkey, self._backend._lib.EVP_PKEY_free
-        )
-        res = self._backend._lib.EVP_PKEY_set1_EC_KEY(evp_pkey, self._ec_key)
-        assert res == 1
         return self._backend._public_key_bytes(
             encoding,
             format,
-            evp_pkey,
+            self._evp_pkey,
             None
         )
diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py
index 1dbbb84..21414c0 100644
--- a/src/cryptography/hazmat/backends/openssl/rsa.py
+++ b/src/cryptography/hazmat/backends/openssl/rsa.py
@@ -508,17 +508,9 @@
 
 @utils.register_interface(RSAPrivateKeyWithSerialization)
 class _RSAPrivateKey(object):
-    def __init__(self, backend, rsa_cdata):
+    def __init__(self, backend, rsa_cdata, evp_pkey):
         self._backend = backend
         self._rsa_cdata = rsa_cdata
-
-        evp_pkey = self._backend._lib.EVP_PKEY_new()
-        assert evp_pkey != self._backend._ffi.NULL
-        evp_pkey = self._backend._ffi.gc(
-            evp_pkey, self._backend._lib.EVP_PKEY_free
-        )
-        res = self._backend._lib.EVP_PKEY_set1_RSA(evp_pkey, rsa_cdata)
-        assert res == 1
         self._evp_pkey = evp_pkey
 
         self._key_size = self._backend._lib.BN_num_bits(self._rsa_cdata.n)
@@ -543,7 +535,8 @@
         ctx.n = self._backend._lib.BN_dup(self._rsa_cdata.n)
         res = self._backend._lib.RSA_blinding_on(ctx, self._backend._ffi.NULL)
         assert res == 1
-        return _RSAPublicKey(self._backend, ctx)
+        evp_pkey = self._backend._rsa_cdata_to_evp_pkey(ctx)
+        return _RSAPublicKey(self._backend, ctx, evp_pkey)
 
     def private_numbers(self):
         return rsa.RSAPrivateNumbers(
@@ -571,17 +564,9 @@
 
 @utils.register_interface(RSAPublicKeyWithSerialization)
 class _RSAPublicKey(object):
-    def __init__(self, backend, rsa_cdata):
+    def __init__(self, backend, rsa_cdata, evp_pkey):
         self._backend = backend
         self._rsa_cdata = rsa_cdata
-
-        evp_pkey = self._backend._lib.EVP_PKEY_new()
-        assert evp_pkey != self._backend._ffi.NULL
-        evp_pkey = self._backend._ffi.gc(
-            evp_pkey, self._backend._lib.EVP_PKEY_free
-        )
-        res = self._backend._lib.EVP_PKEY_set1_RSA(evp_pkey, rsa_cdata)
-        assert res == 1
         self._evp_pkey = evp_pkey
 
         self._key_size = self._backend._lib.BN_num_bits(self._rsa_cdata.n)