blob: 69628e5c9c963fd4e1d5e158f31f527f8b2a213b [file] [log] [blame]
# Copyright 2014 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import base64
16import datetime
17import os
18
19import pytest
20
21from google.auth import _helpers
22from google.auth import crypt
23from google.auth import jwt
24
25
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')


def _read_data_file(filename):
    """Return the raw bytes of ``filename`` located inside ``DATA_DIR``."""
    path = os.path.join(DATA_DIR, filename)
    with open(path, 'rb') as stream:
        return stream.read()


# Key material shared by the tests below. OTHER_CERT_BYTES does not match
# PRIVATE_KEY_BYTES; the tests expect signature checks against it to fail.
PRIVATE_KEY_BYTES = _read_data_file('privatekey.pem')
PUBLIC_CERT_BYTES = _read_data_file('public_cert.pem')
OTHER_CERT_BYTES = _read_data_file('other_cert.pem')
36
37
@pytest.fixture
def signer():
    """Provide a signer for PRIVATE_KEY_BYTES whose key id is ``'1'``."""
    key_id = '1'
    return crypt.Signer.from_string(PRIVATE_KEY_BYTES, key_id)
41
42
def test_encode_basic(signer):
    """Encoding preserves the payload and emits the default JWT header."""
    original_payload = {'test': 'value'}
    token = jwt.encode(signer, original_payload)

    decoded_header, decoded_payload, _, _ = jwt._unverified_decode(token)

    expected_header = {'typ': 'JWT', 'alg': 'RS256', 'kid': signer.key_id}
    assert decoded_payload == original_payload
    assert decoded_header == expected_header
49
50
def test_encode_extra_headers(signer):
    """Caller-supplied header fields are merged into the default header."""
    token = jwt.encode(signer, {}, header={'extra': 'value'})

    decoded_header = jwt.decode_header(token)

    expected_header = {
        'typ': 'JWT',
        'alg': 'RS256',
        'kid': signer.key_id,
        'extra': 'value',
    }
    assert decoded_header == expected_header
56
57
@pytest.fixture
def token_factory(signer):
    """Return a callable that mints signed test tokens.

    The returned ``factory(claims=None, key_id=None)`` accepts:
        claims (dict): Claims merged over the default payload, which holds
            an audience, ``iat``/``exp`` timestamps 300 seconds apart, and
            sample user data.
        key_id (str | bool | None): Key id passed to :func:`jwt.encode`.
            ``None`` keeps the fixture signer's own key id. ``False`` signs
            with a signer that has no key id, for testing headers without
            key ids.
    """
    def factory(claims=None, key_id=None):
        now = _helpers.datetime_to_secs(_helpers.utcnow())
        payload = {
            'aud': 'audience@example.com',
            'iat': now,
            'exp': now + 300,
            'user': 'billy bob',
            'metadata': {'meta': 'data'}
        }
        payload.update(claims or {})

        token_signer = signer
        if key_id is False:
            # Use a fresh key-less signer rather than clearing
            # ``signer.key_id``. Mutating the shared fixture would leak into
            # every later factory() call made by the same test.
            token_signer = crypt.Signer.from_string(PRIVATE_KEY_BYTES, None)
            key_id = None

        return jwt.encode(token_signer, payload, key_id=key_id)
    return factory
79
80
def test_decode_valid(token_factory):
    """A freshly minted token verifies against the matching certificate."""
    token = token_factory()
    claims = jwt.decode(token, certs=PUBLIC_CERT_BYTES)

    assert claims['aud'] == 'audience@example.com'
    assert claims['user'] == 'billy bob'
    assert claims['metadata']['meta'] == 'data'
86
87
def test_decode_valid_with_audience(token_factory):
    """Verification succeeds when the expected audience matches ``aud``."""
    expected_audience = 'audience@example.com'
    token = token_factory()
    claims = jwt.decode(
        token, certs=PUBLIC_CERT_BYTES, audience=expected_audience)

    assert claims['aud'] == expected_audience
    assert claims['user'] == 'billy bob'
    assert claims['metadata']['meta'] == 'data'
95
96
def test_decode_valid_unverified(token_factory):
    """With ``verify=False`` the payload is returned despite a wrong cert."""
    token = token_factory()
    claims = jwt.decode(token, certs=OTHER_CERT_BYTES, verify=False)

    assert claims['aud'] == 'audience@example.com'
    assert claims['user'] == 'billy bob'
    assert claims['metadata']['meta'] == 'data'
102
103
def test_decode_bad_token_wrong_number_of_segments():
    """A token made of only two segments is rejected."""
    two_segment_token = '1.2'
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(two_segment_token, PUBLIC_CERT_BYTES)
    assert exc_info.match(r'Wrong number of segments')
108
109
def test_decode_bad_token_not_base64():
    """A segment that is not valid base64 is rejected.

    The message comes from the underlying base64 decoder, and its wording
    differs between interpreters. Older ones report ``Incorrect padding``.
    Python 3.7+ instead says the number of data characters "cannot be 1
    more than a multiple of 4". Accept either wording.
    """
    with pytest.raises((ValueError, TypeError)) as excinfo:
        jwt.decode('1.2.3', PUBLIC_CERT_BYTES)
    assert excinfo.match(r'Incorrect padding|more than a multiple of 4')
114
115
def test_decode_bad_token_not_json():
    """Segments that base64-decode but are not JSON are rejected."""
    segment = base64.urlsafe_b64encode(b'123!')
    token = b'.'.join([segment, segment, segment])
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(token, PUBLIC_CERT_BYTES)
    assert exc_info.match(r'Can\'t parse segment')
121
122
def test_decode_bad_token_no_iat_or_exp(signer):
    """A token lacking the ``iat``/``exp`` claims fails verification."""
    claims_without_timestamps = {'test': 'value'}
    token = jwt.encode(signer, claims_without_timestamps)
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(token, PUBLIC_CERT_BYTES)
    assert exc_info.match(r'Token does not contain required claim')
128
129
def test_decode_bad_token_too_early(token_factory):
    """A token whose ``iat`` is an hour in the future is rejected."""
    one_hour_ahead = _helpers.utcnow() + datetime.timedelta(hours=1)
    future_iat = {'iat': _helpers.datetime_to_secs(one_hour_ahead)}
    token = token_factory(claims=future_iat)
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(token, PUBLIC_CERT_BYTES)
    assert exc_info.match(r'Token used too early')
138
139
def test_decode_bad_token_expired(token_factory):
    """A token whose ``exp`` passed an hour ago is rejected."""
    one_hour_ago = _helpers.utcnow() - datetime.timedelta(hours=1)
    past_exp = {'exp': _helpers.datetime_to_secs(one_hour_ago)}
    token = token_factory(claims=past_exp)
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(token, PUBLIC_CERT_BYTES)
    assert exc_info.match(r'Token expired')
148
149
def test_decode_bad_token_wrong_audience(token_factory):
    """Requiring an audience other than the token's ``aud`` fails."""
    token = token_factory()
    unexpected_audience = 'audience2@example.com'
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(token, PUBLIC_CERT_BYTES, audience=unexpected_audience)
    assert exc_info.match(r'Token has wrong audience')
156
157
def test_decode_wrong_cert(token_factory):
    """Verifying against a non-matching certificate fails."""
    mismatched_cert = OTHER_CERT_BYTES
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(token_factory(), mismatched_cert)
    assert exc_info.match(r'Could not verify token signature')
162
163
def test_decode_multicert_bad_cert(token_factory):
    """The cert mapped to the token's key id ('1') is the wrong one.

    Verification is expected to fail even though a matching certificate
    is present under a different key id.
    """
    certs_by_key_id = {'1': OTHER_CERT_BYTES, '2': PUBLIC_CERT_BYTES}
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(token_factory(), certs_by_key_id)
    assert exc_info.match(r'Could not verify token signature')
169
170
def test_decode_no_cert(token_factory):
    """Decoding fails if no certificate exists for the token's key id."""
    certs_without_key_1 = {'2': PUBLIC_CERT_BYTES}
    with pytest.raises(ValueError) as exc_info:
        jwt.decode(token_factory(), certs_without_key_1)
    assert exc_info.match(r'Certificate for key id 1 not found')
176
177
def test_decode_no_key_id(token_factory):
    """A token without a key id still verifies against the supplied certs."""
    keyless_token = token_factory(key_id=False)
    available_certs = {'2': PUBLIC_CERT_BYTES}
    claims = jwt.decode(keyless_token, available_certs)
    assert claims['user'] == 'billy bob'
183
184
def test_roundtrip_explicit_key_id(token_factory):
    """An explicit key id replaces the signer's and selects its cert."""
    explicit_key_id = '3'
    token = token_factory(key_id=explicit_key_id)
    certs_by_key_id = {'2': OTHER_CERT_BYTES, '3': PUBLIC_CERT_BYTES}
    claims = jwt.decode(token, certs_by_key_id)
    assert claims['user'] == 'billy bob'
189 assert payload['user'] == 'billy bob'