Add option to automatically retry requests.

Reviewed in: https://codereview.appspot.com/9920043/
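
Retries use randomized exponential backoff: before retry N the client sleeps
for random.random() * 2**N seconds, so the Nth retry waits at most 2**N
seconds. Only responses with status >= 500 are retried; anything else is
returned (or raised) immediately, and num_retries=0 keeps the old
single-attempt behavior.

A minimal usage sketch, mirroring the new tests in tests/test_http.py (the
HttpMockSequence stands in for a real transport):

    from apiclient.http import HttpMockSequence, HttpRequest
    from apiclient.model import JsonModel

    # First response is a transient server error, second one succeeds.
    http = HttpMockSequence([
        ({'status': '500'}, ''),
        ({'status': '200'}, '{}'),
    ])
    request = HttpRequest(
        http, JsonModel().response,
        u'https://www.googleapis.com/someapi/v1/collection/?foo=bar',
        method=u'POST', body=u'{}',
        headers={'content-type': 'application/json'})

    # Retry 5xx responses up to 3 times before giving up.
    response = request.execute(num_retries=3)

The same num_retries keyword is accepted by HttpRequest.next_chunk() and
MediaIoBaseDownload.next_chunk() for resumable uploads and downloads.
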
diff --git a/apiclient/http.py b/apiclient/http.py
index b73014a..31a1c44 100644
--- a/apiclient/http.py
+++ b/apiclient/http.py
@@ -26,10 +26,13 @@
 import copy
 import gzip
 import httplib2
+import logging
 import mimeparse
 import mimetypes
 import os
+import random
 import sys
+import time
 import urllib
 import urlparse
 import uuid
@@ -508,9 +511,20 @@
     self._original_follow_redirects = request.http.follow_redirects
     request.http.follow_redirects = False
 
-  def next_chunk(self):
+    # Stubs for testing.
+    self._sleep = time.sleep
+    self._rand = random.random
+
+  @util.positional(1)
+  def next_chunk(self, num_retries=0):
     """Get the next chunk of the download.
 
+    Args:
+      num_retries: Integer, number of times to retry 5xx responses with
+            randomized exponential backoff. If all retries fail, the raised
+            HttpError represents the last request. If zero (default), we
+            attempt the request only once.
+
     Returns:
       (status, done): (MediaDownloadStatus, boolean)
          The value of 'done' will be True when the media has been fully
@@ -526,7 +540,17 @@
         }
     http = self._request.http
 
-    resp, content = http.request(self._uri, headers=headers)
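+    # Retry the chunk request on 5xx responses, backing off between attempts.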
+    for retry_num in xrange(num_retries + 1):
+      if retry_num > 0:
+        self._sleep(self._rand() * 2**retry_num)
+        logging.warning(
+            'Retry #%d for media download: GET %s, following status: %d'
+            % (retry_num, self._uri, resp.status))
+
+      resp, content = http.request(self._uri, headers=headers)
+      if resp.status < 500:
+        break
+
     if resp.status in [301, 302, 303, 307, 308] and 'location' in resp:
         self._uri = resp['location']
         resp, content = http.request(self._uri, headers=headers)
@@ -635,13 +659,21 @@
     # The bytes that have been uploaded.
     self.resumable_progress = 0
 
+    # Stubs for testing.
+    self._rand = random.random
+    self._sleep = time.sleep
+
   @util.positional(1)
-  def execute(self, http=None):
+  def execute(self, http=None, num_retries=0):
     """Execute the request.
 
     Args:
       http: httplib2.Http, an http object to be used in place of the
             one the HttpRequest request object was constructed with.
+      num_retries: Integer, number of times to retry 5xx responses with
+            randomized exponential backoff. If all retries fail, the raised
+            HttpError represents the last request. If zero (default), we
+            attempt the request only once.
 
     Returns:
       A deserialized object model of the response body as determined
@@ -653,33 +685,46 @@
     """
     if http is None:
       http = self.http
+
     if self.resumable:
       body = None
       while body is None:
-        _, body = self.next_chunk(http=http)
+        _, body = self.next_chunk(http=http, num_retries=num_retries)
       return body
-    else:
-      if 'content-length' not in self.headers:
-        self.headers['content-length'] = str(self.body_size)
-      # If the request URI is too long then turn it into a POST request.
-      if len(self.uri) > MAX_URI_LENGTH and self.method == 'GET':
-        self.method = 'POST'
-        self.headers['x-http-method-override'] = 'GET'
-        self.headers['content-type'] = 'application/x-www-form-urlencoded'
-        parsed = urlparse.urlparse(self.uri)
-        self.uri = urlparse.urlunparse(
-            (parsed.scheme, parsed.netloc, parsed.path, parsed.params, None,
-             None)
-            )
-        self.body = parsed.query
-        self.headers['content-length'] = str(len(self.body))
+
+    # Non-resumable case.
+
+    if 'content-length' not in self.headers:
+      self.headers['content-length'] = str(self.body_size)
+    # If the request URI is too long then turn it into a POST request.
+    if len(self.uri) > MAX_URI_LENGTH and self.method == 'GET':
+      self.method = 'POST'
+      self.headers['x-http-method-override'] = 'GET'
+      self.headers['content-type'] = 'application/x-www-form-urlencoded'
+      parsed = urlparse.urlparse(self.uri)
+      self.uri = urlparse.urlunparse(
+          (parsed.scheme, parsed.netloc, parsed.path, parsed.params, None,
+           None)
+          )
+      self.body = parsed.query
+      self.headers['content-length'] = str(len(self.body))
+
+    # Handle retries for server-side errors.
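+    # Before retry N, sleep a random fraction of 2**N seconds.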
+    for retry_num in xrange(num_retries + 1):
+      if retry_num > 0:
+        self._sleep(self._rand() * 2**retry_num)
+        logging.warning('Retry #%d for request: %s %s, following status: %d'
+                        % (retry_num, self.method, self.uri, resp.status))
 
       resp, content = http.request(str(self.uri), method=str(self.method),
                                    body=self.body, headers=self.headers)
-      for callback in self.response_callbacks:
-        callback(resp)
-      if resp.status >= 300:
-        raise HttpError(resp, content, uri=self.uri)
+      if resp.status < 500:
+        break
+
+    for callback in self.response_callbacks:
+      callback(resp)
+    if resp.status >= 300:
+      raise HttpError(resp, content, uri=self.uri)
     return self.postproc(resp, content)
 
   @util.positional(2)
@@ -695,7 +740,7 @@
     self.response_callbacks.append(cb)
 
   @util.positional(1)
-  def next_chunk(self, http=None):
+  def next_chunk(self, http=None, num_retries=0):
     """Execute the next step of a resumable upload.
 
     Can only be used if the method being executed supports media uploads and
@@ -717,6 +762,14 @@
           print "Upload %d%% complete." % int(status.progress() * 100)
 
 
+    Args:
+      http: httplib2.Http, an http object to be used in place of the
+            one the HttpRequest request object was constructed with.
+      num_retries: Integer, number of times to retry 5xx responses with
+            randomized exponential backoff. If all retries fail, the raised
+            HttpError represents the last request. If zero (default), we
+            attempt the request only once.
+
     Returns:
       (status, body): (ResumableMediaStatus, object)
          The body will be None until the resumable media is fully uploaded.
@@ -740,9 +793,19 @@
         start_headers['X-Upload-Content-Length'] = size
       start_headers['content-length'] = str(self.body_size)
 
-      resp, content = http.request(self.uri, method=self.method,
-                                   body=self.body,
-                                   headers=start_headers)
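+      # 5xx responses to the resumable-URI request are retried with backoff.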
+      for retry_num in xrange(num_retries + 1):
+        if retry_num > 0:
+          self._sleep(self._rand() * 2**retry_num)
+          logging.warning(
+              'Retry #%d for resumable URI request: %s %s, following status: %d'
+              % (retry_num, self.method, self.uri, resp.status))
+
+        resp, content = http.request(self.uri, method=self.method,
+                                     body=self.body,
+                                     headers=start_headers)
+        if resp.status < 500:
+          break
+
       if resp.status == 200 and 'location' in resp:
         self.resumable_uri = resp['location']
       else:
@@ -794,13 +857,23 @@
         # calculate the size when working with _StreamSlice.
         'Content-Length': str(chunk_end - self.resumable_progress + 1)
         }
-    try:
-      resp, content = http.request(self.resumable_uri, method='PUT',
-                                   body=data,
-                                   headers=headers)
-    except:
-      self._in_error_state = True
-      raise
+
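+    # Retry the chunk upload on 5xx; exceptions still set the error state.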
+    for retry_num in xrange(num_retries + 1):
+      if retry_num > 0:
+        self._sleep(self._rand() * 2**retry_num)
+        logging.warning(
+            'Retry #%d for media upload: %s %s, following status: %d'
+            % (retry_num, self.method, self.uri, resp.status))
+
+      try:
+        resp, content = http.request(self.resumable_uri, method='PUT',
+                                     body=data,
+                                     headers=headers)
+      except:
+        self._in_error_state = True
+        raise
+      if resp.status < 500:
+        break
 
     return self._process_response(resp, content)
 
@@ -841,6 +914,8 @@
       d['resumable'] = self.resumable.to_json()
     del d['http']
     del d['postproc']
+    del d['_sleep']
+    del d['_rand']
 
     return simplejson.dumps(d)
 
diff --git a/tests/test_http.py b/tests/test_http.py
index 1651140..7d6fff7 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -23,10 +23,13 @@
 
 # Do not remove the httplib2 import
 import httplib2
+import logging
 import os
 import unittest
 import urllib
+import random
 import StringIO
+import time
 
 from apiclient.discovery import build
 from apiclient.errors import BatchError
@@ -184,6 +187,9 @@
     self.assertEqual(http, new_req.http)
     self.assertEqual(media_upload.to_json(), new_req.resumable.to_json())
 
+    self.assertEqual(random.random, new_req._rand)
+    self.assertEqual(time.sleep, new_req._sleep)
+
 
 class TestMediaIoBaseUpload(unittest.TestCase):
 
@@ -276,6 +282,48 @@
     except ImportError:
       pass
 
+  def test_media_io_base_next_chunk_retries(self):
+    try:
+      import io
+    except ImportError:
+      return
+
+    f = open(datafile('small.png'), 'r')
+    fd = io.BytesIO(f.read())
+    upload = MediaIoBaseUpload(
+        fd=fd, mimetype='image/png', chunksize=500, resumable=True)
+
+    # Simulate 5xx responses for both the request that creates the resumable
+    # upload and the upload itself.
+    http = HttpMockSequence([
+      ({'status': '500'}, ''),
+      ({'status': '500'}, ''),
+      ({'status': '503'}, ''),
+      ({'status': '200', 'location': 'location'}, ''),
+      ({'status': '500'}, ''),
+      ({'status': '500'}, ''),
+      ({'status': '503'}, ''),
+      ({'status': '200'}, '{}'),
+    ])
+
+    model = JsonModel()
+    uri = u'https://www.googleapis.com/someapi/v1/upload/?foo=bar'
+    method = u'POST'
+    request = HttpRequest(
+        http,
+        model.response,
+        uri,
+        method=method,
+        headers={},
+        resumable=upload)
+
+    sleeptimes = []
+    request._sleep = lambda x: sleeptimes.append(x)
+    request._rand = lambda: 10
+
+    request.execute(num_retries=3)
+    self.assertEqual([20, 40, 80, 20, 40, 80], sleeptimes)
+
 
 class TestMediaIoBaseDownload(unittest.TestCase):
 
@@ -367,6 +415,59 @@
 
     self.assertEqual(self.fd.getvalue(), '123')
 
+  def test_media_io_base_download_retries_5xx(self):
+    self.request.http = HttpMockSequence([
+      ({'status': '500'}, ''),
+      ({'status': '500'}, ''),
+      ({'status': '500'}, ''),
+      ({'status': '200',
+        'content-range': '0-2/5'}, '123'),
+      ({'status': '503'}, ''),
+      ({'status': '503'}, ''),
+      ({'status': '503'}, ''),
+      ({'status': '200',
+        'content-range': '3-4/5'}, '45'),
+    ])
+
+    download = MediaIoBaseDownload(
+        fd=self.fd, request=self.request, chunksize=3)
+
+    self.assertEqual(self.fd, download._fd)
+    self.assertEqual(3, download._chunksize)
+    self.assertEqual(0, download._progress)
+    self.assertEqual(None, download._total_size)
+    self.assertEqual(False, download._done)
+    self.assertEqual(self.request.uri, download._uri)
+
+    # Set time.sleep and random.random stubs.
+    sleeptimes = []
+    download._sleep = lambda x: sleeptimes.append(x)
+    download._rand = lambda: 10
+
+    status, done = download.next_chunk(num_retries=3)
+
+    # Check for exponential backoff using the rand function above.
+    self.assertEqual([20, 40, 80], sleeptimes)
+
+    self.assertEqual(self.fd.getvalue(), '123')
+    self.assertEqual(False, done)
+    self.assertEqual(3, download._progress)
+    self.assertEqual(5, download._total_size)
+    self.assertEqual(3, status.resumable_progress)
+
+    # Reset the recorded sleep times before downloading the next chunk.
+    del sleeptimes[:]
+
+    status, done = download.next_chunk(num_retries=3)
+
+    # Check for exponential backoff using the rand function above.
+    self.assertEqual([20, 40, 80], sleeptimes)
+
+    self.assertEqual(self.fd.getvalue(), '12345')
+    self.assertEqual(True, done)
+    self.assertEqual(5, download._progress)
+    self.assertEqual(5, download._total_size)
+
 EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1
 Content-Type: application/json
 MIME-Version: 1.0
@@ -508,6 +609,58 @@
     self.assertEqual(method, http.method)
     self.assertEqual(str, type(http.method))
 
+  def test_retry(self):
+    num_retries = 5
+    resp_seq = [({'status': '500'}, '')] * num_retries
+    resp_seq.append(({'status': '200'}, '{}'))
+
+    http = HttpMockSequence(resp_seq)
+    model = JsonModel()
+    uri = u'https://www.googleapis.com/someapi/v1/collection/?foo=bar'
+    method = u'POST'
+    request = HttpRequest(
+        http,
+        model.response,
+        uri,
+        method=method,
+        body=u'{}',
+        headers={'content-type': 'application/json'})
+
+    sleeptimes = []
+    request._sleep = lambda x: sleeptimes.append(x)
+    request._rand = lambda: 10
+
+    request.execute(num_retries=num_retries)
+
+    self.assertEqual(num_retries, len(sleeptimes))
+    for retry_num in xrange(num_retries):
+      self.assertEqual(10 * 2**(retry_num + 1), sleeptimes[retry_num])
+
+  def test_no_retry_fails_fast(self):
+    http = HttpMockSequence([
+        ({'status': '500'}, ''),
+        ({'status': '200'}, '{}')
+        ])
+    model = JsonModel()
+    uri = u'https://www.googleapis.com/someapi/v1/collection/?foo=bar'
+    method = u'POST'
+    request = HttpRequest(
+        http,
+        model.response,
+        uri,
+        method=method,
+        body=u'{}',
+        headers={'content-type': 'application/json'})
+
+    request._rand = lambda: 1.0
+    request._sleep = lambda _: self.fail('sleep should not have been called.')
+
+    try:
+      request.execute()
+      self.fail('Should have raised an exception.')
+    except HttpError:
+      pass
+
 
 class TestBatch(unittest.TestCase):
 
@@ -856,4 +1009,5 @@
 
 
 if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.ERROR)
   unittest.main()