Merge pull request #2343 from simo5/ANS_X963

Implement ANSI X9.63 KDF
diff --git a/docs/hazmat/primitives/key-derivation-functions.rst b/docs/hazmat/primitives/key-derivation-functions.rst
index 35e2dd8..10b806e 100644
--- a/docs/hazmat/primitives/key-derivation-functions.rst
+++ b/docs/hazmat/primitives/key-derivation-functions.rst
@@ -506,6 +506,99 @@
         ``key_material`` generates the same key as the ``expected_key``, and
         raises an exception if they do not match.
 
+.. currentmodule:: cryptography.hazmat.primitives.kdf.x963kdf
+
+.. class:: X963KDF(algorithm, length, sharedinfo, backend)
+
+    .. versionadded:: 1.1
+
+    X963KDF (ANSI X9.63 Key Derivation Function) is defined by ANSI
+    in the `ANSI X9.63:2001`_ document. It is used to derive keys
+    after a key exchange negotiation operation.
+
+    SECG in `SEC 1 v2.0`_ recommends that
+    :class:`~cryptography.hazmat.primitives.kdf.concatkdf.ConcatKDFHash` be
+    used for new projects. This KDF should only be used for backwards
+    compatibility with pre-existing implementations.
+
+
+    .. warning::
+
+        X963KDF should not be used for password storage.
+
+    .. doctest::
+
+        >>> import os
+        >>> from cryptography.hazmat.primitives import hashes
+        >>> from cryptography.hazmat.primitives.kdf.x963kdf import X963KDF
+        >>> from cryptography.hazmat.backends import default_backend
+        >>> backend = default_backend()
+        >>> sharedinfo = b"ANSI X9.63 Example"
+        >>> xkdf = X963KDF(
+        ...     algorithm=hashes.SHA256(),
+        ...     length=256,
+        ...     sharedinfo=sharedinfo,
+        ...     backend=backend
+        ... )
+        >>> key = xkdf.derive(b"input key")
+        >>> xkdf = X963KDF(
+        ...     algorithm=hashes.SHA256(),
+        ...     length=256,
+        ...     sharedinfo=sharedinfo,
+        ...     backend=backend
+        ... )
+        >>> xkdf.verify(b"input key", key)
+
+    :param algorithm: An instance of a
+        :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm`
+        provider.
+
+    :param int length: The desired length of the derived key in bytes.
+        Maximum is ``hashlen * (2^32 - 1)``.
+
+    :param bytes sharedinfo: Application specific context information.
+        If ``None`` is explicitly passed, an empty byte string will be used.
+
+    :param backend: A cryptography backend
+        :class:`~cryptography.hazmat.backends.interfaces.HashBackend`
+        provider.
+
+    :raises cryptography.exceptions.UnsupportedAlgorithm: This is raised
+        if the provided ``backend`` does not implement
+        :class:`~cryptography.hazmat.backends.interfaces.HashBackend`.
+
+    :raises TypeError: This exception is raised if ``sharedinfo`` is
+        neither ``bytes`` nor ``None``.
+
+    .. method:: derive(key_material)
+
+        :param bytes key_material: The input key material.
+        :return bytes: The derived key.
+        :raises TypeError: This exception is raised if ``key_material`` is
+                            not ``bytes``.
+
+        Derives a new key from the input key material.
+
+    .. method:: verify(key_material, expected_key)
+
+        :param bytes key_material: The input key material. This is the same as
+                                   ``key_material`` in :meth:`derive`.
+        :param bytes expected_key: The expected result of deriving a new key,
+                                   this is the same as the return value of
+                                   :meth:`derive`.
+        :raises cryptography.exceptions.InvalidKey: This is raised when the
+                                                    derived key does not match
+                                                    the expected key.
+        :raises cryptography.exceptions.AlreadyFinalized: This is raised when
+                                                          :meth:`derive` or
+                                                          :meth:`verify` is
+                                                          called more than
+                                                          once.
+
+        This checks whether deriving a new key from the supplied
+        ``key_material`` generates the same key as the ``expected_key``, and
+        raises an exception if they do not match.
+
 
 Interface
 ~~~~~~~~~
@@ -556,6 +649,8 @@
 
 .. _`NIST SP 800-132`: http://csrc.nist.gov/publications/nistpubs/800-132/nist-sp800-132.pdf
 .. _`NIST SP 800-56Ar2`: http://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Ar2.pdf
+.. _`ANSI X9.63:2001`: https://webstore.ansi.org
+.. _`SEC 1 v2.0`: http://www.secg.org/sec1-v2.pdf
 .. _`Password Storage Cheat Sheet`: https://www.owasp.org/index.php/Password_Storage_Cheat_Sheet
 .. _`PBKDF2`: https://en.wikipedia.org/wiki/PBKDF2
 .. _`scrypt`: https://en.wikipedia.org/wiki/Scrypt
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 7549784..16aa5ef 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -47,6 +47,7 @@
 paddings
 pickleable
 plaintext
+pre
 preprocessor
 preprocessors
 pseudorandom
diff --git a/src/cryptography/hazmat/primitives/kdf/x963kdf.py b/src/cryptography/hazmat/primitives/kdf/x963kdf.py
new file mode 100644
index 0000000..83789b3
--- /dev/null
+++ b/src/cryptography/hazmat/primitives/kdf/x963kdf.py
@@ -0,0 +1,70 @@
+# This file is dual licensed under the terms of the Apache License, Version
+# 2.0, and the BSD License. See the LICENSE file in the root of this repository
+# for complete details.
+
+from __future__ import absolute_import, division, print_function
+
+import struct
+
+from cryptography import utils
+from cryptography.exceptions import (
+    AlreadyFinalized, InvalidKey, UnsupportedAlgorithm, _Reasons
+)
+from cryptography.hazmat.backends.interfaces import HashBackend
+from cryptography.hazmat.primitives import constant_time, hashes
+from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
+
+
+def _int_to_u32be(n):
+    return struct.pack('>I', n)
+
+
+@utils.register_interface(KeyDerivationFunction)
+class X963KDF(object):
+    def __init__(self, algorithm, length, sharedinfo, backend):
+
+        max_len = algorithm.digest_size * (2 ** 32 - 1)
+        if length > max_len:
+            raise ValueError(
+                "Can not derive keys larger than {0} bits.".format(max_len))
+        if not (sharedinfo is None or isinstance(sharedinfo, bytes)):
+            raise TypeError("sharedinfo must be bytes.")
+        self._algorithm = algorithm
+        self._length = length
+        self._sharedinfo = sharedinfo
+
+        if not isinstance(backend, HashBackend):
+            raise UnsupportedAlgorithm(
+                "Backend object does not implement HashBackend.",
+                _Reasons.BACKEND_MISSING_INTERFACE
+            )
+        self._backend = backend
+        self._used = False
+
+    def derive(self, key_material):
+        if self._used:
+            raise AlreadyFinalized
+        self._used = True
+
+        if not isinstance(key_material, bytes):
+            raise TypeError("key_material must be bytes.")
+
+        output = [b""]
+        outlen = 0
+        counter = 1
+
+        while self._length > outlen:
+            h = hashes.Hash(self._algorithm, self._backend)
+            h.update(key_material)
+            h.update(_int_to_u32be(counter))
+            if self._sharedinfo is not None:
+                h.update(self._sharedinfo)
+            output.append(h.finalize())
+            outlen += len(output[-1])
+            counter += 1
+
+        return b"".join(output)[:self._length]
+
+    def verify(self, key_material, expected_key):
+        if not constant_time.bytes_eq(self.derive(key_material), expected_key):
+            raise InvalidKey
diff --git a/tests/hazmat/primitives/test_X963_vectors.py b/tests/hazmat/primitives/test_X963_vectors.py
new file mode 100644
index 0000000..14bcff4
--- /dev/null
+++ b/tests/hazmat/primitives/test_X963_vectors.py
@@ -0,0 +1,72 @@
+# This file is dual licensed under the terms of the Apache License, Version
+# 2.0, and the BSD License. See the LICENSE file in the root of this repository
+# for complete details.
+
+from __future__ import absolute_import, division, print_function
+
+import binascii
+import os
+
+import pytest
+
+from cryptography import utils
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.backends.interfaces import HashBackend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.kdf.x963kdf import X963KDF
+
+from ...utils import load_vectors_from_file, load_x963_vectors
+
+
+@utils.register_interface(hashes.HashAlgorithm)
+class UnsupportedDummyHash(object):
+    name = "unsupported-dummy-hash"
+    block_size = None
+    digest_size = None
+
+
+def _skip_hashfn_unsupported(backend, hashfn):
+    if not backend.hash_supported(hashfn):
+        pytest.skip(
+            "Hash {0} is not supported by this backend {1}".format(
+                hashfn.name, backend
+            )
+        )
+
+
+@pytest.mark.requires_backend_interface(interface=HashBackend)
+class TestX963(object):
+    _algorithms_dict = {
+        'SHA-1': hashes.SHA1,
+        'SHA-224': hashes.SHA224,
+        'SHA-256': hashes.SHA256,
+        'SHA-384': hashes.SHA384,
+        'SHA-512': hashes.SHA512
+    }
+
+    @pytest.mark.parametrize(
+        ("vector"),
+        load_vectors_from_file(
+            os.path.join("KDF", "ansx963_2001.txt"),
+            load_x963_vectors
+        )
+    )
+    def test_x963(self, backend, vector):
+        hashfn = self._algorithms_dict[vector["hash"]]
+        _skip_hashfn_unsupported(backend, hashfn())
+
+        key = binascii.unhexlify(vector["Z"])
+        sharedinfo = None
+        if vector["SharedInfo length"] != 0:
+            sharedinfo = binascii.unhexlify(vector["SharedInfo"])
+        key_data_len = vector["key data length"] // 8
+        key_data = binascii.unhexlify(vector["key_data"])
+
+        xkdf = X963KDF(algorithm=hashfn(),
+                       length=key_data_len,
+                       sharedinfo=sharedinfo,
+                       backend=default_backend())
+        xkdf.verify(key, key_data)
+
+    def test_unsupported_hash(self, backend):
+        _skip_hashfn_unsupported(backend, UnsupportedDummyHash())
diff --git a/tests/hazmat/primitives/test_x963kdf.py b/tests/hazmat/primitives/test_x963kdf.py
new file mode 100644
index 0000000..d87a46b
--- /dev/null
+++ b/tests/hazmat/primitives/test_x963kdf.py
@@ -0,0 +1,120 @@
+# This file is dual licensed under the terms of the Apache License, Version
+# 2.0, and the BSD License. See the LICENSE file in the root of this repository
+# for complete details.
+
+from __future__ import absolute_import, division, print_function
+
+import binascii
+
+import pytest
+
+from cryptography.exceptions import (
+    AlreadyFinalized, InvalidKey, _Reasons
+)
+from cryptography.hazmat.backends.interfaces import HashBackend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.kdf.x963kdf import X963KDF
+
+from ...utils import raises_unsupported_algorithm
+
+
+@pytest.mark.requires_backend_interface(interface=HashBackend)
+class TestX963KDF(object):
+    def test_length_limit(self, backend):
+        big_length = hashes.SHA256().digest_size * (2 ** 32 - 1) + 1
+
+        with pytest.raises(ValueError):
+            X963KDF(hashes.SHA256(), big_length, None, backend)
+
+    def test_already_finalized(self, backend):
+        xkdf = X963KDF(hashes.SHA256(), 16, None, backend)
+
+        xkdf.derive(b"\x01" * 16)
+
+        with pytest.raises(AlreadyFinalized):
+            xkdf.derive(b"\x02" * 16)
+
+    def test_derive(self, backend):
+        key = binascii.unhexlify(
+            b"96c05619d56c328ab95fe84b18264b08725b85e33fd34f08"
+        )
+
+        derivedkey = binascii.unhexlify(b"443024c3dae66b95e6f5670601558f71")
+
+        xkdf = X963KDF(hashes.SHA256(), 16, None, backend)
+
+        assert xkdf.derive(key) == derivedkey
+
+    def test_verify(self, backend):
+        key = binascii.unhexlify(
+            b"22518b10e70f2a3f243810ae3254139efbee04aa57c7af7d"
+        )
+
+        sharedinfo = binascii.unhexlify(b"75eef81aa3041e33b80971203d2c0c52")
+
+        derivedkey = binascii.unhexlify(
+            b"c498af77161cc59f2962b9a713e2b215152d139766ce34a776df11866a69bf2e"
+            b"52a13d9c7c6fc878c50c5ea0bc7b00e0da2447cfd874f6cf92f30d0097111485"
+            b"500c90c3af8b487872d04685d14c8d1dc8d7fa08beb0ce0ababc11f0bd496269"
+            b"142d43525a78e5bc79a17f59676a5706dc54d54d4d1f0bd7e386128ec26afc21"
+        )
+
+        xkdf = X963KDF(hashes.SHA256(), 128, sharedinfo, backend)
+
+        assert xkdf.verify(key, derivedkey) is None
+
+    def test_invalid_verify(self, backend):
+        key = binascii.unhexlify(
+            b"96c05619d56c328ab95fe84b18264b08725b85e33fd34f08"
+        )
+
+        xkdf = X963KDF(hashes.SHA256(), 16, None, backend)
+
+        with pytest.raises(InvalidKey):
+            xkdf.verify(key, b"wrong derived key")
+
+    def test_unicode_typeerror(self, backend):
+        with pytest.raises(TypeError):
+            X963KDF(
+                hashes.SHA256(),
+                16,
+                sharedinfo=u"foo",
+                backend=backend
+            )
+
+        with pytest.raises(TypeError):
+            xkdf = X963KDF(
+                hashes.SHA256(),
+                16,
+                sharedinfo=None,
+                backend=backend
+            )
+
+            xkdf.derive(u"foo")
+
+        with pytest.raises(TypeError):
+            xkdf = X963KDF(
+                hashes.SHA256(),
+                16,
+                sharedinfo=None,
+                backend=backend
+            )
+
+            xkdf.verify(u"foo", b"bar")
+
+        with pytest.raises(TypeError):
+            xkdf = X963KDF(
+                hashes.SHA256(),
+                16,
+                sharedinfo=None,
+                backend=backend
+            )
+
+            xkdf.verify(b"foo", u"bar")
+
+
+def test_invalid_backend():
+    pretend_backend = object()
+
+    with raises_unsupported_algorithm(_Reasons.BACKEND_MISSING_INTERFACE):
+        X963KDF(hashes.SHA256(), 16, None, pretend_backend)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 210e929..023f57c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -24,6 +24,7 @@
     load_hash_vectors, load_kasvs_dh_vectors,
     load_kasvs_ecdh_vectors, load_nist_vectors,
     load_pkcs1_vectors, load_rsa_nist_vectors, load_vectors_from_file,
+    load_x963_vectors,
     raises_unsupported_algorithm, select_backends, skip_if_empty
 )
 
@@ -3410,6 +3411,86 @@
     assert expected == load_kasvs_ecdh_vectors(vector_data)
 
 
+def test_load_x963_vectors():
+    vector_data = textwrap.dedent("""
+    # CAVS 12.0
+    # 'ANS X9.63-2001' information for sample
+
+    [SHA-1]
+    [shared secret length = 192]
+    [SharedInfo length = 0]
+    [key data length = 128]
+
+    COUNT = 0
+    Z = 1c7d7b5f0597b03d06a018466ed1a93e30ed4b04dc64ccdd
+    SharedInfo =
+        Counter = 00000001
+        Hash input 1 = 1c7d7b5f0597b03d06a018466ed1a93e30ed4b04dc64ccdd00000001
+        K1 = bf71dffd8f4d99223936beb46fee8ccc60439b7e
+    key_data = bf71dffd8f4d99223936beb46fee8ccc
+
+    COUNT = 1
+    Z = 5ed096510e3fcf782ceea98e9737993e2b21370f6cda2ab1
+    SharedInfo =
+        Counter = 00000001
+        Hash input 1 = 5ed096510e3fcf782ceea98e9737993e2b21370f6cda2ab100000001
+        K1 = ec3e224446bfd7b3be1df404104af953c1b2d0f5
+    key_data = ec3e224446bfd7b3be1df404104af953
+
+    [SHA-512]
+    [shared secret length = 521]
+    [SharedInfo length = 128]
+    [key data length = 1024]
+
+    COUNT = 0
+    Z = 00aa5bb79b33e389fa58ceadc047197f14e73712f452caa9fc4c9adb369348b8150739\
+2f1a86ddfdb7c4ff8231c4bd0f44e44a1b55b1404747a9e2e753f55ef05a2d
+    SharedInfo = e3b5b4c1b0d5cf1d2b3a2f9937895d31
+        Counter = 00000001
+        Hash input 1 = 00aa5bb79b33e389fa58ceadc047197f14e73712f452caa9fc4c9ad\
+b369348b81507392f1a86ddfdb7c4ff8231c4bd0f44e44a1b55b1404747a9e2e753f55ef05a2d0\
+0000001e3b5b4c1b0d5cf1d2b3a2f9937895d31
+        K1 = 4463f869f3cc18769b52264b0112b5858f7ad32a5a2d96d8cffabf7fa733633d6\
+e4dd2a599acceb3ea54a6217ce0b50eef4f6b40a5c30250a5a8eeee20800226
+        Counter = 00000002
+        Hash input 2 = 00aa5bb79b33e389fa58ceadc047197f14e73712f452caa9fc4c9ad\
+b369348b81507392f1a86ddfdb7c4ff8231c4bd0f44e44a1b55b1404747a9e2e753f55ef05a2d0\
+0000002e3b5b4c1b0d5cf1d2b3a2f9937895d31
+        K2 = 7089dbf351f3f5022aa9638bf1ee419dea9c4ff745a25ac27bda33ca08bd56dd1\
+a59b4106cf2dbbc0ab2aa8e2efa7b17902d34276951ceccab87f9661c3e8816
+    key_data = 4463f869f3cc18769b52264b0112b5858f7ad32a5a2d96d8cffabf7fa733633\
+d6e4dd2a599acceb3ea54a6217ce0b50eef4f6b40a5c30250a5a8eeee208002267089dbf351f3f\
+5022aa9638bf1ee419dea9c4ff745a25ac27bda33ca08bd56dd1a59b4106cf2dbbc0ab2aa8e2ef\
+a7b17902d34276951ceccab87f9661c3e8816
+    """).splitlines()
+
+    assert load_x963_vectors(vector_data) == [
+        {"hash": "SHA-1", "count": 0,
+         "shared secret length": 192,
+         "Z": "1c7d7b5f0597b03d06a018466ed1a93e30ed4b04dc64ccdd",
+         "SharedInfo length": 0,
+         "key data length": 128,
+         "key_data": "bf71dffd8f4d99223936beb46fee8ccc"},
+        {"hash": "SHA-1", "count": 1,
+         "shared secret length": 192,
+         "Z": "5ed096510e3fcf782ceea98e9737993e2b21370f6cda2ab1",
+         "SharedInfo length": 0,
+         "key data length": 128,
+         "key_data": "ec3e224446bfd7b3be1df404104af953"},
+        {"hash": "SHA-512", "count": 0,
+         "shared secret length": 521,
+         "Z": "00aa5bb79b33e389fa58ceadc047197f14e73712f452caa9fc4c9adb369348b\
+81507392f1a86ddfdb7c4ff8231c4bd0f44e44a1b55b1404747a9e2e753f55ef05a2d",
+         "SharedInfo length": 128,
+         "SharedInfo": "e3b5b4c1b0d5cf1d2b3a2f9937895d31",
+         "key data length": 1024,
+         "key_data": "4463f869f3cc18769b52264b0112b5858f7ad32a5a2d96d8cffabf7f\
+a733633d6e4dd2a599acceb3ea54a6217ce0b50eef4f6b40a5c30250a5a8eeee208002267089db\
+f351f3f5022aa9638bf1ee419dea9c4ff745a25ac27bda33ca08bd56dd1a59b4106cf2dbbc0ab2\
+aa8e2efa7b17902d34276951ceccab87f9661c3e8816"},
+    ]
+
+
 def test_vector_version():
     assert cryptography.__version__ == cryptography_vectors.__version__
 
diff --git a/tests/utils.py b/tests/utils.py
index 7e7abdf..c0052c9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -6,6 +6,7 @@
 
 import binascii
 import collections
+import math
 import re
 from contextlib import contextmanager
 
@@ -760,3 +761,51 @@
             }
 
     return vectors
+
+
+def load_x963_vectors(vector_data):
+    """
+    Loads data out of the X9.63 vector data
+    """
+
+    vectors = []
+
+    # Sets Metadata
+    hashname = None
+    vector = dict()
+    for line in vector_data:
+        line = line.strip()
+
+        if line.startswith("[SHA"):
+            hashname = line[1:-1]
+            shared_secret_len = 0
+            shared_info_len = 0
+            key_data_len = 0
+        elif line.startswith("[shared secret length"):
+            shared_secret_len = int(line[1:-1].split("=")[1].strip())
+        elif line.startswith("[SharedInfo length"):
+            shared_info_len = int(line[1:-1].split("=")[1].strip())
+        elif line.startswith("[key data length"):
+            key_data_len = int(line[1:-1].split("=")[1].strip())
+        elif line.startswith("COUNT"):
+            count = int(line.split("=")[1].strip())
+            vector["hash"] = hashname
+            vector["count"] = count
+            vector["shared secret length"] = shared_secret_len
+            vector["SharedInfo length"] = shared_info_len
+            vector["key data length"] = key_data_len
+        elif line.startswith("Z"):
+            vector["Z"] = line.split("=")[1].strip()
+            assert math.ceil(shared_secret_len / 8) * 2 == len(vector["Z"])
+        elif line.startswith("SharedInfo"):
+            if shared_info_len != 0:
+                vector["SharedInfo"] = line.split("=")[1].strip()
+                silen = len(vector["SharedInfo"])
+                assert math.ceil(shared_info_len / 8) * 2 == silen
+        elif line.startswith("key_data"):
+            vector["key_data"] = line.split("=")[1].strip()
+            assert math.ceil(key_data_len / 8) * 2 == len(vector["key_data"])
+            vectors.append(vector)
+            vector = dict()
+
+    return vectors