fix(google.auth.compute_engine.metadata): add retry to google.auth.compute_engine._metadata.get() (#398)
Initial fix of issue #211 was done in CL #323, but only for .ping()
This one is adding same behaviour & tests for .get() method, as the problem still occurres
See the issue for details
Refs: #323
Resolves: #211
diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py
index f4fae72..30cd3d4 100644
--- a/google/auth/compute_engine/_metadata.py
+++ b/google/auth/compute_engine/_metadata.py
@@ -99,7 +99,7 @@
return False
-def get(request, path, root=_METADATA_ROOT, recursive=False):
+def get(request, path, root=_METADATA_ROOT, recursive=False, retry_count=5):
"""Fetch a resource from the metadata server.
Args:
@@ -111,6 +111,8 @@
recursive (bool): Whether to do a recursive query of metadata. See
https://cloud.google.com/compute/docs/metadata#aggcontents for more
details.
+ retry_count (int): How many times to attempt connecting to metadata
+ server using above timeout.
Returns:
Union[Mapping, str]: If the metadata server returns JSON, a mapping of
@@ -129,7 +131,24 @@
url = _helpers.update_query(base_url, query_params)
- response = request(url=url, method="GET", headers=_METADATA_HEADERS)
+ retries = 0
+ while retries < retry_count:
+ try:
+ response = request(url=url, method="GET", headers=_METADATA_HEADERS)
+ break
+
+ except exceptions.TransportError:
+ _LOGGER.info(
+ "Compute Engine Metadata server unavailable on" "attempt %s of %s",
+ retries + 1,
+ retry_count,
+ )
+ retries += 1
+ else:
+ raise exceptions.TransportError(
+ "Failed to retrieve {} from the Google Compute Engine"
+ "metadata service. Compute Engine Metadata server unavailable".format(url)
+ )
if response.status == http_client.OK:
content = _helpers.from_bytes(response.data)
diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py
index bd06b74..0898e1f 100644
--- a/tests/compute_engine/test__metadata.py
+++ b/tests/compute_engine/test__metadata.py
@@ -30,14 +30,17 @@
PATH = "instance/service-accounts/default"
-def make_request(data, status=http_client.OK, headers=None):
+def make_request(data, status=http_client.OK, headers=None, retry=False):
response = mock.create_autospec(transport.Response, instance=True)
response.status = status
response.data = _helpers.to_bytes(data)
response.headers = headers or {}
request = mock.create_autospec(transport.Request)
- request.return_value = response
+ if retry:
+ request.side_effect = [exceptions.TransportError(), response]
+ else:
+ request.return_value = response
return request
@@ -55,6 +58,20 @@
)
+def test_ping_success_retry():
+ request = make_request("", headers=_metadata._METADATA_HEADERS, retry=True)
+
+ assert _metadata.ping(request)
+
+ request.assert_called_with(
+ method="GET",
+ url=_metadata._METADATA_IP_ROOT,
+ headers=_metadata._METADATA_HEADERS,
+ timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
+ )
+ assert request.call_count == 2
+
+
def test_ping_failure_bad_flavor():
request = make_request("", headers={_metadata._METADATA_FLAVOR_HEADER: "meep"})
@@ -105,6 +122,25 @@
assert result[key] == value
+def test_get_success_retry():
+ key, value = "foo", "bar"
+
+ data = json.dumps({key: value})
+ request = make_request(
+ data, headers={"content-type": "application/json"}, retry=True
+ )
+
+ result = _metadata.get(request, PATH)
+
+ request.assert_called_with(
+ method="GET",
+ url=_metadata._METADATA_ROOT + PATH,
+ headers=_metadata._METADATA_HEADERS,
+ )
+ assert request.call_count == 2
+ assert result[key] == value
+
+
def test_get_success_text():
data = "foobar"
request = make_request(data, headers={"content-type": "text/plain"})
@@ -154,6 +190,23 @@
)
+def test_get_failure_connection_failed():
+ request = make_request("")
+ request.side_effect = exceptions.TransportError()
+
+ with pytest.raises(exceptions.TransportError) as excinfo:
+ _metadata.get(request, PATH)
+
+ assert excinfo.match(r"Compute Engine Metadata server unavailable")
+
+ request.assert_called_with(
+ method="GET",
+ url=_metadata._METADATA_ROOT + PATH,
+ headers=_metadata._METADATA_HEADERS,
+ )
+ assert request.call_count == 5
+
+
def test_get_failure_bad_json():
request = make_request("{", headers={"content-type": "application/json"})