Merge pull request #460 from Lukasa/issue/458
Raise NotImplementedError when SNI not present.
diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py
index 800ae1e..794199c 100644
--- a/src/OpenSSL/SSL.py
+++ b/src/OpenSSL/SSL.py
@@ -406,34 +406,41 @@
return _ffi.string(_lib.SSLeay_version(type))
-def _requires_npn(func):
+def _make_requires(flag, error):
"""
- Wraps any function that requires NPN support in OpenSSL, ensuring that
- NotImplementedError is raised if NPN is not present.
+ Return a decorator that makes functions relying on OpenSSL APIs missing
+ from this build raise NotImplementedError(error) instead of AttributeError
+ from cryptography; when ``flag`` is true, functions are left unwrapped.
+
+ :param flag: A cryptography flag that guards the functions, e.g.
+ ``Cryptography_HAS_NEXTPROTONEG``. It is read once, when decorating.
+ :param error: The string to be used in the exception if the flag is false.
"""
- @wraps(func)
- def wrapper(*args, **kwargs):
- if not _lib.Cryptography_HAS_NEXTPROTONEG:
- raise NotImplementedError("NPN not available.")
+ def _requires_decorator(func):
+ if not flag:  # checked at decoration time, not on every call
+ @wraps(func)
+ def explode(*args, **kwargs):  # stands in for func; never calls it
+ raise NotImplementedError(error)
+ return explode
+ else:
+ return func  # feature present: no wrapper, so no per-call overhead
- return func(*args, **kwargs)
-
- return wrapper
+ return _requires_decorator
-def _requires_alpn(func):
- """
- Wraps any function that requires ALPN support in OpenSSL, ensuring that
- NotImplementedError is raised if ALPN support is not present.
- """
- @wraps(func)
- def wrapper(*args, **kwargs):
- if not _lib.Cryptography_HAS_ALPN:
- raise NotImplementedError("ALPN not available.")
+_requires_npn = _make_requires(  # guard for NPN-dependent functions
+ _lib.Cryptography_HAS_NEXTPROTONEG, "NPN not available"  # NOTE(review): old NPN/ALPN messages ended in "."; confirm nothing matches on it
+)
- return func(*args, **kwargs)
- return wrapper
+_requires_alpn = _make_requires(  # guard for ALPN-dependent functions
+ _lib.Cryptography_HAS_ALPN, "ALPN not available"
+)
+
+
+_requires_sni = _make_requires(  # guard for the servername (SNI) methods
+ _lib.Cryptography_HAS_TLSEXT_HOSTNAME, "SNI not available"
+)
class Session(object):
@@ -991,6 +998,7 @@
return _lib.SSL_CTX_set_mode(self._context, mode)
+ @_requires_sni
def set_tlsext_servername_callback(self, callback):
"""
Specify a callback function to be called when clients specify a server
@@ -1209,6 +1217,7 @@
_lib.SSL_set_SSL_CTX(self._ssl, context._context)
self._context = context
+ @_requires_sni
def get_servername(self):
"""
Retrieve the servername extension value if provided in the client hello
@@ -1224,6 +1233,7 @@
return _ffi.string(name)
+ @_requires_sni
def set_tlsext_host_name(self, name):
"""
Set the value of the servername extension to send in the client hello.
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index ab316fc..b1592af 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -55,6 +55,7 @@
Error, SysCallError, WantReadError, WantWriteError, ZeroReturnError)
from OpenSSL.SSL import (
Context, ContextType, Session, Connection, ConnectionType, SSLeay_version)
+from OpenSSL.SSL import _make_requires
from OpenSSL._util import lib as _lib
@@ -3855,5 +3856,46 @@
self.assertTrue(isinstance(const, int))
+class TestRequires(object):
+ """
+ Tests for ``_make_requires``, the decorator factory used to raise
+ NotImplementedError when this OpenSSL build lacks a feature.
+ """
+ def test_available(self):
+ """
+ When the guarding flag is true, the decorated function runs normally and
+ its return value is passed through.
+ """
+ feature_guard = _make_requires(True, "Error text")  # feature present
+ results = []  # records whether inner's body actually executed
+
+ @feature_guard
+ def inner():
+ results.append(True)
+ return True
+
+ assert inner() is True
+ assert [True] == results
+
+ def test_unavailable(self):
+ """
+ When the guarding flag is false, calling the decorated function raises
+ NotImplementedError carrying the error text, and its body never runs.
+ """
+ feature_guard = _make_requires(False, "Error text")  # feature absent
+ results = []  # stays empty if inner's body is correctly skipped
+
+ @feature_guard
+ def inner():
+ results.append(True)
+ return True
+
+ with pytest.raises(NotImplementedError) as e:
+ inner()
+
+ assert "Error text" in str(e.value)
+ assert results == []
+
+
if __name__ == '__main__':
main()