Implement retry-related functions

TEST: manual test
Change-Id: I165ab62bf7be96f629982676822543ac02940816
diff --git a/internal/lib/utils_test.py b/internal/lib/utils_test.py
index 9899d75..bb1200e 100644
--- a/internal/lib/utils_test.py
+++ b/internal/lib/utils_test.py
@@ -19,6 +19,7 @@
 import getpass
 import os
 import subprocess
+import time
 
 import mock
 
@@ -51,6 +52,63 @@
         utils.SSH_KEYGEN_CMD + ["-C", getpass.getuser(), "-f", private_key],
         stdout=mock.ANY, stderr=mock.ANY)
 
+  def testRetryOnException(self):
+    """Test RetryOnException retries while the checker accepts the error."""
+    def _IsValueError(exc):
+      return isinstance(exc, ValueError)
+    num_retry = 5
+
+    @utils.RetryOnException(_IsValueError, num_retry)
+    def _RaiseAndRetry(sentinel):
+      sentinel.alert()
+      raise ValueError("Fake error.")
+
+    sentinel = mock.MagicMock()
+    self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
+    # One initial call plus num_retry retries.
+    self.assertEqual(1 + num_retry, sentinel.alert.call_count)
+
+  def testRetryExceptionType(self):
+    """Test RetryExceptionType function."""
+    def _RaiseAndRetry(sentinel):
+      sentinel.alert()
+      raise ValueError("Fake error.")
+
+    num_retry = 5
+    sentinel = mock.MagicMock()
+    self.assertRaises(ValueError, utils.RetryExceptionType,
+                      (KeyError, ValueError), num_retry, _RaiseAndRetry,
+                      sentinel=sentinel)
+    self.assertEqual(1 + num_retry, sentinel.alert.call_count)
+
+    # An exception whose type is not listed must propagate without a retry.
+    sentinel = mock.MagicMock()
+    self.assertRaises(ValueError, utils.RetryExceptionType,
+                      (KeyError,), num_retry, _RaiseAndRetry,
+                      sentinel=sentinel)
+    self.assertEqual(1, sentinel.alert.call_count)
+
+  def testRetry(self):
+    """Test RetryExceptionType backoff sleeps between retries."""
+    self.Patch(time, "sleep")
+    def _RaiseAndRetry(sentinel):
+      sentinel.alert()
+      raise ValueError("Fake error.")
+
+    num_retry = 5
+    sentinel = mock.MagicMock()
+    self.assertRaises(ValueError, utils.RetryExceptionType,
+                      (ValueError, KeyError), num_retry, _RaiseAndRetry,
+                      sleep_multiplier=1,
+                      retry_backoff_factor=2,
+                      sentinel=sentinel)
+
+    self.assertEqual(1 + num_retry, sentinel.alert.call_count)
+    # Exactly one sleep per retry, each retry_backoff_factor times longer.
+    self.assertEqual(
+        [mock.call(1), mock.call(2), mock.call(4), mock.call(8), mock.call(16)],
+        time.sleep.call_args_list)
+
 
 if __name__ == "__main__":
     unittest.main()