Add google.auth.jwt (#7)
diff --git a/docs/reference/google.auth.jwt.rst b/docs/reference/google.auth.jwt.rst
new file mode 100644
index 0000000..79e3080
--- /dev/null
+++ b/docs/reference/google.auth.jwt.rst
@@ -0,0 +1,7 @@
+google.auth.jwt module
+======================
+
+.. automodule:: google.auth.jwt
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/reference/google.auth.rst b/docs/reference/google.auth.rst
index 4c39cb9..57a16b4 100644
--- a/docs/reference/google.auth.rst
+++ b/docs/reference/google.auth.rst
@@ -12,4 +12,5 @@
.. toctree::
google.auth.crypt
+ google.auth.jwt
diff --git a/google/auth/_helpers.py b/google/auth/_helpers.py
index 2d3b653..0a62209 100644
--- a/google/auth/_helpers.py
+++ b/google/auth/_helpers.py
@@ -14,9 +14,34 @@
"""Helper functions for commonly used utilities."""
+
+import calendar
+import datetime
+
import six
+def utcnow():
+ """Returns the current UTC datetime.
+
+ Returns:
+ datetime: The current time in UTC.
+ """
+ return datetime.datetime.utcnow()
+
+
+def datetime_to_secs(value):
+ """Convert a datetime object to the number of seconds since the UNIX epoch.
+
+ Args:
+ value (datetime): The datetime to convert.
+
+ Returns:
+ int: The number of seconds since the UNIX epoch.
+ """
+ return calendar.timegm(value.utctimetuple())
+
+
def to_bytes(value, encoding='utf-8'):
"""Converts a string value to bytes, if necessary.
diff --git a/google/auth/jwt.py b/google/auth/jwt.py
new file mode 100644
index 0000000..5349e29
--- /dev/null
+++ b/google/auth/jwt.py
@@ -0,0 +1,235 @@
+# Copyright 2016 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""JSON Web Tokens
+
+Provides support for creating (encoding) and verifying (decoding) JWTs,
+especially JWTs generated and consumed by Google infrastructure.
+
+See `rfc7519`_ for more details on JWTs.
+
+To encode a JWT::
+
+ from google.auth import crypto
+ from google.auth import jwt
+
+ signer = crypt.Signer(private_key)
+ payload = {'some': 'payload'}
+ encoded = jwt.encode(signer, payload)
+
+To decode a JWT and verify claims::
+
+ claims = jwt.decode(encoded, certs=public_certs)
+
+You can also skip verification::
+
+ claims = jwt.decode(encoded, verify=False)
+
+.. _rfc7519: https://tools.ietf.org/html/rfc7519
+
+"""
+
+import base64
+import collections
+import json
+
+from google.auth import crypt
+from google.auth import _helpers
+
+
+_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in sections
+_CLOCK_SKEW_SECS = 300 # 5 minutes in seconds
+
+
+def encode(signer, payload, header=None, key_id=None):
+ """Make a signed JWT.
+
+ Args:
+ signer (google.auth.crypt.Signer): The signer used to sign the JWT.
+ payload (Mapping): The JWT payload.
+ header (Mapping): Additional JWT header payload.
+ key_id (str): The key id to add to the JWT header. If the
+ signer has a key id it will be used as the default. If this is
+ specified it will override the signer's key id.
+
+ Returns:
+ bytes: The encoded JWT.
+ """
+ if header is None:
+ header = {}
+
+ if key_id is None:
+ key_id = signer.key_id
+
+ header.update({'typ': 'JWT', 'alg': 'RS256'})
+
+ if key_id is not None:
+ header['kid'] = key_id
+
+ segments = [
+ base64.urlsafe_b64encode(json.dumps(header).encode('utf-8')),
+ base64.urlsafe_b64encode(json.dumps(payload).encode('utf-8')),
+ ]
+
+ signing_input = b'.'.join(segments)
+ signature = signer.sign(signing_input)
+ segments.append(base64.urlsafe_b64encode(signature))
+
+ return b'.'.join(segments)
+
+
+def _decode_jwt_segment(encoded_section):
+ """Decodes a single JWT segment."""
+ section_bytes = base64.urlsafe_b64decode(encoded_section)
+ try:
+ return json.loads(section_bytes.decode('utf-8'))
+ except ValueError:
+ raise ValueError('Can\'t parse segment: {0}'.format(section_bytes))
+
+
+def _unverified_decode(token):
+ """Decodes a token and does no verification.
+
+ Args:
+ token (Union[str, bytes]): The encoded JWT.
+
+ Returns:
+ Tuple(str, str, str, str): header, payload, signed_section, and
+ signature.
+
+ Raises:
+ ValueError: if there are an incorrect amount of segments in the token.
+ """
+ token = _helpers.to_bytes(token)
+
+ if token.count(b'.') != 2:
+ raise ValueError(
+ 'Wrong number of segments in token: {0}'.format(token))
+
+ encoded_header, encoded_payload, signature = token.split(b'.')
+ signed_section = encoded_header + b'.' + encoded_payload
+ signature = base64.urlsafe_b64decode(signature)
+
+ # Parse segments
+ header = _decode_jwt_segment(encoded_header)
+ payload = _decode_jwt_segment(encoded_payload)
+
+ return header, payload, signed_section, signature
+
+
+def decode_header(token):
+ """Return the decoded header of a token.
+
+ No verification is done. This is useful to extract the key id from
+ the header in order to acquire the appropriate certificate to verify
+ the token.
+
+ Args:
+ token (Union[str, bytes]): the encoded JWT.
+
+ Returns:
+ Mapping: The decoded JWT header.
+ """
+ header, _, _, _ = _unverified_decode(token)
+ return header
+
+
+def _verify_iat_and_exp(payload):
+ """Verifies the iat (Issued At) and exp (Expires) claims in a token
+ payload.
+
+ Args:
+ payload (mapping): The JWT payload.
+
+ Raises:
+ ValueError: if any checks failed.
+ """
+ now = _helpers.datetime_to_secs(_helpers.utcnow())
+
+ # Make sure the iat and exp claims are present
+ for key in ('iat', 'exp'):
+ if key not in payload:
+ raise ValueError(
+ 'Token does not contain required claim {}'.format(key))
+
+ # Make sure the token wasn't issued in the future
+ iat = payload['iat']
+ earliest = iat - _CLOCK_SKEW_SECS
+ if now < earliest:
+ raise ValueError('Token used too early, {} < {}'.format(now, iat))
+
+ # Make sure the token wasn't issue in the past
+ exp = payload['exp']
+ latest = exp + _CLOCK_SKEW_SECS
+ if latest < now:
+ raise ValueError('Token expired, {} < {}'.format(latest, now))
+
+
+def decode(token, certs=None, verify=True, audience=None):
+ """Decode and verify a JWT.
+
+ Args:
+ token (string): The encoded JWT.
+ certs (Union[str, bytes, Mapping]): The certificate used to
+ validate. If bytes or string, it must the the public key
+ certificate in PEM format. If a mapping, it must be a mapping of
+ key IDs to public key certificates in PEM format. The mapping must
+ contain the same key ID that's specified in the token's header.
+ verify (bool): Whether to perform signature and claim validation.
+ Verification is done by default.
+ audience (str): The audience claim, 'aud', that this JWT should
+ contain. If None then the JWT's 'aud' parameter is not verified.
+
+ Returns:
+ Mapping: The deserialized JSON payload in the JWT.
+
+ Raises:
+ ValueError: if any verification checks failed.
+ """
+ header, payload, signed_section, signature = _unverified_decode(token)
+
+ if not verify:
+ return payload
+
+ # If certs is specified as a dictionary of key IDs to certificates, then
+ # use the certificate identified by the key ID in the token header.
+ if isinstance(certs, collections.Mapping):
+ key_id = header.get('kid')
+ if key_id:
+ if key_id not in certs:
+ raise ValueError(
+ 'Certificate for key id {} not found.'.format(key_id))
+ certs_to_check = [certs[key_id]]
+ # If there's no key id in the header, check against all of the certs.
+ else:
+ certs_to_check = certs.values()
+ else:
+ certs_to_check = certs
+
+ # Verify that the signature matches the message.
+ if not crypt.verify_signature(signed_section, signature, certs_to_check):
+ raise ValueError('Could not verify token signature.')
+
+ # Verify the issued at and created times in the payload.
+ _verify_iat_and_exp(payload)
+
+ # Check audience.
+ if audience is not None:
+ claim_audience = payload.get('aud')
+ if audience != claim_audience:
+ raise ValueError(
+ 'Token has wrong audience {}, expected {}'.format(
+ claim_audience, audience))
+
+ return payload
diff --git a/tests/test__helpers.py b/tests/test__helpers.py
index b7e0bab..c2bc4a7 100644
--- a/tests/test__helpers.py
+++ b/tests/test__helpers.py
@@ -12,12 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import datetime
import pytest
from google.auth import _helpers
+def test_utcnow():
+ assert isinstance(_helpers.utcnow(), datetime.datetime)
+
+
+def test_datetime_to_secs():
+ assert _helpers.datetime_to_secs(
+ datetime.datetime(1970, 1, 1)) == 0
+ assert _helpers.datetime_to_secs(
+ datetime.datetime(1990, 5, 29)) == 643939200
+
+
def test_to_bytes_with_bytes():
value = b'bytes-val'
assert _helpers.to_bytes(value) == value
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
new file mode 100644
index 0000000..69628e5
--- /dev/null
+++ b/tests/test_jwt.py
@@ -0,0 +1,189 @@
+# Copyright 2014 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import datetime
+import os
+
+import pytest
+
+from google.auth import _helpers
+from google.auth import crypt
+from google.auth import jwt
+
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
+
+with open(os.path.join(DATA_DIR, 'privatekey.pem'), 'rb') as fh:
+ PRIVATE_KEY_BYTES = fh.read()
+
+with open(os.path.join(DATA_DIR, 'public_cert.pem'), 'rb') as fh:
+ PUBLIC_CERT_BYTES = fh.read()
+
+with open(os.path.join(DATA_DIR, 'other_cert.pem'), 'rb') as fh:
+ OTHER_CERT_BYTES = fh.read()
+
+
+@pytest.fixture
+def signer():
+ return crypt.Signer.from_string(PRIVATE_KEY_BYTES, '1')
+
+
+def test_encode_basic(signer):
+ test_payload = {'test': 'value'}
+ encoded = jwt.encode(signer, test_payload)
+ header, payload, _, _ = jwt._unverified_decode(encoded)
+ assert payload == test_payload
+ assert header == {'typ': 'JWT', 'alg': 'RS256', 'kid': signer.key_id}
+
+
+def test_encode_extra_headers(signer):
+ encoded = jwt.encode(signer, {}, header={'extra': 'value'})
+ header = jwt.decode_header(encoded)
+ assert header == {
+ 'typ': 'JWT', 'alg': 'RS256', 'kid': signer.key_id, 'extra': 'value'}
+
+
+@pytest.fixture
+def token_factory(signer):
+ def factory(claims=None, key_id=None):
+ now = _helpers.datetime_to_secs(_helpers.utcnow())
+ payload = {
+ 'aud': 'audience@example.com',
+ 'iat': now,
+ 'exp': now + 300,
+ 'user': 'billy bob',
+ 'metadata': {'meta': 'data'}
+ }
+ payload.update(claims or {})
+
+ # False is specified to remove the signer's key id for testing
+ # headers without key ids.
+ if key_id is False:
+ signer.key_id = None
+ key_id = None
+
+ return jwt.encode(signer, payload, key_id=key_id)
+ return factory
+
+
+def test_decode_valid(token_factory):
+ payload = jwt.decode(token_factory(), certs=PUBLIC_CERT_BYTES)
+ assert payload['aud'] == 'audience@example.com'
+ assert payload['user'] == 'billy bob'
+ assert payload['metadata']['meta'] == 'data'
+
+
+def test_decode_valid_with_audience(token_factory):
+ payload = jwt.decode(
+ token_factory(), certs=PUBLIC_CERT_BYTES,
+ audience='audience@example.com')
+ assert payload['aud'] == 'audience@example.com'
+ assert payload['user'] == 'billy bob'
+ assert payload['metadata']['meta'] == 'data'
+
+
+def test_decode_valid_unverified(token_factory):
+ payload = jwt.decode(token_factory(), certs=OTHER_CERT_BYTES, verify=False)
+ assert payload['aud'] == 'audience@example.com'
+ assert payload['user'] == 'billy bob'
+ assert payload['metadata']['meta'] == 'data'
+
+
+def test_decode_bad_token_wrong_number_of_segments():
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode('1.2', PUBLIC_CERT_BYTES)
+ assert excinfo.match(r'Wrong number of segments')
+
+
+def test_decode_bad_token_not_base64():
+ with pytest.raises((ValueError, TypeError)) as excinfo:
+ jwt.decode('1.2.3', PUBLIC_CERT_BYTES)
+ assert excinfo.match(r'Incorrect padding')
+
+
+def test_decode_bad_token_not_json():
+ token = b'.'.join([base64.urlsafe_b64encode(b'123!')] * 3)
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode(token, PUBLIC_CERT_BYTES)
+ assert excinfo.match(r'Can\'t parse segment')
+
+
+def test_decode_bad_token_no_iat_or_exp(signer):
+ token = jwt.encode(signer, {'test': 'value'})
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode(token, PUBLIC_CERT_BYTES)
+ assert excinfo.match(r'Token does not contain required claim')
+
+
+def test_decode_bad_token_too_early(token_factory):
+ token = token_factory(claims={
+ 'iat': _helpers.datetime_to_secs(
+ _helpers.utcnow() + datetime.timedelta(hours=1))
+ })
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode(token, PUBLIC_CERT_BYTES)
+ assert excinfo.match(r'Token used too early')
+
+
+def test_decode_bad_token_expired(token_factory):
+ token = token_factory(claims={
+ 'exp': _helpers.datetime_to_secs(
+ _helpers.utcnow() - datetime.timedelta(hours=1))
+ })
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode(token, PUBLIC_CERT_BYTES)
+ assert excinfo.match(r'Token expired')
+
+
+def test_decode_bad_token_wrong_audience(token_factory):
+ token = token_factory()
+ audience = 'audience2@example.com'
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode(token, PUBLIC_CERT_BYTES, audience=audience)
+ assert excinfo.match(r'Token has wrong audience')
+
+
+def test_decode_wrong_cert(token_factory):
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode(token_factory(), OTHER_CERT_BYTES)
+ assert excinfo.match(r'Could not verify token signature')
+
+
+def test_decode_multicert_bad_cert(token_factory):
+ certs = {'1': OTHER_CERT_BYTES, '2': PUBLIC_CERT_BYTES}
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode(token_factory(), certs)
+ assert excinfo.match(r'Could not verify token signature')
+
+
+def test_decode_no_cert(token_factory):
+ certs = {'2': PUBLIC_CERT_BYTES}
+ with pytest.raises(ValueError) as excinfo:
+ jwt.decode(token_factory(), certs)
+ assert excinfo.match(r'Certificate for key id 1 not found')
+
+
+def test_decode_no_key_id(token_factory):
+ token = token_factory(key_id=False)
+ certs = {'2': PUBLIC_CERT_BYTES}
+ payload = jwt.decode(token, certs)
+ assert payload['user'] == 'billy bob'
+
+
+def test_roundtrip_explicit_key_id(token_factory):
+ token = token_factory(key_id='3')
+ certs = {'2': OTHER_CERT_BYTES, '3': PUBLIC_CERT_BYTES}
+ payload = jwt.decode(token, certs)
+ assert payload['user'] == 'billy bob'