]> jfr.im git - yt-dlp.git/commitdiff
[utils] `traverse_obj`: Move `is_user_input` into output template (#8673)
authorSimon Sawicki <redacted>
Wed, 6 Dec 2023 20:46:45 +0000 (21:46 +0100)
committerGitHub <redacted>
Wed, 6 Dec 2023 20:46:45 +0000 (21:46 +0100)
Authored by: Grub4K

test/test_utils.py
yt_dlp/YoutubeDL.py
yt_dlp/utils/traversal.py

index 77040f29c60404faf22d459c8c29c38586a16b89..100f117889935a4366e4c5e64ca44a6252651dae 100644 (file)
@@ -2317,23 +2317,6 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [],
                          msg='branching should result in list if `traverse_string`')
 
-        # Test is_user_input behavior
-        _IS_USER_INPUT_DATA = {'range8': list(range(8))}
-        self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
-                                      is_user_input=True), 3,
-                         msg='allow for string indexing if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
-                                           is_user_input=True), tuple(range(8))[3:],
-                              msg='allow for string slice if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
-                                           is_user_input=True), tuple(range(8))[:4:2],
-                              msg='allow step in string slice if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
-                                           is_user_input=True), range(8),
-                              msg='`:` should be treated as `...` if `is_user_input`')
-        with self.assertRaises(TypeError, msg='too many params should result in error'):
-            traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True)
-
         # Test re.Match as input obj
         mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123')
         self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None],
index 29dd761862faa362af0e9b0b7c3816d2894f80fa..0c07866e4992312b6e0e53b1bb579c0696302b86 100644 (file)
@@ -1201,6 +1201,15 @@ def prepare_outtmpl(self, outtmpl, info_dict, sanitize=False):
                 (?:\|(?P<default>.*?))?
             )$''')
 
+        def _from_user_input(field):
+            if field == ':':
+                return ...
+            elif ':' in field:
+                return slice(*map(int_or_none, field.split(':')))
+            elif int_or_none(field) is not None:
+                return int(field)
+            return field
+
         def _traverse_infodict(fields):
             fields = [f for x in re.split(r'\.({.+?})\.?', fields)
                       for f in ([x] if x.startswith('{') else x.split('.'))]
@@ -1210,11 +1219,12 @@ def _traverse_infodict(fields):
 
             for i, f in enumerate(fields):
                 if not f.startswith('{'):
+                    fields[i] = _from_user_input(f)
                     continue
                 assert f.endswith('}'), f'No closing brace for {f} in {fields}'
-                fields[i] = {k: k.split('.') for k in f[1:-1].split(',')}
+                fields[i] = {k: list(map(_from_user_input, k.split('.'))) for k in f[1:-1].split(',')}
 
-            return traverse_obj(info_dict, fields, is_user_input=True, traverse_string=True)
+            return traverse_obj(info_dict, fields, traverse_string=True)
 
         def get_value(mdict):
             # Object traversal
index 462c3ba5dff120172eaf12ebc57b1cf15161808c..ff5703198aea65e844f7a5831e492015e89f4e28 100644 (file)
@@ -8,7 +8,7 @@
     IDENTITY,
     NO_DEFAULT,
     LazyList,
-    int_or_none,
+    deprecation_warning,
     is_iterable_like,
     try_call,
     variadic,
@@ -17,7 +17,7 @@
 
 def traverse_obj(
         obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
-        casesense=True, is_user_input=False, traverse_string=False):
+        casesense=True, is_user_input=NO_DEFAULT, traverse_string=False):
     """
     Safely traverse nested `dict`s and `Iterable`s
 
@@ -63,10 +63,8 @@ def traverse_obj(
     @param get_all          If `False`, return the first matching result, otherwise all matching ones.
     @param casesense        If `False`, consider string dictionary keys as case insensitive.
 
-    The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API
+    `traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API
 
-    @param is_user_input    Whether the keys are generated from user input.
-                            If `True` strings get converted to `int`/`slice` if needed.
     @param traverse_string  Whether to traverse into objects as strings.
                             If `True`, any non-compatible object will first be
                             converted into a string and then traversed into.
@@ -80,6 +78,9 @@ def traverse_obj(
                             If no `default` is given and the last path branches, a `list` of results
                             is always returned. If a path ends on a `dict` that result will always be a `dict`.
     """
+    if is_user_input is not NO_DEFAULT:
+        deprecation_warning('The is_user_input parameter is deprecated and no longer works')
+
     casefold = lambda k: k.casefold() if isinstance(k, str) else k
 
     if isinstance(expected_type, type):
@@ -195,14 +196,6 @@ def apply_path(start_obj, path, test_type):
 
         key = None
         for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
-            if is_user_input and isinstance(key, str):
-                if key == ':':
-                    key = ...
-                elif ':' in key:
-                    key = slice(*map(int_or_none, key.split(':')))
-                elif int_or_none(key) is not None:
-                    key = int(key)
-
             if not casesense and isinstance(key, str):
                 key = key.casefold()