shake128/256 support (#4611)

* shake128/256 support

* remove block_size

* doc an exception

* change how we detect XOF by adding _xof attribute

* interface!

* review feedback
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 5acc050..2f5c802 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -26,6 +26,9 @@
   1.1.1.
 * Added support for :doc:`/hazmat/primitives/asymmetric/x448` when using
   OpenSSL 1.1.1.
+* Added support for :class:`~cryptography.hazmat.primitives.hashes.SHAKE128`
+  and :class:`~cryptography.hazmat.primitives.hashes.SHAKE256` when using
+  OpenSSL 1.1.1.
 * Added initial support for parsing PKCS12 files with
   :func:`~cryptography.hazmat.primitives.serialization.pkcs12.load_key_and_certificates`.
 * Added support for :class:`~cryptography.x509.IssuingDistributionPoint`.
diff --git a/docs/hazmat/primitives/cryptographic-hashes.rst b/docs/hazmat/primitives/cryptographic-hashes.rst
index bc97936..24cc70b 100644
--- a/docs/hazmat/primitives/cryptographic-hashes.rst
+++ b/docs/hazmat/primitives/cryptographic-hashes.rst
@@ -185,6 +185,36 @@
     SHA3/512 is a cryptographic hash function from the SHA-3 family and is
     standardized by NIST. It produces a 512-bit message digest.
 
+.. class:: SHAKE128(digest_size)
+
+    .. versionadded:: 2.5
+
+    SHAKE128 is an extendable output function (XOF) based on the same core
+    permutations as SHA3. It allows the caller to obtain an arbitrarily long
+    digest length. Longer lengths, however, do not increase security or
+    collision resistance and lengths shorter than 128 bit (16 bytes) will
+    decrease it.
+
+    :param int digest_size: The length of output desired. Must be greater than
+        zero.
+
+    :raises ValueError: If the ``digest_size`` is invalid.
+
+.. class:: SHAKE256(digest_size)
+
+    .. versionadded:: 2.5
+
+    SHAKE256 is an extendable output function (XOF) based on the same core
+    permutations as SHA3. It allows the caller to obtain an arbitrarily long
+    digest length. Longer lengths, however, do not increase security or
+    collision resistance and lengths shorter than 256 bit (32 bytes) will
+    decrease it.
+
+    :param int digest_size: The length of output desired. Must be greater than
+        zero.
+
+    :raises ValueError: If the ``digest_size`` is invalid.
+
 SHA-1
 ~~~~~
 
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 4614421..7e08f12 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -41,6 +41,7 @@
 El
 Encodings
 endian
+extendable
 fallback
 Fernet
 fernet
diff --git a/src/cryptography/hazmat/backends/openssl/hashes.py b/src/cryptography/hazmat/backends/openssl/hashes.py
index e9a5070..549fa2b 100644
--- a/src/cryptography/hazmat/backends/openssl/hashes.py
+++ b/src/cryptography/hazmat/backends/openssl/hashes.py
@@ -54,10 +54,25 @@
         self._backend.openssl_assert(res != 0)
 
     def finalize(self):
+        if isinstance(self.algorithm, hashes.ExtendableOutputFunction):
+            # extendable output functions use a different finalize
+            return self._finalize_xof()
+        else:
+            buf = self._backend._ffi.new("unsigned char[]",
+                                         self._backend._lib.EVP_MAX_MD_SIZE)
+            outlen = self._backend._ffi.new("unsigned int *")
+            res = self._backend._lib.EVP_DigestFinal_ex(self._ctx, buf, outlen)
+            self._backend.openssl_assert(res != 0)
+            self._backend.openssl_assert(
+                outlen[0] == self.algorithm.digest_size
+            )
+            return self._backend._ffi.buffer(buf)[:outlen[0]]
+
+    def _finalize_xof(self):
         buf = self._backend._ffi.new("unsigned char[]",
-                                     self._backend._lib.EVP_MAX_MD_SIZE)
-        outlen = self._backend._ffi.new("unsigned int *")
-        res = self._backend._lib.EVP_DigestFinal_ex(self._ctx, buf, outlen)
+                                     self.algorithm.digest_size)
+        res = self._backend._lib.EVP_DigestFinalXOF(
+            self._ctx, buf, self.algorithm.digest_size
+        )
         self._backend.openssl_assert(res != 0)
-        self._backend.openssl_assert(outlen[0] == self.algorithm.digest_size)
-        return self._backend._ffi.buffer(buf)[:outlen[0]]
+        return self._backend._ffi.buffer(buf)[:self.algorithm.digest_size]
diff --git a/src/cryptography/hazmat/primitives/hashes.py b/src/cryptography/hazmat/primitives/hashes.py
index 0d6e47f..9be2b60 100644
--- a/src/cryptography/hazmat/primitives/hashes.py
+++ b/src/cryptography/hazmat/primitives/hashes.py
@@ -57,6 +57,13 @@
         """
 
 
+@six.add_metaclass(abc.ABCMeta)
+class ExtendableOutputFunction(object):
+    """
+    An interface for extendable output functions.
+    """
+
+
 @utils.register_interface(HashContext)
 class Hash(object):
     def __init__(self, algorithm, backend, ctx=None):
@@ -174,6 +181,40 @@
 
 
 @utils.register_interface(HashAlgorithm)
+@utils.register_interface(ExtendableOutputFunction)
+class SHAKE128(object):
+    name = "shake128"
+
+    def __init__(self, digest_size):
+        if not isinstance(digest_size, six.integer_types):
+            raise TypeError("digest_size must be an integer")
+
+        if digest_size < 1:
+            raise ValueError("digest_size must be a positive integer")
+
+        self._digest_size = digest_size
+
+    digest_size = utils.read_only_property("_digest_size")
+
+
+@utils.register_interface(HashAlgorithm)
+@utils.register_interface(ExtendableOutputFunction)
+class SHAKE256(object):
+    name = "shake256"
+
+    def __init__(self, digest_size):
+        if not isinstance(digest_size, six.integer_types):
+            raise TypeError("digest_size must be an integer")
+
+        if digest_size < 1:
+            raise ValueError("digest_size must be a positive integer")
+
+        self._digest_size = digest_size
+
+    digest_size = utils.read_only_property("_digest_size")
+
+
+@utils.register_interface(HashAlgorithm)
 class MD5(object):
     name = "md5"
     digest_size = 16
diff --git a/tests/hazmat/primitives/test_hash_vectors.py b/tests/hazmat/primitives/test_hash_vectors.py
index f8561fc..5225a00 100644
--- a/tests/hazmat/primitives/test_hash_vectors.py
+++ b/tests/hazmat/primitives/test_hash_vectors.py
@@ -4,6 +4,7 @@
 
 from __future__ import absolute_import, division, print_function
 
+import binascii
 import os
 
 import pytest
@@ -11,8 +12,8 @@
 from cryptography.hazmat.backends.interfaces import HashBackend
 from cryptography.hazmat.primitives import hashes
 
-from .utils import generate_hash_test
-from ...utils import load_hash_vectors
+from .utils import _load_all_params, generate_hash_test
+from ...utils import load_hash_vectors, load_nist_vectors
 
 
 @pytest.mark.supported(
@@ -250,3 +251,75 @@
         ],
         hashes.SHA3_512(),
     )
+
+
+@pytest.mark.supported(
+    only_if=lambda backend: backend.hash_supported(
+        hashes.SHAKE128(digest_size=16)),
+    skip_message="Does not support SHAKE128",
+)
+@pytest.mark.requires_backend_interface(interface=HashBackend)
+class TestSHAKE128(object):
+    test_shake128 = generate_hash_test(
+        load_hash_vectors,
+        os.path.join("hashes", "SHAKE"),
+        [
+            "SHAKE128LongMsg.rsp",
+            "SHAKE128ShortMsg.rsp",
+        ],
+        hashes.SHAKE128(digest_size=16),
+    )
+
+    @pytest.mark.parametrize(
+        "vector",
+        _load_all_params(
+            os.path.join("hashes", "SHAKE"),
+            [
+                "SHAKE128VariableOut.rsp",
+            ],
+            load_nist_vectors,
+        )
+    )
+    def test_shake128_variable(self, vector, backend):
+        output_length = int(vector['outputlen']) // 8
+        msg = binascii.unhexlify(vector['msg'])
+        shake = hashes.SHAKE128(digest_size=output_length)
+        m = hashes.Hash(shake, backend=backend)
+        m.update(msg)
+        assert m.finalize() == binascii.unhexlify(vector['output'])
+
+
+@pytest.mark.supported(
+    only_if=lambda backend: backend.hash_supported(
+        hashes.SHAKE256(digest_size=32)),
+    skip_message="Does not support SHAKE256",
+)
+@pytest.mark.requires_backend_interface(interface=HashBackend)
+class TestSHAKE256(object):
+    test_shake256 = generate_hash_test(
+        load_hash_vectors,
+        os.path.join("hashes", "SHAKE"),
+        [
+            "SHAKE256LongMsg.rsp",
+            "SHAKE256ShortMsg.rsp",
+        ],
+        hashes.SHAKE256(digest_size=32),
+    )
+
+    @pytest.mark.parametrize(
+        "vector",
+        _load_all_params(
+            os.path.join("hashes", "SHAKE"),
+            [
+                "SHAKE256VariableOut.rsp",
+            ],
+            load_nist_vectors,
+        )
+    )
+    def test_shake256_variable(self, vector, backend):
+        output_length = int(vector['outputlen']) // 8
+        msg = binascii.unhexlify(vector['msg'])
+        shake = hashes.SHAKE256(digest_size=output_length)
+        m = hashes.Hash(shake, backend=backend)
+        m.update(msg)
+        assert m.finalize() == binascii.unhexlify(vector['output'])
diff --git a/tests/hazmat/primitives/test_hashes.py b/tests/hazmat/primitives/test_hashes.py
index 6cba84b..b10fadc 100644
--- a/tests/hazmat/primitives/test_hashes.py
+++ b/tests/hazmat/primitives/test_hashes.py
@@ -179,3 +179,24 @@
     assert h.finalize() == binascii.unhexlify(
         b"dff2e73091f6c05e528896c4c831b9448653dc2ff043528f6769437bc7b975c2"
     )
+
+
+class TestSHAKE(object):
+    @pytest.mark.parametrize(
+        "xof",
+        [hashes.SHAKE128, hashes.SHAKE256]
+    )
+    def test_invalid_digest_type(self, xof):
+        with pytest.raises(TypeError):
+            xof(digest_size=object())
+
+    @pytest.mark.parametrize(
+        "xof",
+        [hashes.SHAKE128, hashes.SHAKE256]
+    )
+    def test_invalid_digest_size(self, xof):
+        with pytest.raises(ValueError):
+            xof(digest_size=-5)
+
+        with pytest.raises(ValueError):
+            xof(digest_size=0)
diff --git a/tests/utils.py b/tests/utils.py
index 364a349..b481280 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -134,7 +134,7 @@
             # string as hex 00, which is of course not actually an empty
             # string. So we parse the provided length and catch this edge case.
             msg = line.split(" = ")[1].encode("ascii") if length > 0 else b""
-        elif line.startswith("MD"):
+        elif line.startswith("MD") or line.startswith("Output"):
             md = line.split(" = ")[1]
             # after MD is found the Msg+MD (+ potential key) tuple is complete
             if key is not None: