Merge pull request #460 from Lukasa/issue/458

Raise NotImplementedError when SNI is not available.
diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py
index 800ae1e..794199c 100644
--- a/src/OpenSSL/SSL.py
+++ b/src/OpenSSL/SSL.py
@@ -406,34 +406,41 @@
     return _ffi.string(_lib.SSLeay_version(type))
 
 
-def _requires_npn(func):
+def _make_requires(flag, error):
     """
-    Wraps any function that requires NPN support in OpenSSL, ensuring that
-    NotImplementedError is raised if NPN is not present.
+    Build a decorator that makes functions needing an optional OpenSSL
+    feature raise NotImplementedError (not cryptography's AttributeError)
+    when it is missing. ``flag`` is checked at decoration time, not per call.
+
+    :param flag: Truthy when the feature is available, e.g.
+        ``_lib.Cryptography_HAS_NEXTPROTONEG``.
+    :param error: The NotImplementedError message used when ``flag`` is falsy.
     """
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if not _lib.Cryptography_HAS_NEXTPROTONEG:
-            raise NotImplementedError("NPN not available.")
+    def _requires_decorator(func):
+        if flag:
+            return func  # Feature present: no wrapper, no per-call cost.
+
+        @wraps(func)  # Keep func's name and docstring on the stub.
+        def explode(*args, **kwargs):
+            raise NotImplementedError(error)
+        return explode
 
-        return func(*args, **kwargs)
-
-    return wrapper
+    return _requires_decorator
 
 
-def _requires_alpn(func):
-    """
-    Wraps any function that requires ALPN support in OpenSSL, ensuring that
-    NotImplementedError is raised if ALPN support is not present.
-    """
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if not _lib.Cryptography_HAS_ALPN:
-            raise NotImplementedError("ALPN not available.")
+_requires_npn = _make_requires(
+    _lib.Cryptography_HAS_NEXTPROTONEG, "NPN not available"
+)
 
-        return func(*args, **kwargs)
 
-    return wrapper
+_requires_alpn = _make_requires(
+    _lib.Cryptography_HAS_ALPN, "ALPN not available"
+)
+
+
+_requires_sni = _make_requires(
+    _lib.Cryptography_HAS_TLSEXT_HOSTNAME, "SNI not available"
+)
 
 
 class Session(object):
@@ -991,6 +998,7 @@
 
         return _lib.SSL_CTX_set_mode(self._context, mode)
 
+    @_requires_sni  # NotImplementedError if this OpenSSL build lacks SNI.
     def set_tlsext_servername_callback(self, callback):
         """
         Specify a callback function to be called when clients specify a server
@@ -1209,6 +1217,7 @@
         _lib.SSL_set_SSL_CTX(self._ssl, context._context)
         self._context = context
 
+    @_requires_sni  # NotImplementedError if this OpenSSL build lacks SNI.
     def get_servername(self):
         """
         Retrieve the servername extension value if provided in the client hello
@@ -1224,6 +1233,7 @@
 
         return _ffi.string(name)
 
+    @_requires_sni  # NotImplementedError if this OpenSSL build lacks SNI.
     def set_tlsext_host_name(self, name):
         """
         Set the value of the servername extension to send in the client hello.
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index ab316fc..b1592af 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -55,6 +55,7 @@
     Error, SysCallError, WantReadError, WantWriteError, ZeroReturnError)
 from OpenSSL.SSL import (
     Context, ContextType, Session, Connection, ConnectionType, SSLeay_version)
+from OpenSSL.SSL import _make_requires
 
 from OpenSSL._util import lib as _lib
 
@@ -3855,5 +3856,63 @@
             self.assertTrue(isinstance(const, int))
 
 
+class TestRequires(object):
+    """
+    Tests for the decorator factory used to conditionally raise
+    NotImplementedError when older OpenSSLs are used.
+    """
+    def test_available(self):
+        """
+        When the OpenSSL functionality is available the decorated functions
+        work appropriately.
+        """
+        feature_guard = _make_requires(True, "Error text")
+        results = []
+
+        @feature_guard
+        def inner():
+            results.append(True)
+            return True
+
+        assert inner() is True
+        assert [True] == results
+
+    def test_unavailable(self):
+        """
+        When the OpenSSL functionality is not available the decorated function
+        does not execute and NotImplementedError is raised with exactly the
+        given message.
+        """
+        feature_guard = _make_requires(False, "Error text")
+        results = []
+
+        @feature_guard
+        def inner():
+            results.append(True)
+            return True
+
+        with pytest.raises(NotImplementedError) as e:
+            inner()
+
+        assert str(e.value) == "Error text"
+        assert results == []
+
+    def test_unavailable_preserves_metadata(self):
+        """
+        The NotImplementedError stub installed when the functionality is not
+        available keeps the decorated function's name and docstring.
+        """
+        feature_guard = _make_requires(False, "Error text")
+
+        @feature_guard
+        def inner():
+            """
+            Docstring.
+            """
+
+        assert inner.__name__ == "inner"
+        assert inner.__doc__.strip() == "Docstring."
+
+
 if __name__ == '__main__':
     main()