jfr.im git - yt-dlp.git/blobdiff - yt_dlp/utils.py
Standardize retry mechanism (#1649)
[yt-dlp.git] / yt_dlp / utils.py
index 545c027635da2213809c01b746155ca0e3e436d4..a5c2d10ef510cfcb41785c7b335f5510818c80c5 100644 (file)
@@ -599,6 +599,7 @@ def sanitize_open(filename, open_mode):
     if filename == '-':
         if sys.platform == 'win32':
             import msvcrt
+
             # stdout may be any IO stream. Eg, when using contextlib.redirect_stdout
             with contextlib.suppress(io.UnsupportedOperation):
                 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
@@ -5650,6 +5651,62 @@ def items_(self):
 KNOWN_EXTENSIONS = (*MEDIA_EXTENSIONS.video, *MEDIA_EXTENSIONS.audio, *MEDIA_EXTENSIONS.manifests)
 
 
+class RetryManager:
+    """Iterable that drives a retry loop.
+
+    Every iteration yields the manager itself. The loop body reports a failed
+    attempt by assigning the exception to `retry.error`. When an attempt ends
+    without an error being assigned, iteration stops. After a failed attempt,
+    the error callback is invoked as `callback(error, attempt, retries)`, and
+    a new attempt is started as long as `attempt <= retries`.
+
+    Usage:
+        for retry in RetryManager(...):
+            try:
+                ...
+            except SomeException as err:
+                retry.error = err
+                continue
+
+    Arguments:
+        _retries         Number of retries allowed; a falsy value is treated as 0
+        _error_callback  Callable invoked after each failed attempt
+        **kwargs         Keyword arguments pre-bound to `_error_callback`
+    """
+    # Class-level defaults; both are shadowed by instance attributes once iteration starts
+    attempt, _error = 0, None
+
+    def __init__(self, _retries, _error_callback, **kwargs):
+        self.retries = _retries or 0
+        self.error_callback = functools.partial(_error_callback, **kwargs)
+
+    def _should_retry(self):
+        """Whether another attempt should be made"""
+        # `_error` still being NO_DEFAULT means the previous attempt did not report an error
+        return self._error is not NO_DEFAULT and self.attempt <= self.retries
+
+    @property
+    def error(self):
+        """The error reported for the current attempt, or None if there is none"""
+        if self._error is NO_DEFAULT:
+            return None
+        return self._error
+
+    @error.setter
+    def error(self, value):
+        self._error = value
+
+    def __iter__(self):
+        while self._should_retry():
+            # Reset to the sentinel so that an attempt which reports nothing ends the loop
+            self.error = NO_DEFAULT
+            self.attempt += 1
+            yield self
+            if self.error:
+                self.error_callback(self.error, self.attempt, self.retries)
+
+    @staticmethod
+    def report_retry(e, count, retries, *, sleep_func, info, warn, error=None, suffix=None):
+        """Utility function for reporting retries
+
+        Arguments:
+            e           The error raised by the failed attempt
+            count       Attempt counter (RetryManager passes its `attempt`)
+            retries     Number of retries allowed
+            sleep_func  Delay before the next attempt; either a value, or a
+                        callable that is called as `sleep_func(n=count - 1)`
+            info        Callable used to report the sleep message
+            warn        Callable used to report the retry warning
+            error       Optional callable used to report the final failure.
+                        If not given, `e` is raised instead
+            suffix      Optional text appended after "Retrying" in the warning
+
+        Returns the result of `error`/`warn` when returning early, else None.
+        Raises `e` when the retries are exhausted and `error` is not given.
+        """
+        if count > retries:
+            if error:
+                return error(f'{e}. Giving up after {count - 1} retries') if count > 1 else error(str(e))
+            raise e
+
+        if not count:
+            return warn(e)
+        elif isinstance(e, ExtractorError):
+            # The cause need not be a string (it may be the underlying exception object),
+            # so stringify it before it is passed on for trailing-dot removal
+            cause = None if e.cause is None else str(e.cause)
+            e = remove_end(cause or e.orig_msg, '.')
+        warn(f'{e}. Retrying{format_field(suffix, None, " %s")} ({count}/{retries})...')
+
+        delay = float_or_none(sleep_func(n=count - 1)) if callable(sleep_func) else sleep_func
+        if delay:
+            info(f'Sleeping {delay:.2f} seconds ...')
+            time.sleep(delay)
+
+
 # Deprecated
 has_certifi = bool(certifi)
 has_websockets = bool(websockets)