]> jfr.im git - yt-dlp.git/commitdiff
[utils] `traverse_obj`: Various improvements
authorSimon Sawicki <redacted>
Thu, 2 Feb 2023 05:40:19 +0000 (06:40 +0100)
committerGitHub <redacted>
Thu, 2 Feb 2023 05:40:19 +0000 (06:40 +0100)
- Add `set` key for transformations/filters
- Add `re.Match` group names
- Fix behavior for `expected_type` with `dict` key
- Raise for filter function signature mismatch in debug

Authored by: Grub4K

test/test_utils.py
yt_dlp/utils.py

index 3d5a6ea6baaed3278d64a63a136bf1166297a7a0..ffe1b729fe024f19d9100a157dabd44edb478b73 100644 (file)
     sanitized_Request,
     shell_quote,
     smuggle_url,
+    str_or_none,
     str_to_int,
     strip_jsonp,
     strip_or_none,
@@ -2015,6 +2016,29 @@ def test_traverse_obj(self):
                          msg='function as query key should perform a filter based on (key, value)')
         self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
                               msg='exceptions in the query function should be catched')
+        if __debug__:
+            with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
+                traverse_obj(_TEST_DATA, lambda a: ...)
+            with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
+                traverse_obj(_TEST_DATA, lambda a, b, c: ...)
+
+        # Test set as key (transformation/type, like `expected_type`)
+        self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper}, )), ['STR'],
+                         msg='Function in set should be a transformation')
+        self.assertEqual(traverse_obj(_TEST_DATA, (..., {str})), ['str'],
+                         msg='Type in set should be a type filter')
+        self.assertEqual(traverse_obj(_TEST_DATA, {dict}), _TEST_DATA,
+                         msg='A single set should be wrapped into a path')
+        self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper})), ['STR'],
+                         msg='Transformation function should not raise')
+        self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})),
+                         [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
+                         msg='Function in set should be a transformation')
+        if __debug__:
+            with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
+                traverse_obj(_TEST_DATA, set())
+            with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
+                traverse_obj(_TEST_DATA, {str.upper, str})
 
         # Test alternative paths
         self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
@@ -2106,6 +2130,20 @@ def test_traverse_obj(self):
                          msg='wrap expected_type fuction in try_call')
         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'],
                          msg='eliminate items that expected_type fails on')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), {0: 100},
+                         msg='type as expected_type should filter dict values')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), {0: '100', 1: '1.2'},
+                         msg='function as expected_type should transform dict values')
+        self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), 1,
+                         msg='expected_type should not filter non final dict values')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), {0: {0: 100}},
+                         msg='expected_type should transform deep dict values')
+        self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)), [{0: ...}, {0: ...}],
+                         msg='expected_type should transform branched dict values')
+        self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), [4],
+                         msg='expected_type regression for type matching in tuple branching')
+        self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int), [],
+                         msg='expected_type regression for type matching in dict result')
 
         # Test get_all behavior
         _GET_ALL_DATA = {'key': [0, 1, 2]}
@@ -2189,6 +2227,8 @@ def test_traverse_obj(self):
                          msg='failing str key on a `re.Match` should return `default`')
         self.assertEqual(traverse_obj(mobj, 8), None,
                          msg='failing int key on a `re.Match` should return `default`')
+        self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
+                         msg='function on a `re.Match` should give group name as well')
 
 
 if __name__ == '__main__':
index 7d51fe472e10f1bcaf86e00be066d08c174aa42c..55e1c4415097452fce79c0411cd593f30bbe8258 100644 (file)
@@ -5424,6 +5424,9 @@ def traverse_obj(
 
     The keys in the path can be one of:
         - `None`:           Return the current object.
+        - `set`:            Requires the only item in the set to be a type or function,
+                            like `{type}`/`{func}`. If a `type`, returns only values
+                            of this type. If a function, returns `func(obj)`.
         - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
         - `slice`:          Branch out and return all values in `obj[key]`.
         - `Ellipsis`:       Branch out and return a list of all values.
@@ -5432,6 +5435,8 @@ def traverse_obj(
         - `function`:       Branch out and return values filtered by the function.
                             Read as: `[value for key, value in obj if function(key, value)]`.
                             For `Sequence`s, `key` is the index of the value.
+                            For `re.Match`es, `key` is the group number (0 = full match)
+                            as well as additionally any group names, if given.
         - `dict`            Transform the current object and return a matching dict.
                             Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
 
@@ -5441,6 +5446,8 @@ def traverse_obj(
     @param default          Value to return if the paths do not match.
     @param expected_type    If a `type`, only accept final values of this type.
                             If any other callable, try to call the function on each result.
+                            If the last key in the path is a `dict`, it will apply to each value inside
+                            the dict instead, recursively. This does respect branching paths.
     @param get_all          If `False`, return the first matching result, otherwise all matching ones.
     @param casesense        If `False`, consider string dictionary keys as case insensitive.
 
@@ -5466,16 +5473,25 @@ def traverse_obj(
     else:
         type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
 
-    def apply_key(key, obj):
+    def apply_key(key, test_type, obj):
         if obj is None:
             return
 
         elif key is None:
             yield obj
 
+        elif isinstance(key, set):
+            assert len(key) == 1, 'Set should only be used to wrap a single item'
+            item = next(iter(key))
+            if isinstance(item, type):
+                if isinstance(obj, item):
+                    yield obj
+            else:
+                yield try_call(item, args=(obj,))
+
         elif isinstance(key, (list, tuple)):
             for branch in key:
-                _, result = apply_path(obj, branch)
+                _, result = apply_path(obj, branch, test_type)
                 yield from result
 
         elif key is ...:
@@ -5494,7 +5510,9 @@ def apply_key(key, obj):
             elif isinstance(obj, collections.abc.Mapping):
                 iter_obj = obj.items()
             elif isinstance(obj, re.Match):
-                iter_obj = enumerate((obj.group(), *obj.groups()))
+                iter_obj = itertools.chain(
+                    enumerate((obj.group(), *obj.groups())),
+                    obj.groupdict().items())
             elif traverse_string:
                 iter_obj = enumerate(str(obj))
             else:
@@ -5502,7 +5520,7 @@ def apply_key(key, obj):
             yield from (v for k, v in iter_obj if try_call(key, args=(k, v)))
 
         elif isinstance(key, dict):
-            iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
+            iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items())
             yield {k: v if v is not None else default for k, v in iter_obj
                    if v is not None or default is not NO_DEFAULT}
 
@@ -5537,11 +5555,24 @@ def apply_key(key, obj):
             with contextlib.suppress(IndexError):
                 yield obj[key]
 
-    def apply_path(start_obj, path):
+    def lazy_last(iterable):
+        iterator = iter(iterable)
+        prev = next(iterator, NO_DEFAULT)
+        if prev is NO_DEFAULT:
+            return
+
+        for item in iterator:
+            yield False, prev
+            prev = item
+
+        yield True, prev
+
+    def apply_path(start_obj, path, test_type=False):
         objs = (start_obj,)
         has_branched = False
 
-        for key in variadic(path):
+        key = None
+        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
             if is_user_input and key == ':':
                 key = ...
 
@@ -5551,14 +5582,21 @@ def apply_path(start_obj, path):
             if key is ... or isinstance(key, (list, tuple)) or callable(key):
                 has_branched = True
 
-            key_func = functools.partial(apply_key, key)
+            if __debug__ and callable(key):
+                # Verify function signature
+                inspect.signature(key).bind(None, None)
+
+            key_func = functools.partial(apply_key, key, last)
             objs = itertools.chain.from_iterable(map(key_func, objs))
 
+        if test_type and not isinstance(key, (dict, list, tuple)):
+            objs = map(type_test, objs)
+
         return has_branched, objs
 
-    def _traverse_obj(obj, path, use_list=True):
-        has_branched, results = apply_path(obj, path)
-        results = LazyList(x for x in map(type_test, results) if x is not None)
+    def _traverse_obj(obj, path, use_list=True, test_type=True):
+        has_branched, results = apply_path(obj, path, test_type)
+        results = LazyList(x for x in results if x is not None)
 
         if get_all and has_branched:
             return results.exhaust() if results or use_list else None