Core: Mitigate busy reopen loop in ResumableBidiRpc consuming 100% CPU (#8193)
* Add bidi._Throttle helper class
* Add optional reopen throttling to ResumableBidiRpc
* Enable Bidi reopen throttling in SPM
* Change bidi._Throttle signature
The commit renames the entry_cap parameter to access_limit, and
changes the type of the time_window argument from float to timedelta.
diff --git a/google/api_core/bidi.py b/google/api_core/bidi.py
index 053363a..3b69e91 100644
--- a/google/api_core/bidi.py
+++ b/google/api_core/bidi.py
@@ -14,8 +14,11 @@
"""Bi-directional streaming RPC helpers."""
+import collections
+import datetime
import logging
import threading
+import time
from six.moves import queue
@@ -134,6 +137,73 @@
yield item
+class _Throttle(object):
+ """A context manager limiting the total entries in a sliding time window.
+
+ If more than ``access_limit`` attempts are made to enter the context manager
+ instance in the last ``time window`` interval, the exceeding requests block
+ until enough time elapses.
+
+ The context manager instances are thread-safe and can be shared between
+ multiple threads. If multiple requests are blocked and waiting to enter,
+ the exact order in which they are allowed to proceed is not determined.
+
+ Example::
+
+ max_three_per_second = _Throttle(
+ access_limit=3, time_window=datetime.timedelta(seconds=1)
+ )
+
+ for i in range(5):
+ with max_three_per_second as time_waited:
+ print("{}: Waited {} seconds to enter".format(i, time_waited))
+
+ Args:
+ access_limit (int): the maximum number of entries allowed in the time window
+ time_window (datetime.timedelta): the width of the sliding time window
+ """
+
+ def __init__(self, access_limit, time_window):
+ if access_limit < 1:
+ raise ValueError("access_limit argument must be positive")
+
+ if time_window <= datetime.timedelta(0):
+ raise ValueError("time_window argument must be a positive timedelta")
+
+ self._time_window = time_window
+ self._access_limit = access_limit
+ self._past_entries = collections.deque(maxlen=access_limit) # least recent first
+ self._entry_lock = threading.Lock()
+
+ def __enter__(self):
+ with self._entry_lock:
+ cutoff_time = datetime.datetime.now() - self._time_window
+
+ # drop the entries that are too old, as they are no longer relevant
+ while self._past_entries and self._past_entries[0] < cutoff_time:
+ self._past_entries.popleft()
+
+ if len(self._past_entries) < self._access_limit:
+ self._past_entries.append(datetime.datetime.now())
+ return 0.0 # no waiting was needed
+
+ to_wait = (self._past_entries[0] - cutoff_time).total_seconds()
+ time.sleep(to_wait)
+
+ self._past_entries.append(datetime.datetime.now())
+ return to_wait
+
+ def __exit__(self, *_):
+ pass
+
+ def __repr__(self):
+ return "{}(access_limit={}, time_window={})".format(
+ self.__class__.__name__,
+ self._access_limit,
+ repr(self._time_window),
+ )
+
+
class BidiRpc(object):
"""A helper for consuming a bi-directional streaming RPC.
@@ -323,15 +393,31 @@
whenever an error is encountered on the stream.
metadata Sequence[Tuple(str, str)]: RPC metadata to include in
the request.
+ throttle_reopen (bool): If ``True``, throttling will be applied to
+ stream reopen calls. Defaults to ``False``.
"""
- def __init__(self, start_rpc, should_recover, initial_request=None, metadata=None):
+ def __init__(
+ self,
+ start_rpc,
+ should_recover,
+ initial_request=None,
+ metadata=None,
+ throttle_reopen=False,
+ ):
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
self._should_recover = should_recover
self._operational_lock = threading.RLock()
self._finalized = False
self._finalize_lock = threading.Lock()
+ if throttle_reopen:
+ self._reopen_throttle = _Throttle(
+ access_limit=5, time_window=datetime.timedelta(seconds=10),
+ )
+ else:
+ self._reopen_throttle = None
+
def _finalize(self, result):
with self._finalize_lock:
if self._finalized:
@@ -374,7 +460,11 @@
# retryable error.
try:
- self.open()
+ if self._reopen_throttle:
+ with self._reopen_throttle:
+ self.open()
+ else:
+ self.open()
# If re-opening or re-calling the method fails for any reason,
# consider it a terminal error and finalize the stream.
except Exception as exc:
@@ -573,7 +663,7 @@
thread = threading.Thread(
name=_BIDIRECTIONAL_CONSUMER_NAME,
target=self._thread_main,
- args=(ready,)
+ args=(ready,),
)
thread.daemon = True
thread.start()
diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py
index 08d2021..8e9f262 100644
--- a/tests/unit/test_bidi.py
+++ b/tests/unit/test_bidi.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import datetime
import logging
import threading
@@ -116,6 +117,87 @@
assert items == []
+class Test_Throttle(object):
+ def test_repr(self):
+ delta = datetime.timedelta(seconds=4.5)
+ instance = bidi._Throttle(access_limit=42, time_window=delta)
+ assert repr(instance) == \
+ "_Throttle(access_limit=42, time_window={})".format(repr(delta))
+
+ def test_raises_error_on_invalid_init_arguments(self):
+ with pytest.raises(ValueError) as exc_info:
+ bidi._Throttle(
+ access_limit=10, time_window=datetime.timedelta(seconds=0.0)
+ )
+ assert "time_window" in str(exc_info.value)
+ assert "must be a positive timedelta" in str(exc_info.value)
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi._Throttle(
+ access_limit=0, time_window=datetime.timedelta(seconds=10)
+ )
+ assert "access_limit" in str(exc_info.value)
+ assert "must be positive" in str(exc_info.value)
+
+ def test_does_not_delay_entry_attempts_under_threshold(self):
+ throttle = bidi._Throttle(
+ access_limit=3, time_window=datetime.timedelta(seconds=1)
+ )
+ entries = []
+
+ for _ in range(3):
+ with throttle as time_waited:
+ entry_info = {
+ "entered_at": datetime.datetime.now(),
+ "reported_wait": time_waited,
+ }
+ entries.append(entry_info)
+
+ # check the reported wait times ...
+ assert all(entry["reported_wait"] == 0.0 for entry in entries)
+
+ # .. and the actual wait times
+ delta = entries[1]["entered_at"] - entries[0]["entered_at"]
+ assert delta.total_seconds() < 0.1
+ delta = entries[2]["entered_at"] - entries[1]["entered_at"]
+ assert delta.total_seconds() < 0.1
+
+ def test_delays_entry_attempts_above_threshold(self):
+ throttle = bidi._Throttle(
+ access_limit=3, time_window=datetime.timedelta(seconds=1)
+ )
+ entries = []
+
+ for _ in range(6):
+ with throttle as time_waited:
+ entry_info = {
+ "entered_at": datetime.datetime.now(),
+ "reported_wait": time_waited,
+ }
+ entries.append(entry_info)
+
+ # For each group of 4 consecutive entries the time difference between
+ # the first and the last entry must have been greater than time_window,
+ # because a maximum of 3 are allowed in each time_window.
+ for i, entry in enumerate(entries[3:], start=3):
+ first_entry = entries[i - 3]
+ delta = entry["entered_at"] - first_entry["entered_at"]
+ assert delta.total_seconds() > 1.0
+
+ # check the reported wait times
+ # (NOTE: not using assert all(...), b/c the coverage check would complain)
+ for i, entry in enumerate(entries):
+ if i != 3:
+ assert entry["reported_wait"] == 0.0
+
+ # The delayed entry is expected to have been delayed for a significant
+ # chunk of the full second, and the actual and reported delay times
+ # should reflect that.
+ assert entries[3]["reported_wait"] > 0.7
+ delta = entries[3]["entered_at"] - entries[2]["entered_at"]
+ assert delta.total_seconds() > 0.7
+
+
class _CallAndFuture(grpc.Call, grpc.Future):
pass
@@ -442,6 +524,22 @@
assert bidi_rpc.is_active is False
callback.assert_called_once_with(error2)
+ def test_using_throttle_on_reopen_requests(self):
+ call = CallStub([])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, return_value=call
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(
+ start_rpc, should_recover, throttle_reopen=True
+ )
+
+ patcher = mock.patch.object(bidi_rpc._reopen_throttle.__class__, "__enter__")
+ with patcher as mock_enter:
+ bidi_rpc._reopen()
+
+ mock_enter.assert_called_once()
+
def test_send_not_open(self):
bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)