Add SHAKE128/SHAKE256 support (#4611)

* shake128/256 support

* remove block_size

* doc an exception

* change how we detect XOF by adding _xof attribute

* interface!

* review feedback
diff --git a/tests/hazmat/primitives/test_hash_vectors.py b/tests/hazmat/primitives/test_hash_vectors.py
index f8561fc..5225a00 100644
--- a/tests/hazmat/primitives/test_hash_vectors.py
+++ b/tests/hazmat/primitives/test_hash_vectors.py
@@ -4,6 +4,7 @@
 
 from __future__ import absolute_import, division, print_function
 
+import binascii
 import os
 
 import pytest
@@ -11,8 +12,8 @@
 from cryptography.hazmat.backends.interfaces import HashBackend
 from cryptography.hazmat.primitives import hashes
 
-from .utils import generate_hash_test
-from ...utils import load_hash_vectors
+from .utils import _load_all_params, generate_hash_test
+from ...utils import load_hash_vectors, load_nist_vectors
 
 
 @pytest.mark.supported(
@@ -250,3 +251,75 @@
         ],
         hashes.SHA3_512(),
     )
+
+
+@pytest.mark.supported(  # skip if backend lacks SHAKE128
+    only_if=lambda backend: backend.hash_supported(
+        hashes.SHAKE128(digest_size=16)),
+    skip_message="Does not support SHAKE128",
+)
+@pytest.mark.requires_backend_interface(interface=HashBackend)
+class TestSHAKE128(object):  # SHAKE128 known-answer tests
+    test_shake128 = generate_hash_test(  # one fixed digest_size (16)
+        load_hash_vectors,
+        os.path.join("hashes", "SHAKE"),
+        [
+            "SHAKE128LongMsg.rsp",
+            "SHAKE128ShortMsg.rsp",
+        ],
+        hashes.SHAKE128(digest_size=16),
+    )
+
+    @pytest.mark.parametrize(
+        "vector",
+        _load_all_params(  # one test case per vector
+            os.path.join("hashes", "SHAKE"),
+            [
+                "SHAKE128VariableOut.rsp",
+            ],
+            load_nist_vectors,
+        )
+    )
+    def test_shake128_variable(self, vector, backend):  # per-vector length
+        output_length = int(vector['outputlen']) // 8  # assumed bits -> bytes
+        msg = binascii.unhexlify(vector['msg'])
+        shake = hashes.SHAKE128(digest_size=output_length)
+        m = hashes.Hash(shake, backend=backend)
+        m.update(msg)
+        assert m.finalize() == binascii.unhexlify(vector['output'])
+
+
+@pytest.mark.supported(  # skip if backend lacks SHAKE256
+    only_if=lambda backend: backend.hash_supported(
+        hashes.SHAKE256(digest_size=32)),
+    skip_message="Does not support SHAKE256",
+)
+@pytest.mark.requires_backend_interface(interface=HashBackend)
+class TestSHAKE256(object):  # SHAKE256 known-answer tests
+    test_shake256 = generate_hash_test(  # one fixed digest_size (32)
+        load_hash_vectors,
+        os.path.join("hashes", "SHAKE"),
+        [
+            "SHAKE256LongMsg.rsp",
+            "SHAKE256ShortMsg.rsp",
+        ],
+        hashes.SHAKE256(digest_size=32),
+    )
+
+    @pytest.mark.parametrize(
+        "vector",
+        _load_all_params(  # one test case per vector
+            os.path.join("hashes", "SHAKE"),
+            [
+                "SHAKE256VariableOut.rsp",
+            ],
+            load_nist_vectors,
+        )
+    )
+    def test_shake256_variable(self, vector, backend):  # per-vector length
+        output_length = int(vector['outputlen']) // 8  # assumed bits -> bytes
+        msg = binascii.unhexlify(vector['msg'])
+        shake = hashes.SHAKE256(digest_size=output_length)
+        m = hashes.Hash(shake, backend=backend)
+        m.update(msg)
+        assert m.finalize() == binascii.unhexlify(vector['output'])
diff --git a/tests/hazmat/primitives/test_hashes.py b/tests/hazmat/primitives/test_hashes.py
index 6cba84b..b10fadc 100644
--- a/tests/hazmat/primitives/test_hashes.py
+++ b/tests/hazmat/primitives/test_hashes.py
@@ -179,3 +179,24 @@
     assert h.finalize() == binascii.unhexlify(
         b"dff2e73091f6c05e528896c4c831b9448653dc2ff043528f6769437bc7b975c2"
     )
+
+
+class TestSHAKE(object):  # digest_size validation for SHAKE128/256
+    @pytest.mark.parametrize(
+        "xof",
+        [hashes.SHAKE128, hashes.SHAKE256]
+    )
+    def test_invalid_digest_type(self, xof):  # wrong-type digest_size
+        with pytest.raises(TypeError):
+            xof(digest_size=object())
+
+    @pytest.mark.parametrize(
+        "xof",
+        [hashes.SHAKE128, hashes.SHAKE256]
+    )
+    def test_invalid_digest_size(self, xof):  # 0 and negative rejected
+        with pytest.raises(ValueError):
+            xof(digest_size=-5)
+
+        with pytest.raises(ValueError):
+            xof(digest_size=0)
diff --git a/tests/utils.py b/tests/utils.py
index 364a349..b481280 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -134,7 +134,7 @@
             # string as hex 00, which is of course not actually an empty
             # string. So we parse the provided length and catch this edge case.
             msg = line.split(" = ")[1].encode("ascii") if length > 0 else b""
-        elif line.startswith("MD"):
+        elif line.startswith("MD") or line.startswith("Output"):
             md = line.split(" = ")[1]
             # after MD is found the Msg+MD (+ potential key) tuple is complete
             if key is not None: