]> jfr.im git - yt-dlp.git/blobdiff - yt_dlp/YoutubeDL.py
[ie/youtube:tab] Detect looping feeds (#6621)
[yt-dlp.git] / yt_dlp / YoutubeDL.py
index 79b7d47b037525e1a1bf01b3e7434d8cade0c4af..7f557166694c7fec7686b6f1ed10596c5a49c82e 100644 (file)
@@ -1,9 +1,11 @@
 import collections
 import contextlib
+import copy
 import datetime
 import errno
 import fileinput
 import functools
+import http.cookiejar
 import io
 import itertools
 import json
@@ -25,7 +27,7 @@
 from .cache import Cache
 from .compat import urllib  # isort: split
 from .compat import compat_os_name, compat_shlex_quote
-from .cookies import load_cookies
+from .cookies import LenientSimpleCookie, load_cookies
 from .downloader import FFmpegFD, get_suitable_downloader, shorten_protocol_name
 from .downloader.rtmp import rtmpdump_version
 from .extractor import gen_extractor_classes, get_info_extractor
@@ -673,6 +675,9 @@ def process_color_policy(stream):
         if auto_init and auto_init != 'no_verbose_header':
             self.print_debug_header()
 
+        self.__header_cookies = []
+        self._load_cookies(traverse_obj(self.params.get('http_headers'), 'cookie', casesense=False))  # compat
+
         def check_deprecated(param, option, suggestion):
             if self.params.get(param) is not None:
                 self.report_warning(f'{option} is deprecated. Use {suggestion} instead')
@@ -983,6 +988,7 @@ def trouble(self, message=None, tb=None, is_error=True):
         ID='green',
         DELIM='blue',
         ERROR='red',
+        BAD_FORMAT='light red',
         WARNING='yellow',
         SUPPRESS='light black',
     )
@@ -1271,21 +1277,20 @@ def create_key(outer_mobj):
                 return outer_mobj.group(0)
             key = outer_mobj.group('key')
             mobj = re.match(INTERNAL_FORMAT_RE, key)
-            initial_field = mobj.group('fields') if mobj else ''
-            value, replacement, default = None, None, na
+            value, replacement, default, last_field = None, None, na, ''
             while mobj:
                 mobj = mobj.groupdict()
                 default = mobj['default'] if mobj['default'] is not None else default
                 value = get_value(mobj)
-                replacement = mobj['replacement']
+                last_field, replacement = mobj['fields'], mobj['replacement']
                 if value is None and mobj['alternate']:
                     mobj = re.match(INTERNAL_FORMAT_RE, mobj['remaining'][1:])
                 else:
                     break
 
             fmt = outer_mobj.group('format')
-            if fmt == 's' and value is not None and key in field_size_compat_map.keys():
-                fmt = f'0{field_size_compat_map[key]:d}d'
+            if fmt == 's' and value is not None and last_field in field_size_compat_map.keys():
+                fmt = f'0{field_size_compat_map[last_field]:d}d'
 
             if None not in (value, replacement):
                 try:
@@ -1322,7 +1327,7 @@ def create_key(outer_mobj):
                 value = format_decimal_suffix(value, f'%{num_fmt}f%s' if num_fmt else '%d%s',
                                               factor=1024 if '#' in flags else 1000)
             elif fmt[-1] == 'S':  # filename sanitization
-                value, fmt = filename_sanitizer(initial_field, value, restricted='#' in flags), str_fmt
+                value, fmt = filename_sanitizer(last_field, value, restricted='#' in flags), str_fmt
             elif fmt[-1] == 'c':
                 if value:
                     value = str(value)[0]
@@ -1341,7 +1346,7 @@ def create_key(outer_mobj):
                 elif fmt[-1] == 'a':
                     value, fmt = ascii(value), str_fmt
                 if fmt[-1] in 'csra':
-                    value = sanitizer(initial_field, value)
+                    value = sanitizer(last_field, value)
 
             key = '%s\0%s' % (key.replace('%', '%\0'), outer_mobj.group('format'))
             TMPL_DICT[key] = value
@@ -1625,8 +1630,60 @@ def progress(msg):
                 self.to_screen('')
             raise
 
+    def _load_cookies(self, data, *, from_headers=True):
+        """Loads cookies from a `Cookie` header
+
+        This tries to work around the security vulnerability of passing cookies to every domain.
+        See: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-v8mc-9377-rwjj
+        The unscoped cookies are saved for later to be stored in the jar with a limited scope.
+
+        @param data         The Cookie header as string to load the cookies from
+        @param from_headers If `False`, allows Set-Cookie syntax in the cookie string (at least a domain will be required)
+        """
+        for cookie in LenientSimpleCookie(data).values():
+            if from_headers and any(cookie.values()):  # header input: any truthy entry of the parsed cookie (presumably attributes like Domain/Path - TODO confirm) is rejected
+                raise ValueError('Invalid syntax in Cookie Header')  # not handled here; propagates to the caller
+
+            domain = cookie.get('domain') or ''  # '' when the cookie carries no domain
+            expiry = cookie.get('expires')
+            if expiry == '':  # 0 is valid
+                expiry = None  # normalise an empty value to None
+            prepared_cookie = http.cookiejar.Cookie(
+                cookie.get('version') or 0, cookie.key, cookie.value, None, False,  # version, name, value, port, port_specified
+                domain, True, True, cookie.get('path') or '', bool(cookie.get('path')),  # domain, domain_specified, domain_initial_dot, path, path_specified
+                cookie.get('secure') or False, expiry, False, None, None, {})  # secure, expires, discard, comment, comment_url, rest
+
+            if domain:
+                self.cookiejar.set_cookie(prepared_cookie)  # already scoped by its own domain: store directly
+            elif from_headers:
+                self.deprecated_feature(
+                    'Passing cookies as a header is a potential security risk; '
+                    'they will be scoped to the domain of the downloaded urls. '
+                    'Please consider loading cookies from a file or browser instead.')
+                self.__header_cookies.append(prepared_cookie)  # unscoped: held back until _apply_header_cookies() assigns a domain
+            else:
+                self.report_error('Unscoped cookies are not allowed; please specify some sort of scoping',
+                                  tb=False, is_error=False)  # no domain and not from a header: reported, cookie is not stored
+
+    def _apply_header_cookies(self, url):
+        """Applies stray header cookies to the provided url
+
+        This loads header cookies and scopes them to the domain provided in `url`.
+        While this is not ideal, it helps reduce the risk of them being sent
+        to an unintended destination while mostly maintaining compatibility.
+        """
+        parsed = urllib.parse.urlparse(url)
+        if not parsed.hostname:
+            return  # no hostname to scope the cookies to, so nothing is applied
+
+        for cookie in map(copy.copy, self.__header_cookies):  # work on copies so the stored cookies are not mutated and can be scoped again for another url
+            cookie.domain = f'.{parsed.hostname}'  # domain becomes the url's hostname with a leading dot
+            self.cookiejar.set_cookie(cookie)
+
     @_handle_extraction_exceptions
     def __extract_info(self, url, ie, download, extra_info, process):
+        self._apply_header_cookies(url)
+
         try:
             ie_result = ie.extract(url)
         except UserNotLive as e:
@@ -2086,8 +2143,6 @@ def syntax_error(note, start):
         allow_multiple_streams = {'audio': self.params.get('allow_multiple_audio_streams', False),
                                   'video': self.params.get('allow_multiple_video_streams', False)}
 
-        check_formats = self.params.get('check_formats') == 'selected'
-
         def _parse_filter(tokens):
             filter_parts = []
             for type, string_, start, _, _ in tokens:
@@ -2260,10 +2315,19 @@ def _merge(formats_pair):
             return new_dict
 
         def _check_formats(formats):
-            if not check_formats:
+            if (self.params.get('check_formats') is not None
+                    or self.params.get('allow_unplayable_formats')):
                 yield from formats
                 return
-            yield from self._check_formats(formats)
+            elif self.params.get('check_formats') == 'selected':
+                yield from self._check_formats(formats)
+                return
+
+            for f in formats:
+                if f.get('has_drm'):
+                    yield from self._check_formats([f])
+                else:
+                    yield f
 
         def _build_selector_function(selector):
             if isinstance(selector, list):  # ,
@@ -2407,9 +2471,24 @@ def _calc_headers(self, info_dict):
         if 'Youtubedl-No-Compression' in res:  # deprecated
             res.pop('Youtubedl-No-Compression', None)
             res['Accept-Encoding'] = 'identity'
-        cookies = self.cookiejar.get_cookie_header(info_dict['url'])
+        cookies = self.cookiejar.get_cookies_for_url(info_dict['url'])
         if cookies:
-            res['Cookie'] = cookies
+            encoder = LenientSimpleCookie()
+            values = []
+            for cookie in cookies:
+                _, value = encoder.value_encode(cookie.value)
+                values.append(f'{cookie.name}={value}')
+                if cookie.domain:
+                    values.append(f'Domain={cookie.domain}')
+                if cookie.path:
+                    values.append(f'Path={cookie.path}')
+                if cookie.secure:
+                    values.append('Secure')
+                if cookie.expires:
+                    values.append(f'Expires={cookie.expires}')
+                if cookie.version:
+                    values.append(f'Version={cookie.version}')
+            info_dict['cookies'] = '; '.join(values)
 
         if 'X-Forwarded-For' not in res:
             x_forwarded_for_ip = info_dict.get('__x_forwarded_for_ip')
@@ -2615,10 +2694,10 @@ def sanitize_numeric_fields(info):
         if field_preference:
             info_dict['_format_sort_fields'] = field_preference
 
-        # or None ensures --clean-infojson removes it
-        info_dict['_has_drm'] = any(f.get('has_drm') for f in formats) or None
+        info_dict['_has_drm'] = any(  # or None ensures --clean-infojson removes it
+            f.get('has_drm') and f['has_drm'] != 'maybe' for f in formats) or None
         if not self.params.get('allow_unplayable_formats'):
-            formats = [f for f in formats if not f.get('has_drm')]
+            formats = [f for f in formats if not f.get('has_drm') or f['has_drm'] == 'maybe']
 
         if formats and all(f.get('acodec') == f.get('vcodec') == 'none' for f in formats):
             self.report_warning(
@@ -2767,11 +2846,8 @@ def is_wellformed(f):
             formats_to_download = list(format_selector({
                 'formats': formats,
                 'has_merged_format': any('none' not in (f.get('acodec'), f.get('vcodec')) for f in formats),
-                'incomplete_formats': (
-                    # All formats are video-only or
-                    all(f.get('vcodec') != 'none' and f.get('acodec') == 'none' for f in formats)
-                    # all formats are audio-only
-                    or all(f.get('vcodec') == 'none' and f.get('acodec') != 'none' for f in formats)),
+                'incomplete_formats': (all(f.get('vcodec') == 'none' for f in formats)  # No formats with video
+                                       or all(f.get('acodec') == 'none' for f in formats)),  # OR, No formats with audio
             }))
             if interactive_format_selection and not formats_to_download:
                 self.report_error('Requested format is not available', tb=False, is_error=False)
@@ -2806,11 +2882,13 @@ def to_screen(*msg):
                 new_info.update(fmt)
                 offset, duration = info_dict.get('section_start') or 0, info_dict.get('duration') or float('inf')
                 end_time = offset + min(chapter.get('end_time', duration), duration)
+                # duration may not be accurate. So allow deviations <1sec
+                if end_time == float('inf') or end_time > offset + duration + 1:
+                    end_time = None
                 if chapter or offset:
                     new_info.update({
                         'section_start': offset + chapter.get('start_time', 0),
-                        # duration may not be accurate. So allow deviations <1sec
-                        'section_end': end_time if end_time <= offset + duration + 1 else None,
+                        'section_end': end_time,
                         'section_title': chapter.get('title'),
                         'section_number': chapter.get('index'),
                     })
@@ -3417,6 +3495,8 @@ def download_with_info_file(self, info_filename):
             infos = [self.sanitize_info(info, self.params.get('clean_infojson', True))
                      for info in variadic(json.loads('\n'.join(f)))]
         for info in infos:
+            self._load_cookies(info.get('cookies'), from_headers=False)
+            self._load_cookies(traverse_obj(info.get('http_headers'), 'Cookie', casesense=False))  # compat
             try:
                 self.__download_wrapper(self.process_ie_result)(info, download=True)
             except (DownloadError, EntryNotInPlaylist, ReExtractInfo) as e:
@@ -3686,7 +3766,7 @@ def render_formats_table(self, info_dict):
 
         def simplified_codec(f, field):
             assert field in ('acodec', 'vcodec')
-            codec = f.get(field, 'unknown')
+            codec = f.get(field)
             if not codec:
                 return 'unknown'
             elif codec != 'none':
@@ -3721,14 +3801,13 @@ def simplified_codec(f, field):
                 simplified_codec(f, 'acodec'),
                 format_field(f, 'abr', '\t%dk', func=round),
                 format_field(f, 'asr', '\t%s', func=format_decimal_suffix),
-                join_nonempty(
-                    self._format_out('UNSUPPORTED', 'light red') if f.get('ext') in ('f4f', 'f4m') else None,
-                    self._format_out('DRM', 'light red') if f.get('has_drm') else None,
-                    format_field(f, 'language', '[%s]'),
-                    join_nonempty(format_field(f, 'format_note'),
-                                  format_field(f, 'container', ignore=(None, f.get('ext'))),
-                                  delim=', '),
-                    delim=' '),
+                join_nonempty(format_field(f, 'language', '[%s]'), join_nonempty(
+                    self._format_out('UNSUPPORTED', self.Styles.BAD_FORMAT) if f.get('ext') in ('f4f', 'f4m') else None,
+                    (self._format_out('Maybe DRM', self.Styles.WARNING) if f.get('has_drm') == 'maybe'
+                     else self._format_out('DRM', self.Styles.BAD_FORMAT) if f.get('has_drm') else None),
+                    format_field(f, 'format_note'),
+                    format_field(f, 'container', ignore=(None, f.get('ext'))),
+                    delim=', '), delim=' '),
             ] for f in formats if f.get('preference') is None or f['preference'] >= -1000]
         header_line = self._list_format_headers(
             'ID', 'EXT', 'RESOLUTION', '\tFPS', 'HDR', 'CH', delim, '\tFILESIZE', '\tTBR', 'PROTO',