feat: Implement ES256 for JWT verification (#340)

feat: Implement ES256 for JWT verification
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index b0c6e48..488aee4 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -37,6 +37,12 @@
 with open(os.path.join(DATA_DIR, "other_cert.pem"), "rb") as fh:
     OTHER_CERT_BYTES = fh.read()
 
+with open(os.path.join(DATA_DIR, "es256_privatekey.pem"), "rb") as fh:
+    EC_PRIVATE_KEY_BYTES = fh.read()  # used by the es256_signer fixture
+
+with open(os.path.join(DATA_DIR, "es256_public_cert.pem"), "rb") as fh:
+    EC_PUBLIC_CERT_BYTES = fh.read()  # verifies the ES256-signed test tokens
+
 SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")
 
 with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
@@ -68,8 +74,21 @@
 
 
 @pytest.fixture
-def token_factory(signer):
-    def factory(claims=None, key_id=None):
+def es256_signer():
+    return crypt.ES256Signer.from_string(EC_PRIVATE_KEY_BYTES, "1")  # "1" is the key_id
+
+
+def test_encode_basic_es256(es256_signer):
+    test_payload = {"test": "value"}
+    encoded = jwt.encode(es256_signer, test_payload)  # kid comes from the signer
+    header, payload, _, _ = jwt._unverified_decode(encoded)
+    assert payload == test_payload
+    assert header == {"typ": "JWT", "alg": "ES256", "kid": es256_signer.key_id}
+
+
+@pytest.fixture
+def token_factory(signer, es256_signer):  # both signers are injected as fixtures
+    def factory(claims=None, key_id=None, use_es256_signer=False):
         now = _helpers.datetime_to_secs(_helpers.utcnow())
         payload = {
             "aud": "audience@example.com",
@@ -86,7 +105,10 @@
             signer._key_id = None
             key_id = None
 
-        return jwt.encode(signer, payload, key_id=key_id)
+        if use_es256_signer:  # NOTE(review): only signer._key_id is cleared above
+            return jwt.encode(es256_signer, payload, key_id=key_id)
+        else:
+            return jwt.encode(signer, payload, key_id=key_id)
 
     return factory
 
@@ -98,6 +120,15 @@
     assert payload["metadata"]["meta"] == "data"
 
 
+def test_decode_valid_es256(token_factory):
+    # Sign with es256_signer, then verify against the EC public certificate.
+    es256_token = token_factory(use_es256_signer=True)
+    claims = jwt.decode(es256_token, certs=EC_PUBLIC_CERT_BYTES)
+    assert claims["aud"] == "audience@example.com"
+    assert claims["user"] == "billy bob"
+    assert claims["metadata"]["meta"] == "data"
+
+
 def test_decode_valid_with_audience(token_factory):
     payload = jwt.decode(
         token_factory(), certs=PUBLIC_CERT_BYTES, audience="audience@example.com"
@@ -201,6 +232,29 @@
     assert payload["user"] == "billy bob"
 
 
+def test_decode_unknown_alg():
+    headers = json.dumps({u"kid": u"1", u"alg": u"fakealg"})
+    token = b".".join(
+        map(lambda seg: base64.b64encode(seg.encode("utf-8")), [headers, u"{}", u"sig"])
+    )
+
+    with pytest.raises(ValueError) as excinfo:
+        jwt.decode(token)  # no certs: the alg check is expected to raise first
+    assert excinfo.match(r"fakealg")
+
+
+def test_decode_missing_cryptography_alg(monkeypatch):
+    monkeypatch.delitem(jwt._ALGORITHM_TO_VERIFIER_CLASS, "ES256")
+    headers = json.dumps({u"kid": u"1", u"alg": u"ES256"})  # verifier removed above
+    token = b".".join(
+        map(lambda seg: base64.b64encode(seg.encode("utf-8")), [headers, u"{}", u"sig"])
+    )
+
+    with pytest.raises(ValueError) as excinfo:
+        jwt.decode(token)
+    assert excinfo.match(r"cryptography")
+
+
 def test_roundtrip_explicit_key_id(token_factory):
     token = token_factory(key_id="3")
     certs = {"2": OTHER_CERT_BYTES, "3": PUBLIC_CERT_BYTES}