Respect custom headers set on the request supplied to MediaIoBaseDownload within each call to next_chunk (#546)
Closes #207
diff --git a/googleapiclient/http.py b/googleapiclient/http.py
index 6a47bf7..9e549de 100644
--- a/googleapiclient/http.py
+++ b/googleapiclient/http.py
@@ -647,6 +647,14 @@
self._sleep = time.sleep
self._rand = random.random
+ self._headers = {}
+ for k, v in six.iteritems(request.headers):
+ # allow users to supply custom headers by setting them on the request
+ # but strip out the ones that are set by default on requests generated by
+ # API methods like Drive's files().get(fileId=...)
+ if not k.lower() in ('accept', 'accept-encoding', 'user-agent'):
+ self._headers[k] = v
+
@util.positional(1)
def next_chunk(self, num_retries=0):
"""Get the next chunk of the download.
@@ -666,10 +674,9 @@
googleapiclient.errors.HttpError if the response was not a 2xx.
httplib2.HttpLib2Error if a transport error has occured.
"""
- headers = {
- 'range': 'bytes=%d-%d' % (
+ headers = self._headers.copy()
+ headers['range'] = 'bytes=%d-%d' % (
self._progress, self._progress + self._chunksize)
- }
http = self._request.http
resp, content = _retry_request(
diff --git a/tests/test_http.py b/tests/test_http.py
index 8a976ee..e381294 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -29,6 +29,7 @@
from six.moves.urllib.parse import urlencode
# Do not remove the httplib2 import
+import json
import httplib2
import logging
import mock
@@ -456,6 +457,41 @@
self.assertEqual(5, download._progress)
self.assertEqual(5, download._total_size)
+ def test_media_io_base_download_custom_request_headers(self):
+ self.request.http = HttpMockSequence([
+ ({'status': '200',
+ 'content-range': '0-2/5'}, 'echo_request_headers_as_json'),
+ ({'status': '200',
+ 'content-range': '3-4/5'}, 'echo_request_headers_as_json'),
+ ])
+ self.assertEqual(True, self.request.http.follow_redirects)
+
+ self.request.headers['Cache-Control'] = 'no-store'
+
+ download = MediaIoBaseDownload(
+ fd=self.fd, request=self.request, chunksize=3)
+
+ self.assertEqual(download._headers, {'Cache-Control':'no-store'})
+
+ status, done = download.next_chunk()
+
+ result = self.fd.getvalue().decode('utf-8')
+
+ # we abuse the internals of the object we're testing, pay no attention
+ # to the actual bytes= values here; we are just asserting that the
+ # header we added to the original request is sent up to the server
+ # on each call to next_chunk
+
+ self.assertEqual(json.loads(result),
+ {"Cache-Control": "no-store", "range": "bytes=0-3"})
+
+ download._fd = self.fd = BytesIO()
+ status, done = download.next_chunk()
+
+ result = self.fd.getvalue().decode('utf-8')
+ self.assertEqual(json.loads(result),
+ {"Cache-Control": "no-store", "range": "bytes=51-54"})
+
def test_media_io_base_download_handle_redirects(self):
self.request.http = HttpMockSequence([
({'status': '200',