Fix bugs with auth, batch and retries.

Reviewed in http://codereview.appspot.com/5633052/.
diff --git a/apiclient/http.py b/apiclient/http.py
index 94eb266..ff61cb1 100644
--- a/apiclient/http.py
+++ b/apiclient/http.py
@@ -37,6 +37,7 @@
 import urlparse
 import uuid
 
+from email.generator import Generator
 from email.mime.multipart import MIMEMultipart
 from email.mime.nonmultipart import MIMENonMultipart
 from email.parser import FeedParser
@@ -498,9 +499,12 @@
     # Global callback to be called for each individual response in the batch.
     self._callback = callback
 
-    # A map from id to (request, callback) pairs.
+    # A map from id to request.
     self._requests = {}
 
+    # A map from id to callback.
+    self._callbacks = {}
+
     # List of request ids, in the order in which they were added.
     self._order = []
 
@@ -510,6 +514,39 @@
     # Unique ID on which to base the Content-ID headers.
     self._base_id = None
 
+    # A map from request id to (headers, content) response pairs
+    self._responses = {}
+
+    # A map of id(Credentials) that have been refreshed.
+    self._refreshed_credentials = {}
+
+  def _refresh_and_apply_credentials(self, request, http):
+    """Refresh the credentials and apply to the request.
+
+    Args:
+      request: HttpRequest, the request.
+      http: httplib2.Http, the global http object for the batch.
+    """
+    # For the credentials to refresh, but only once per refresh_token
+    # If there is no http per the request then refresh the http passed in
+    # via execute()
+    creds = None
+    if request.http is not None and hasattr(request.http.request,
+        'credentials'):
+      creds = request.http.request.credentials
+    elif http is not None and hasattr(http.request, 'credentials'):
+      creds = http.request.credentials
+    if creds is not None:
+      if id(creds) not in self._refreshed_credentials:
+        creds.refresh(http)
+        self._refreshed_credentials[id(creds)] = 1
+
+    # Only apply the credentials if we are using the http object passed in,
+    # otherwise apply() will get called during _serialize_request().
+    if request.http is None or not hasattr(request.http.request,
+        'credentials'):
+      creds.apply(request.headers)
+
   def _id_to_header(self, id_):
     """Convert an id to a Content-ID header value.
 
@@ -568,6 +605,10 @@
     msg = MIMENonMultipart(major, minor)
     headers = request.headers.copy()
 
+    if request.http is not None and hasattr(request.http.request,
+        'credentials'):
+      request.http.request.credentials.apply(headers)
+
     # MIMENonMultipart adds its own Content-Type header.
     if 'content-type' in headers:
       del headers['content-type']
@@ -581,7 +622,13 @@
       msg.set_payload(request.body)
       msg['content-length'] = str(len(request.body))
 
-    body = msg.as_string(False)
+    # Serialize the mime message.
+    fp = StringIO.StringIO()
+    # maxheaderlen=0 means don't line wrap headers.
+    g = Generator(fp, maxheaderlen=0)
+    g.flatten(msg, unixfrom=False)
+    body = fp.getvalue()
+
     # Strip off the \n\n that the MIME lib tacks onto the end of the payload.
     if request.body is None:
       body = body[:-2]
@@ -661,9 +708,71 @@
       raise BatchError("Resumable requests cannot be used in a batch request.")
     if request_id in self._requests:
       raise KeyError("A request with this ID already exists: %s" % request_id)
-    self._requests[request_id] = (request, callback)
+    self._requests[request_id] = request
+    self._callbacks[request_id] = callback
     self._order.append(request_id)
 
+  def _execute(self, http, order, requests):
+    """Serialize batch request, send to server, process response.
+
+    Args:
+      http: httplib2.Http, an http object to be used to make the request with.
+      order: list, list of request ids in the order they were added to the
+        batch.
+      request: list, list of request objects to send.
+
+    Raises:
+      httplib2.Error if a transport error has occured.
+      apiclient.errors.BatchError if the response is the wrong format.
+    """
+    message = MIMEMultipart('mixed')
+    # Message should not write out it's own headers.
+    setattr(message, '_write_headers', lambda self: None)
+
+    # Add all the individual requests.
+    for request_id in order:
+      request = requests[request_id]
+
+      msg = MIMENonMultipart('application', 'http')
+      msg['Content-Transfer-Encoding'] = 'binary'
+      msg['Content-ID'] = self._id_to_header(request_id)
+
+      body = self._serialize_request(request)
+      msg.set_payload(body)
+      message.attach(msg)
+
+    body = message.as_string()
+
+    headers = {}
+    headers['content-type'] = ('multipart/mixed; '
+                               'boundary="%s"') % message.get_boundary()
+
+    resp, content = http.request(self._batch_uri, 'POST', body=body,
+                                 headers=headers)
+
+    if resp.status >= 300:
+      raise HttpError(resp, content, self._batch_uri)
+
+    # Now break out the individual responses and store each one.
+    boundary, _ = content.split(None, 1)
+
+    # Prepend with a content-type header so FeedParser can handle it.
+    header = 'content-type: %s\r\n\r\n' % resp['content-type']
+    for_parser = header + content
+
+    parser = FeedParser()
+    parser.feed(for_parser)
+    mime_response = parser.close()
+
+    if not mime_response.is_multipart():
+      raise BatchError("Response not in multipart/mixed format.", resp,
+          content)
+
+    for part in mime_response.get_payload():
+      request_id = self._header_to_id(part['Content-ID'])
+      headers, content = self._deserialize_response(part.get_payload())
+      self._responses[request_id] = (headers, content)
+
   def execute(self, http=None):
     """Execute all the requests as a single batched HTTP request.
 
@@ -676,84 +785,61 @@
       None
 
     Raises:
-      apiclient.errors.HttpError if the response was not a 2xx.
       httplib2.Error if a transport error has occured.
       apiclient.errors.BatchError if the response is the wrong format.
     """
+
+    # If http is not supplied use the first valid one given in the requests.
     if http is None:
       for request_id in self._order:
-        request, callback = self._requests[request_id]
+        request = self._requests[request_id]
         if request is not None:
           http = request.http
           break
+
     if http is None:
       raise ValueError("Missing a valid http object.")
 
+    self._execute(http, self._order, self._requests)
 
-    msgRoot = MIMEMultipart('mixed')
-    # msgRoot should not write out it's own headers
-    setattr(msgRoot, '_write_headers', lambda self: None)
+    # Loop over all the requests and check for 401s. For each 401 request the
+    # credentials should be refreshed and then sent again in a separate batch.
+    redo_requests = {}
+    redo_order = []
 
-    # Add all the individual requests.
     for request_id in self._order:
-      request, callback = self._requests[request_id]
+      headers, content = self._responses[request_id]
+      if headers['status'] == '401':
+        redo_order.append(request_id)
+        request = self._requests[request_id]
+        self._refresh_and_apply_credentials(request, http)
+        redo_requests[request_id] = request
 
-      msg = MIMENonMultipart('application', 'http')
-      msg['Content-Transfer-Encoding'] = 'binary'
-      msg['Content-ID'] = self._id_to_header(request_id)
+    if redo_requests:
+      self._execute(http, redo_order, redo_requests)
 
-      body = self._serialize_request(request)
-      msg.set_payload(body)
-      msgRoot.attach(msg)
+    # Now process all callbacks that are erroring, and raise an exception for
+    # ones that return a non-2xx response? Or add extra parameter to callback
+    # that contains an HttpError?
 
-    body = msgRoot.as_string()
+    for request_id in self._order:
+      headers, content = self._responses[request_id]
 
-    headers = {}
-    headers['content-type'] = ('multipart/mixed; '
-                               'boundary="%s"') % msgRoot.get_boundary()
+      request = self._requests[request_id]
+      callback = self._callbacks[request_id]
 
-    resp, content = http.request(self._batch_uri, 'POST', body=body,
-                                 headers=headers)
+      response = None
+      exception = None
+      try:
+        r = httplib2.Response(headers)
+        response = request.postproc(r, content)
+      except HttpError, e:
+        exception = e
 
-    if resp.status >= 300:
-      raise HttpError(resp, content, self._batch_uri)
-
-    # Now break up the response and process each one with the correct postproc
-    # and trigger the right callbacks.
-    boundary, _ = content.split(None, 1)
-
-    # Prepend with a content-type header so FeedParser can handle it.
-    header = 'content-type: %s\r\n\r\n' % resp['content-type']
-    for_parser = header + content
-
-    parser = FeedParser()
-    parser.feed(for_parser)
-    respRoot = parser.close()
-
-    if not respRoot.is_multipart():
-      raise BatchError("Response not in multipart/mixed format.", resp,
-          content)
-
-    parts = respRoot.get_payload()
-    for part in parts:
-      request_id = self._header_to_id(part['Content-ID'])
-
-      headers, content = self._deserialize_response(part.get_payload())
-
-      # TODO(jcgregorio) Remove this temporary hack once the server stops
-      # gzipping individual response bodies.
-      if content[0] != '{':
-        gzipped_content = content
-        content = gzip.GzipFile(
-            fileobj=StringIO.StringIO(gzipped_content)).read()
-
-      request, cb = self._requests[request_id]
-      postproc = request.postproc
-      response = postproc(resp, content)
-      if cb is not None:
-        cb(request_id, response)
+      if callback is not None:
+        callback(request_id, response, exception)
       if self._callback is not None:
-        self._callback(request_id, response)
+        self._callback(request_id, response, exception)
 
 
 class HttpRequestMock(object):
diff --git a/oauth2client/client.py b/oauth2client/client.py
index c88b358..ce033ca 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -129,6 +129,23 @@
     """
     _abstract()
 
+  def refresh(self, http):
+    """Forces a refresh of the access_token.
+
+    Args:
+      http: httplib2.Http, an http object to be used to make the refresh
+        request.
+    """
+    _abstract()
+
+  def apply(self, headers):
+    """Add the authorization to the headers.
+
+    Args:
+      headers: dict, the headers to add the Authorization header to.
+    """
+    _abstract()
+
   def _to_json(self, strip):
     """Utility function for creating a JSON representation of an instance of Credentials.
 
@@ -324,6 +341,92 @@
     # refreshed.
     self.invalid = False
 
+  def authorize(self, http):
+    """Authorize an httplib2.Http instance with these credentials.
+
+    The modified http.request method will add authentication headers to each
+    request and will refresh access_tokens when a 401 is received on a
+    request. In addition the http.request method has a credentials property,
+    http.request.credentials, which is the Credentials object that authorized
+    it.
+
+    Args:
+       http: An instance of httplib2.Http
+           or something that acts like it.
+
+    Returns:
+       A modified instance of http that was passed in.
+
+    Example:
+
+      h = httplib2.Http()
+      h = credentials.authorize(h)
+
+    You can't create a new OAuth subclass of httplib2.Authenication
+    because it never gets passed the absolute URI, which is needed for
+    signing. So instead we have to overload 'request' with a closure
+    that adds in the Authorization header and then calls the original
+    version of 'request()'.
+    """
+    request_orig = http.request
+
+    # The closure that will replace 'httplib2.Http.request'.
+    def new_request(uri, method='GET', body=None, headers=None,
+                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
+                    connection_type=None):
+      if not self.access_token:
+        logger.info('Attempting refresh to obtain initial access_token')
+        self._refresh(request_orig)
+
+      # Modify the request headers to add the appropriate
+      # Authorization header.
+      if headers is None:
+        headers = {}
+      self.apply(headers)
+
+      if self.user_agent is not None:
+        if 'user-agent' in headers:
+          headers['user-agent'] = self.user_agent + ' ' + headers['user-agent']
+        else:
+          headers['user-agent'] = self.user_agent
+
+      resp, content = request_orig(uri, method, body, headers,
+                                   redirections, connection_type)
+
+      if resp.status == 401:
+        logger.info('Refreshing due to a 401')
+        self._refresh(request_orig)
+        self.apply(headers)
+        return request_orig(uri, method, body, headers,
+                            redirections, connection_type)
+      else:
+        return (resp, content)
+
+    # Replace the request method with our own closure.
+    http.request = new_request
+
+    # Set credentials as a property of the request method.
+    setattr(http.request, 'credentials', self)
+
+    return http
+
+  def refresh(self, http):
+    """Forces a refresh of the access_token.
+
+    Args:
+      http: httplib2.Http, an http object to be used to make the refresh
+        request.
+    """
+    self._refresh(http.request)
+
+  def apply(self, headers):
+    """Add the authorization to the headers.
+
+    Args:
+      headers: dict, the headers to add the Authorization header to.
+    """
+    headers['Authorization'] = 'Bearer ' + self.access_token
+
   def to_json(self):
     return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
 
@@ -431,6 +534,13 @@
     This method first checks by reading the Storage object if available.
     If a refresh is still needed, it holds the Storage lock until the
     refresh is completed.
+
+    Args:
+      http_request: callable, a callable that matches the method signature of
+        httplib2.Http.request, used to make the refresh request.
+
+    Raises:
+      AccessTokenRefreshError: When the refresh fails.
     """
     if not self.store:
       self._do_refresh_request(http_request)
@@ -451,8 +561,8 @@
     """Refresh the access_token using the refresh_token.
 
     Args:
-       http: An instance of httplib2.Http.request
-           or something that acts like it.
+      http_request: callable, a callable that matches the method signature of
+        httplib2.Http.request, used to make the refresh request.
 
     Raises:
       AccessTokenRefreshError: When the refresh fails.
@@ -491,64 +601,6 @@
         pass
       raise AccessTokenRefreshError(error_msg)
 
-  def authorize(self, http):
-    """Authorize an httplib2.Http instance with these credentials.
-
-    Args:
-       http: An instance of httplib2.Http
-           or something that acts like it.
-
-    Returns:
-       A modified instance of http that was passed in.
-
-    Example:
-
-      h = httplib2.Http()
-      h = credentials.authorize(h)
-
-    You can't create a new OAuth subclass of httplib2.Authenication
-    because it never gets passed the absolute URI, which is needed for
-    signing. So instead we have to overload 'request' with a closure
-    that adds in the Authorization header and then calls the original
-    version of 'request()'.
-    """
-    request_orig = http.request
-
-    # The closure that will replace 'httplib2.Http.request'.
-    def new_request(uri, method='GET', body=None, headers=None,
-                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
-                    connection_type=None):
-      if not self.access_token:
-        logger.info('Attempting refresh to obtain initial access_token')
-        self._refresh(request_orig)
-
-      # Modify the request headers to add the appropriate
-      # Authorization header.
-      if headers is None:
-        headers = {}
-      headers['authorization'] = 'OAuth ' + self.access_token
-
-      if self.user_agent is not None:
-        if 'user-agent' in headers:
-          headers['user-agent'] = self.user_agent + ' ' + headers['user-agent']
-        else:
-          headers['user-agent'] = self.user_agent
-
-      resp, content = request_orig(uri, method, body, headers,
-                                   redirections, connection_type)
-
-      if resp.status == 401:
-        logger.info('Refreshing due to a 401')
-        self._refresh(request_orig)
-        headers['authorization'] = 'OAuth ' + self.access_token
-        return request_orig(uri, method, body, headers,
-                            redirections, connection_type)
-      else:
-        return (resp, content)
-
-    http.request = new_request
-    return http
-
 
 class AccessTokenCredentials(OAuth2Credentials):
   """Credentials object for OAuth 2.0.
diff --git a/tests/test_http.py b/tests/test_http.py
index 67e8aa6..cb7a832 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -35,6 +35,52 @@
 from apiclient.http import MediaInMemoryUpload
 from apiclient.http import set_user_agent
 from apiclient.model import JsonModel
+from oauth2client.client import Credentials
+
+
+class MockCredentials(Credentials):
+  """Mock class for all Credentials objects."""
+  def __init__(self, bearer_token):
+    super(MockCredentials, self).__init__()
+    self._authorized = 0
+    self._refreshed = 0
+    self._applied = 0
+    self._bearer_token = bearer_token
+
+  def authorize(self, http):
+    self._authorized += 1
+
+    request_orig = http.request
+
+    # The closure that will replace 'httplib2.Http.request'.
+    def new_request(uri, method='GET', body=None, headers=None,
+                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
+                    connection_type=None):
+      # Modify the request headers to add the appropriate
+      # Authorization header.
+      if headers is None:
+        headers = {}
+      self.apply(headers)
+
+      resp, content = request_orig(uri, method, body, headers,
+                                   redirections, connection_type)
+
+      return resp, content
+
+    # Replace the request method with our own closure.
+    http.request = new_request
+
+    # Set credentials as a property of the request method.
+    setattr(http.request, 'credentials', self)
+
+    return http
+
+  def refresh(self, http):
+    self._refreshed += 1
+
+  def apply(self, headers):
+    self._applied += 1
+    headers['authorization'] = self._bearer_token + ' ' + str(self._refreshed)
 
 
 DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
@@ -52,7 +98,7 @@
 
     http = set_user_agent(http, "my_app/5.5")
     resp, content = http.request("http://example.com")
-    self.assertEqual(content['user-agent'], 'my_app/5.5')
+    self.assertEqual('my_app/5.5', content['user-agent'])
 
   def test_set_user_agent_nested(self):
     http = HttpMockSequence([
@@ -62,25 +108,25 @@
     http = set_user_agent(http, "my_app/5.5")
     http = set_user_agent(http, "my_library/0.1")
     resp, content = http.request("http://example.com")
-    self.assertEqual(content['user-agent'], 'my_app/5.5 my_library/0.1')
+    self.assertEqual('my_app/5.5 my_library/0.1', content['user-agent'])
 
   def test_media_file_upload_to_from_json(self):
     upload = MediaFileUpload(
         datafile('small.png'), chunksize=500, resumable=True)
-    self.assertEquals('image/png', upload.mimetype())
-    self.assertEquals(190, upload.size())
-    self.assertEquals(True, upload.resumable())
-    self.assertEquals(500, upload.chunksize())
-    self.assertEquals('PNG', upload.getbytes(1, 3))
+    self.assertEqual('image/png', upload.mimetype())
+    self.assertEqual(190, upload.size())
+    self.assertEqual(True, upload.resumable())
+    self.assertEqual(500, upload.chunksize())
+    self.assertEqual('PNG', upload.getbytes(1, 3))
 
     json = upload.to_json()
     new_upload = MediaUpload.new_from_json(json)
 
-    self.assertEquals('image/png', new_upload.mimetype())
-    self.assertEquals(190, new_upload.size())
-    self.assertEquals(True, new_upload.resumable())
-    self.assertEquals(500, new_upload.chunksize())
-    self.assertEquals('PNG', new_upload.getbytes(1, 3))
+    self.assertEqual('image/png', new_upload.mimetype())
+    self.assertEqual(190, new_upload.size())
+    self.assertEqual(True, new_upload.resumable())
+    self.assertEqual(500, new_upload.chunksize())
+    self.assertEqual('PNG', new_upload.getbytes(1, 3))
 
   def test_http_request_to_from_json(self):
 
@@ -103,13 +149,13 @@
     json = req.to_json()
     new_req = HttpRequest.from_json(json, http, _postproc)
 
-    self.assertEquals(new_req.headers,
-                      {'content-type':
-                       'multipart/related; boundary="---flubber"'})
-    self.assertEquals(new_req.uri, 'http://example.com')
-    self.assertEquals(new_req.body, '{}')
-    self.assertEquals(new_req.http, http)
-    self.assertEquals(new_req.resumable.to_json(), media_upload.to_json())
+    self.assertEqual({'content-type':
+                       'multipart/related; boundary="---flubber"'},
+                       new_req.headers)
+    self.assertEqual('http://example.com', new_req.uri)
+    self.assertEqual('{}', new_req.body)
+    self.assertEqual(http, new_req.http)
+    self.assertEqual(media_upload.to_json(), new_req.resumable.to_json())
 
 EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1
 Content-Type: application/json
@@ -153,6 +199,50 @@
 --batch_foobarbaz--"""
 
 
+BATCH_RESPONSE_WITH_401 = """--batch_foobarbaz
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: <randomness+1>
+
+HTTP/1.1 401 Authoration Required
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/pony"\r\n\r\n{"error": {"message":
+  "Authorizaton failed."}}
+
+--batch_foobarbaz
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: <randomness+2>
+
+HTTP/1.1 200 OK
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/sheep"\r\n\r\n{"baz": "qux"}
+--batch_foobarbaz--"""
+
+
+BATCH_SINGLE_RESPONSE = """--batch_foobarbaz
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: <randomness+1>
+
+HTTP/1.1 200 OK
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/pony"\r\n\r\n{"foo": 42}
+--batch_foobarbaz--"""
+
+class Callbacks(object):
+  def __init__(self):
+    self.responses = {}
+    self.exceptions = {}
+
+  def f(self, request_id, response, exception):
+    self.responses[request_id] = response
+    self.exceptions[request_id] = exception
+
+
 class TestBatch(unittest.TestCase):
 
   def setUp(self):
@@ -196,7 +286,7 @@
         methodId=None,
         resumable=None)
     s = batch._serialize_request(request).splitlines()
-    self.assertEquals(s, EXPECTED.splitlines())
+    self.assertEqual(EXPECTED.splitlines(), s)
 
   def test_serialize_request_media_body(self):
     batch = BatchHttpRequest()
@@ -213,9 +303,9 @@
         headers={'content-type': 'application/json'},
         methodId=None,
         resumable=None)
+    # Just testing it shouldn't raise an exception.
     s = batch._serialize_request(request).splitlines()
 
-
   def test_serialize_request_no_body(self):
     batch = BatchHttpRequest()
     request = HttpRequest(
@@ -228,30 +318,30 @@
         methodId=None,
         resumable=None)
     s = batch._serialize_request(request).splitlines()
-    self.assertEquals(s, NO_BODY_EXPECTED.splitlines())
+    self.assertEqual(NO_BODY_EXPECTED.splitlines(), s)
 
   def test_deserialize_response(self):
     batch = BatchHttpRequest()
     resp, content = batch._deserialize_response(RESPONSE)
 
-    self.assertEquals(resp.status, 200)
-    self.assertEquals(resp.reason, 'OK')
-    self.assertEquals(resp.version, 11)
-    self.assertEquals(content, '{"answer": 42}')
+    self.assertEqual(200, resp.status)
+    self.assertEqual('OK', resp.reason)
+    self.assertEqual(11, resp.version)
+    self.assertEqual('{"answer": 42}', content)
 
   def test_new_id(self):
     batch = BatchHttpRequest()
 
     id_ = batch._new_id()
-    self.assertEquals(id_, '1')
+    self.assertEqual('1', id_)
 
     id_ = batch._new_id()
-    self.assertEquals(id_, '2')
+    self.assertEqual('2', id_)
 
     batch.add(self.request1, request_id='3')
 
     id_ = batch._new_id()
-    self.assertEquals(id_, '4')
+    self.assertEqual('4', id_)
 
   def test_add(self):
     batch = BatchHttpRequest()
@@ -267,13 +357,6 @@
     self.assertRaises(BatchError, batch.add, self.request1, request_id='1')
 
   def test_execute(self):
-    class Callbacks(object):
-      def __init__(self):
-        self.responses = {}
-
-      def f(self, request_id, response):
-        self.responses[request_id] = response
-
     batch = BatchHttpRequest()
     callbacks = Callbacks()
 
@@ -285,8 +368,10 @@
        BATCH_RESPONSE),
       ])
     batch.execute(http)
-    self.assertEqual(callbacks.responses['1'], {'foo': 42})
-    self.assertEqual(callbacks.responses['2'], {'baz': 'qux'})
+    self.assertEqual({'foo': 42}, callbacks.responses['1'])
+    self.assertEqual(None, callbacks.exceptions['1'])
+    self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
+    self.assertEqual(None, callbacks.exceptions['2'])
 
   def test_execute_request_body(self):
     batch = BatchHttpRequest()
@@ -311,14 +396,82 @@
       header = parts[1].splitlines()[1]
       self.assertEqual('Content-Type: application/http', header)
 
+  def test_execute_refresh_and_retry_on_401(self):
+    batch = BatchHttpRequest()
+    callbacks = Callbacks()
+    cred_1 = MockCredentials('Foo')
+    cred_2 = MockCredentials('Bar')
+
+    http = HttpMockSequence([
+      ({'status': '200',
+        'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+       BATCH_RESPONSE_WITH_401),
+      ({'status': '200',
+        'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+       BATCH_SINGLE_RESPONSE),
+      ])
+
+    creds_http_1 = HttpMockSequence([])
+    cred_1.authorize(creds_http_1)
+
+    creds_http_2 = HttpMockSequence([])
+    cred_2.authorize(creds_http_2)
+
+    self.request1.http = creds_http_1
+    self.request2.http = creds_http_2
+
+    batch.add(self.request1, callback=callbacks.f)
+    batch.add(self.request2, callback=callbacks.f)
+    batch.execute(http)
+
+    self.assertEqual({'foo': 42}, callbacks.responses['1'])
+    self.assertEqual(None, callbacks.exceptions['1'])
+    self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
+    self.assertEqual(None, callbacks.exceptions['2'])
+
+    self.assertEqual(1, cred_1._refreshed)
+    self.assertEqual(0, cred_2._refreshed)
+
+    self.assertEqual(1, cred_1._authorized)
+    self.assertEqual(1, cred_2._authorized)
+
+    self.assertEqual(1, cred_2._applied)
+    self.assertEqual(2, cred_1._applied)
+
+  def test_http_errors_passed_to_callback(self):
+    batch = BatchHttpRequest()
+    callbacks = Callbacks()
+    cred_1 = MockCredentials('Foo')
+    cred_2 = MockCredentials('Bar')
+
+    http = HttpMockSequence([
+      ({'status': '200',
+        'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+       BATCH_RESPONSE_WITH_401),
+      ({'status': '200',
+        'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+       BATCH_RESPONSE_WITH_401),
+      ])
+
+    creds_http_1 = HttpMockSequence([])
+    cred_1.authorize(creds_http_1)
+
+    creds_http_2 = HttpMockSequence([])
+    cred_2.authorize(creds_http_2)
+
+    self.request1.http = creds_http_1
+    self.request2.http = creds_http_2
+
+    batch.add(self.request1, callback=callbacks.f)
+    batch.add(self.request2, callback=callbacks.f)
+    batch.execute(http)
+
+    self.assertEqual(None, callbacks.responses['1'])
+    self.assertEqual(401, callbacks.exceptions['1'].resp.status)
+    self.assertEqual({u'baz': u'qux'}, callbacks.responses['2'])
+    self.assertEqual(None, callbacks.exceptions['2'])
+
   def test_execute_global_callback(self):
-    class Callbacks(object):
-      def __init__(self):
-        self.responses = {}
-
-      def f(self, request_id, response):
-        self.responses[request_id] = response
-
     callbacks = Callbacks()
     batch = BatchHttpRequest(callback=callbacks.f)
 
@@ -330,8 +483,8 @@
        BATCH_RESPONSE),
       ])
     batch.execute(http)
-    self.assertEqual(callbacks.responses['1'], {'foo': 42})
-    self.assertEqual(callbacks.responses['2'], {'baz': 'qux'})
+    self.assertEqual({'foo': 42}, callbacks.responses['1'])
+    self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
 
   def test_media_inmemory_upload(self):
     media = MediaInMemoryUpload('abcdef', 'text/plain', chunksize=10,
diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py
index d50d6a0..3e512a1 100644
--- a/tests/test_oauth2client.py
+++ b/tests/test_oauth2client.py
@@ -70,7 +70,7 @@
       ])
     http = self.credentials.authorize(http)
     resp, content = http.request("http://example.com")
-    self.assertEqual(content['authorization'], 'OAuth 1/3w')
+    self.assertEqual('Bearer 1/3w', content['Authorization'])
 
   def test_token_refresh_failure(self):
     http = HttpMockSequence([
@@ -95,11 +95,11 @@
   def test_to_from_json(self):
     json = self.credentials.to_json()
     instance = OAuth2Credentials.from_json(json)
-    self.assertEquals(type(instance), OAuth2Credentials)
+    self.assertEqual(OAuth2Credentials, type(instance))
     instance.token_expiry = None
     self.credentials.token_expiry = None
 
-    self.assertEquals(self.credentials.__dict__, instance.__dict__)
+    self.assertEqual(instance.__dict__, self.credentials.__dict__)
 
 
 class AccessTokenCredentialsTests(unittest.TestCase):
@@ -136,7 +136,7 @@
       ])
     http = self.credentials.authorize(http)
     resp, content = http.request('http://example.com')
-    self.assertEqual(content['authorization'], 'OAuth foo')
+    self.assertEqual('Bearer foo', content['Authorization'])
 
 
 class TestAssertionCredentials(unittest.TestCase):
@@ -155,8 +155,8 @@
 
   def test_assertion_body(self):
     body = urlparse.parse_qs(self.credentials._generate_refresh_request_body())
-    self.assertEqual(body['assertion'][0], self.assertion_text)
-    self.assertEqual(body['assertion_type'][0], self.assertion_type)
+    self.assertEqual(self.assertion_text, body['assertion'][0])
+    self.assertEqual(self.assertion_type, body['assertion_type'][0])
 
   def test_assertion_refresh(self):
     http = HttpMockSequence([
@@ -165,7 +165,7 @@
       ])
     http = self.credentials.authorize(http)
     resp, content = http.request("http://example.com")
-    self.assertEqual(content['authorization'], 'OAuth 1/3w')
+    self.assertEqual('Bearer 1/3w', content['Authorization'])
 
 
 class ExtractIdTokenText(unittest.TestCase):
@@ -177,7 +177,7 @@
     jwt = 'stuff.' + payload + '.signature'
 
     extracted = _extract_id_token(jwt)
-    self.assertEqual(body, extracted)
+    self.assertEqual(extracted, body)
 
   def test_extract_failure(self):
     body = {'foo': 'bar'}
@@ -201,11 +201,11 @@
 
     parsed = urlparse.urlparse(authorize_url)
     q = parse_qs(parsed[4])
-    self.assertEqual(q['client_id'][0], 'client_id+1')
-    self.assertEqual(q['response_type'][0], 'code')
-    self.assertEqual(q['scope'][0], 'foo')
-    self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN')
-    self.assertEqual(q['access_type'][0], 'offline')
+    self.assertEqual('client_id+1', q['client_id'][0])
+    self.assertEqual('code', q['response_type'][0])
+    self.assertEqual('foo', q['scope'][0])
+    self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0])
+    self.assertEqual('offline', q['access_type'][0])
 
   def test_override_flow_access_type(self):
     """Passing access_type overrides the default."""
@@ -220,11 +220,11 @@
 
     parsed = urlparse.urlparse(authorize_url)
     q = parse_qs(parsed[4])
-    self.assertEqual(q['client_id'][0], 'client_id+1')
-    self.assertEqual(q['response_type'][0], 'code')
-    self.assertEqual(q['scope'][0], 'foo')
-    self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN')
-    self.assertEqual(q['access_type'][0], 'online')
+    self.assertEqual('client_id+1', q['client_id'][0])
+    self.assertEqual('code', q['response_type'][0])
+    self.assertEqual('foo', q['scope'][0])
+    self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0])
+    self.assertEqual('online', q['access_type'][0])
 
   def test_exchange_failure(self):
     http = HttpMockSequence([
@@ -246,9 +246,9 @@
       ])
 
     credentials = self.flow.step2_exchange('some random code', http)
-    self.assertEqual(credentials.access_token, 'SlAV32hkKG')
-    self.assertNotEqual(credentials.token_expiry, None)
-    self.assertEqual(credentials.refresh_token, '8xLOxBtZp8')
+    self.assertEqual('SlAV32hkKG', credentials.access_token)
+    self.assertNotEqual(None, credentials.token_expiry)
+    self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
 
   def test_exchange_no_expires_in(self):
     http = HttpMockSequence([
@@ -257,7 +257,7 @@
       ])
 
     credentials = self.flow.step2_exchange('some random code', http)
-    self.assertEqual(credentials.token_expiry, None)
+    self.assertEqual(None, credentials.token_expiry)
 
   def test_exchange_id_token_fail(self):
     http = HttpMockSequence([
@@ -282,7 +282,7 @@
       ])
 
     credentials = self.flow.step2_exchange('some random code', http)
-    self.assertEquals(body, credentials.id_token)
+    self.assertEqual(credentials.id_token, body)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_oauth2client_jwt.py b/tests/test_oauth2client_jwt.py
index 83551c6..dcbb33c 100644
--- a/tests/test_oauth2client_jwt.py
+++ b/tests/test_oauth2client_jwt.py
@@ -100,8 +100,8 @@
     certs = {'foo': public_key }
     audience = 'some_audience_address@testing.gserviceaccount.com'
     contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
-    self.assertEquals('billy bob', contents['user'])
-    self.assertEquals('data', contents['metadata']['meta'])
+    self.assertEqual('billy bob', contents['user'])
+    self.assertEqual('data', contents['metadata']['meta'])
 
   def test_verify_id_token_with_certs_uri(self):
     jwt = self._create_signed_jwt()
@@ -112,8 +112,8 @@
 
     contents = verify_id_token(jwt,
         'some_audience_address@testing.gserviceaccount.com', http)
-    self.assertEquals('billy bob', contents['user'])
-    self.assertEquals('data', contents['metadata']['meta'])
+    self.assertEqual('billy bob', contents['user'])
+    self.assertEqual('data', contents['metadata']['meta'])
 
   def test_verify_id_token_with_certs_uri_fails(self):
     jwt = self._create_signed_jwt()
@@ -195,7 +195,7 @@
       ])
     http = credentials.authorize(http)
     resp, content = http.request('http://example.org')
-    self.assertEquals(content['authorization'], 'OAuth 1/3w')
+    self.assertEqual('Bearer 1/3w', content['Authorization'])
 
 if __name__ == '__main__':
   unittest.main()