Merge pull request #101 from pyca/ecdhe

Add basic support for using ECDHE.
diff --git a/ChangeLog b/ChangeLog
index 53fb1df..6416e18 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,12 @@
+2014-04-19  Jean-Paul Calderone  <exarkun@twistedmatrix.com>
+
+	* OpenSSL/crypto.py: Based on work from Alex Gaynor, Andrew
+	  Lutomirski, Tobias Oberstein, Laurens Van Houtven, and Hynek
+	  Schlawack, add ``get_elliptic_curve`` and ``get_elliptic_curves``
+	  to support TLS ECDHE modes.
+	* OpenSSL/SSL.py: Add ``Context.set_tmp_ecdh`` to configure a TLS
+	  context with a particular elliptic curve for ECDHE modes.
+
 2014-04-19  Markus Unterwaditzer <markus@unterwaditzer.net>
 
 	* OpenSSL/SSL.py: ``Connection.send`` and ``Connection.sendall``
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index 593d89f..58553d6 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -130,7 +130,6 @@
 SSL_CB_HANDSHAKE_START = _lib.SSL_CB_HANDSHAKE_START
 SSL_CB_HANDSHAKE_DONE = _lib.SSL_CB_HANDSHAKE_DONE
 
-
 class Error(Exception):
     """
     An error occurred in an `OpenSSL.SSL` API.
@@ -604,6 +603,19 @@
         _lib.SSL_CTX_set_tmp_dh(self._context, dh)
 
 
+    def set_tmp_ecdh(self, curve):
+        """
+        Select a curve to use for ECDHE key exchange.
+
+        :param curve: A curve object to use as returned by either
+            :py:func:`OpenSSL.crypto.get_elliptic_curve` or
+            :py:func:`OpenSSL.crypto.get_elliptic_curves`.
+
+        :return: None
+        """
+        _lib.SSL_CTX_set_tmp_ecdh(self._context, curve._to_EC_KEY())
+
+
     def set_cipher_list(self, cipher_list):
         """
         Change the cipher list
@@ -1224,7 +1236,7 @@
         The makefile() method is not implemented, since there is no dup semantics
         for SSL connections
 
-        :raise NotImplementedError
+        :raise: NotImplementedError
         """
         raise NotImplementedError("Cannot make file object of OpenSSL.SSL.Connection")
 
diff --git a/OpenSSL/crypto.py b/OpenSSL/crypto.py
index 65e28d7..03fe853 100644
--- a/OpenSSL/crypto.py
+++ b/OpenSSL/crypto.py
@@ -5,7 +5,8 @@
 
 from six import (
     integer_types as _integer_types,
-    text_type as _text_type)
+    text_type as _text_type,
+    PY3 as _PY3)
 
 from OpenSSL._util import (
     ffi as _ffi,
@@ -263,6 +264,156 @@
 
 
 
+class _EllipticCurve(object):
+    """
+    A representation of a supported elliptic curve.
+
+    @cvar _curves: :py:obj:`None` until an attempt is made to load the curves.
+        Thereafter, a :py:class:`set` containing :py:class:`_EllipticCurve`
+        instances each of which represents one curve supported by the system.
+    @type _curves: :py:class:`NoneType` or :py:class:`set`
+    """
+    _curves = None
+
+    if _PY3:
+        # Only necessary on Python 3.  Moreover, it is broken on Python 2.
+        def __ne__(self, other):
+            """
+            Implement cooperation with the right-hand side argument of ``!=``.
+
+            Python 3 seems to have dropped this cooperation in this very narrow
+            circumstance.
+            """
+            if isinstance(other, _EllipticCurve):
+                return super(_EllipticCurve, self).__ne__(other)
+            return NotImplemented
+
+
+    @classmethod
+    def _load_elliptic_curves(cls, lib):
+        """
+        Get the curves supported by OpenSSL.
+
+        :param lib: The OpenSSL library binding object.
+
+        :return: A :py:class:`set` of ``cls`` instances, one for each of the
+            elliptic curves the underlying library supports.
+        """
+        if lib.Cryptography_HAS_EC:
+            num_curves = lib.EC_get_builtin_curves(_ffi.NULL, 0)
+            builtin_curves = _ffi.new('EC_builtin_curve[]', num_curves)
+            # The return value on this call should be num_curves again.  We could
+            # check it to make sure but if it *isn't* then.. what could we do?
+            # Abort the whole process, I suppose...?  -exarkun
+            lib.EC_get_builtin_curves(builtin_curves, num_curves)
+            return set(
+                cls.from_nid(lib, c.nid)
+                for c in builtin_curves)
+        return set()
+
+
+    @classmethod
+    def _get_elliptic_curves(cls, lib):
+        """
+        Get, cache, and return the curves supported by OpenSSL.
+
+        :param lib: The OpenSSL library binding object.
+
+        :return: A :py:class:`set` of ``cls`` instances, one for each of the
+            elliptic curves the underlying library supports.
+        """
+        if cls._curves is None:
+            cls._curves = cls._load_elliptic_curves(lib)
+        return cls._curves
+
+
+    @classmethod
+    def from_nid(cls, lib, nid):
+        """
+        Instantiate a new :py:class:`_EllipticCurve` associated with the given
+        OpenSSL NID.
+
+        :param lib: The OpenSSL library binding object.
+
+        :param nid: The OpenSSL NID the resulting curve object will represent.
+            This must be a curve NID (and not, for example, a hash NID) or
+            subsequent operations will fail in unpredictable ways.
+        :type nid: :py:class:`int`
+
+        :return: The curve object.
+        """
+        return cls(lib, nid, _ffi.string(lib.OBJ_nid2sn(nid)).decode("ascii"))
+
+
+    def __init__(self, lib, nid, name):
+        """
+        :param lib: The :py:mod:`cryptography` binding instance used to
+            interface with OpenSSL.
+
+        :param nid: The OpenSSL NID identifying the curve this object
+            represents.
+        :type nid: :py:class:`int`
+
+        :param name: The OpenSSL short name identifying the curve this object
+            represents.
+        :type name: :py:class:`unicode`
+        """
+        self._lib = lib
+        self._nid = nid
+        self.name = name
+
+
+    def __repr__(self):
+        return "<Curve %r>" % (self.name,)
+
+
+    def _to_EC_KEY(self):
+        """
+        Create a new OpenSSL EC_KEY structure initialized to use this curve.
+
+        :return: An ``EC_KEY*`` cdata object whose structure is freed with
+            ``EC_KEY_free`` when the cdata object is garbage collected.
+        """
+        key = self._lib.EC_KEY_new_by_curve_name(self._nid)
+        return _ffi.gc(key, self._lib.EC_KEY_free)
+
+
+
+def get_elliptic_curves():
+    """
+    Return a new set of objects representing the elliptic curves supported
+    in the OpenSSL build in use; modifying it does not affect later calls.
+
+    Each curve object has a :py:class:`unicode` ``name`` attribute by which
+    it identifies itself.
+
+    The curve objects are useful as values for the argument accepted by
+    :py:meth:`OpenSSL.SSL.Context.set_tmp_ecdh` to specify which elliptic
+    curve should be used for ECDHE key exchange.
+    """
+    return set(_EllipticCurve._get_elliptic_curves(_lib))
+
+
+
+def get_elliptic_curve(name):
+    """
+    Look up one supported curve object by its OpenSSL short name.
+
+    See :py:func:`get_elliptic_curves` for information about curve objects.
+
+    :param name: The OpenSSL short name of the wanted curve, for example
+        ``u"prime256v1"``.
+    :type name: :py:class:`unicode`
+
+    :raise ValueError: If no supported curve has that name.
+    """
+    matching = [c for c in get_elliptic_curves() if c.name == name]
+    if not matching:
+        raise ValueError("unknown curve name", name)
+    return matching[0]
+
+
+
 class X509Name(object):
     def __init__(self, name):
         """
diff --git a/OpenSSL/test/test_crypto.py b/OpenSSL/test/test_crypto.py
index a3685a9..34e60a3 100644
--- a/OpenSSL/test/test_crypto.py
+++ b/OpenSSL/test/test_crypto.py
@@ -11,7 +11,7 @@
 from subprocess import PIPE, Popen
 from datetime import datetime, timedelta
 
-from six import binary_type
+from six import u, b, binary_type
 
 from OpenSSL.crypto import TYPE_RSA, TYPE_DSA, Error, PKey, PKeyType
 from OpenSSL.crypto import X509, X509Type, X509Name, X509NameType
@@ -25,9 +25,10 @@
 from OpenSSL.crypto import PKCS12, PKCS12Type, load_pkcs12
 from OpenSSL.crypto import CRL, Revoked, load_crl
 from OpenSSL.crypto import NetscapeSPKI, NetscapeSPKIType
-from OpenSSL.crypto import sign, verify
-from OpenSSL.test.util import TestCase, b
-from OpenSSL._util import native
+from OpenSSL.crypto import (
+    sign, verify, get_elliptic_curve, get_elliptic_curves)
+from OpenSSL.test.util import EqualityTestsMixin, TestCase
+from OpenSSL._util import native, lib
 
 def normalize_certificate_pem(pem):
     return dump_certificate(FILETYPE_PEM, load_certificate(FILETYPE_PEM, pem))
@@ -3058,5 +3059,154 @@
         verify(good_cert, sig, content, "sha1")
 
 
+
+class EllipticCurveTests(TestCase):
+    """
+    Tests for :py:class:`_EllipticCurve`, :py:obj:`get_elliptic_curve`, and
+    :py:obj:`get_elliptic_curves`.
+    """
+    def test_set(self):
+        """
+        :py:obj:`get_elliptic_curves` returns a :py:obj:`set`.
+        """
+        self.assertIsInstance(get_elliptic_curves(), set)
+
+
+    def test_some_curves(self):
+        """
+        If :py:mod:`cryptography` has elliptic curve support then the set
+        returned by :py:obj:`get_elliptic_curves` has some elliptic curves in
+        it.
+
+        There could be an OpenSSL that violates this assumption.  If so, this
+        test will fail and we'll find out.
+        """
+        curves = get_elliptic_curves()
+        if lib.Cryptography_HAS_EC:
+            self.assertTrue(curves)
+        else:
+            self.assertFalse(curves)
+
+
+    def test_a_curve(self):
+        """
+        :py:obj:`get_elliptic_curve` can be used to retrieve a particular
+        supported curve.
+        """
+        curves = get_elliptic_curves()
+        if curves:
+            curve = next(iter(curves))
+            self.assertEqual(curve.name, get_elliptic_curve(curve.name).name)
+        else:
+            self.assertRaises(ValueError, get_elliptic_curve, u("prime256v1"))
+
+
+    def test_not_a_curve(self):
+        """
+        :py:obj:`get_elliptic_curve` raises :py:class:`ValueError` if called
+        with a name which does not identify a supported curve.
+        """
+        self.assertRaises(
+            ValueError, get_elliptic_curve, u("this curve was just invented"))
+
+
+    def test_repr(self):
+        """
+        The string representation of a curve object simply states that the
+        object is a curve and what its name is.
+        """
+        curves = get_elliptic_curves()
+        if curves:
+            curve = next(iter(curves))
+            self.assertEqual("<Curve %r>" % (curve.name,), repr(curve))
+
+
+    def test_to_EC_KEY(self):
+        """
+        The curve object can export a version of itself as an EC_KEY* via the
+        private :py:meth:`_EllipticCurve._to_EC_KEY`.
+        """
+        curves = get_elliptic_curves()
+        if curves:
+            curve = next(iter(curves))
+            # It's not easy to assert anything about this object.  However, see
+            # leakcheck/crypto.py for a test that demonstrates it at least does
+            # not leak memory.
+            curve._to_EC_KEY()
+
+
+
+class EllipticCurveFactory(object):
+    """
+    Record the names of two different supported curves as ``curve_name`` and
+    ``another_curve_name``; both are ``None`` if fewer than two exist.
+    """
+    def __init__(self):
+        names = [curve.name for curve in get_elliptic_curves()]
+        if len(names) < 2:
+            self.curve_name = self.another_curve_name = None
+        else:
+            self.curve_name, self.another_curve_name = names[0], names[1]
+
+
+
+class EllipticCurveEqualityTests(TestCase, EqualityTestsMixin):
+    """
+    Tests :py:class:`_EllipticCurve`'s implementation of ``==`` and ``!=``.
+    """
+    curve_factory = EllipticCurveFactory()
+
+    if curve_factory.curve_name is None:
+        skip = "No curves are available, so there can be no curve objects."
+
+
+    def anInstance(self):
+        """
+        Get the curve object for an arbitrary curve supported by the system.
+        """
+        return get_elliptic_curve(self.curve_factory.curve_name)
+
+
+    def anotherInstance(self):
+        """
+        Get the curve object for an arbitrary curve supported by the system -
+        but not the one returned by C{anInstance}.
+        """
+        return get_elliptic_curve(self.curve_factory.another_curve_name)
+
+
+
+class EllipticCurveHashTests(TestCase):
+    """
+    Tests for :py:class:`_EllipticCurve`'s implementation of hashing (thus use
+    as an item in a :py:class:`dict` or :py:class:`set`).
+    """
+    curve_factory = EllipticCurveFactory()
+
+    if curve_factory.curve_name is None:
+        skip = "No curves are available, so there can be no curve objects."
+
+
+    def test_contains(self):
+        """
+        The ``in`` operator reports that a :py:class:`set` containing a curve
+        does contain that curve.
+        """
+        curve = get_elliptic_curve(self.curve_factory.curve_name)
+        curves = set([curve])
+        self.assertIn(curve, curves)
+
+
+    def test_does_not_contain(self):
+        """
+        The ``in`` operator reports that a :py:class:`set` not containing a
+        curve does not contain that curve.
+        """
+        curve = get_elliptic_curve(self.curve_factory.curve_name)
+        curves = set([get_elliptic_curve(self.curve_factory.another_curve_name)])
+        self.assertNotIn(curve, curves)
+
+
+
 if __name__ == '__main__':
     main()
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 2dc0912..1d18fd0 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -20,7 +20,9 @@
 from OpenSSL.crypto import PKey, X509, X509Extension, X509Store
 from OpenSSL.crypto import dump_privatekey, load_privatekey
 from OpenSSL.crypto import dump_certificate, load_certificate
+from OpenSSL.crypto import get_elliptic_curves
 
+from OpenSSL.SSL import _lib
 from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS
 from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON
 from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN
@@ -1172,6 +1174,18 @@
         # XXX What should I assert here? -exarkun
 
 
+    def test_set_tmp_ecdh(self):
+        """
+        :py:obj:`Context.set_tmp_ecdh` accepts each curve object returned by
+        :py:obj:`get_elliptic_curves`.
+        """
+        ctx = Context(TLSv1_METHOD)
+        # Not raising an exception is the only readily assertable outcome,
+        # so every supported curve is simply handed to the context.
+        for supported_curve in get_elliptic_curves():
+            ctx.set_tmp_ecdh(supported_curve)
+
+
     def test_set_cipher_list_bytes(self):
         """
         :py:obj:`Context.set_cipher_list` accepts a :py:obj:`bytes` naming the
diff --git a/OpenSSL/test/util.py b/OpenSSL/test/util.py
index 4e4d812..21bbdc4 100644
--- a/OpenSSL/test/util.py
+++ b/OpenSSL/test/util.py
@@ -210,6 +210,23 @@
         return containee
     assertIn = failUnlessIn
 
+    def assertNotIn(self, containee, container, msg=None):
+        """
+        Fail the test if C{containee} is found in C{container}.
+
+        @param containee: the value that should not be in C{container}
+        @param container: a sequence type, or in the case of a mapping type,
+                          will follow semantics of 'if key in dict.keys()'
+        @param msg: if msg is None, then the failure message will be
+                    '%r in %r' % (containee, container)
+        """
+        if containee in container:
+            raise self.failureException(msg or "%r in %r"
+                                        % (containee, container))
+        return containee
+    failIfIn = assertNotIn
+
+
     def failUnlessIdentical(self, first, second, msg=None):
         """
         Fail the test if :py:data:`first` is not :py:data:`second`.  This is an
@@ -300,3 +317,133 @@
         self.assertTrue(isinstance(theType, type))
         instance = theType(*constructionArgs)
         self.assertIdentical(type(instance), theType)
+
+
+
+class EqualityTestsMixin(object):
+    """
+    A mixin for C{TestCase} subclasses defining tests of standard C{==}/C{!=}.
+    """
+    def anInstance(self):
+        """
+        Return an instance of the class under test.  Each call to this method
+        must return a different object.  All objects returned must be equal to
+        each other.
+        """
+        raise NotImplementedError()
+
+
+    def anotherInstance(self):
+        """
+        Return an instance of the class under test.  Each call to this method
+        must return a different object.  The objects must not be equal to the
+        objects returned by C{anInstance}.  They may or may not be equal to
+        each other (they will not be compared against each other).
+        """
+        raise NotImplementedError()
+
+
+    def test_identicalEq(self):
+        """
+        An object compares equal to itself using the C{==} operator.
+        """
+        o = self.anInstance()
+        self.assertTrue(o == o)
+
+
+    def test_identicalNe(self):
+        """
+        An object doesn't compare not equal to itself using the C{!=} operator.
+        """
+        o = self.anInstance()
+        self.assertFalse(o != o)
+
+
+    def test_sameEq(self):
+        """
+        Two objects that are equal to each other compare equal to each other
+        using the C{==} operator.
+        """
+        a = self.anInstance()
+        b = self.anInstance()
+        self.assertTrue(a == b)
+
+
+    def test_sameNe(self):
+        """
+        Two objects that are equal to each other do not compare not equal to
+        each other using the C{!=} operator.
+        """
+        a = self.anInstance()
+        b = self.anInstance()
+        self.assertFalse(a != b)
+
+
+    def test_differentEq(self):
+        """
+        Two objects that are not equal to each other do not compare equal to
+        each other using the C{==} operator.
+        """
+        a = self.anInstance()
+        b = self.anotherInstance()
+        self.assertFalse(a == b)
+
+
+    def test_differentNe(self):
+        """
+        Two objects that are not equal to each other compare not equal to each
+        other using the C{!=} operator.
+        """
+        a = self.anInstance()
+        b = self.anotherInstance()
+        self.assertTrue(a != b)
+
+
+    def test_anotherTypeEq(self):
+        """
+        The object does not compare equal to an object of an unrelated type
+        (which does not implement the comparison) using the C{==} operator.
+        """
+        a = self.anInstance()
+        b = object()
+        self.assertFalse(a == b)
+
+
+    def test_anotherTypeNe(self):
+        """
+        The object compares not equal to an object of an unrelated type (which
+        does not implement the comparison) using the C{!=} operator.
+        """
+        a = self.anInstance()
+        b = object()
+        self.assertTrue(a != b)
+
+
+    def test_delegatedEq(self):
+        """
+        The result of comparison using C{==} is delegated to the right-hand
+        operand if it is of an unrelated type.
+        """
+        class Delegate(object):
+            def __eq__(self, other):
+                # Return a distinctive non-bool so delegation is observable.
+                return [self]
+
+        a = self.anInstance()
+        b = Delegate()
+        self.assertEqual(a == b, [b])
+
+
+    def test_delegateNe(self):
+        """
+        The result of comparison using C{!=} is delegated to the right-hand
+        operand if it is of an unrelated type.
+        """
+        class Delegate(object):
+            def __ne__(self, other):
+                # Return a distinctive non-bool so delegation is observable.
+                return [self]
+
+        a = self.anInstance()
+        b = Delegate()
+        self.assertEqual(a != b, [b])
diff --git a/doc/api/crypto.rst b/doc/api/crypto.rst
index ee93cfb..b360e89 100644
--- a/doc/api/crypto.rst
+++ b/doc/api/crypto.rst
@@ -119,6 +119,28 @@
     Generic exception used in the :py:mod:`.crypto` module.
 
 
+.. py:function:: get_elliptic_curves()
+
+    Return a set of objects representing the elliptic curves supported in the
+    OpenSSL build in use.
+
+    The curve objects have a :py:class:`unicode` ``name`` attribute by which
+    they identify themselves.
+
+    The curve objects are useful as values for the argument accepted by
+    :py:meth:`OpenSSL.SSL.Context.set_tmp_ecdh` to specify which elliptic
+    curve should be used for ECDHE key exchange.
+
+
+.. py:function:: get_elliptic_curve(name)
+
+    Return a single curve object selected by name.
+
+    See :py:func:`get_elliptic_curves` for information about curve objects.
+
+    If the named curve is not supported then :py:class:`ValueError` is raised.
+
+
 .. py:function:: dump_certificate(type, cert)
 
     Dump the certificate *cert* into a buffer string encoded with the type
diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst
index e1c1d8a..a75af1f 100644
--- a/doc/api/ssl.rst
+++ b/doc/api/ssl.rst
@@ -317,6 +317,15 @@
     Load parameters for Ephemeral Diffie-Hellman from *dhfile*.
 
 
+.. py:method:: Context.set_tmp_ecdh(curve)
+
+    Select a curve to use for ECDHE key exchange.
+
+    The valid values of *curve* are the objects returned by
+    :py:func:`OpenSSL.crypto.get_elliptic_curves` or
+    :py:func:`OpenSSL.crypto.get_elliptic_curve`.
+
+
 .. py:method:: Context.set_app_data(data)
 
     Associate *data* with this Context object. *data* can be retrieved
diff --git a/leakcheck/crypto.py b/leakcheck/crypto.py
index f5fe2f8..ca79b7c 100644
--- a/leakcheck/crypto.py
+++ b/leakcheck/crypto.py
@@ -5,7 +5,7 @@
 
 from OpenSSL.crypto import (
     FILETYPE_PEM, TYPE_DSA, Error, PKey, X509, load_privatekey, CRL, Revoked,
-    _X509_REVOKED_dup)
+    get_elliptic_curves, _X509_REVOKED_dup)
 
 from OpenSSL._util import lib as _lib
 
@@ -145,6 +145,22 @@
 
 
 
+class Checker_EllipticCurve(BaseChecker):
+    """
+    Leak checks for the private :py:obj:`_EllipticCurve` helper class.
+    """
+    def check_to_EC_KEY(self):
+        """
+        Convert one curve object to an EC_KEY* many times over; each result
+        is expected to be freed automatically once it is garbage collected.
+        """
+        curve = next(iter(get_elliptic_curves()), None)
+        if curve is not None:
+            for _ in xrange(self.iterations * 1000):
+                curve._to_EC_KEY()
+
+
+
 def vmsize():
     return [x for x in file('/proc/self/status').readlines() if 'VmSize' in x]