Add support for batch requests.

Reviewed in http://codereview.appspot.com/5434059/
diff --git a/apiclient/discovery.py b/apiclient/discovery.py
index 9e5b230..a3d4beb 100644
--- a/apiclient/discovery.py
+++ b/apiclient/discovery.py
@@ -160,10 +160,10 @@
 
   requested_url = uritemplate.expand(discoveryServiceUrl, params)
 
-  # REMOTE_ADDR is defined by the CGI spec [RFC3875] as the environment variable
-  # that contains the network address of the client sending the request. If it
-  # exists then add that to the request for the discovery document to avoid
-  # exceeding the quota on discovery requests.
+  # REMOTE_ADDR is defined by the CGI spec [RFC3875] as the environment
+  # variable that contains the network address of the client sending the
+  # request. If it exists then add that to the request for the discovery
+  # document to avoid exceeding the quota on discovery requests.
   if 'REMOTE_ADDR' in os.environ:
     requested_url = _add_query_parameter(requested_url, 'userIp',
                                          os.environ['REMOTE_ADDR'])
@@ -459,8 +459,7 @@
         elif isinstance(media_filename, MediaUpload):
           media_upload = media_filename
         else:
-          raise TypeError(
-              'media_filename must be str or MediaUpload. Got %s' % type(media_upload))
+          raise TypeError('media_filename must be str or MediaUpload.')
 
         if media_upload.resumable():
           resumable = media_upload
diff --git a/apiclient/errors.py b/apiclient/errors.py
index 30a48e8..0d420df 100644
--- a/apiclient/errors.py
+++ b/apiclient/errors.py
@@ -70,6 +70,7 @@
   """Link type unknown or unexpected."""
   pass
 
+
 class UnknownApiNameOrVersion(Error):
   """No API with that name and version exists."""
   pass
@@ -90,6 +91,11 @@
   pass
 
 
+class BatchError(Error):
+  """Error occurred during batch operations."""
+  pass
+
+
 class UnexpectedMethodError(Error):
   """Exception raised by RequestMockBuilder on unexpected calls."""
 
diff --git a/apiclient/http.py b/apiclient/http.py
index 0b45a44..333461e 100644
--- a/apiclient/http.py
+++ b/apiclient/http.py
@@ -25,18 +25,27 @@
     'set_user_agent', 'tunnel_patch'
     ]
 
+import StringIO
 import copy
+import gzip
 import httplib2
-import os
 import mimeparse
 import mimetypes
+import os
+import urllib
+import urlparse
+import uuid
 
-from model import JsonModel
+from anyjson import simplejson
+from email.mime.multipart import MIMEMultipart
+from email.mime.nonmultipart import MIMENonMultipart
+from email.parser import FeedParser
+from errors import BatchError
 from errors import HttpError
 from errors import ResumableUploadError
 from errors import UnexpectedBodyError
 from errors import UnexpectedMethodError
-from anyjson import simplejson
+from model import JsonModel
 
 
 class MediaUploadProgress(object):
@@ -54,7 +63,7 @@
 
   def progress(self):
     """Percent of upload completed, as a float."""
-    return float(self.resumable_progress)/float(self.total_size)
+    return float(self.resumable_progress) / float(self.total_size)
 
 
 class MediaUpload(object):
@@ -126,6 +135,7 @@
     from_json = getattr(kls, 'from_json')
     return from_json(s)
 
+
 class MediaFileUpload(MediaUpload):
   """A MediaUpload for a file.
 
@@ -150,8 +160,8 @@
         guessed from the file extension.
       chunksize: int, File will be uploaded in chunks of this many bytes. Only
         used if resumable=True.
-      resumable: bool, True if this is a resumable upload. False means upload in
-        a single request.
+      resumable: bool, True if this is a resumable upload. False means upload
+        in a single request.
     """
     self._filename = filename
     self._size = os.path.getsize(filename)
@@ -207,8 +217,7 @@
 
 
 class HttpRequest(object):
-  """Encapsulates a single HTTP request.
-  """
+  """Encapsulates a single HTTP request."""
 
   def __init__(self, http, postproc, uri,
                method='GET',
@@ -239,6 +248,7 @@
     self.postproc = postproc
     self.resumable = resumable
 
+    # Pull the multipart boundary out of the content-type header.
     major, minor, params = mimeparse.parse_mime_type(
         headers.get('content-type', 'application/json'))
     self.multipart_boundary = params.get('boundary', '').strip('"')
@@ -252,12 +262,17 @@
     # The bytes that have been uploaded.
     self.resumable_progress = 0
 
+    self.total_size = 0
+
     if resumable is not None:
       if self.body is not None:
         self.multipart_size = len(self.body)
       else:
         self.multipart_size = 0
-      self.total_size = self.resumable.size() + self.multipart_size + len(self.multipart_boundary)
+      self.total_size = (
+          self.resumable.size() +
+          self.multipart_size +
+          len(self.multipart_boundary))
 
   def execute(self, http=None):
     """Execute the request.
@@ -293,13 +308,13 @@
   def next_chunk(self, http=None):
     """Execute the next step of a resumable upload.
 
-    Can only be used if the method being executed supports media uploads and the
-    MediaUpload object passed in was flagged as using resumable upload.
+    Can only be used if the method being executed supports media uploads and
+    the MediaUpload object passed in was flagged as using resumable upload.
 
     Example:
 
-      media = MediaFileUpload('smiley.png', mimetype='image/png', chunksize=1000,
-                              resumable=True)
+      media = MediaFileUpload('smiley.png', mimetype='image/png',
+                              chunksize=1000, resumable=True)
       request = service.objects().insert(
           bucket=buckets['items'][0]['id'],
           name='smiley.png',
@@ -351,8 +366,8 @@
                                  headers=headers)
     if resp.status in [200, 201]:
       return None, self.postproc(resp, content)
-    # A "308 Resume Incomplete" indicates we are not done.
     elif resp.status == 308:
+      # A "308 Resume Incomplete" indicates we are not done.
       self.resumable_progress = int(resp['range'].split('-')[1]) + 1
       if self.resumable_progress >= self.multipart_size:
         self.body = None
@@ -381,14 +396,288 @@
     return HttpRequest(
         http,
         postproc,
-        uri = d['uri'],
-        method= d['method'],
+        uri=d['uri'],
+        method=d['method'],
         body=d['body'],
         headers=d['headers'],
         methodId=d['methodId'],
         resumable=d['resumable'])
 
 
+class BatchHttpRequest(object):
+  """Batches multiple HttpRequest objects into a single HTTP request."""
+
+  def __init__(self, callback=None, batch_uri=None):
+    """Constructor for a BatchHttpRequest.
+
+    Args:
+      callback: callable, A callback to be called for each response, of the
+        form callback(id, response). The first parameter is the request id, and
+        the second is the deserialized response object.
+      batch_uri: string, URI to send batch requests to.
+    """
+    if batch_uri is None:
+      batch_uri = 'https://www.googleapis.com/batch'
+    self._batch_uri = batch_uri
+
+    # Global callback to be called for each individual response in the batch.
+    self._callback = callback
+
+    # A map from id to (request, callback) pairs.
+    self._requests = {}
+
+    # List of request ids, in the order in which they were added.
+    self._order = []
+
+    # The last auto generated id.
+    self._last_auto_id = 0
+
+    # Unique ID on which to base the Content-ID headers.
+    self._base_id = None
+
+  def _id_to_header(self, id_):
+    """Convert an id to a Content-ID header value.
+
+    Args:
+      id_: string, identifier of individual request.
+
+    Returns:
+      A Content-ID header with the id_ encoded into it. A UUID is prepended to
+      the value because Content-ID headers are supposed to be universally
+      unique.
+    """
+    if self._base_id is None:
+      self._base_id = uuid.uuid4()
+
+    return '<%s+%s>' % (self._base_id, urllib.quote(id_))
+
+  def _header_to_id(self, header):
+    """Convert a Content-ID header value to an id.
+
+    Presumes the Content-ID header conforms to the format that _id_to_header()
+    returns.
+
+    Args:
+      header: string, Content-ID header value.
+
+    Returns:
+      The extracted id value.
+
+    Raises:
+      BatchError if the header is not in the expected format.
+    """
+    if header[0] != '<' or header[-1] != '>':
+      raise BatchError("Invalid value for Content-ID: %s" % header)
+    if '+' not in header:
+      raise BatchError("Invalid value for Content-ID: %s" % header)
+    base, id_ = header[1:-1].rsplit('+', 1)
+
+    return urllib.unquote(id_)
+
+  def _serialize_request(self, request):
+    """Convert an HttpRequest object into a string.
+
+    Args:
+      request: HttpRequest, the request to serialize.
+
+    Returns:
+      The request as a string in application/http format.
+    """
+    # Construct the HTTP request line (held in status_line below).
+    parsed = urlparse.urlparse(request.uri)
+    request_line = urlparse.urlunparse(
+        (None, None, parsed.path, parsed.params, parsed.query, None)
+        )
+    status_line = request.method + ' ' + request_line + ' HTTP/1.1\n'
+    major, minor = request.headers.get('content-type', 'text/plain').split('/')
+    msg = MIMENonMultipart(major, minor)
+    headers = request.headers.copy()
+
+    # MIMENonMultipart adds its own Content-Type header.
+    if 'content-type' in headers:
+      del headers['content-type']
+
+    for key, value in headers.iteritems():
+      msg[key] = value
+    msg['Host'] = parsed.netloc
+    msg.set_unixfrom(None)
+
+    if request.body is not None:
+      msg.set_payload(request.body)
+
+    body = msg.as_string(False)
+    # Strip off the \n\n that the MIME lib tacks onto the end of the payload.
+    if request.body is None:
+      body = body[:-2]
+
+    return status_line + body
+
+  def _deserialize_response(self, payload):
+    """Convert string into httplib2 response and content.
+
+    Args:
+      payload: string, headers and body as a string.
+
+    Returns:
+      A pair (resp, content), as would be returned by httplib2.request.
+    """
+    # Strip off the status line
+    status_line, payload = payload.split('\n', 1)
+    protocol, status, reason = status_line.split(' ')
+
+    # Parse the rest of the response
+    parser = FeedParser()
+    parser.feed(payload)
+    msg = parser.close()
+    msg['status'] = status
+
+    # Create httplib2.Response from the parsed headers.
+    resp = httplib2.Response(msg)
+    resp.reason = reason
+    resp.version = int(protocol.split('/', 1)[1].replace('.', ''))
+
+    content = payload.split('\r\n\r\n', 1)[1]
+
+    return resp, content
+
+  def _new_id(self):
+    """Create a new id.
+
+    Auto incrementing number that avoids conflicts with ids already used.
+
+    Returns:
+       string, a new unique id.
+    """
+    self._last_auto_id += 1
+    while str(self._last_auto_id) in self._requests:
+      self._last_auto_id += 1
+    return str(self._last_auto_id)
+
+  def add(self, request, callback=None, request_id=None):
+    """Add a new request.
+
+    Every callback added will be paired with a unique id, the request_id. That
+    unique id will be passed back to the callback when the response comes back
+    from the server. The default behavior is to have the library generate its
+    own unique id. If the caller passes in a request_id then they must ensure
+    uniqueness for each request_id, and if they are not unique an exception is
+    raised. Callers should either supply all request_ids or never supply a
+    request id, to avoid such an error.
+
+    Args:
+      request: HttpRequest, Request to add to the batch.
+      callback: callable, A callback to be called for this response, of the
+        form callback(id, response). The first parameter is the request id, and
+        the second is the deserialized response object.
+      request_id: string, A unique id for the request. The id will be passed to
+        the callback with the response.
+
+    Returns:
+      None
+
+    Raises:
+      BatchError if a resumable request is added to a batch.
+      KeyError if the request_id is not unique.
+    """
+    if request_id is None:
+      request_id = self._new_id()
+    if request.resumable is not None:
+      raise BatchError("Resumable requests cannot be used in a batch request.")
+    if request_id in self._requests:
+      raise KeyError("A request with this ID already exists: %s" % request_id)
+    self._requests[request_id] = (request, callback)
+    self._order.append(request_id)
+
+  def execute(self, http=None):
+    """Execute all the requests as a single batched HTTP request.
+
+    Args:
+      http: httplib2.Http, an http object to be used in place of the one the
+        HttpRequest request object was constructed with.  If one isn't supplied
+        then use a http object from the requests in this batch.
+
+    Returns:
+      None
+
+    Raises:
+      apiclient.errors.HttpError if the response was not a 2xx.
+      httplib2.Error if a transport error has occurred.
+    """
+    if http is None:
+      for request_id in self._order:
+        request, callback = self._requests[request_id]
+        if request is not None:
+          http = request.http
+          break
+    if http is None:
+      raise ValueError("Missing a valid http object.")
+
+
+    msgRoot = MIMEMultipart('mixed')
+    # msgRoot should not write out its own headers.
+    setattr(msgRoot, '_write_headers', lambda self: None)
+
+    # Add all the individual requests.
+    for request_id in self._order:
+      request, callback = self._requests[request_id]
+
+      msg = MIMENonMultipart('application', 'http')
+      msg['Content-Transfer-Encoding'] = 'binary'
+      msg['Content-ID'] = self._id_to_header(request_id)
+
+      body = self._serialize_request(request)
+      msg.set_payload(body)
+      msgRoot.attach(msg)
+
+    body = msgRoot.as_string()
+
+    headers = {}
+    headers['content-type'] = ('multipart/mixed; '
+                               'boundary="%s"') % msgRoot.get_boundary()
+
+    resp, content = http.request(self._batch_uri, 'POST', body=body,
+                                 headers=headers)
+
+    if resp.status >= 300:
+      raise HttpError(resp, content, self._batch_uri)
+
+    # Now break up the response and process each one with the correct postproc
+    # and trigger the right callbacks.
+    boundary, _ = content.split(None, 1)
+
+    # Prepend with a content-type header so FeedParser can handle it.
+    header = 'Content-Type: %s\r\n\r\n' % resp['content-type']
+    content = header + content
+
+    parser = FeedParser()
+    parser.feed(content)
+    respRoot = parser.close()
+
+    if not respRoot.is_multipart():
+      raise BatchError("Response not in multipart/mixed format.")
+
+    parts = respRoot.get_payload()
+    for part in parts:
+      request_id = self._header_to_id(part['Content-ID'])
+
+      headers, content = self._deserialize_response(part.get_payload())
+
+      # TODO(jcgregorio) Remove this temporary hack once the server stops
+      # gzipping individual response bodies.
+      if content[0] != '{':
+        gzipped_content = content
+        content = gzip.GzipFile(
+            fileobj=StringIO.StringIO(gzipped_content)).read()
+
+      request, cb = self._requests[request_id]
+      postproc = request.postproc
+      response = postproc(resp, content)
+      if cb is not None:
+        cb(request_id, response)
+      if self._callback is not None:
+        self._callback(request_id, response)
+
+
 class HttpRequestMock(object):
   """Mock of HttpRequest.
 
@@ -441,8 +730,8 @@
       apiclient.discovery.build("plus", "v1", requestBuilder=requestBuilder)
 
     Methods that you do not supply a response for will return a
-    200 OK with an empty string as the response content or raise an excpetion if
-    check_unexpected is set to True. The methodId is taken from the rpcName
+    200 OK with an empty string as the response content or raise an exception
+    if check_unexpected is set to True. The methodId is taken from the rpcName
     in the discovery document.
 
     For more details see the project wiki.
diff --git a/tests/test_http.py b/tests/test_http.py
index 09d6eb8..f502bac 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -26,11 +26,14 @@
 import os
 import unittest
 
-from apiclient.http import set_user_agent
+from apiclient.errors import BatchError
+from apiclient.http import BatchHttpRequest
 from apiclient.http import HttpMockSequence
 from apiclient.http import HttpRequest
-from apiclient.http import MediaUpload
 from apiclient.http import MediaFileUpload
+from apiclient.http import MediaUpload
+from apiclient.http import set_user_agent
+from apiclient.model import JsonModel
 
 
 DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
@@ -99,7 +102,7 @@
     json = req.to_json()
     new_req = HttpRequest.from_json(json, http, _postproc)
 
-    self.assertEquals(new_req.headers, 
+    self.assertEquals(new_req.headers,
                       {'content-type':
                        'multipart/related; boundary="---flubber"'})
     self.assertEquals(new_req.uri, 'http://example.com')
@@ -108,6 +111,163 @@
     self.assertEquals(new_req.resumable.to_json(), media_upload.to_json())
     self.assertEquals(new_req.multipart_boundary, '---flubber')
 
+EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1
+Content-Type: application/json
+MIME-Version: 1.0
+Host: www.googleapis.com\r\n\r\n{}"""
+
+
+RESPONSE = """HTTP/1.1 200 OK
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/pony"\r\n\r\n{"answer": 42}"""
+
+
+BATCH_RESPONSE = """--batch_foobarbaz
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: <randomness+1>
+
+HTTP/1.1 200 OK
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/pony"\r\n\r\n{"foo": 42}
+
+--batch_foobarbaz
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: <randomness+2>
+
+HTTP/1.1 200 OK
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/sheep"\r\n\r\n{"baz": "qux"}
+--batch_foobarbaz--"""
+
+class TestBatch(unittest.TestCase):
+
+  def setUp(self):
+    model = JsonModel()
+    self.request1 = HttpRequest(
+        None,
+        model.response,
+        'https://www.googleapis.com/someapi/v1/collection/?foo=bar',
+        method='POST',
+        body='{}',
+        headers={'content-type': 'application/json'})
+
+    self.request2 = HttpRequest(
+        None,
+        model.response,
+        'https://www.googleapis.com/someapi/v1/collection/?foo=bar',
+        method='POST',
+        body='{}',
+        headers={'content-type': 'application/json'})
+
+
+  def test_id_to_from_content_id_header(self):
+    batch = BatchHttpRequest()
+    self.assertEquals('12', batch._header_to_id(batch._id_to_header('12')))
+
+  def test_invalid_content_id_header(self):
+    batch = BatchHttpRequest()
+    self.assertRaises(BatchError, batch._header_to_id, '[foo+x]')
+    self.assertRaises(BatchError, batch._header_to_id, 'foo+1')
+    self.assertRaises(BatchError, batch._header_to_id, '<foo>')
+
+  def test_serialize_request(self):
+    batch = BatchHttpRequest()
+    request = HttpRequest(
+        None,
+        None,
+        'https://www.googleapis.com/someapi/v1/collection/?foo=bar',
+        method='POST',
+        body='{}',
+        headers={'content-type': 'application/json'},
+        methodId=None,
+        resumable=None)
+    s = batch._serialize_request(request).splitlines()
+    self.assertEquals(s, EXPECTED.splitlines())
+
+  def test_deserialize_response(self):
+    batch = BatchHttpRequest()
+    resp, content = batch._deserialize_response(RESPONSE)
+
+    self.assertEquals(resp.status, 200)
+    self.assertEquals(resp.reason, 'OK')
+    self.assertEquals(resp.version, 11)
+    self.assertEquals(content, '{"answer": 42}')
+
+  def test_new_id(self):
+    batch = BatchHttpRequest()
+
+    id_ = batch._new_id()
+    self.assertEquals(id_, '1')
+
+    id_ = batch._new_id()
+    self.assertEquals(id_, '2')
+
+    batch.add(self.request1, request_id='3')
+
+    id_ = batch._new_id()
+    self.assertEquals(id_, '4')
+
+  def test_add(self):
+    batch = BatchHttpRequest()
+    batch.add(self.request1, request_id='1')
+    self.assertRaises(KeyError, batch.add, self.request1, request_id='1')
+
+  def test_add_fail_for_resumable(self):
+    batch = BatchHttpRequest()
+
+    upload = MediaFileUpload(
+        datafile('small.png'), chunksize=500, resumable=True)
+    self.request1.resumable = upload
+    self.assertRaises(BatchError, batch.add, self.request1, request_id='1')
+
+  def test_execute(self):
+    class Callbacks(object):
+      def __init__(self):
+        self.responses = {}
+
+      def f(self, request_id, response):
+        self.responses[request_id] = response
+
+    batch = BatchHttpRequest()
+    callbacks = Callbacks()
+
+    batch.add(self.request1, callback=callbacks.f)
+    batch.add(self.request2, callback=callbacks.f)
+    http = HttpMockSequence([
+      ({'status': '200',
+        'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+       BATCH_RESPONSE),
+      ])
+    batch.execute(http)
+    self.assertEqual(callbacks.responses['1'], {'foo': 42})
+    self.assertEqual(callbacks.responses['2'], {'baz': 'qux'})
+
+  def test_execute_global_callback(self):
+    class Callbacks(object):
+      def __init__(self):
+        self.responses = {}
+
+      def f(self, request_id, response):
+        self.responses[request_id] = response
+
+    callbacks = Callbacks()
+    batch = BatchHttpRequest(callback=callbacks.f)
+
+    batch.add(self.request1)
+    batch.add(self.request2)
+    http = HttpMockSequence([
+      ({'status': '200',
+        'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+       BATCH_RESPONSE),
+      ])
+    batch.execute(http)
+    self.assertEqual(callbacks.responses['1'], {'foo': 42})
+    self.assertEqual(callbacks.responses['2'], {'baz': 'qux'})
 
 if __name__ == '__main__':
   unittest.main()