]> jfr.im git - yt-dlp.git/commitdiff
[utils] `traverse_obj`: Fix several behavioral problems
authorSimon Sawicki <redacted>
Wed, 8 Feb 2023 03:11:08 +0000 (04:11 +0100)
committerGitHub <redacted>
Wed, 8 Feb 2023 03:11:08 +0000 (04:11 +0100)
See #6180 for further info

Authored by: Grub4K

test/test_utils.py
yt_dlp/utils.py

index ffe1b729fe024f19d9100a157dabd44edb478b73..190e4ef9b02ffb8eba78b3510c52e211abc5cb9a 100644 (file)
@@ -2000,8 +2000,8 @@ def test_traverse_obj(self):
 
         # Test Ellipsis behavior
         self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
-                              (item for item in _TEST_DATA.values() if item is not None),
-                              msg='`...` should give all values except `None`')
+                              (item for item in _TEST_DATA.values() if item not in (None, [], {})),
+                              msg='`...` should give all non discarded values')
         self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
                               msg='`...` selection for dicts should select all values')
         self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')),
@@ -2084,15 +2084,23 @@ def test_traverse_obj(self):
                          {0: ['https://www.example.com/1', 'https://www.example.com/0']},
                          msg='tripple nesting in dict path should be treated as branches')
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
-                         msg='remove `None` values when dict key')
+                         msg='remove `None` values when top level dict key fails')
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
-                         msg='do not remove `None` values if `default`')
-        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}},
-                         msg='do not remove empty values when dict key')
-        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}},
-                         msg='do not remove empty values when dict key and a default')
-        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []},
-                         msg='if branch in dict key not successful, return `[]`')
+                         msg='use `default` if key fails and `default`')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
+                         msg='remove empty values when dict key')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: ...},
+                         msg='use `default` when dict key and `default`')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
+                         msg='remove empty values when nested dict key fails')
+        self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
+                         msg='default to dict if pruned')
+        self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {},
+                         msg='default to dict if pruned and default is given')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...), {0: {0: ...}},
+                         msg='use nested `default` when nested dict key fails and `default`')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {},
+                         msg='remove key if branch in dict key not successful')
 
         # Testing default parameter behavior
         _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
@@ -2183,14 +2191,17 @@ def test_traverse_obj(self):
                                       traverse_string=True), '.',
                          msg='traverse into converted data if `traverse_string`')
         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...),
-                                      traverse_string=True), list('str'),
-                         msg='`...` branching into string should result in list')
+                                      traverse_string=True), 'str',
+                         msg='`...` should result in string (same value) if `traverse_string`')
+        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
+                                      traverse_string=True), 'sr',
+                         msg='`slice` should result in string if `traverse_string`')
+        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
+                                      traverse_string=True), 'str',
+                         msg='function should result in string if `traverse_string`')
         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
                                       traverse_string=True), ['s', 'r'],
-                         msg='branching into string should result in list')
-        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x),
-                                      traverse_string=True), list('str'),
-                         msg='function branching into string should result in list')
+                         msg='branching should result in list if `traverse_string`')
 
         # Test is_user_input behavior
         _IS_USER_INPUT_DATA = {'range8': list(range(8))}
index e1e0f7b25add8a417085efefede169a27d96afa7..878b2b6a8252d557bf4e2afd64fb50cfc2b3dea1 100644 (file)
@@ -5420,7 +5420,7 @@ def traverse_obj(
     Each of the provided `paths` is tested and the first producing a valid result will be returned.
     The next path will also be tested if the path branched but no results could be found.
     Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
-    A value of None is treated as the absence of a value.
+    Unhelpful values (`[]`, `{}`, `None`) are treated as the absence of a value and discarded.
 
     The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
 
@@ -5446,6 +5446,8 @@ def traverse_obj(
 
     @params paths           Paths which to traverse by.
     @param default          Value to return if the paths do not match.
+                            If the last key in the path is a `dict`, it will apply to each value inside
+                            the dict instead, depth first. Try to avoid if using nested `dict` keys.
     @param expected_type    If a `type`, only accept final values of this type.
                             If any other callable, try to call the function on each result.
                             If the last key in the path is a `dict`, it will apply to each value inside
@@ -5460,12 +5462,15 @@ def traverse_obj(
     @param traverse_string  Whether to traverse into objects as strings.
                             If `True`, any non-compatible object will first be
                             converted into a string and then traversed into.
+                            The return value of that path will be a string instead,
+                            not respecting any further branching.
 
 
     @returns                The result of the object traversal.
                             If successful, `get_all=True`, and the path branches at least once,
                             then a list of results is returned instead.
-                            A list is always returned if the last path branches and no `default` is given.
+                            If no `default` is given and the last path branches, a `list` of results
+                            is always returned. If a path ends on a `dict` that result will always be a `dict`.
     """
     is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
     casefold = lambda k: k.casefold() if isinstance(k, str) else k
@@ -5475,87 +5480,94 @@ def traverse_obj(
     else:
         type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
 
-    def apply_key(key, test_type, obj):
+    def apply_key(key, obj, is_last):
+        branching = False
+        result = None
+
         if obj is None:
-            return
+            pass
 
         elif key is None:
-            yield obj
+            result = obj
 
         elif isinstance(key, set):
             assert len(key) == 1, 'Set should only be used to wrap a single item'
             item = next(iter(key))
             if isinstance(item, type):
                 if isinstance(obj, item):
-                    yield obj
+                    result = obj
             else:
-                yield try_call(item, args=(obj,))
+                result = try_call(item, args=(obj,))
 
         elif isinstance(key, (list, tuple)):
-            for branch in key:
-                _, result = apply_path(obj, branch, test_type)
-                yield from result
+            branching = True
+            result = itertools.chain.from_iterable(
+                apply_path(obj, branch, is_last)[0] for branch in key)
 
         elif key is ...:
+            branching = True
             if isinstance(obj, collections.abc.Mapping):
-                yield from obj.values()
+                result = obj.values()
             elif is_sequence(obj):
-                yield from obj
+                result = obj
             elif isinstance(obj, re.Match):
-                yield from obj.groups()
+                result = obj.groups()
             elif traverse_string:
-                yield from str(obj)
+                branching = False
+                result = str(obj)
+            else:
+                result = ()
 
         elif callable(key):
-            if is_sequence(obj):
-                iter_obj = enumerate(obj)
-            elif isinstance(obj, collections.abc.Mapping):
+            branching = True
+            if isinstance(obj, collections.abc.Mapping):
                 iter_obj = obj.items()
+            elif is_sequence(obj):
+                iter_obj = enumerate(obj)
             elif isinstance(obj, re.Match):
                 iter_obj = itertools.chain(
                     enumerate((obj.group(), *obj.groups())),
                     obj.groupdict().items())
             elif traverse_string:
+                branching = False
                 iter_obj = enumerate(str(obj))
             else:
-                return
-            yield from (v for k, v in iter_obj if try_call(key, args=(k, v)))
+                iter_obj = ()
+
+            result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
+            if not branching:  # string traversal
+                result = ''.join(result)
 
         elif isinstance(key, dict):
-            iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items())
-            yield {k: v if v is not None else default for k, v in iter_obj
-                   if v is not None or default is not NO_DEFAULT}
+            iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
+            result = {
+                k: v if v is not None else default for k, v in iter_obj
+                if v is not None or default is not NO_DEFAULT
+            } or None
 
         elif isinstance(obj, collections.abc.Mapping):
-            yield (obj.get(key) if casesense or (key in obj)
-                   else next((v for k, v in obj.items() if casefold(k) == key), None))
+            result = (obj.get(key) if casesense or (key in obj) else
+                      next((v for k, v in obj.items() if casefold(k) == key), None))
 
         elif isinstance(obj, re.Match):
             if isinstance(key, int) or casesense:
                 with contextlib.suppress(IndexError):
-                    yield obj.group(key)
-                    return
+                    result = obj.group(key)
 
-            if not isinstance(key, str):
-                return
-
-            yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
-
-        else:
-            if is_user_input:
-                key = (int_or_none(key) if ':' not in key
-                       else slice(*map(int_or_none, key.split(':'))))
-
-            if not isinstance(key, (int, slice)):
-                return
+            elif isinstance(key, str):
+                result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
 
+        elif isinstance(key, (int, slice)):
             if not is_sequence(obj):
-                if not traverse_string:
-                    return
-                obj = str(obj)
+                if traverse_string:
+                    with contextlib.suppress(IndexError):
+                        result = str(obj)[key]
+            else:
+                branching = isinstance(key, slice)
+                with contextlib.suppress(IndexError):
+                    result = obj[key]
 
-            with contextlib.suppress(IndexError):
-                yield obj[key]
+        return branching, result if branching else (result,)
 
     def lazy_last(iterable):
         iterator = iter(iterable)
@@ -5569,45 +5581,54 @@ def lazy_last(iterable):
 
         yield True, prev
 
-    def apply_path(start_obj, path, test_type=False):
+    def apply_path(start_obj, path, test_type):
         objs = (start_obj,)
         has_branched = False
 
         key = None
         for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
-            if is_user_input and key == ':':
-                key = ...
+            if is_user_input and isinstance(key, str):
+                if key == ':':
+                    key = ...
+                elif ':' in key:
+                    key = slice(*map(int_or_none, key.split(':')))
+                elif int_or_none(key) is not None:
+                    key = int(key)
 
             if not casesense and isinstance(key, str):
                 key = key.casefold()
 
-            if key is ... or isinstance(key, (list, tuple)) or callable(key):
-                has_branched = True
-
             if __debug__ and callable(key):
                 # Verify function signature
                 inspect.signature(key).bind(None, None)
 
-            key_func = functools.partial(apply_key, key, last)
-            objs = itertools.chain.from_iterable(map(key_func, objs))
+            new_objs = []
+            for obj in objs:
+                branching, results = apply_key(key, obj, last)
+                has_branched |= branching
+                new_objs.append(results)
+
+            objs = itertools.chain.from_iterable(new_objs)
 
         if test_type and not isinstance(key, (dict, list, tuple)):
             objs = map(type_test, objs)
 
-        return has_branched, objs
-
-    def _traverse_obj(obj, path, use_list=True, test_type=True):
-        has_branched, results = apply_path(obj, path, test_type)
-        results = LazyList(x for x in results if x is not None)
+        return objs, has_branched, isinstance(key, dict)
 
+    def _traverse_obj(obj, path, allow_empty, test_type):
+        results, has_branched, is_dict = apply_path(obj, path, test_type)
+        results = LazyList(item for item in results if item not in (None, [], {}))
         if get_all and has_branched:
-            return results.exhaust() if results or use_list else None
+            if results:
+                return results.exhaust()
+            if allow_empty:
+                return [] if default is NO_DEFAULT else default
+            return None
 
-        return results[0] if results else None
+        return results[0] if results else {} if allow_empty and is_dict else None
 
     for index, path in enumerate(paths, 1):
-        use_list = default is NO_DEFAULT and index == len(paths)
-        result = _traverse_obj(obj, path, use_list)
+        result = _traverse_obj(obj, path, index == len(paths), True)
         if result is not None:
             return result