blob: 74342ce03f68d4fc1d46949c1483e94219ae3f2c [file] [log] [blame]
salrashid1231fbc6792018-11-09 11:05:34 -08001# Copyright 2018 Google Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import json
17import os
18
19import mock
20import pytest
21from six.moves import http_client
22
23from google.auth import _helpers
24from google.auth import crypt
25from google.auth import exceptions
26from google.auth import impersonated_credentials
27from google.auth import transport
28from google.auth.impersonated_credentials import Credentials
29from google.oauth2 import service_account
30
# Directory holding the key and service-account fixture files for these tests.
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')

# Raw bytes of the test RSA private key.
with open(os.path.join(DATA_DIR, 'privatekey.pem'), 'rb') as fh:
    PRIVATE_KEY_BYTES = fh.read()

SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, 'service_account.json')

# Parsed contents of the service-account JSON fixture.
with open(SERVICE_ACCOUNT_JSON_FILE, 'r') as fh:
    SERVICE_ACCOUNT_INFO = json.load(fh)

# Signer for the source service-account credentials; '1' is the key id.
SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, '1')
# Placeholder (example.com) token endpoint for the source credentials.
TOKEN_URI = 'https://example.com/oauth2/token'
43
44
@pytest.fixture
def mock_donor_credentials():
    """Patch ``google.oauth2._client.jwt_grant`` for the duration of a test.

    The autospec'd mock returns a canned ``("source token", expiry, {})``
    tuple, where ``expiry`` is 500 seconds after the fixture is entered.
    This lets the source service-account credentials refresh without
    network access.

    Yields:
        The patched ``jwt_grant`` mock.
    """
    patcher = mock.patch('google.oauth2._client.jwt_grant', autospec=True)
    with patcher as jwt_grant:
        source_expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        jwt_grant.return_value = ("source token", source_expiry, {})
        yield jwt_grant
53
54
class TestImpersonatedCredentials(object):
    """Tests for ``google.auth.impersonated_credentials.Credentials``.

    Tests that refresh credentials request the ``mock_donor_credentials``
    fixture, so the source service-account credentials get a token from a
    patched ``jwt_grant``. The impersonation token exchange is answered by a
    mock ``transport.Request`` built with :meth:`make_request`.
    """

    SERVICE_ACCOUNT_EMAIL = 'service-account@example.com'
    TARGET_PRINCIPAL = 'impersonated@project.iam.gserviceaccount.com'
    TARGET_SCOPES = ['https://www.googleapis.com/auth/devstorage.read_only']
    DELEGATES = []
    LIFETIME = 3600
    SOURCE_CREDENTIALS = service_account.Credentials(
        SIGNER, SERVICE_ACCOUNT_EMAIL, TOKEN_URI)

    def make_credentials(self, lifetime=LIFETIME):
        """Build impersonated credentials from the class-level settings.

        Args:
            lifetime: Forwarded as ``lifetime`` to ``Credentials``; defaults
                to :attr:`LIFETIME`.

        Returns:
            A new ``impersonated_credentials.Credentials`` instance.
        """
        return Credentials(
            source_credentials=self.SOURCE_CREDENTIALS,
            target_principal=self.TARGET_PRINCIPAL,
            target_scopes=self.TARGET_SCOPES,
            delegates=self.DELEGATES,
            lifetime=lifetime)

    def make_request(self, data, status=http_client.OK,
                     headers=None, side_effect=None):
        """Return a mock ``transport.Request`` that yields a canned response.

        Args:
            data: Response body; converted with ``_helpers.to_bytes``.
            status: Integer HTTP status code for the response.
            headers: Response headers; ``None`` means an empty dict.
            side_effect: Optional ``side_effect`` for the request mock.

        Returns:
            An autospec'd ``transport.Request`` mock whose return value is
            the prepared response mock.
        """
        response = mock.create_autospec(transport.Response, instance=False)
        response.status = status
        response.data = _helpers.to_bytes(data)
        response.headers = headers or {}

        request = mock.create_autospec(transport.Request, instance=False)
        request.side_effect = side_effect
        request.return_value = response

        return request

    @staticmethod
    def make_expire_time(seconds=500):
        """Return a UTC timestamp ``seconds`` from now, e.g. ``...T12:00:00Z``.

        Microseconds are dropped and a trailing ``Z`` is appended; the refresh
        tests treat this shape as a well-formed ``expireTime``.
        """
        expiry = (_helpers.utcnow().replace(microsecond=0) +
                  datetime.timedelta(seconds=seconds))
        return expiry.isoformat('T') + 'Z'

    @staticmethod
    def make_token_response(expire_time, token='token'):
        """Return a JSON token-exchange body with the given token and expiry."""
        return json.dumps({
            "accessToken": token,
            "expireTime": expire_time
        })

    def test_default_state(self):
        credentials = self.make_credentials()
        assert not credentials.valid
        assert credentials.expired

    def test_refresh_success(self, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=None)

        request = self.make_request(
            data=self.make_token_response(self.make_expire_time()),
            status=http_client.OK)

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

    def test_refresh_failure_malformed_expire_time(
            self, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=None)

        # Deliberately malformed: unlike make_expire_time(), the trailing 'Z'
        # is omitted (and fractional seconds are normally present).
        expire_time = (
            _helpers.utcnow() + datetime.timedelta(seconds=500)).isoformat('T')

        request = self.make_request(
            data=self.make_token_response(expire_time),
            status=http_client.OK)

        with pytest.raises(exceptions.RefreshError) as excinfo:
            credentials.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        assert not credentials.valid
        assert credentials.expired

    def test_refresh_failure_lifetime_specified(self, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=500)

        request = self.make_request(
            data=self.make_token_response(self.make_expire_time()),
            status=http_client.OK)

        # The first refresh must succeed (it runs outside pytest.raises);
        # only the second refresh of credentials created with an explicit
        # lifetime is expected to fail with _LIFETIME_ERROR.
        credentials.refresh(request)

        with pytest.raises(exceptions.RefreshError) as excinfo:
            credentials.refresh(request)

        assert excinfo.match(impersonated_credentials._LIFETIME_ERROR)

        assert not credentials.valid
        assert credentials.expired

    def test_refresh_failure_unauthorized(self, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=None)

        response_body = {
            "error": {
                "code": 403,
                "message": "The caller does not have permission",
                "status": "PERMISSION_DENIED"
            }
        }

        request = self.make_request(
            data=json.dumps(response_body),
            status=http_client.UNAUTHORIZED)

        with pytest.raises(exceptions.RefreshError) as excinfo:
            credentials.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        assert not credentials.valid
        assert credentials.expired

    def test_refresh_failure_http_error(self, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=None)

        # A real (integer) non-OK status with an empty JSON body; previously
        # the exception class http_client.HTTPException was passed as status.
        request = self.make_request(
            data=json.dumps({}),
            status=http_client.INTERNAL_SERVER_ERROR)

        with pytest.raises(exceptions.RefreshError) as excinfo:
            credentials.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        assert not credentials.valid
        assert credentials.expired

    def test_expired(self):
        credentials = self.make_credentials(lifetime=None)
        assert credentials.expired