Merge pull request #1346 from reaperhulk/fix-pkcs8-ec-load

Process curve name when loading EC keys. Fixes #1336
diff --git a/cryptography/hazmat/backends/openssl/backend.py b/cryptography/hazmat/backends/openssl/backend.py
index 9d767ae..389ef0b 100644
--- a/cryptography/hazmat/backends/openssl/backend.py
+++ b/cryptography/hazmat/backends/openssl/backend.py
@@ -474,12 +474,14 @@
             assert dsa_cdata != self._ffi.NULL
             dsa_cdata = self._ffi.gc(dsa_cdata, self._lib.DSA_free)
             return _DSAPrivateKey(self, dsa_cdata)
-        elif self._lib.Cryptography_HAS_EC == 1 \
-                and type == self._lib.EVP_PKEY_EC:
+        elif (self._lib.Cryptography_HAS_EC == 1 and
+              type == self._lib.EVP_PKEY_EC):
             ec_cdata = self._lib.EVP_PKEY_get1_EC_KEY(evp_pkey)
             assert ec_cdata != self._ffi.NULL
             ec_cdata = self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
-            return _EllipticCurvePrivateKey(self, ec_cdata, None)
+            sn = self._ec_key_curve_sn(ec_cdata)
+            curve = self._sn_to_elliptic_curve(sn)
+            return _EllipticCurvePrivateKey(self, ec_cdata, curve)
         else:
             raise UnsupportedAlgorithm("Unsupported key type.")
 
@@ -501,15 +503,30 @@
             assert dsa_cdata != self._ffi.NULL
             dsa_cdata = self._ffi.gc(dsa_cdata, self._lib.DSA_free)
             return _DSAPublicKey(self, dsa_cdata)
-        elif self._lib.Cryptography_HAS_EC == 1 \
-                and type == self._lib.EVP_PKEY_EC:
+        elif (self._lib.Cryptography_HAS_EC == 1 and
+              type == self._lib.EVP_PKEY_EC):
             ec_cdata = self._lib.EVP_PKEY_get1_EC_KEY(evp_pkey)
             assert ec_cdata != self._ffi.NULL
             ec_cdata = self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
-            return _EllipticCurvePublicKey(self, ec_cdata, None)
+            sn = self._ec_key_curve_sn(ec_cdata)
+            curve = self._sn_to_elliptic_curve(sn)
+            return _EllipticCurvePublicKey(self, ec_cdata, curve)
         else:
             raise UnsupportedAlgorithm("Unsupported key type.")
 
+    def _ec_key_curve_sn(self, ec_key):
+        """Return the OpenSSL short name of ``ec_key``'s curve as a str."""
+        group = self._lib.EC_KEY_get0_group(ec_key)
+        assert group != self._ffi.NULL
+        nid = self._lib.EC_GROUP_get_curve_name(group)
+        if nid == self._lib.NID_undef:  # no named curve -> unsupported
+            raise UnsupportedAlgorithm(
+                "EC keys with unnamed curves are unsupported.",
+                _Reasons.UNSUPPORTED_ELLIPTIC_CURVE)
+        curve_name = self._lib.OBJ_nid2sn(nid)
+        assert curve_name != self._ffi.NULL
+        return self._ffi.string(curve_name).decode('ascii')
+
     def _pem_password_cb(self, password):
         """
         Generate a pem_password_cb function pointer that copied the password to
@@ -1048,6 +1065,15 @@
             )
         return curve_nid
 
+    def _sn_to_elliptic_curve(self, sn):
+        """Instantiate the curve class registered under short name ``sn``."""
+        curve_cls = ec._CURVE_TYPES.get(sn)
+        if curve_cls is None:
+            raise UnsupportedAlgorithm(
+                "{0} is not a supported elliptic curve".format(sn),
+                _Reasons.UNSUPPORTED_ELLIPTIC_CURVE)
+        return curve_cls()
+
     @contextmanager
     def _tmp_bn_ctx(self):
         bn_ctx = self._lib.BN_CTX_new()
diff --git a/cryptography/hazmat/primitives/asymmetric/ec.py b/cryptography/hazmat/primitives/asymmetric/ec.py
index 220a419..98eca27 100644
--- a/cryptography/hazmat/primitives/asymmetric/ec.py
+++ b/cryptography/hazmat/primitives/asymmetric/ec.py
@@ -184,6 +184,30 @@
         return 192
 
 
+_CURVE_TYPES = {  # curve short name -> curve class (callers instantiate it)
+    "prime192v1": SECP192R1,  # alias: same class as "secp192r1"
+    "prime256v1": SECP256R1,  # alias: same class as "secp256r1"
+    # secp* short names
+    "secp192r1": SECP192R1,
+    "secp224r1": SECP224R1,
+    "secp256r1": SECP256R1,
+    "secp384r1": SECP384R1,
+    "secp521r1": SECP521R1,
+    # sect*k1 short names
+    "sect163k1": SECT163K1,
+    "sect233k1": SECT233K1,
+    "sect283k1": SECT283K1,
+    "sect409k1": SECT409K1,
+    "sect571k1": SECT571K1,
+    # sect163r2 and sect*r1 short names
+    "sect163r2": SECT163R2,
+    "sect233r1": SECT233R1,
+    "sect283r1": SECT283R1,
+    "sect409r1": SECT409R1,
+    "sect571r1": SECT571R1,
+}
+
+
 @utils.register_interface(interfaces.EllipticCurveSignatureAlgorithm)
 class ECDSA(object):
     def __init__(self, algorithm):
diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py
index d4c5e2e..110bbdb 100644
--- a/tests/hazmat/backends/test_openssl.py
+++ b/tests/hazmat/backends/test_openssl.py
@@ -493,7 +493,7 @@
             )
 
 
-class TestOpenSSLNoEllipticCurve(object):
+class TestOpenSSLEllipticCurve(object):
     def test_elliptic_curve_supported(self, monkeypatch):
         monkeypatch.setattr(backend._lib, "Cryptography_HAS_EC", 0)
 
@@ -506,6 +506,10 @@
             None, None
         ) is False
 
+    def test_sn_to_elliptic_curve_not_supported(self):
+        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_ELLIPTIC_CURVE):
+            backend._sn_to_elliptic_curve("fake")  # str, as in real use
+
 
 class TestDeprecatedRSABackendMethods(object):
     def test_create_rsa_signature_ctx(self):
diff --git a/tests/hazmat/primitives/test_ec.py b/tests/hazmat/primitives/test_ec.py
index 2690e79..65461f7 100644
--- a/tests/hazmat/primitives/test_ec.py
+++ b/tests/hazmat/primitives/test_ec.py
@@ -29,26 +29,6 @@
     raises_unsupported_algorithm
 )
 
-_CURVE_TYPES = {
-    "secp192r1": ec.SECP192R1,
-    "secp224r1": ec.SECP224R1,
-    "secp256r1": ec.SECP256R1,
-    "secp384r1": ec.SECP384R1,
-    "secp521r1": ec.SECP521R1,
-
-    "sect163k1": ec.SECT163K1,
-    "sect233k1": ec.SECT233K1,
-    "sect283k1": ec.SECT283K1,
-    "sect409k1": ec.SECT409K1,
-    "sect571k1": ec.SECT571K1,
-
-    "sect163r2": ec.SECT163R2,
-    "sect233r1": ec.SECT233R1,
-    "sect283r1": ec.SECT283R1,
-    "sect409r1": ec.SECT409R1,
-    "sect571r1": ec.SECT571R1,
-}
-
 _HASH_TYPES = {
     "SHA-1": hashes.SHA1,
     "SHA-224": hashes.SHA224,
@@ -162,7 +142,7 @@
         ))
     )
     def test_signing_with_example_keys(self, backend, vector, hash_type):
-        curve_type = _CURVE_TYPES[vector['curve']]
+        curve_type = ec._CURVE_TYPES[vector['curve']]
 
         _skip_ecdsa_vector(backend, curve_type, hash_type)
 
@@ -188,7 +168,7 @@
         verifier.verify()
 
     @pytest.mark.parametrize(
-        "curve", _CURVE_TYPES.values()
+        "curve", ec._CURVE_TYPES.values()
     )
     def test_generate_vector_curves(self, backend, curve):
         _skip_curve_unsupported(backend, curve())
@@ -244,7 +224,7 @@
     )
     def test_signatures(self, backend, vector):
         hash_type = _HASH_TYPES[vector['digest_algorithm']]
-        curve_type = _CURVE_TYPES[vector['curve']]
+        curve_type = ec._CURVE_TYPES[vector['curve']]
 
         _skip_ecdsa_vector(backend, curve_type, hash_type)
 
@@ -276,7 +256,7 @@
     )
     def test_signature_failures(self, backend, vector):
         hash_type = _HASH_TYPES[vector['digest_algorithm']]
-        curve_type = _CURVE_TYPES[vector['curve']]
+        curve_type = ec._CURVE_TYPES[vector['curve']]
 
         _skip_ecdsa_vector(backend, curve_type, hash_type)
 
diff --git a/tests/hazmat/primitives/test_serialization.py b/tests/hazmat/primitives/test_serialization.py
index 8405f4b..5ee68b2 100644
--- a/tests/hazmat/primitives/test_serialization.py
+++ b/tests/hazmat/primitives/test_serialization.py
@@ -135,6 +135,8 @@
         )
         assert key
         assert isinstance(key, interfaces.EllipticCurvePublicKey)
+        assert key.curve.name == "secp256r1"
+        assert key.curve.key_size == 256
 
 
 @pytest.mark.traditional_openssl_serialization
@@ -412,6 +414,8 @@
         )
         assert key
         assert isinstance(key, interfaces.EllipticCurvePrivateKey)
+        assert key.curve.name == "secp256r1"
+        assert key.curve.key_size == 256
 
     def test_unused_password(self, backend):
         key_file = os.path.join(