jfr.im git - yt-dlp.git/commitdiff
[downloader/fragment] Improve `--live-from-start` for YouTube livestreams (#2870)
author: Lesmiscore (Naoya Ozaki) <redacted>
Thu, 24 Feb 2022 17:00:46 +0000 (02:00 +0900)
committer: GitHub <redacted>
Thu, 24 Feb 2022 17:00:46 +0000 (02:00 +0900)
yt_dlp/downloader/fragment.py
yt_dlp/extractor/youtube.py

index 19c0990d39d0a35d4797c3bf876ca942ec7c73da..082581b54d59c8dec2ba7cd303080dfa78ae5c18 100644 (file)
@@ -25,6 +25,7 @@
     error_to_compat_str,
     encodeFilename,
     sanitized_Request,
+    traverse_obj,
 )
 
 
@@ -382,6 +383,7 @@ def download_and_append_fragments_multiple(self, *args, pack_func=None, finish_f
         max_workers = self.params.get('concurrent_fragment_downloads', 1)
         if max_progress > 1:
             self._prepare_multiline_status(max_progress)
+        is_live = any(traverse_obj(args, (..., 2, 'is_live'), default=[]))
 
         def thread_func(idx, ctx, fragments, info_dict, tpe):
             ctx['max_progress'] = max_progress
@@ -395,25 +397,44 @@ class FTPE(concurrent.futures.ThreadPoolExecutor):
             def __exit__(self, exc_type, exc_val, exc_tb):
                 pass
 
-        spins = []
         if compat_os_name == 'nt':
-            self.report_warning('Ctrl+C does not work on Windows when used with parallel threads. '
-                                'This is a known issue and patches are welcome')
+            def bindoj_result(future):
+                while True:
+                    try:
+                        return future.result(0.1)
+                    except KeyboardInterrupt:
+                        raise
+                    except concurrent.futures.TimeoutError:
+                        continue
+        else:
+            def bindoj_result(future):
+                return future.result()
+
+        spins = []
         for idx, (ctx, fragments, info_dict) in enumerate(args):
             tpe = FTPE(math.ceil(max_workers / max_progress))
-            job = tpe.submit(thread_func, idx, ctx, fragments, info_dict, tpe)
+
+            def interrupt_trigger_iter():
+                for f in fragments:
+                    if not interrupt_trigger[0]:
+                        break
+                    yield f
+
+            job = tpe.submit(thread_func, idx, ctx, interrupt_trigger_iter(), info_dict, tpe)
             spins.append((tpe, job))
 
         result = True
         for tpe, job in spins:
             try:
-                result = result and job.result()
+                result = result and bindoj_result(job)
             except KeyboardInterrupt:
                 interrupt_trigger[0] = False
             finally:
                 tpe.shutdown(wait=True)
-        if not interrupt_trigger[0]:
+        if not interrupt_trigger[0] and not is_live:
             raise KeyboardInterrupt()
+        # we expect the user wants to stop but DOES WANT the preceding postprocessors to run;
+        # so we return an intermediate result here instead of raising KeyboardInterrupt on live
         return result
 
     def download_and_append_fragments(
@@ -431,10 +452,11 @@ def download_and_append_fragments(
             pack_func = lambda frag_content, _: frag_content
 
         def download_fragment(fragment, ctx):
+            if not interrupt_trigger[0]:
+                return False, fragment['frag_index']
+
             frag_index = ctx['fragment_index'] = fragment['frag_index']
             ctx['last_error'] = None
-            if not interrupt_trigger[0]:
-                return False, frag_index
             headers = info_dict.get('http_headers', {}).copy()
             byte_range = fragment.get('byte_range')
             if byte_range:
@@ -500,8 +522,6 @@ def _download_fragment(fragment):
             self.report_warning('The download speed shown is only of one thread. This is a known issue and patches are welcome')
             with tpe or concurrent.futures.ThreadPoolExecutor(max_workers) as pool:
                 for fragment, frag_content, frag_index, frag_filename in pool.map(_download_fragment, fragments):
-                    if not interrupt_trigger[0]:
-                        break
                     ctx['fragment_filename_sanitized'] = frag_filename
                     ctx['fragment_index'] = frag_index
                     result = append_fragment(decrypt_fragment(fragment, frag_content), frag_index, ctx)
index 636bf42b602e4ef9a82d3ab0492392fb72a20560..47b3c5a85218a32715ce10e5a1d9411703f833d3 100644 (file)
@@ -2135,6 +2135,7 @@ def mpd_feed(format_id, delay):
             return f['manifest_url'], f['manifest_stream_number'], is_live
 
         for f in formats:
+            f['is_live'] = True
             f['protocol'] = 'http_dash_segments_generator'
             f['fragments'] = functools.partial(
                 self._live_dash_fragments, f['format_id'], live_start_time, mpd_feed)
@@ -2157,12 +2158,12 @@ def _live_dash_fragments(self, format_id, live_start_time, mpd_feed, ctx):
         known_idx, no_fragment_score, last_segment_url = begin_index, 0, None
         fragments, fragment_base_url = None, None
 
-        def _extract_sequence_from_mpd(refresh_sequence):
+        def _extract_sequence_from_mpd(refresh_sequence, immediate):
             nonlocal mpd_url, stream_number, is_live, no_fragment_score, fragments, fragment_base_url
             # Obtain from MPD's maximum seq value
             old_mpd_url = mpd_url
             last_error = ctx.pop('last_error', None)
-            expire_fast = last_error and isinstance(last_error, compat_HTTPError) and last_error.code == 403
+            expire_fast = immediate or last_error and isinstance(last_error, compat_HTTPError) and last_error.code == 403
             mpd_url, stream_number, is_live = (mpd_feed(format_id, 5 if expire_fast else 18000)
                                                or (mpd_url, stream_number, False))
             if not refresh_sequence:
@@ -2176,7 +2177,7 @@ def _extract_sequence_from_mpd(refresh_sequence):
             except ExtractorError:
                 fmts = None
             if not fmts:
-                no_fragment_score += 1
+                no_fragment_score += 2
                 return False, last_seq
             fmt_info = next(x for x in fmts if x['manifest_stream_number'] == stream_number)
             fragments = fmt_info['fragments']
@@ -2199,11 +2200,12 @@ def _extract_sequence_from_mpd(refresh_sequence):
                     urlh = None
                 last_seq = try_get(urlh, lambda x: int_or_none(x.headers['X-Head-Seqnum']))
                 if last_seq is None:
-                    no_fragment_score += 1
+                    no_fragment_score += 2
                     last_segment_url = None
                     continue
             else:
-                should_continue, last_seq = _extract_sequence_from_mpd(True)
+                should_continue, last_seq = _extract_sequence_from_mpd(True, no_fragment_score > 15)
+                no_fragment_score += 2
                 if not should_continue:
                     continue
 
@@ -2221,7 +2223,7 @@ def _extract_sequence_from_mpd(refresh_sequence):
             try:
                 for idx in range(known_idx, last_seq):
                     # do not update the sequence here, or some part of it will be skipped
-                    should_continue, _ = _extract_sequence_from_mpd(False)
+                    should_continue, _ = _extract_sequence_from_mpd(False, False)
                     if not should_continue:
                         known_idx = idx - 1
                         raise ExtractorError('breaking out of outer loop')