blob: eeaa45354b1d1cad30133cc31f866d3952c7a716 [file] [log] [blame]
# Copyright 2010 Google Inc. All Rights Reserved.
2
3"""An OAuth 2.0 client
4
5Tools for interacting with OAuth 2.0 protected
6resources.
7"""
8
9__author__ = 'jcgregorio@google.com (Joe Gregorio)'
10
11import copy
12import datetime
13import httplib2
14import logging
15import urllib
16import urlparse
17
18try: # pragma: no cover
19 import simplejson
20except ImportError: # pragma: no cover
21 try:
22 # Try to import from django, should work on App Engine
23 from django.utils import simplejson
24 except ImportError:
25 # Should work for Python2.6 and higher.
26 import json as simplejson
27
28try:
29 from urlparse import parse_qsl
30except ImportError:
31 from cgi import parse_qsl
32
33
class Error(Exception):
    """Root of the exception hierarchy defined by this module."""
37
38
class RequestError(Error):
    """Signals that a request made by this module did not succeed."""
42
43
class MissingParameter(Error):
    """Error type for a missing parameter.

    NOTE(review): meaning inferred from the name only; no raise site is
    visible in this part of the module -- confirm intended use.
    """
46
47
48def _abstract():
49 raise NotImplementedError('You need to override this function')
50
51
class Credentials(object):
    """Abstract parent of every Credentials implementation.

    A concrete subclass has to provide authorize(), which applies the
    credentials to an HTTP transport object.
    """

    def authorize(self, http):
        """Apply these credentials to an HTTP transport.

        Implementations typically swap http.request() for a wrapper that
        injects the required headers and then hands off to the original
        request() method.

        Args:
          http: An httplib2.Http instance, or an object that behaves like one.

        Raises:
          NotImplementedError: always, in this base class.
        """
        # The base class has no credentials to apply; subclasses override.
        _abstract()
67
class Flow(object):
    """Common parent type shared by every Flow implementation."""
71
72
class OAuth2Credentials(Credentials):
    """Credentials object for OAuth 2.0.

    Credentials can be applied to an httplib2.Http object using the
    authorize() method, which then signs each request from that object
    with the OAuth 2.0 access token.

    OAuth2Credentials objects may be safely pickled and unpickled; the
    'store' callable is dropped when pickling and reset to None when
    unpickling.
    """

    def __init__(self, access_token, client_id, client_secret, refresh_token,
                 token_expiry, token_uri, user_agent):
        """Create an instance of OAuth2Credentials.

        This constructor is not usually called by the user, instead
        OAuth2Credentials objects are instantiated by the
        OAuth2WebServerFlow.

        Args:
          access_token: string, access token.
          client_id: string, client identifier.
          client_secret: string, client secret.
          refresh_token: string, refresh token.
          token_expiry: datetime, when the access_token expires.
          token_uri: string, URI of token endpoint.
          user_agent: string, The HTTP User-Agent to provide for this
            application.

        Notes:
          store: callable, a callable that when passed a Credential
            will store the credential back to where it came from.
            This is needed to store the latest access_token if it
            has expired and been refreshed. It starts out as None; use
            set_store() to supply one.
        """
        self.access_token = access_token
        self.client_id = client_id
        self.client_secret = client_secret
        self.refresh_token = refresh_token
        self.store = None
        self.token_expiry = token_expiry
        self.token_uri = token_uri
        self.user_agent = user_agent

    def set_store(self, store):
        """Set the storage for the credential.

        Args:
          store: callable, a callable that when passed a Credential
            will store the credential back to where it came from.
            This is needed to store the latest access_token if it
            has expired and been refreshed.
        """
        self.store = store

    def __getstate__(self):
        """Trim the state down to something that can be pickled.

        Returns:
          A shallow copy of the instance dict without the 'store' entry.
        """
        state = copy.copy(self.__dict__)
        del state['store']
        return state

    def __setstate__(self, state):
        """Reconstitute the state of the object from being pickled.

        Args:
          state: dict, as produced by __getstate__().
        """
        self.__dict__.update(state)
        # The store callable is never pickled, so start without one.
        self.store = None

    def _refresh(self, http_request):
        """Refresh the access_token using the refresh_token.

        On success the new access token (and, when the response supplies
        them, a new refresh token and expiry) are recorded on this object
        and the store callable, if any, is invoked with this object.

        Args:
          http_request: An instance of httplib2.Http.request
            or something that acts like it.

        Raises:
          RequestError: if the token endpoint answers with a status other
            than 200.
        """
        body = urllib.urlencode({
            'grant_type': 'refresh_token',
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'refresh_token': self.refresh_token,
        })
        headers = {
            'user-agent': self.user_agent,
            'content-type': 'application/x-www-form-urlencoded',
        }
        resp, content = http_request(self.token_uri, method='POST',
                                     body=body, headers=headers)
        if resp.status != 200:
            logging.error('Failed to retrieve access token: %s', content)
            raise RequestError('Invalid response %s.' % resp['status'])

        # TODO(jcgregorio) Raise an error if loads fails?
        d = simplejson.loads(content)
        self.access_token = d['access_token']
        # Keep the current refresh token when the response omits a new one.
        self.refresh_token = d.get('refresh_token', self.refresh_token)
        if 'expires_in' in d:
            lifetime = datetime.timedelta(seconds=int(d['expires_in']))
            self.token_expiry = lifetime + datetime.datetime.now()
        else:
            self.token_expiry = None
        if self.store is not None:
            self.store(self)

    def authorize(self, http):
        """Authorize an httplib2.Http instance with these credentials.

        Args:
          http: An instance of httplib2.Http
            or something that acts like it.

        Returns:
          A modified instance of http that was passed in.

        Example:

          h = httplib2.Http()
          h = credentials.authorize(h)

        You can't create a new OAuth subclass of httplib2.Authentication
        because it never gets passed the absolute URI, which is needed for
        signing. So instead we have to overload 'request' with a closure
        that adds in the Authorization header and then calls the original
        version of 'request()'.
        """
        request_orig = http.request

        # The closure that will replace 'httplib2.Http.request'.
        def new_request(uri, method='GET', body=None, headers=None,
                        redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                        connection_type=None):
            """Modify the request headers to add the appropriate
            Authorization header."""
            # Work on a private copy so the caller's dict is never mutated;
            # otherwise a reused dict would accumulate the user-agent prefix.
            if headers is None:
                headers = {}
            else:
                headers = dict(headers)
            headers['authorization'] = 'OAuth ' + self.access_token
            if 'user-agent' in headers:
                headers['user-agent'] = (
                    self.user_agent + ' ' + headers['user-agent'])
            else:
                headers['user-agent'] = self.user_agent
            resp, content = request_orig(uri, method, body, headers,
                                         redirections, connection_type)
            if resp.status != 401:
                return (resp, content)

            logging.info("Refreshing because we got a 401")
            self._refresh(request_orig)
            # The refresh replaced self.access_token; the retry must carry
            # the new token rather than the one that was just rejected.
            headers['authorization'] = 'OAuth ' + self.access_token
            return request_orig(uri, method, body, headers,
                                redirections, connection_type)

        http.request = new_request
        return http
222
223
class OAuth2WebServerFlow(Flow):
    """Does the Web Server Flow for OAuth 2.0.

    OAuth2WebServerFlow objects may be safely pickled and unpickled.
    """

    def __init__(self, client_id, client_secret, scope, user_agent,
                 authorization_uri='https://www.google.com/accounts/o8/oauth2/authorization',
                 token_uri='https://www.google.com/accounts/o8/oauth2/token',
                 **kwargs):
        """Constructor for OAuth2WebServerFlow.

        Args:
          client_id: string, client identifier.
          client_secret: string, client secret.
          scope: string, scope of the credentials being requested.
          user_agent: string, HTTP User-Agent to provide for this
            application.
          authorization_uri: string, URI for authorization endpoint.
          token_uri: string, URI for token endpoint.
          **kwargs: dict, extra parameters that are merged into the query
            string of the URL built by step1_get_authorize_url().
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope
        self.user_agent = user_agent
        self.authorization_uri = authorization_uri
        self.token_uri = token_uri
        self.params = kwargs
        # Filled in by step1_get_authorize_url() and reused by
        # step2_exchange().
        self.redirect_uri = None

    def step1_get_authorize_url(self, redirect_uri='oob'):
        """Returns a URI to redirect to the provider.

        Args:
          redirect_uri: string, Either the string 'oob' for a non-web-based
            application, or a URI that handles the callback from
            the authorization server.

        If redirect_uri is 'oob' then pass in the
        generated verification code to step2_exchange,
        otherwise pass in the query parameters received
        at the callback uri to step2_exchange.

        Returns:
          string, the authorization URI with its query string filled in.
        """
        self.redirect_uri = redirect_uri
        query = {
            'response_type': 'code',
            'client_id': self.client_id,
            'redirect_uri': redirect_uri,
            'scope': self.scope,
        }
        # Precedence, lowest to highest: the defaults above, the constructor
        # kwargs, then any parameters already present in authorization_uri.
        query.update(self.params)
        parts = list(urlparse.urlparse(self.authorization_uri))
        query.update(dict(parse_qsl(parts[4])))  # 4 is the index of the query part
        parts[4] = urllib.urlencode(query)
        return urlparse.urlunparse(parts)

    def step2_exchange(self, code):
        """Exchanges a code for OAuth2Credentials.

        The redirect_uri sent to the token endpoint is the one recorded by
        step1_get_authorize_url(); it is None if that method was not called
        on this object first.

        Args:
          code: string or dict, either the code as a string, or a dictionary
            of the query parameters to the redirect_uri, which contains
            the code.

        Returns:
          An OAuth2Credentials object built from the token endpoint's
          response.

        Raises:
          RequestError: if the token endpoint answers with a status other
            than 200.
        """
        if not isinstance(code, basestring):
            code = code['code']

        body = urllib.urlencode({
            'grant_type': 'authorization_code',
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'code': code,
            'redirect_uri': self.redirect_uri,
            'scope': self.scope,
        })
        headers = {
            'user-agent': self.user_agent,
            'content-type': 'application/x-www-form-urlencoded',
        }
        h = httplib2.Http()
        resp, content = h.request(self.token_uri, method='POST', body=body,
                                  headers=headers)
        if resp.status != 200:
            logging.error('Failed to retrieve access token: %s', content)
            raise RequestError('Invalid response %s.' % resp['status'])

        # TODO(jcgregorio) Raise an error if simplejson.loads fails?
        d = simplejson.loads(content)
        access_token = d['access_token']
        refresh_token = d.get('refresh_token', None)
        token_expiry = None
        if 'expires_in' in d:
            lifetime = datetime.timedelta(seconds=int(d['expires_in']))
            token_expiry = datetime.datetime.now() + lifetime

        # Do not log the response body: it contains the access token and
        # possibly the refresh token, which must not end up in log files.
        logging.info('Successfully retrieved access token')
        return OAuth2Credentials(access_token, self.client_id,
                                 self.client_secret, refresh_token,
                                 token_expiry, self.token_uri,
                                 self.user_agent)