Raise MemoryError when backend.derive_scrypt can't malloc enough (#4592)
* Raise MemoryError when backend.derive_scrypt can't malloc enough
* Expose ERR_R_MALLOC_FAILURE and use the reason_match pattern to catch it
* Add test_scrypt_malloc_failure in test_scrypt
* let's see if this passes
* add comment to filippo's blog post about scrypt's params
diff --git a/src/_cffi_src/openssl/err.py b/src/_cffi_src/openssl/err.py
index 5975135..b4d053c 100644
--- a/src/_cffi_src/openssl/err.py
+++ b/src/_cffi_src/openssl/err.py
@@ -22,6 +22,8 @@
static const int ERR_LIB_SSL;
static const int ERR_LIB_X509;
+static const int ERR_R_MALLOC_FAILURE;
+
static const int ASN1_R_BOOLEAN_IS_WRONG_LENGTH;
static const int ASN1_R_BUFFER_TOO_SMALL;
static const int ASN1_R_CIPHER_HAS_NO_OBJECT_IDENTIFIER;
diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py
index ae966cd..fda6293 100644
--- a/src/cryptography/hazmat/backends/openssl/backend.py
+++ b/src/cryptography/hazmat/backends/openssl/backend.py
@@ -2120,7 +2120,24 @@
key_material, len(key_material), salt, len(salt), n, r, p,
scrypt._MEM_LIMIT, buf, length
)
- self.openssl_assert(res == 1)
+ if res != 1:
+ errors = self._consume_errors()
+ if not self._lib.CRYPTOGRAPHY_OPENSSL_LESS_THAN_111:
+ # This error is only added to the stack in 1.1.1+
+ self.openssl_assert(
+ errors[0]._lib_reason_match(
+ self._lib.ERR_LIB_EVP,
+ self._lib.ERR_R_MALLOC_FAILURE
+ )
+ )
+
+ # memory required formula explained here:
+ # https://blog.filippo.io/the-scrypt-parameters/
+ min_memory = 128 * n * r // (1024**2)
+ raise MemoryError(
+ "Not enough memory to derive key. These parameters require"
+ " {} MB of memory.".format(min_memory)
+ )
return self._ffi.buffer(buf)[:]
def aead_cipher_supported(self, cipher):
diff --git a/tests/hazmat/primitives/test_scrypt.py b/tests/hazmat/primitives/test_scrypt.py
index 64abfe7..25d2c61 100644
--- a/tests/hazmat/primitives/test_scrypt.py
+++ b/tests/hazmat/primitives/test_scrypt.py
@@ -80,6 +80,20 @@
Scrypt(salt, length, work_factor, block_size,
parallelization_factor, backend)
+ def test_scrypt_malloc_failure(self, backend):
+ password = b"NaCl"
+ work_factor = 1024 ** 3
+ block_size = 589824
+ parallelization_factor = 16
+ length = 64
+ salt = b"NaCl"
+
+ scrypt = Scrypt(salt, length, work_factor, block_size,
+ parallelization_factor, backend)
+
+ with pytest.raises(MemoryError):
+ scrypt.derive(password)
+
def test_password_not_bytes(self, backend):
password = 1
work_factor = 1024