Fix bugs with auth, batch and retries.
Reviewed in http://codereview.appspot.com/5633052/.
diff --git a/apiclient/http.py b/apiclient/http.py
index 94eb266..ff61cb1 100644
--- a/apiclient/http.py
+++ b/apiclient/http.py
@@ -37,6 +37,7 @@
import urlparse
import uuid
+from email.generator import Generator
from email.mime.multipart import MIMEMultipart
from email.mime.nonmultipart import MIMENonMultipart
from email.parser import FeedParser
@@ -498,9 +499,12 @@
# Global callback to be called for each individual response in the batch.
self._callback = callback
- # A map from id to (request, callback) pairs.
+ # A map from id to request.
self._requests = {}
+ # A map from id to callback.
+ self._callbacks = {}
+
# List of request ids, in the order in which they were added.
self._order = []
@@ -510,6 +514,39 @@
# Unique ID on which to base the Content-ID headers.
self._base_id = None
+ # A map from request id to (headers, content) response pairs
+ self._responses = {}
+
+ # A map of id(Credentials) that have been refreshed.
+ self._refreshed_credentials = {}
+
+  def _refresh_and_apply_credentials(self, request, http):
+    """Refresh the credentials and apply to the request.
+
+    Args:
+      request: HttpRequest, the request to resend; its headers may be updated.
+      http: httplib2.Http, the batch's http object; also used for the refresh.
+    """
+    # Force the credentials to refresh, but only once per credentials object
+    # (tracked by id() in self._refreshed_credentials). Prefer credentials on
+    # the request's own http; else fall back to the http given to execute().
+    creds = None
+    if request.http is not None and hasattr(request.http.request,
+                                           'credentials'):
+      creds = request.http.request.credentials
+    elif http is not None and hasattr(http.request, 'credentials'):
+      creds = http.request.credentials
+    if creds is not None:
+      if id(creds) not in self._refreshed_credentials:
+        creds.refresh(http)
+        self._refreshed_credentials[id(creds)] = 1
+
+      # Only apply the credentials if we are using the http object passed in,
+      # otherwise apply() will get called during _serialize_request().
+      if request.http is None or not hasattr(request.http.request,
+                                             'credentials'):
+        creds.apply(request.headers)
+
def _id_to_header(self, id_):
"""Convert an id to a Content-ID header value.
@@ -568,6 +605,10 @@
msg = MIMENonMultipart(major, minor)
headers = request.headers.copy()
+ if request.http is not None and hasattr(request.http.request,
+ 'credentials'):
+ request.http.request.credentials.apply(headers)
+
# MIMENonMultipart adds its own Content-Type header.
if 'content-type' in headers:
del headers['content-type']
@@ -581,7 +622,13 @@
msg.set_payload(request.body)
msg['content-length'] = str(len(request.body))
- body = msg.as_string(False)
+ # Serialize the mime message.
+ fp = StringIO.StringIO()
+ # maxheaderlen=0 means don't line wrap headers.
+ g = Generator(fp, maxheaderlen=0)
+ g.flatten(msg, unixfrom=False)
+ body = fp.getvalue()
+
# Strip off the \n\n that the MIME lib tacks onto the end of the payload.
if request.body is None:
body = body[:-2]
@@ -661,9 +708,71 @@
raise BatchError("Resumable requests cannot be used in a batch request.")
if request_id in self._requests:
raise KeyError("A request with this ID already exists: %s" % request_id)
- self._requests[request_id] = (request, callback)
+ self._requests[request_id] = request
+ self._callbacks[request_id] = callback
self._order.append(request_id)
+  def _execute(self, http, order, requests):
+    """Send the given requests as one batch; store replies in self._responses.
+
+    Args:
+      http: httplib2.Http, an http object to be used to make the request with.
+      order: list, list of request ids in the order they were added to the
+        batch.
+      requests: dict, map from request id to the request object to send.
+
+    Raises:
+      httplib2.Error if a transport error has occurred.
+      apiclient.errors.HttpError if the batch response status is >= 300.
+      apiclient.errors.BatchError if the response is the wrong format.
+    """
+    message = MIMEMultipart('mixed')
+    # Message should not write out its own headers.
+    setattr(message, '_write_headers', lambda self: None)
+
+    # Add all the individual requests.
+    for request_id in order:
+      request = requests[request_id]
+
+      msg = MIMENonMultipart('application', 'http')
+      msg['Content-Transfer-Encoding'] = 'binary'
+      msg['Content-ID'] = self._id_to_header(request_id)
+
+      body = self._serialize_request(request)
+      msg.set_payload(body)
+      message.attach(msg)
+
+    body = message.as_string()
+
+    headers = {}
+    headers['content-type'] = ('multipart/mixed; '
+                               'boundary="%s"') % message.get_boundary()
+
+    resp, content = http.request(self._batch_uri, 'POST', body=body,
+                                 headers=headers)
+
+    if resp.status >= 300:
+      raise HttpError(resp, content, self._batch_uri)
+
+    # Parse the multipart body; a malformed or empty body raises BatchError.
+
+    # Prepend with a content-type header so FeedParser can handle it.
+    header = 'content-type: %s\r\n\r\n' % resp.get('content-type', '')
+    for_parser = header + content
+
+    parser = FeedParser()
+    parser.feed(for_parser)
+    mime_response = parser.close()
+
+    if not mime_response.is_multipart():
+      raise BatchError("Response not in multipart/mixed format.", resp,
+                       content)
+
+    for part in mime_response.get_payload():
+      request_id = self._header_to_id(part['Content-ID'])
+      headers, content = self._deserialize_response(part.get_payload())
+      self._responses[request_id] = (headers, content)
+
def execute(self, http=None):
"""Execute all the requests as a single batched HTTP request.
@@ -676,84 +785,61 @@
None
Raises:
- apiclient.errors.HttpError if the response was not a 2xx.
httplib2.Error if a transport error has occured.
apiclient.errors.BatchError if the response is the wrong format.
"""
+
+ # If http is not supplied use the first valid one given in the requests.
if http is None:
for request_id in self._order:
- request, callback = self._requests[request_id]
+ request = self._requests[request_id]
if request is not None:
http = request.http
break
+
if http is None:
raise ValueError("Missing a valid http object.")
+ self._execute(http, self._order, self._requests)
- msgRoot = MIMEMultipart('mixed')
- # msgRoot should not write out it's own headers
- setattr(msgRoot, '_write_headers', lambda self: None)
+ # Loop over all the requests and check for 401s. For each 401 request the
+ # credentials should be refreshed and then sent again in a separate batch.
+ redo_requests = {}
+ redo_order = []
- # Add all the individual requests.
for request_id in self._order:
- request, callback = self._requests[request_id]
+ headers, content = self._responses[request_id]
+ if headers['status'] == '401':
+ redo_order.append(request_id)
+ request = self._requests[request_id]
+ self._refresh_and_apply_credentials(request, http)
+ redo_requests[request_id] = request
- msg = MIMENonMultipart('application', 'http')
- msg['Content-Transfer-Encoding'] = 'binary'
- msg['Content-ID'] = self._id_to_header(request_id)
+ if redo_requests:
+ self._execute(http, redo_order, redo_requests)
- body = self._serialize_request(request)
- msg.set_payload(body)
- msgRoot.attach(msg)
+ # Now process all callbacks that are erroring, and raise an exception for
+ # ones that return a non-2xx response? Or add extra parameter to callback
+ # that contains an HttpError?
- body = msgRoot.as_string()
+ for request_id in self._order:
+ headers, content = self._responses[request_id]
- headers = {}
- headers['content-type'] = ('multipart/mixed; '
- 'boundary="%s"') % msgRoot.get_boundary()
+ request = self._requests[request_id]
+ callback = self._callbacks[request_id]
- resp, content = http.request(self._batch_uri, 'POST', body=body,
- headers=headers)
+ response = None
+ exception = None
+ try:
+ r = httplib2.Response(headers)
+ response = request.postproc(r, content)
+ except HttpError, e:
+ exception = e
- if resp.status >= 300:
- raise HttpError(resp, content, self._batch_uri)
-
- # Now break up the response and process each one with the correct postproc
- # and trigger the right callbacks.
- boundary, _ = content.split(None, 1)
-
- # Prepend with a content-type header so FeedParser can handle it.
- header = 'content-type: %s\r\n\r\n' % resp['content-type']
- for_parser = header + content
-
- parser = FeedParser()
- parser.feed(for_parser)
- respRoot = parser.close()
-
- if not respRoot.is_multipart():
- raise BatchError("Response not in multipart/mixed format.", resp,
- content)
-
- parts = respRoot.get_payload()
- for part in parts:
- request_id = self._header_to_id(part['Content-ID'])
-
- headers, content = self._deserialize_response(part.get_payload())
-
- # TODO(jcgregorio) Remove this temporary hack once the server stops
- # gzipping individual response bodies.
- if content[0] != '{':
- gzipped_content = content
- content = gzip.GzipFile(
- fileobj=StringIO.StringIO(gzipped_content)).read()
-
- request, cb = self._requests[request_id]
- postproc = request.postproc
- response = postproc(resp, content)
- if cb is not None:
- cb(request_id, response)
+ if callback is not None:
+ callback(request_id, response, exception)
if self._callback is not None:
- self._callback(request_id, response)
+ self._callback(request_id, response, exception)
class HttpRequestMock(object):