Fix bugs with auth, batch and retries.

Reviewed in http://codereview.appspot.com/5633052/.
diff --git a/apiclient/http.py b/apiclient/http.py
index 94eb266..ff61cb1 100644
--- a/apiclient/http.py
+++ b/apiclient/http.py
@@ -37,6 +37,7 @@
 import urlparse
 import uuid
 
+from email.generator import Generator
 from email.mime.multipart import MIMEMultipart
 from email.mime.nonmultipart import MIMENonMultipart
 from email.parser import FeedParser
@@ -498,9 +499,12 @@
     # Global callback to be called for each individual response in the batch.
     self._callback = callback

-    # A map from id to (request, callback) pairs.
+    # A map from request id to HttpRequest.
     self._requests = {}

+    # A map from request id to that request's callback (may be None).
+    self._callbacks = {}
+
     # List of request ids, in the order in which they were added.
     self._order = []

@@ -510,6 +514,39 @@
     # Unique ID on which to base the Content-ID headers.
     self._base_id = None

+    # A map from request id to (headers, content) response pairs
+    self._responses = {}
+
+    # The set of id(Credentials) that have already been refreshed.
+    self._refreshed_credentials = set()
+
+  def _refresh_and_apply_credentials(self, request, http):
+    """Refresh the credentials, if any, and apply them to the request.
+
+    Args:
+      request: HttpRequest, the request.
+      http: httplib2.Http, the global http object for the batch.
+    """
+    # Force a refresh, at most once per Credentials object. Prefer the
+    # request's own http credentials, else those of the http from execute().
+    creds = None
+    request_has_creds = (request.http is not None and
+                         hasattr(request.http.request, 'credentials'))
+    if request_has_creds:
+      creds = request.http.request.credentials
+    elif http is not None and hasattr(http.request, 'credentials'):
+      creds = http.request.credentials
+    if creds is None:
+      return  # No credentials found: nothing to refresh or apply.
+    if id(creds) not in self._refreshed_credentials:
+      creds.refresh(http)
+      self._refreshed_credentials.add(id(creds))
+
+    # Only apply the credentials if we are using the http object passed in,
+    # otherwise apply() will get called during _serialize_request().
+    if not request_has_creds:
+      creds.apply(request.headers)
+
   def _id_to_header(self, id_):
     """Convert an id to a Content-ID header value.

@@ -568,6 +605,10 @@
     msg = MIMENonMultipart(major, minor)
     headers = request.headers.copy()

+    # Apply the request's own http credentials to this copy of the headers.
+    if request.http is not None and hasattr(request.http.request,
+        'credentials'):
+      request.http.request.credentials.apply(headers)
     # MIMENonMultipart adds its own Content-Type header.
     if 'content-type' in headers:
       del headers['content-type']
@@ -581,7 +622,13 @@
       msg.set_payload(request.body)
       msg['content-length'] = str(len(request.body))

-    body = msg.as_string(False)
+    # Serialize the MIME message with a Generator so headers are not wrapped.
+    fp = StringIO.StringIO()
+    # maxheaderlen=0 means don't line wrap headers.
+    g = Generator(fp, maxheaderlen=0)
+    g.flatten(msg, unixfrom=False)
+    body = fp.getvalue()
+
     # Strip off the \n\n that the MIME lib tacks onto the end of the payload.
     if request.body is None:
       body = body[:-2]
@@ -661,9 +708,71 @@
       raise BatchError("Resumable requests cannot be used in a batch request.")
     if request_id in self._requests:
       raise KeyError("A request with this ID already exists: %s" % request_id)
-    self._requests[request_id] = (request, callback)
+    self._requests[request_id] = request
+    self._callbacks[request_id] = callback
     self._order.append(request_id)

+  def _execute(self, http, order, requests):
+    """Serialize the batch, send it, and store responses in self._responses.
+
+    Args:
+      http: httplib2.Http, an http object to be used to make the request with.
+      order: list, list of request ids in the order they were added to the
+        batch.
+      requests: dict, map from request id to the HttpRequest to send.
+
+    Raises:
+      httplib2.Error if a transport error has occurred.
+      apiclient.errors.HttpError if the batch response was not a 2xx.
+      apiclient.errors.BatchError if the response is the wrong format.
+    """
+    message = MIMEMultipart('mixed')
+    # Message should not write out its own headers.
+    setattr(message, '_write_headers', lambda self: None)
+
+    # Add all the individual requests.
+    for request_id in order:
+      request = requests[request_id]
+
+      msg = MIMENonMultipart('application', 'http')
+      msg['Content-Transfer-Encoding'] = 'binary'
+      msg['Content-ID'] = self._id_to_header(request_id)
+
+      body = self._serialize_request(request)
+      msg.set_payload(body)
+      message.attach(msg)
+
+    body = message.as_string()
+
+    headers = {}
+    headers['content-type'] = ('multipart/mixed; '
+                               'boundary="%s"') % message.get_boundary()
+
+    resp, content = http.request(self._batch_uri, 'POST', body=body,
+                                 headers=headers)
+
+    if resp.status >= 300:
+      raise HttpError(resp, content, self._batch_uri)
+
+    # Prepend with a content-type header so FeedParser can handle it.
+    header = 'content-type: %s\r\n\r\n' % resp['content-type']
+    for_parser = header + content
+
+    parser = FeedParser()
+    parser.feed(for_parser)
+    mime_response = parser.close()
+
+    if not mime_response.is_multipart():
+      raise BatchError("Response not in multipart/mixed format.", resp,
+          content)
+
+    # Now break out the individual responses and store each one.
+    for part in mime_response.get_payload():
+      request_id = self._header_to_id(part['Content-ID'])
+      part_headers, part_content = self._deserialize_response(
+          part.get_payload())
+      self._responses[request_id] = (part_headers, part_content)
+
   def execute(self, http=None):
     """Execute all the requests as a single batched HTTP request.

@@ -676,84 +785,61 @@
       None

     Raises:
       apiclient.errors.HttpError if the response was not a 2xx.
-      httplib2.Error if a transport error has occured.
+      httplib2.Error if a transport error has occurred.
       apiclient.errors.BatchError if the response is the wrong format.
     """
+    # If http is not supplied, use the first non-None http from the requests.
     if http is None:
       for request_id in self._order:
-        request, callback = self._requests[request_id]
-        if request is not None:
+        request = self._requests[request_id]
+        if request is not None and request.http is not None:
           http = request.http
           break
+
     if http is None:
       raise ValueError("Missing a valid http object.")

+    self._execute(http, self._order, self._requests)

-    msgRoot = MIMEMultipart('mixed')
-    # msgRoot should not write out it's own headers
-    setattr(msgRoot, '_write_headers', lambda self: None)
+    # Loop over all the requests and check for 401s. For each 401 request the
+    # credentials should be refreshed and then sent again in a separate batch.
+    redo_requests = {}
+    redo_order = []

-    # Add all the individual requests.
     for request_id in self._order:
-      request, callback = self._requests[request_id]
+      headers, content = self._responses[request_id]
+      if headers['status'] == '401':
+        redo_order.append(request_id)
+        request = self._requests[request_id]
+        self._refresh_and_apply_credentials(request, http)
+        redo_requests[request_id] = request

-      msg = MIMENonMultipart('application', 'http')
-      msg['Content-Transfer-Encoding'] = 'binary'
-      msg['Content-ID'] = self._id_to_header(request_id)
+    if redo_requests:
+      self._execute(http, redo_order, redo_requests)

-      body = self._serialize_request(request)
-      msg.set_payload(body)
-      msgRoot.attach(msg)
+    # Now process all callbacks. A non-2xx individual response is not raised;
+    # the HttpError from postproc is passed to the callbacks as the third
+    # (exception) argument instead, with response left as None.

-    body = msgRoot.as_string()
+    for request_id in self._order:
+      headers, content = self._responses[request_id]

-    headers = {}
-    headers['content-type'] = ('multipart/mixed; '
-                               'boundary="%s"') % msgRoot.get_boundary()
+      request = self._requests[request_id]
+      callback = self._callbacks[request_id]

-    resp, content = http.request(self._batch_uri, 'POST', body=body,
-                                 headers=headers)
+      response = None
+      exception = None
+      try:
+        r = httplib2.Response(headers)
+        response = request.postproc(r, content)
+      except HttpError, e:
+        exception = e

-    if resp.status >= 300:
-      raise HttpError(resp, content, self._batch_uri)
-
-    # Now break up the response and process each one with the correct postproc
-    # and trigger the right callbacks.
-    boundary, _ = content.split(None, 1)
-
-    # Prepend with a content-type header so FeedParser can handle it.
-    header = 'content-type: %s\r\n\r\n' % resp['content-type']
-    for_parser = header + content
-
-    parser = FeedParser()
-    parser.feed(for_parser)
-    respRoot = parser.close()
-
-    if not respRoot.is_multipart():
-      raise BatchError("Response not in multipart/mixed format.", resp,
-          content)
-
-    parts = respRoot.get_payload()
-    for part in parts:
-      request_id = self._header_to_id(part['Content-ID'])
-
-      headers, content = self._deserialize_response(part.get_payload())
-
-      # TODO(jcgregorio) Remove this temporary hack once the server stops
-      # gzipping individual response bodies.
-      if content[0] != '{':
-        gzipped_content = content
-        content = gzip.GzipFile(
-            fileobj=StringIO.StringIO(gzipped_content)).read()
-
-      request, cb = self._requests[request_id]
-      postproc = request.postproc
-      response = postproc(resp, content)
-      if cb is not None:
-        cb(request_id, response)
+      if callback is not None:
+        callback(request_id, response, exception)
       if self._callback is not None:
-        self._callback(request_id, response)
+        self._callback(request_id, response, exception)


 class HttpRequestMock(object):