Fix refresh token handling and add unit tests for oauth2client.client
diff --git a/apiclient/http.py b/apiclient/http.py
index cecf3f2..f713ec8 100644
--- a/apiclient/http.py
+++ b/apiclient/http.py
@@ -177,3 +177,45 @@
               redirections=1,
               connection_type=None):
     return httplib2.Response(self.headers), self.data
+
+
+class HttpMockSequence(object):
+  """Mock of httplib2.Http
+
+  Mocks a sequence of calls to request returning different responses for each
+  call. Create an instance initialized with the desired response headers
+  and content and then use as if an httplib2.Http instance.
+
+    http = HttpMockSequence([
+      ({'status': '401'}, ''),
+      ({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'),
+      ({'status': '200'}, 'echo_request_headers'),
+      ])
+    resp, content = http.request("http://examples.com")
+
+  There are special values you can pass in for content to trigger
+  behavours that are helpful in testing.
+
+  'echo_request_headers' means return the request headers in the response body
+  'echo_request_body' means return the request body in the response body
+  """
+
+  def __init__(self, iterable):
+    """
+    Args:
+      iterable: iterable, a sequence of pairs of (headers, body)
+    """
+    self._iterable = iterable
+
+  def request(self, uri,
+              method='GET',
+              body=None,
+              headers=None,
+              redirections=1,
+              connection_type=None):
+    resp, content = self._iterable.pop(0)
+    if content == 'echo_request_headers':
+      content = headers
+    elif content == 'echo_request_body':
+      content = body
+    return httplib2.Response(resp), content
diff --git a/oauth2client/client.py b/oauth2client/client.py
index 3d7ded7..c8d0941 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -36,7 +36,12 @@
   pass
 
 
-class RequestError(Error):
+class FlowExchangeError(Error):
+  """Error occurred during request."""
+  pass
+
+
+class AccessTokenRefreshError(Error):
   """Error occurred during request."""
   pass
 
@@ -144,7 +149,7 @@
   @property
   def invalid(self):
     """True if the credentials are invalid, such as being revoked."""
-    return self._invalid
+    return getattr(self, '_invalid', False)
 
   def set_store(self, store):
     """Set the storage for the credential.
@@ -204,16 +209,18 @@
     else:
       # An {'error':...} response body means the token is expired or revoked, so
       # we flag the credentials as such.
+      logging.error('Failed to retrieve access token: %s' % content)
+      error_msg = 'Invalid response %s.' % resp['status']
       try:
         d = simplejson.loads(content)
         if 'error' in d:
+          error_msg = d['error']
           self._invalid = True
           if self.store is not None:
             self.store(self)
       except:
         pass
-      logging.error('Failed to retrieve access token: %s' % content)
-      raise RequestError('Invalid response %s.' % resp['status'])
+      raise AccessTokenRefreshError(error_msg)
 
   def authorize(self, http):
     """Authorize an httplib2.Http instance with these credentials.
@@ -258,6 +265,7 @@
       if resp.status == 401:
         logging.info("Refreshing because we got a 401")
         self._refresh(request_orig)
+        headers['authorization'] = 'OAuth ' + self.access_token
         return request_orig(uri, method, body, headers,
                             redirections, connection_type)
       else:
@@ -378,13 +386,14 @@
     parts[4] = urllib.urlencode(query)
     return urlparse.urlunparse(parts)
 
-  def step2_exchange(self, code):
+  def step2_exchange(self, code, http=None):
     """Exhanges a code for OAuth2Credentials.
 
     Args:
       code: string or dict, either the code as a string, or a dictionary
         of the query parameters to the redirect_uri, which contains
         the code.
+      http: httplib2.Http, optional http instance to use to do the fetch
     """
 
     if not (isinstance(code, str) or isinstance(code, unicode)):
@@ -402,8 +411,9 @@
       'user-agent': self.user_agent,
       'content-type': 'application/x-www-form-urlencoded'
     }
-    h = httplib2.Http()
-    resp, content = h.request(self.token_uri, method='POST', body=body, headers=headers)
+    if http is None:
+      http = httplib2.Http()
+    resp, content = http.request(self.token_uri, method='POST', body=body, headers=headers)
     if resp.status == 200:
       # TODO(jcgregorio) Raise an error if simplejson.loads fails?
       d = simplejson.loads(content)
@@ -419,4 +429,12 @@
                                self.user_agent)
     else:
       logging.error('Failed to retrieve access token: %s' % content)
-      raise RequestError('Invalid response %s.' % resp['status'])
+      error_msg = 'Invalid response %s.' % resp['status']
+      try:
+        d = simplejson.loads(content)
+        if 'error' in d:
+          error_msg = d['error']
+      except:
+        pass
+
+      raise FlowExchangeError(error_msg)
diff --git a/samples/moderator/moderator.py b/samples/moderator/moderator.py
index c22f86e..93e3ff6 100644
--- a/samples/moderator/moderator.py
+++ b/samples/moderator/moderator.py
@@ -33,7 +33,7 @@
     flow = FlowThreeLegged(moderator_discovery,
                            consumer_key='anonymous',
                            consumer_secret='anonymous',
-                           user_agent='google-api-client-python-mdrtr-cmdline/1.0',
+                           user_agent='python-moderator-sample/1.0',
                            domain='anonymous',
                            scope='https://www.googleapis.com/auth/moderator',
                            #scope='tag:google.com,2010:auth/moderator',
diff --git a/samples/threadqueue/main.py b/samples/threadqueue/main.py
index d184f6d..63d092d 100644
--- a/samples/threadqueue/main.py
+++ b/samples/threadqueue/main.py
@@ -88,7 +88,7 @@
     flow = FlowThreeLegged(moderator_discovery,
                            consumer_key='anonymous',
                            consumer_secret='anonymous',
-                           user_agent='google-api-client-python-thread-sample/1.0',
+                           user_agent='python-threading-sample/1.0',
                            domain='anonymous',
                            scope='https://www.googleapis.com/auth/moderator',
                            xoauth_displayname='Google API Client Example App')
diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py
new file mode 100644
index 0000000..9f30eff
--- /dev/null
+++ b/tests/test_oauth2client.py
@@ -0,0 +1,172 @@
+#!/usr/bin/python2.4
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Discovery document tests
+
+Unit tests for objects created from discovery documents.
+"""
+
+__author__ = 'jcgregorio@google.com (Joe Gregorio)'
+
+import unittest
+import urlparse
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from cgi import parse_qs
+
+from apiclient.http import HttpMockSequence
+from oauth2client.client import AccessTokenCredentials
+from oauth2client.client import AccessTokenCredentialsError
+from oauth2client.client import AccessTokenRefreshError
+from oauth2client.client import FlowExchangeError
+from oauth2client.client import OAuth2Credentials
+from oauth2client.client import OAuth2WebServerFlow
+
+
+class OAuth2CredentialsTests(unittest.TestCase):
+
+  def setUp(self):
+    access_token = "foo"
+    client_id = "some_client_id"
+    client_secret = "cOuDdkfjxxnv+"
+    refresh_token = "1/0/a.df219fjls0"
+    token_expiry = "ignored"
+    token_uri = "https://www.google.com/accounts/o8/oauth2/token"
+    user_agent = "refresh_checker/1.0"
+    self.credentials = OAuth2Credentials(
+      access_token, client_id, client_secret,
+      refresh_token, token_expiry, token_uri,
+      user_agent)
+
+  def test_token_refresh_success(self):
+    http = HttpMockSequence([
+      ({'status': '401'}, ''),
+      ({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'),
+      ({'status': '200'}, 'echo_request_headers'),
+      ])
+    http = self.credentials.authorize(http)
+    resp, content = http.request("http://example.com")
+    self.assertEqual(content['authorization'], 'OAuth 1/3w')
+
+  def test_token_refresh_failure(self):
+    http = HttpMockSequence([
+      ({'status': '401'}, ''),
+      ({'status': '400'}, '{"error":"access_denied"}'),
+      ])
+    http = self.credentials.authorize(http)
+    try:
+      http.request("http://example.com")
+      self.fail("should raise AccessTokenRefreshError exception")
+    except AccessTokenRefreshError:
+      pass
+
+  def test_non_401_error_response(self):
+    http = HttpMockSequence([
+      ({'status': '400'}, ''),
+      ])
+    http = self.credentials.authorize(http)
+    resp, content = http.request("http://example.com")
+    self.assertEqual(400, resp.status)
+
+
+class AccessTokenCredentialsTests(unittest.TestCase):
+
+  def setUp(self):
+    access_token = "foo"
+    user_agent = "refresh_checker/1.0"
+    self.credentials = AccessTokenCredentials(access_token, user_agent)
+
+  def test_token_refresh_success(self):
+    http = HttpMockSequence([
+      ({'status': '401'}, ''),
+      ])
+    http = self.credentials.authorize(http)
+    try:
+      resp, content = http.request("http://example.com")
+      self.fail("should throw exception if token expires")
+    except AccessTokenCredentialsError:
+      pass
+    except Exception:
+      self.fail("should only throw AccessTokenCredentialsError")
+
+  def test_non_401_error_response(self):
+    http = HttpMockSequence([
+      ({'status': '400'}, ''),
+      ])
+    http = self.credentials.authorize(http)
+    resp, content = http.request("http://example.com")
+    self.assertEqual(400, resp.status)
+
+
+class OAuth2WebServerFlowTest(unittest.TestCase):
+
+  def setUp(self):
+    self.flow = OAuth2WebServerFlow(
+        client_id='client_id+1',
+        client_secret='secret+1',
+        scope='foo',
+        user_agent='unittest-sample/1.0',
+        )
+
+  def test_construct_authorize_url(self):
+    authorize_url = self.flow.step1_get_authorize_url('oob')
+
+    parsed = urlparse.urlparse(authorize_url)
+    q = parse_qs(parsed[4])
+    self.assertEqual(q['client_id'][0], 'client_id+1')
+    self.assertEqual(q['response_type'][0], 'code')
+    self.assertEqual(q['scope'][0], 'foo')
+    self.assertEqual(q['redirect_uri'][0], 'oob')
+
+  def test_exchange_failure(self):
+    http = HttpMockSequence([
+      ({'status': '400'}, '{"error":"invalid_request"}')
+      ])
+
+    try:
+      credentials = self.flow.step2_exchange('some random code', http)
+      self.fail("should raise exception if exchange doesn't get 200")
+    except FlowExchangeError:
+      pass
+
+  def test_exchange_success(self):
+    http = HttpMockSequence([
+      ({'status': '200'},
+      """{ "access_token":"SlAV32hkKG",
+       "expires_in":3600,
+       "refresh_token":"8xLOxBtZp8" }"""),
+      ])
+
+    credentials = self.flow.step2_exchange('some random code', http)
+    self.assertEqual(credentials.access_token, 'SlAV32hkKG')
+    self.assertNotEqual(credentials.token_expiry, None)
+    self.assertEqual(credentials.refresh_token, '8xLOxBtZp8')
+
+
+  def test_exchange_no_expires_in(self):
+    http = HttpMockSequence([
+      ({'status': '200'}, """{ "access_token":"SlAV32hkKG",
+       "refresh_token":"8xLOxBtZp8" }"""),
+      ])
+
+    credentials = self.flow.step2_exchange('some random code', http)
+    self.assertEqual(credentials.token_expiry, None)
+
+
+if __name__ == '__main__':
+  unittest.main()