Don't expose extract and expand on this class yet because we don't know how best to expose verify functionality, continue testing the stages using the private methods.
diff --git a/cryptography/hazmat/primitives/kdf/hkdf.py b/cryptography/hazmat/primitives/kdf/hkdf.py
index 2b5ba81..ae24f67 100644
--- a/cryptography/hazmat/primitives/kdf/hkdf.py
+++ b/cryptography/hazmat/primitives/kdf/hkdf.py
@@ -57,27 +57,11 @@
self._used = False
- def extract(self, key_material):
- if self._used:
- raise exceptions.AlreadyFinalized
-
- self._used = True
-
- return self._extract(key_material)
-
def _extract(self, key_material):
h = hmac.HMAC(self._salt, self._algorithm, backend=self._backend)
h.update(key_material)
return h.finalize()
- def expand(self, key_material):
- if self._used:
- raise exceptions.AlreadyFinalized
-
- self._used = True
-
- return self._expand(key_material)
-
def _expand(self, key_material):
output = [b""]
counter = 1
diff --git a/tests/hazmat/primitives/test_hkdf.py b/tests/hazmat/primitives/test_hkdf.py
index f3345b0..66993f0 100644
--- a/tests/hazmat/primitives/test_hkdf.py
+++ b/tests/hazmat/primitives/test_hkdf.py
@@ -71,24 +71,6 @@
backend=backend
)
- hkdf.extract(b"\x01" * 16)
-
- with pytest.raises(exceptions.AlreadyFinalized):
- hkdf.extract(b"\x02" * 16)
-
- hkdf = HKDF(
- hashes.SHA256(),
- 16,
- salt=None,
- info=None,
- backend=backend
- )
-
- hkdf.expand(b"\x01" * 16)
-
- with pytest.raises(exceptions.AlreadyFinalized):
- hkdf.expand(b"\x02" * 16)
-
def test_verify(self, backend):
hkdf = HKDF(
hashes.SHA256(),
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py
index 2584272..5a8dc3a 100644
--- a/tests/hazmat/primitives/utils.py
+++ b/tests/hazmat/primitives/utils.py
@@ -326,7 +326,7 @@
backend=backend
)
- prk = hkdf.extract(binascii.unhexlify(params["ikm"]))
+ prk = hkdf._extract(binascii.unhexlify(params["ikm"]))
assert prk == binascii.unhexlify(params["prk"])
@@ -340,7 +340,7 @@
backend=backend
)
- okm = hkdf.expand(binascii.unhexlify(params["prk"]))
+ okm = hkdf._expand(binascii.unhexlify(params["prk"]))
assert okm == binascii.unhexlify(params["okm"])