Merge pull request #1576 from alex/openssh-elliptic-curve

Fixes #1533 -- Initial work at parsing ECDSA public keys in OpenSSH format
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index b777d45..1ceb39d 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -6,6 +6,9 @@
 
 .. note:: This version is not yet released and is under active development.
 
+* :func:`~cryptography.hazmat.primitives.serialization.load_ssh_public_key` can
+  now load elliptic curve public keys.
+
 0.7 - 2014-12-17
 ~~~~~~~~~~~~~~~~
 
diff --git a/docs/hazmat/primitives/asymmetric/serialization.rst b/docs/hazmat/primitives/asymmetric/serialization.rst
index c184cdf..1456b0d 100644
--- a/docs/hazmat/primitives/asymmetric/serialization.rst
+++ b/docs/hazmat/primitives/asymmetric/serialization.rst
@@ -119,9 +119,6 @@
 
 The format used by OpenSSH to store public keys, as specified in :rfc:`4253`.
 
-Currently, only RSA and DSA public keys are supported. Any other type of key
-will result in an exception being thrown.
-
 An example RSA key in OpenSSH format (line breaks added for formatting
 purposes)::
 
@@ -134,7 +131,8 @@
     2MzHvnbv testkey@localhost
 
 DSA keys look almost identical but begin with ``ssh-dss`` rather than
-``ssh-rsa``.
+``ssh-rsa``. ECDSA keys have a slightly different format; they begin with
+``ecdsa-sha2-{curve}``.
 
 .. function:: load_ssh_public_key(data, backend)
 
@@ -143,12 +141,17 @@
     Deserialize a public key from OpenSSH (:rfc:`4253`) encoded data to an
     instance of the public key type for the specified backend.
 
+    .. note::
+
+        Currently Ed25519 keys are not supported.
+
     :param bytes data: The OpenSSH encoded key data.
 
     :param backend: A backend providing
-        :class:`~cryptography.hazmat.backends.interfaces.RSABackend` or
-        :class:`~cryptography.hazmat.backends.interfaces.DSABackend` depending
-        on key type.
+        :class:`~cryptography.hazmat.backends.interfaces.RSABackend`,
+        :class:`~cryptography.hazmat.backends.interfaces.DSABackend`, or
+        :class:`~cryptography.hazmat.backends.interfaces.EllipticCurveBackend`
+        depending on the key's type.
 
     :returns: A new instance of a public key type.
 
diff --git a/src/cryptography/hazmat/primitives/serialization.py b/src/cryptography/hazmat/primitives/serialization.py
index 083f17e..dad419f 100644
--- a/src/cryptography/hazmat/primitives/serialization.py
+++ b/src/cryptography/hazmat/primitives/serialization.py
@@ -7,11 +7,10 @@
 import base64
 import struct
 
+import six
+
 from cryptography.exceptions import UnsupportedAlgorithm
-from cryptography.hazmat.primitives.asymmetric.dsa import (
-    DSAParameterNumbers, DSAPublicNumbers
-)
-from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
+from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa
 
 
 def load_pem_private_key(data, password, backend):
@@ -30,6 +29,18 @@
             'Key is not in the proper format or contains extra data.')
 
     key_type = key_parts[0]
+
+    if key_type == b'ssh-rsa':
+        loader = _load_ssh_rsa_public_key
+    elif key_type == b'ssh-dss':
+        loader = _load_ssh_dss_public_key
+    elif key_type in [
+        b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521',
+    ]:
+        loader = _load_ssh_ecdsa_public_key
+    else:
+        raise UnsupportedAlgorithm('Key type is not supported.')
+
     key_body = key_parts[1]
 
     try:
@@ -37,53 +48,119 @@
     except TypeError:
         raise ValueError('Key is not in the proper format.')
 
-    if key_type == b'ssh-rsa':
-        return _load_ssh_rsa_public_key(decoded_data, backend)
-    elif key_type == b'ssh-dss':
-        return _load_ssh_dss_public_key(decoded_data, backend)
-    else:
-        raise UnsupportedAlgorithm(
-            'Only RSA and DSA keys are currently supported.'
+    inner_key_type, rest = _read_next_string(decoded_data)
+
+    if inner_key_type != key_type:
+        raise ValueError(
+            'Key header and key body contain different key type values.'
         )
 
+    return loader(key_type, rest, backend)
 
-def _load_ssh_rsa_public_key(decoded_data, backend):
-    key_type, rest = _read_next_string(decoded_data)
-    e, rest = _read_next_mpint(rest)
+
+def _load_ssh_rsa_public_key(key_type, decoded_data, backend):
+    """
+    Loads an RSA public key (e, n) from the body of an OpenSSH public key.
+
+    ``decoded_data`` is the key body after the leading key type string;
+    ``key_type`` is unused and only keeps the loader signatures uniform.
+    Raises ValueError if the body contains trailing bytes.
+    """
+    e, rest = _read_next_mpint(decoded_data)
     n, rest = _read_next_mpint(rest)
 
-    if key_type != b'ssh-rsa':
-        raise ValueError(
-            'Key header and key body contain different key type values.')
-
     if rest:
         raise ValueError('Key body contains extra bytes.')
 
-    return RSAPublicNumbers(e, n).public_key(backend)
+    return rsa.RSAPublicNumbers(e, n).public_key(backend)
 
 
-def _load_ssh_dss_public_key(decoded_data, backend):
-    key_type, rest = _read_next_string(decoded_data)
-    p, rest = _read_next_mpint(rest)
+def _load_ssh_dss_public_key(key_type, decoded_data, backend):
+    """
+    Loads a DSA public key (p, q, g, y) from the body of an OpenSSH key.
+
+    ``decoded_data`` is the key body after the leading key type string;
+    ``key_type`` is unused and only keeps the loader signatures uniform.
+    Raises ValueError if the body contains trailing bytes.
+    """
+    p, rest = _read_next_mpint(decoded_data)
     q, rest = _read_next_mpint(rest)
     g, rest = _read_next_mpint(rest)
     y, rest = _read_next_mpint(rest)
 
-    if key_type != b'ssh-dss':
-        raise ValueError(
-            'Key header and key body contain different key type values.')
-
     if rest:
         raise ValueError('Key body contains extra bytes.')
 
-    parameter_numbers = DSAParameterNumbers(p, q, g)
-    public_numbers = DSAPublicNumbers(y, parameter_numbers)
+    parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
+    public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
 
     return public_numbers.public_key(backend)
 
 
+def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend):
+    """
+    Loads an ECDSA public key from the body of an OpenSSH public key.
+
+    ``decoded_data`` is the key body after the leading key type string; it
+    holds the curve name followed by the encoded curve point.
+    Raises ValueError for malformed data and NotImplementedError when the
+    point is not in uncompressed (0x04-prefixed) form.
+    """
+    curve_name, rest = _read_next_string(decoded_data)
+    data, rest = _read_next_string(rest)
+
+    if expected_key_type != b"ecdsa-sha2-" + curve_name:
+        raise ValueError(
+            'Key header and key body contain different key type values.'
+        )
+
+    if rest:
+        raise ValueError('Key body contains extra bytes.')
+
+    if curve_name == b"nistp256":
+        curve = ec.SECP256R1()
+    elif curve_name == b"nistp384":
+        curve = ec.SECP384R1()
+    elif curve_name == b"nistp521":
+        curve = ec.SECP521R1()
+    else:
+        # Not reachable through load_ssh_public_key today, but fail with a
+        # clear error rather than an UnboundLocalError on ``curve`` below.
+        raise UnsupportedAlgorithm('Key type is not supported.')
+
+    # An empty point has no form byte at all; reject it as malformed
+    # instead of letting indexbytes() raise IndexError.
+    if not data:
+        raise ValueError("Malformed key bytes")
+
+    if six.indexbytes(data, 0) != 4:
+        raise NotImplementedError(
+            "Compressed elliptic curve points are not supported"
+        )
+
+    # key_size is in bits, and sometimes it's not evenly divisible by 8, so we
+    # add 7 to round up the number of bytes.
+    coord_len = (curve.key_size + 7) // 8
+    if len(data) != 1 + 2 * coord_len:
+        raise ValueError("Malformed key bytes")
+
+    x = _int_from_bytes(data[1:1 + coord_len], byteorder='big')
+    y = _int_from_bytes(data[1 + coord_len:], byteorder='big')
+    return ec.EllipticCurvePublicNumbers(x, y, curve).public_key(backend)
+
+
 def _read_next_string(data):
-    """Retrieves the next RFC 4251 string value from the data."""
+    """
+    Retrieves the next RFC 4251 string value from the data.
+
+    While the RFC calls these strings, in Python they are bytes objects.
+    Raises ValueError if the data is too short to hold the 4-byte length
+    prefix or the number of bytes that prefix announces.
+    """
+    if len(data) < 4:
+        raise ValueError('Key is not in the proper format.')
     str_len, = struct.unpack('>I', data[:4])
+    if len(data) < 4 + str_len:
+        raise ValueError('Key is not in the proper format.')
     return data[4:4 + str_len], data[4 + str_len:]
 
diff --git a/tests/hazmat/primitives/test_serialization.py b/tests/hazmat/primitives/test_serialization.py
index f3166d7..8c79f64 100644
--- a/tests/hazmat/primitives/test_serialization.py
+++ b/tests/hazmat/primitives/test_serialization.py
@@ -576,7 +576,7 @@
 @pytest.mark.requires_backend_interface(interface=RSABackend)
 class TestRSASSHSerialization(object):
     def test_load_ssh_public_key_unsupported(self, backend):
-        ssh_key = b'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTY='
+        ssh_key = b'ecdsa-sha2-junk AAAAE2VjZHNhLXNoYTItbmlzdHAyNTY='
 
         with pytest.raises(UnsupportedAlgorithm):
             load_ssh_public_key(ssh_key, backend)
@@ -784,3 +784,120 @@
         )
 
         assert numbers == expected
+
+
+@pytest.mark.requires_backend_interface(interface=EllipticCurveBackend)
+class TestECDSASSHSerialization(object):
+    def test_load_ssh_public_key_ecdsa_nist_p256(self, backend):
+        _skip_curve_unsupported(backend, ec.SECP256R1())
+
+        ssh_key = (
+            b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAy"
+            b"NTYAAABBBGG2MfkHXp0UkxUyllDzWNBAImsvt5t7pFtTXegZK2WbGxml8zMrgWi5"
+            b"teIg1TO03/FD9hbpBFgBeix3NrCFPls= root@cloud-server-01"
+        )
+        key = load_ssh_public_key(ssh_key, backend)
+        assert isinstance(key, interfaces.EllipticCurvePublicKey)
+
+        expected_x = int(
+            "44196257377740326295529888716212621920056478823906609851236662550"
+            "785814128027", 10
+        )
+        expected_y = int(
+            "12257763433170736656417248739355923610241609728032203358057767672"
+            "925775019611", 10
+        )
+
+        assert key.public_numbers() == ec.EllipticCurvePublicNumbers(
+            expected_x, expected_y, ec.SECP256R1()
+        )
+
+    def test_load_ssh_public_key_ecdsa_nist_p384(self, backend):
+        _skip_curve_unsupported(backend, ec.SECP384R1())
+        ssh_key = (
+            b"ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAz"
+            b"ODQAAABhBMzucOm9wbwg4iMr5QL0ya0XNQGXpw4wM5f12E3tWhdcrzyGHyel71t1"
+            b"4bvF9JZ2/WIuSxUr33XDl8jYo+lMQ5N7Vanc7f7i3AR1YydatL3wQfZStQ1I3rBa"
+            b"qQtRSEU8Tg== root@cloud-server-01"
+        )
+        key = load_ssh_public_key(ssh_key, backend)
+        assert isinstance(key, interfaces.EllipticCurvePublicKey)
+
+        expected_x = int(
+            "31541830871345183397582554827482786756220448716666815789487537666"
+            "592636882822352575507883817901562613492450642523901", 10
+        )
+        expected_y = int(
+            "15111413269431823234030344298767984698884955023183354737123929430"
+            "995703524272335782455051101616329050844273733614670", 10
+        )
+
+        assert key.public_numbers() == ec.EllipticCurvePublicNumbers(
+            expected_x, expected_y, ec.SECP384R1()
+        )
+
+    def test_load_ssh_public_key_ecdsa_nist_p521(self, backend):
+        _skip_curve_unsupported(backend, ec.SECP521R1())
+        ssh_key = (
+            b"ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1"
+            b"MjEAAACFBAGTrRhMSEgF6Ni+PXNz+5fjS4lw3ypUILVVQ0Av+0hQxOx+MyozELon"
+            b"I8NKbrbBjijEs1GuImsmkTmWsMXS1j2A7wB4Kseh7W9KA9IZJ1+TMrzWUEwvOOXi"
+            b"wT23pbaWWXG4NaM7vssWfZBnvz3S174TCXnJ+DSccvWBFnKP0KchzLKxbg== "
+            b"root@cloud-server-01"
+        )
+        key = load_ssh_public_key(ssh_key, backend)
+        assert isinstance(key, interfaces.EllipticCurvePublicKey)
+
+        expected_x = int(
+            "54124123120178189598842622575230904027376313369742467279346415219"
+            "77809037378785192537810367028427387173980786968395921877911964629"
+            "142163122798974160187785455", 10
+        )
+        expected_y = int(
+            "16111775122845033200938694062381820957441843014849125660011303579"
+            "15284560361402515564433711416776946492019498546572162801954089916"
+            "006665939539407104638103918", 10
+        )
+
+        assert key.public_numbers() == ec.EllipticCurvePublicNumbers(
+            expected_x, expected_y, ec.SECP521R1()
+        )
+
+    def test_load_ssh_public_key_ecdsa_nist_p256_trailing_data(self, backend):
+        ssh_key = (
+            b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAy"
+            b"NTYAAABBBGG2MfkHXp0UkxUyllDzWNBAImsvt5t7pFtTXegZK2WbGxml8zMrgWi5"
+            b"teIg1TO03/FD9hbpBFgBeix3NrCFPltB= root@cloud-server-01"
+        )
+        with pytest.raises(ValueError):
+            load_ssh_public_key(ssh_key, backend)
+
+    def test_load_ssh_public_key_ecdsa_nist_p256_missing_data(self, backend):
+        ssh_key = (
+            b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAy"
+            b"NTYAAABBBGG2MfkHXp0UkxUyllDzWNBAImsvt5t7pFtTXegZK2WbGxml8zMrgWi5"
+            b"teIg1TO03/FD9hbpBFgBeix3NrCF= root@cloud-server-01"
+        )
+        with pytest.raises(ValueError):
+            load_ssh_public_key(ssh_key, backend)
+
+    def test_load_ssh_public_key_ecdsa_nist_p256_compressed(self, backend):
+        # If we ever implement compressed points, note that this is not a valid
+        # one, it just has the compressed marker in the right place.
+        ssh_key = (
+            b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAy"
+            b"NTYAAABBAWG2MfkHXp0UkxUyllDzWNBAImsvt5t7pFtTXegZK2WbGxml8zMrgWi5"
+            b"teIg1TO03/FD9hbpBFgBeix3NrCFPls= root@cloud-server-01"
+        )
+        with pytest.raises(NotImplementedError):
+            load_ssh_public_key(ssh_key, backend)
+
+    def test_load_ssh_public_key_ecdsa_nist_p256_bad_curve_name(self, backend):
+        ssh_key = (
+            # The curve name in here is changed to be "nistp255".
+            b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAy"
+            b"NTUAAABBBGG2MfkHXp0UkxUyllDzWNBAImsvt5t7pFtTXegZK2WbGxml8zMrgWi5"
+            b"teIg1TO03/FD9hbpBFgBeix3NrCFPls= root@cloud-server-01"
+        )
+        with pytest.raises(ValueError):
+            load_ssh_public_key(ssh_key, backend)