NIST SP 800-108 Counter Mode KDF (#2748)

* NIST SP 800-108 Counter Mode and Feedback Mode KDF

* CounterKDF unit tests

* Refactor to support multiple key based KDF modes.

* Extracting supported algorithms for KBKDF Counter Mode test vectors

* Adding support for different rlen and counter location in KBKDF

* support for multiple L lengths and 24 bit counter length.

* Adding KBKDF Documentation.

* Refactoring KBKDF to KBKDFHMAC to describe hash algorithm used.
diff --git a/docs/hazmat/primitives/key-derivation-functions.rst b/docs/hazmat/primitives/key-derivation-functions.rst
index 4d95629..0415ccf 100644
--- a/docs/hazmat/primitives/key-derivation-functions.rst
+++ b/docs/hazmat/primitives/key-derivation-functions.rst
@@ -600,6 +600,155 @@
         raises an exception if they do not match.
 
 
+.. currentmodule:: cryptography.hazmat.primitives.kdf.kbkdf
+
+.. class:: KBKDFHMAC(algorithm, mode, length, rlen, llen, location,\
+           label, context, fixed, backend)
+
+    .. versionadded:: 1.4
+
+    KBKDF (Key Based Key Derivation Function) is defined by the
+    `NIST SP 800-108`_ document, to be used to derive additional
+    keys from a key that has been established through an automated
+    key-establishment scheme.
+
+    .. warning::
+
+        KBKDFHMAC should not be used for password storage.
+
+    .. doctest::
+
+        >>> import os
+        >>> from cryptography.hazmat.primitives import hashes
+        >>> from cryptography.hazmat.primitives.kdf.kbkdf import (
+        ...    CounterLocation, KBKDFHMAC, Mode
+        ... )
+        >>> from cryptography.hazmat.backends import default_backend
+        >>> backend = default_backend()
+        >>> label = b"KBKDF HMAC Label"
+        >>> context = b"KBKDF HMAC Context"
+        >>> kdf = KBKDFHMAC(
+        ...     algorithm=hashes.SHA256(),
+        ...     mode=Mode.CounterMode,
+        ...     length=256,
+        ...     rlen=4,
+        ...     llen=4,
+        ...     location=CounterLocation.BeforeFixed,
+        ...     label=label,
+        ...     context=context,
+        ...     fixed=None,
+        ...     backend=backend
+        ... )
+        >>> key = kdf.derive(b"input key")
+        >>> kdf = KBKDFHMAC(
+        ...     algorithm=hashes.SHA256(),
+        ...     mode=Mode.CounterMode,
+        ...     length=256,
+        ...     rlen=4,
+        ...     llen=4,
+        ...     location=CounterLocation.BeforeFixed,
+        ...     label=label,
+        ...     context=context,
+        ...     fixed=None,
+        ...     backend=backend
+        ... )
+        >>> kdf.verify(b"input key", key)
+
+    :param algorithm: An instance of a
+        :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm`
+        provider
+
+    :param mode: The desired mode of the PRF. A value from the
+      :class:`~cryptography.hazmat.primitives.kdf.kbkdf.Mode` enum.
+
+    :param int length: The desired length of the derived key in bytes.
+
+    :param int rlen: An integer that indicates the length of the binary
+        representation of the counter in bytes.
+
+    :param int llen: An integer that indicates the binary
+        representation of the ``length`` in bytes.
+
+    :param location: The desired location of the counter. A value from the
+      :class:`~cryptography.hazmat.primitives.kdf.kbkdf.CounterLocation` enum.
+
+    :param bytes label: Application specific label information. If ``None``
+        is explicitly passed an empty byte string will be used.
+
+    :param bytes context: Application specific context information. If ``None``
+        is explicitly passed an empty byte string will be used.
+
+    :param bytes fixed: Instead of specifying ``label`` and ``context`` you
+        may supply your own fixed data. If ``fixed`` is specified, ``label``
+        and ``context`` is ignored.
+
+    :param backend: A cryptography backend
+        :class:`~cryptography.hazmat.backends.interfaces.HashBackend`
+        provider.
+
+    :raises cryptography.exceptions.UnsupportedAlgorithm: This is raised
+        if the provided ``backend`` does not implement
+        :class:`~cryptography.hazmat.backends.interfaces.HashBackend`
+
+    :raises TypeError: This exception is raised if ``label`` or ``context``
+        is not ``bytes``. Also raised if ``rlen`` or ``llen`` is not ``int``.
+
+    :raises ValueError: This exception is raised if ``rlen`` or ``llen``
+        is greater than 4 or less than 1. This exception is also raised if
+        you specify a ``label`` or ``context`` and ``fixed``.
+
+    .. method:: derive(key_material)
+
+        :param bytes key_material: The input key material.
+        :return bytes: The derived key.
+        :raises TypeError: This exception is raised if ``key_material`` is
+                            not ``bytes``.
+
+        Derives a new key from the input key material.
+
+    .. method:: verify(key_material, expected_key)
+
+        :param bytes key_material: The input key material. This is the same as
+                                   ``key_material`` in :meth:`derive`.
+        :param bytes expected_key: The expected result of deriving a new key,
+                                   this is the same as the return value of
+                                   :meth:`derive`.
+        :raises cryptography.exceptions.InvalidKey: This is raised when the
+                                                    derived key does not match
+                                                    the expected key.
+        :raises cryptography.exceptions.AlreadyFinalized: This is raised when
+                                                          :meth:`derive` or
+                                                          :meth:`verify` is
+                                                          called more than
+                                                          once.
+
+        This checks whether deriving a new key from the supplied
+        ``key_material`` generates the same key as the ``expected_key``, and
+        raises an exception if they do not match.
+
+.. class:: Mode
+
+    An enumeration for the key based key derivative modes.
+
+    .. attribute:: CounterMode
+
+        The output of the PRF is computed with a counter
+        as the iteration variable.
+
+.. class:: CounterLocation
+
+    An enumeration for the key based key derivative counter location.
+
+    .. attribute:: BeforeFixed
+
+        The counter iteration variable will be concatenated before
+        the fixed input data.
+
+    .. attribute:: AfterFixed
+
+        The counter iteration variable will be concatenated after
+        the fixed input data.
+
 Interface
 ~~~~~~~~~
 
@@ -648,6 +797,7 @@
 
 
 .. _`NIST SP 800-132`: http://csrc.nist.gov/publications/nistpubs/800-132/nist-sp800-132.pdf
+.. _`NIST SP 800-108`: http://csrc.nist.gov/publications/nistpubs/800-108/sp800-108.pdf
 .. _`NIST SP 800-56Ar2`: http://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Ar2.pdf
 .. _`ANSI X9.63:2001`: https://webstore.ansi.org
 .. _`SEC 1 v2.0`: http://www.secg.org/sec1-v2.pdf
diff --git a/src/cryptography/hazmat/primitives/kdf/kbkdf.py b/src/cryptography/hazmat/primitives/kdf/kbkdf.py
new file mode 100644
index 0000000..70a0fdc
--- /dev/null
+++ b/src/cryptography/hazmat/primitives/kdf/kbkdf.py
@@ -0,0 +1,148 @@
+# This file is dual licensed under the terms of the Apache License, Version
+# 2.0, and the BSD License. See the LICENSE file in the root of this repository
+# for complete details.
+
+from __future__ import absolute_import, division, print_function
+
+from enum import Enum
+
+from six.moves import range
+
+from cryptography import utils
+from cryptography.exceptions import (
+    AlreadyFinalized, InvalidKey, UnsupportedAlgorithm, _Reasons
+)
+from cryptography.hazmat.backends.interfaces import HMACBackend
+from cryptography.hazmat.primitives import constant_time, hashes, hmac
+from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
+
+
+class Mode(Enum):
+    CounterMode = "ctr"
+
+
+class CounterLocation(Enum):
+    BeforeFixed = "before_fixed"
+    AfterFixed = "after_fixed"
+
+
+@utils.register_interface(KeyDerivationFunction)
+class KBKDFHMAC(object):
+    def __init__(self, algorithm, mode, length, rlen, llen,
+                 location, label, context, fixed, backend):
+        if not isinstance(backend, HMACBackend):
+            raise UnsupportedAlgorithm(
+                "Backend object does not implement HMACBackend.",
+                _Reasons.BACKEND_MISSING_INTERFACE
+            )
+
+        if not isinstance(algorithm, hashes.HashAlgorithm):
+            raise UnsupportedAlgorithm(
+                "Algorithm supplied is not a supported hash algorithm.",
+                _Reasons.UNSUPPORTED_HASH
+            )
+
+        if not backend.hmac_supported(algorithm):
+            raise UnsupportedAlgorithm(
+                "Algorithm supplied is not a supported hmac algorithm.",
+                _Reasons.UNSUPPORTED_HASH
+            )
+
+        if not isinstance(mode, Mode):
+            raise TypeError("mode must be of type Mode")
+
+        if not isinstance(location, CounterLocation):
+            raise TypeError("location must be of type CounterLocation")
+
+        if (label or context) and fixed:
+            raise ValueError("When supplying fixed data, "
+                             "label and context are ignored.")
+
+        if rlen is None or not self._valid_byte_length(rlen):
+            raise ValueError("rlen must be between 1 and 4")
+
+        if llen is None and fixed is None:
+            raise ValueError("Please specify an llen")
+
+        if llen is not None and not isinstance(llen, int):
+            raise TypeError("llen must be an integer")
+
+        if label is None:
+            label = b''
+
+        if context is None:
+            context = b''
+
+        if (not isinstance(label, bytes) or
+                not isinstance(context, bytes)):
+            raise TypeError('label and context must be of type bytes')
+
+        self._algorithm = algorithm
+        self._mode = mode
+        self._length = length
+        self._rlen = rlen
+        self._llen = llen
+        self._location = location
+        self._label = label
+        self._context = context
+        self._backend = backend
+        self._used = False
+        self._fixed_data = fixed
+
+    def _valid_byte_length(self, value):
+        if not isinstance(value, int):
+            raise TypeError('value must be of type int')
+
+        value_bin = utils.int_to_bytes(1, value)
+        if not 1 <= len(value_bin) <= 4:
+            return False
+        return True
+
+    def derive(self, key_material):
+        if self._used:
+            raise AlreadyFinalized
+
+        if not isinstance(key_material, bytes):
+            raise TypeError('key_material must be bytes')
+        self._used = True
+
+        # inverse floor division (equivalent to ceiling)
+        rounds = -(-self._length // self._algorithm.digest_size)
+
+        output = [b'']
+
+        # For counter mode, the number of iterations shall not be
+        # larger than 2^r-1, where r ≤ 32 is the binary length of the counter
+        # This ensures that the counter values used as an input to the
+        # PRF will not repeat during a particular call to the KDF function.
+        r_bin = utils.int_to_bytes(1, self._rlen)
+        if rounds > pow(2, len(r_bin) * 8) - 1:
+            raise ValueError('There are too many iterations.')
+
+        for i in range(1, rounds + 1):
+            h = hmac.HMAC(key_material, self._algorithm, backend=self._backend)
+
+            counter = utils.int_to_bytes(i, self._rlen)
+            if self._location == CounterLocation.BeforeFixed:
+                h.update(counter)
+
+            h.update(self._generate_fixed_input())
+
+            if self._location == CounterLocation.AfterFixed:
+                h.update(counter)
+
+            output.append(h.finalize())
+
+        return b''.join(output)[:self._length]
+
+    def _generate_fixed_input(self):
+        if self._fixed_data and isinstance(self._fixed_data, bytes):
+            return self._fixed_data
+
+        l = utils.int_to_bytes(self._length * 8, self._llen)
+
+        return b"".join([self._label, b"\x00", self._context, l])
+
+    def verify(self, key_material, expected_key):
+        if not constant_time.bytes_eq(self.derive(key_material), expected_key):
+            raise InvalidKey
diff --git a/tests/hazmat/primitives/test_kbkdf.py b/tests/hazmat/primitives/test_kbkdf.py
new file mode 100644
index 0000000..45a53ac
--- /dev/null
+++ b/tests/hazmat/primitives/test_kbkdf.py
@@ -0,0 +1,151 @@
+# This file is dual licensed under the terms of the Apache License, Version
+# 2.0, and the BSD License. See the LICENSE file in the root of this repository
+# for complete details.
+
+from __future__ import absolute_import, division, print_function
+
+import pytest
+
+from cryptography.exceptions import (
+    AlreadyFinalized, InvalidKey, _Reasons
+)
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.kdf.kbkdf import (
+    CounterLocation, KBKDFHMAC, Mode
+)
+
+from ...doubles import DummyHashAlgorithm
+from ...utils import raises_unsupported_algorithm
+
+
+class TestKBKDFHMAC(object):
+    def test_invalid_key(self):
+        kdf = KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                        CounterLocation.BeforeFixed, b'label', b'context',
+                        None, backend=default_backend())
+
+        key = kdf.derive(b"material")
+
+        kdf = KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                        CounterLocation.BeforeFixed, b'label', b'context',
+                        None, backend=default_backend())
+
+        with pytest.raises(InvalidKey):
+            kdf.verify(b"material2", key)
+
+    def test_already_finalized(self):
+        kdf = KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                        CounterLocation.BeforeFixed, b'label', b'context',
+                        None, backend=default_backend())
+
+        kdf.derive(b'material')
+
+        with pytest.raises(AlreadyFinalized):
+            kdf.derive(b'material2')
+
+        kdf = KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                        CounterLocation.BeforeFixed, b'label', b'context',
+                        None, backend=default_backend())
+
+        key = kdf.derive(b'material')
+
+        with pytest.raises(AlreadyFinalized):
+            kdf.verify(b'material', key)
+
+        kdf = KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                        CounterLocation.BeforeFixed, b'label', b'context',
+                        None, backend=default_backend())
+        kdf.verify(b'material', key)
+
+        with pytest.raises(AlreadyFinalized):
+            kdf.verify(b"material", key)
+
+    def test_key_length(self):
+        kdf = KBKDFHMAC(hashes.SHA1(), Mode.CounterMode, 85899345920, 4, 4,
+                        CounterLocation.BeforeFixed, b'label', b'context',
+                        None, backend=default_backend())
+
+        with pytest.raises(ValueError):
+            kdf.derive(b'material')
+
+    def test_rlen(self):
+        with pytest.raises(ValueError):
+            KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 5, 4,
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      None, backend=default_backend())
+
+    def test_r_type(self):
+        with pytest.raises(TypeError):
+            KBKDFHMAC(hashes.SHA1(), Mode.CounterMode, 32, b'r', 4,
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      None, backend=default_backend())
+
+    def test_l_type(self):
+        with pytest.raises(TypeError):
+            KBKDFHMAC(hashes.SHA1(), Mode.CounterMode, 32, 4, b'l',
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      None, backend=default_backend())
+
+    def test_l(self):
+        with pytest.raises(ValueError):
+            KBKDFHMAC(hashes.SHA1(), Mode.CounterMode, 32, 4, None,
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      None, backend=default_backend())
+
+    def test_unsupported_mode(self):
+        with pytest.raises(TypeError):
+            KBKDFHMAC(hashes.SHA256(), None, 32, 4, 4,
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      None, backend=default_backend())
+
+    def test_unsupported_location(self):
+        with pytest.raises(TypeError):
+            KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                      None, b'label', b'context', None,
+                      backend=default_backend())
+
+    def test_unsupported_parameters(self):
+        with pytest.raises(ValueError):
+            KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      b'fixed', backend=default_backend())
+
+    def test_unsupported_hash(self):
+        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH):
+            KBKDFHMAC(object(), Mode.CounterMode, 32, 4, 4,
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      None, backend=default_backend())
+
+    def test_unsupported_algorithm(self):
+        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH):
+            KBKDFHMAC(DummyHashAlgorithm(), Mode.CounterMode, 32, 4, 4,
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      None, backend=default_backend())
+
+    def test_invalid_backend(self):
+        mock_backend = object
+
+        with raises_unsupported_algorithm(_Reasons.BACKEND_MISSING_INTERFACE):
+            KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                      CounterLocation.BeforeFixed, b'label', b'context',
+                      None, backend=mock_backend())
+
+    def test_unicode_error_label(self):
+        with pytest.raises(TypeError):
+            KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                      CounterLocation.BeforeFixed, u'label', b'context',
+                      backend=default_backend())
+
+    def test_unicode_error_context(self):
+        with pytest.raises(TypeError):
+            KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                      CounterLocation.BeforeFixed, b'label', u'context',
+                      None, backend=default_backend())
+
+    def test_unicode_error_key_material(self):
+        with pytest.raises(TypeError):
+            kdf = KBKDFHMAC(hashes.SHA256(), Mode.CounterMode, 32, 4, 4,
+                            CounterLocation.BeforeFixed, b'label',
+                            b'context', None, backend=default_backend())
+            kdf.derive(u'material')
diff --git a/tests/hazmat/primitives/test_kbkdf_vectors.py b/tests/hazmat/primitives/test_kbkdf_vectors.py
new file mode 100644
index 0000000..c8263e2
--- /dev/null
+++ b/tests/hazmat/primitives/test_kbkdf_vectors.py
@@ -0,0 +1,23 @@
+# This file is dual licensed under the terms of the Apache License, Version
+# 2.0, and the BSD License. See the LICENSE file in the root of this repository
+# for complete details.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+
+import pytest
+
+from cryptography.hazmat.backends.interfaces import HMACBackend
+
+from .utils import generate_kbkdf_counter_mode_test
+from ...utils import load_nist_kbkdf_vectors
+
+
+@pytest.mark.requires_backend_interface(interface=HMACBackend)
+class TestCounterKDFCounterMode(object):
+    test_HKDFSHA1 = generate_kbkdf_counter_mode_test(
+        load_nist_kbkdf_vectors,
+        os.path.join("KDF"),
+        ["nist-800-108-KBKDF-CTR.txt"]
+    )
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py
index e148bc6..e45466d 100644
--- a/tests/hazmat/primitives/utils.py
+++ b/tests/hazmat/primitives/utils.py
@@ -18,6 +18,9 @@
 from cryptography.hazmat.primitives.asymmetric import rsa
 from cryptography.hazmat.primitives.ciphers import Cipher
 from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand
+from cryptography.hazmat.primitives.kdf.kbkdf import (
+    CounterLocation, KBKDFHMAC, Mode
+)
 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 
 from ...utils import load_vectors_from_file
@@ -370,6 +373,55 @@
     return test_hkdf
 
 
+def generate_kbkdf_counter_mode_test(param_loader, path, file_names):
+    all_params = _load_all_params(path, file_names, param_loader)
+
+    @pytest.mark.parametrize("params", all_params)
+    def test_kbkdf(self, backend, params):
+        kbkdf_counter_mode_test(backend, params)
+    return test_kbkdf
+
+
+def kbkdf_counter_mode_test(backend, params):
+    supported_algorithms = {
+        'hmac_sha1': hashes.SHA1,
+        'hmac_sha224': hashes.SHA224,
+        'hmac_sha256': hashes.SHA256,
+        'hmac_sha384': hashes.SHA384,
+        'hmac_sha512': hashes.SHA512,
+    }
+
+    supportd_counter_locations = {
+        "before_fixed": CounterLocation.BeforeFixed,
+        "after_fixed": CounterLocation.AfterFixed,
+    }
+
+    algorithm = supported_algorithms.get(params.get('prf'))
+    if algorithm is None or not backend.hmac_supported(algorithm()):
+        pytest.skip('Does not support algorithm')
+
+    ctr_loc = supportd_counter_locations.get(params.get("ctrlocation"))
+    if ctr_loc is None or not isinstance(ctr_loc, CounterLocation):
+        pytest.skip("Does not support counter location".format(
+            location=params.get('ctrlocation')
+        ))
+
+    ctrkdf = KBKDFHMAC(
+        algorithm(),
+        Mode.CounterMode,
+        params['l'] // 8,
+        params['rlen'] // 8,
+        None,
+        ctr_loc,
+        None,
+        None,
+        binascii.unhexlify(params['fixedinputdata']),
+        backend=backend)
+
+    ko = ctrkdf.derive(binascii.unhexlify(params['ki']))
+    assert binascii.hexlify(ko) == params["ko"]
+
+
 def generate_rsa_verification_test(param_loader, path, file_names, hash_alg,
                                    pad_factory):
     all_params = _load_all_params(path, file_names, param_loader)