Merge pull request #1016 from Ayrx/hkdf-expand-only

HKDF Expand Only implementation
diff --git a/cryptography/hazmat/primitives/kdf/hkdf.py b/cryptography/hazmat/primitives/kdf/hkdf.py
index 03500aa..daa8fcc 100644
--- a/cryptography/hazmat/primitives/kdf/hkdf.py
+++ b/cryptography/hazmat/primitives/kdf/hkdf.py
@@ -34,6 +34,51 @@
 
         self._algorithm = algorithm
 
+        if isinstance(salt, six.text_type):
+            raise TypeError(
+                "Unicode-objects must be encoded before using them as a salt.")
+
+        if salt is None:
+            salt = b"\x00" * (self._algorithm.digest_size // 8)
+
+        self._salt = salt
+
+        self._backend = backend
+
+        self._hkdf_expand = HKDFExpand(self._algorithm, length, info, backend)
+
+    def _extract(self, key_material):
+        h = hmac.HMAC(self._salt, self._algorithm, backend=self._backend)
+        h.update(key_material)
+        return h.finalize()
+
+    def derive(self, key_material):
+        if isinstance(key_material, six.text_type):
+            raise TypeError(
+                "Unicode-objects must be encoded before using them as key "
+                "material."
+            )
+
+        return self._hkdf_expand.derive(self._extract(key_material))
+
+    def verify(self, key_material, expected_key):
+        if not constant_time.bytes_eq(self.derive(key_material), expected_key):
+            raise InvalidKey
+
+
+@utils.register_interface(interfaces.KeyDerivationFunction)
+class HKDFExpand(object):
+    def __init__(self, algorithm, length, info, backend):
+        if not isinstance(backend, HMACBackend):
+            raise UnsupportedAlgorithm(
+                "Backend object does not implement HMACBackend",
+                _Reasons.BACKEND_MISSING_INTERFACE
+            )
+
+        self._algorithm = algorithm
+
+        self._backend = backend
+
         max_length = 255 * (algorithm.digest_size // 8)
 
         if length > max_length:
@@ -44,15 +89,6 @@
 
         self._length = length
 
-        if isinstance(salt, six.text_type):
-            raise TypeError(
-                "Unicode-objects must be encoded before using them as a salt.")
-
-        if salt is None:
-            salt = b"\x00" * (self._algorithm.digest_size // 8)
-
-        self._salt = salt
-
         if isinstance(info, six.text_type):
             raise TypeError(
                 "Unicode-objects must be encoded before using them as info.")
@@ -61,15 +97,9 @@
             info = b""
 
         self._info = info
-        self._backend = backend
 
         self._used = False
 
-    def _extract(self, key_material):
-        h = hmac.HMAC(self._salt, self._algorithm, backend=self._backend)
-        h.update(key_material)
-        return h.finalize()
-
     def _expand(self, key_material):
         output = [b""]
         counter = 1
@@ -87,7 +117,7 @@
     def derive(self, key_material):
         if isinstance(key_material, six.text_type):
             raise TypeError(
-                "Unicode-objects must be encoded before using them as key "
+                "Unicode-objects must be encoded before using them as key "
                 "material."
             )
 
@@ -95,7 +125,7 @@
             raise AlreadyFinalized
 
         self._used = True
-        return self._expand(self._extract(key_material))
+        return self._expand(key_material)
 
     def verify(self, key_material, expected_key):
         if not constant_time.bytes_eq(self.derive(key_material), expected_key):
diff --git a/docs/hazmat/primitives/key-derivation-functions.rst b/docs/hazmat/primitives/key-derivation-functions.rst
index ee8f8ab..de6bf5f 100644
--- a/docs/hazmat/primitives/key-derivation-functions.rst
+++ b/docs/hazmat/primitives/key-derivation-functions.rst
@@ -219,6 +219,98 @@
         ``key_material`` generates the same key as the ``expected_key``, and
         raises an exception if they do not match.
 
+
+.. class:: HKDFExpand(algorithm, length, info, backend)
+
+    .. versionadded:: 0.5
+
+    HKDF consists of two stages, extract and expand. This class exposes an
+    expand-only version of HKDF that is suitable when the key material is
+    already cryptographically strong.
+
+    .. warning::
+
+        HKDFExpand should only be used if the key material is
+        cryptographically strong. You should use
+        :class:`~cryptography.hazmat.primitives.kdf.hkdf.HKDF` if
+        you are unsure.
+
+    .. doctest::
+
+        >>> import os
+        >>> from cryptography.hazmat.primitives import hashes
+        >>> from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
+        >>> from cryptography.hazmat.backends import default_backend
+        >>> backend = default_backend()
+        >>> info = b"hkdf-example"
+        >>> key_material = os.urandom(32)
+        >>> hkdf = HKDFExpand(
+        ...     algorithm=hashes.SHA256(),
+        ...     length=32,
+        ...     info=info,
+        ...     backend=backend
+        ... )
+        >>> key = hkdf.derive(key_material)
+        >>> hkdf = HKDFExpand(
+        ...     algorithm=hashes.SHA256(),
+        ...     length=32,
+        ...     info=info,
+        ...     backend=backend
+        ... )
+        >>> hkdf.verify(key_material, key)
+
+    :param algorithm: An instance of a
+        :class:`~cryptography.hazmat.primitives.interfaces.HashAlgorithm`
+        provider.
+
+    :param int length: The desired length of the derived key. Maximum is
+        ``255 * (algorithm.digest_size // 8)``.
+
+    :param bytes info: Application specific context information.  If ``None``
+        is explicitly passed an empty byte string will be used.
+
+    :param backend: A
+        :class:`~cryptography.hazmat.backends.interfaces.HMACBackend`
+        provider.
+
+    :raises cryptography.exceptions.UnsupportedAlgorithm: This is raised if the
+        provided ``backend`` does not implement
+        :class:`~cryptography.hazmat.backends.interfaces.HMACBackend`
+    :raises TypeError: This is raised if the provided ``info`` is a unicode object
+
+    .. method:: derive(key_material)
+
+        :param bytes key_material: The input key material.
+        :return bytes: The derived key.
+
+        :raises TypeError: This is raised if the provided ``key_material`` is
+            a unicode object
+
+        Derives a new key from the input key material by performing only the
+        expand operation.
+
+    .. method:: verify(key_material, expected_key)
+
+        :param bytes key_material: The input key material. This is the same as
+                                   ``key_material`` in :meth:`derive`.
+        :param bytes expected_key: The expected result of deriving a new key,
+                                   this is the same as the return value of
+                                   :meth:`derive`.
+        :raises cryptography.exceptions.InvalidKey: This is raised when the
+                                                    derived key does not match
+                                                    the expected key.
+        :raises cryptography.exceptions.AlreadyFinalized: This is raised when
+                                                          :meth:`derive` or
+                                                          :meth:`verify` is
+                                                          called more than
+                                                          once.
+        :raises TypeError: This is raised if the provided ``key_material`` is
+            a unicode object
+
+        This checks whether deriving a new key from the supplied
+        ``key_material`` generates the same key as the ``expected_key``, and
+        raises an exception if they do not match.
+
 .. _`NIST SP 800-132`: http://csrc.nist.gov/publications/nistpubs/800-132/nist-sp800-132.pdf
 .. _`Password Storage Cheat Sheet`: https://www.owasp.org/index.php/Password_Storage_Cheat_Sheet
 .. _`PBKDF2`: https://en.wikipedia.org/wiki/PBKDF2
diff --git a/tests/hazmat/primitives/test_hkdf.py b/tests/hazmat/primitives/test_hkdf.py
index 2e3c0c3..598f09f 100644
--- a/tests/hazmat/primitives/test_hkdf.py
+++ b/tests/hazmat/primitives/test_hkdf.py
@@ -13,6 +13,8 @@
 
 from __future__ import absolute_import, division, print_function
 
+import binascii
+
 import pytest
 
 import six
@@ -21,7 +23,7 @@
     AlreadyFinalized, InvalidKey, _Reasons
 )
 from cryptography.hazmat.primitives import hashes
-from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand
 
 from ...utils import raises_unsupported_algorithm
 
@@ -151,8 +153,67 @@
             hkdf.verify(b"foo", six.u("bar"))
 
 
+@pytest.mark.hmac
+class TestHKDFExpand(object):
+    def test_derive(self, backend):
+        prk = binascii.unhexlify(
+            b"077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5"
+        )
+
+        okm = (b"3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c"
+               b"5bf34007208d5b887185865")
+
+        info = binascii.unhexlify(b"f0f1f2f3f4f5f6f7f8f9")
+        hkdf = HKDFExpand(hashes.SHA256(), 42, info, backend)
+
+        assert binascii.hexlify(hkdf.derive(prk)) == okm
+
+    def test_verify(self, backend):
+        prk = binascii.unhexlify(
+            b"077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5"
+        )
+
+        okm = (b"3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c"
+               b"5bf34007208d5b887185865")
+
+        info = binascii.unhexlify(b"f0f1f2f3f4f5f6f7f8f9")
+        hkdf = HKDFExpand(hashes.SHA256(), 42, info, backend)
+
+        assert hkdf.verify(prk, binascii.unhexlify(okm)) is None
+
+    def test_invalid_verify(self, backend):
+        prk = binascii.unhexlify(
+            b"077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5"
+        )
+
+        info = binascii.unhexlify(b"f0f1f2f3f4f5f6f7f8f9")
+        hkdf = HKDFExpand(hashes.SHA256(), 42, info, backend)
+
+        with pytest.raises(InvalidKey):
+            hkdf.verify(prk, b"wrong key")
+
+    def test_already_finalized(self, backend):
+        info = binascii.unhexlify(b"f0f1f2f3f4f5f6f7f8f9")
+        hkdf = HKDFExpand(hashes.SHA256(), 42, info, backend)
+
+        hkdf.derive(b"first")
+
+        with pytest.raises(AlreadyFinalized):
+            hkdf.derive(b"second")
+
+    def test_unicode_error(self, backend):
+        info = binascii.unhexlify(b"f0f1f2f3f4f5f6f7f8f9")
+        hkdf = HKDFExpand(hashes.SHA256(), 42, info, backend)
+
+        with pytest.raises(TypeError):
+            hkdf.derive(six.u("first"))
+
+
 def test_invalid_backend():
     pretend_backend = object()
 
     with raises_unsupported_algorithm(_Reasons.BACKEND_MISSING_INTERFACE):
         HKDF(hashes.SHA256(), 16, None, None, pretend_backend)
+
+    with raises_unsupported_algorithm(_Reasons.BACKEND_MISSING_INTERFACE):
+        HKDFExpand(hashes.SHA256(), 16, None, pretend_backend)
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py
index 6c3f4c9..a496459 100644
--- a/tests/hazmat/primitives/utils.py
+++ b/tests/hazmat/primitives/utils.py
@@ -26,7 +26,7 @@
 from cryptography.hazmat.primitives import hashes, hmac
 from cryptography.hazmat.primitives.asymmetric import rsa
 from cryptography.hazmat.primitives.ciphers import Cipher
-from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand
 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 
 from ...utils import load_vectors_from_file
@@ -347,15 +347,14 @@
 
 
 def hkdf_expand_test(backend, algorithm, params):
-    hkdf = HKDF(
+    hkdf = HKDFExpand(
         algorithm,
         int(params["l"]),
-        salt=binascii.unhexlify(params["salt"]) or None,
         info=binascii.unhexlify(params["info"]) or None,
         backend=backend
     )
 
-    okm = hkdf._expand(binascii.unhexlify(params["prk"]))
+    okm = hkdf.derive(binascii.unhexlify(params["prk"]))
 
     assert okm == binascii.unhexlify(params["okm"])