Modified HKDF to use HKDFExpand
diff --git a/cryptography/hazmat/primitives/kdf/hkdf.py b/cryptography/hazmat/primitives/kdf/hkdf.py
index 44e1481..d49cc5b 100644
--- a/cryptography/hazmat/primitives/kdf/hkdf.py
+++ b/cryptography/hazmat/primitives/kdf/hkdf.py
@@ -34,6 +34,57 @@
self._algorithm = algorithm
+ if isinstance(salt, six.text_type):
+ raise TypeError(
+ "Unicode-objects must be encoded before using them as a salt.")
+
+ if salt is None:
+ salt = b"\x00" * (self._algorithm.digest_size // 8)
+
+ self._salt = salt
+
+ self._backend = backend
+
+ self._used = False
+
+ self._hkdf_expand = HKDFExpand(self._algorithm, length, info, backend)
+
+ def _extract(self, key_material):
+ h = hmac.HMAC(self._salt, self._algorithm, backend=self._backend)
+ h.update(key_material)
+ return h.finalize()
+
+ def derive(self, key_material):
+ if isinstance(key_material, six.text_type):
+ raise TypeError(
+ "Unicode-objects must be encoded before using them as key "
+ "material."
+ )
+
+ if self._used:
+ raise AlreadyFinalized
+
+ self._used = True
+ return self._hkdf_expand.derive(self._extract(key_material))
+
+ def verify(self, key_material, expected_key):
+ if not constant_time.bytes_eq(self.derive(key_material), expected_key):
+ raise InvalidKey
+
+
+@utils.register_interface(interfaces.KeyDerivationFunction)
+class HKDFExpand(HKDF):
+ def __init__(self, algorithm, length, info, backend):
+ if not isinstance(backend, HMACBackend):
+ raise UnsupportedAlgorithm(
+ "Backend object does not implement HMACBackend",
+ _Reasons.BACKEND_MISSING_INTERFACE
+ )
+
+ self._algorithm = algorithm
+
+ self._backend = backend
+
max_length = 255 * (algorithm.digest_size // 8)
if length > max_length:
@@ -44,15 +95,6 @@
self._length = length
- if isinstance(salt, six.text_type):
- raise TypeError(
- "Unicode-objects must be encoded before using them as a salt.")
-
- if salt is None:
- salt = b"\x00" * (self._algorithm.digest_size // 8)
-
- self._salt = salt
-
if isinstance(info, six.text_type):
raise TypeError(
"Unicode-objects must be encoded before using them as info.")
@@ -61,15 +103,9 @@
info = b""
self._info = info
- self._backend = backend
self._used = False
- def _extract(self, key_material):
- h = hmac.HMAC(self._salt, self._algorithm, backend=self._backend)
- h.update(key_material)
- return h.finalize()
-
def _expand(self, key_material):
output = [b""]
counter = 1
@@ -87,29 +123,6 @@
def derive(self, key_material):
if isinstance(key_material, six.text_type):
raise TypeError(
- "Unicode-objects must be encoded before using them as key "
- "material."
- )
-
- if self._used:
- raise AlreadyFinalized
-
- self._used = True
- return self._expand(self._extract(key_material))
-
- def verify(self, key_material, expected_key):
- if not constant_time.bytes_eq(self.derive(key_material), expected_key):
- raise InvalidKey
-
-
-@utils.register_interface(interfaces.KeyDerivationFunction)
-class HKDFExpand(HKDF):
- def __init__(self, algorithm, length, info, backend):
- HKDF.__init__(self, algorithm, length, None, info, backend)
-
- def derive(self, key_material):
- if isinstance(key_material, six.text_type):
- raise TypeError(
"Unicode-objects must be encoded before using them as key"
"material."
)
diff --git a/tests/hazmat/primitives/test_hkdf.py b/tests/hazmat/primitives/test_hkdf.py
index bee4217..598f09f 100644
--- a/tests/hazmat/primitives/test_hkdf.py
+++ b/tests/hazmat/primitives/test_hkdf.py
@@ -214,3 +214,6 @@
with raises_unsupported_algorithm(_Reasons.BACKEND_MISSING_INTERFACE):
HKDF(hashes.SHA256(), 16, None, None, pretend_backend)
+
+ with raises_unsupported_algorithm(_Reasons.BACKEND_MISSING_INTERFACE):
+ HKDFExpand(hashes.SHA256(), 16, None, pretend_backend)
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py
index 6c3f4c9..7cf5efd 100644
--- a/tests/hazmat/primitives/utils.py
+++ b/tests/hazmat/primitives/utils.py
@@ -26,7 +26,7 @@
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher
-from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from ...utils import load_vectors_from_file
@@ -347,10 +347,9 @@
def hkdf_expand_test(backend, algorithm, params):
- hkdf = HKDF(
+ hkdf = HKDFExpand(
algorithm,
int(params["l"]),
- salt=binascii.unhexlify(params["salt"]) or None,
info=binascii.unhexlify(params["info"]) or None,
backend=backend
)