]> jfr.im git - yt-dlp.git/commitdiff
[utils] `traverse_obj`: Always return list when branching (#5170)
authorSimon Sawicki <redacted>
Sun, 9 Oct 2022 01:27:32 +0000 (03:27 +0200)
committerGitHub <redacted>
Sun, 9 Oct 2022 01:27:32 +0000 (06:57 +0530)
Fixes #5162
Authored by: Grub4K

test/test_utils.py
yt_dlp/utils.py

index 69313564a1366cb9bf1f9c7ea89a2bb118cba84a..6f3f6cb914be724a469f8154b21b2740da4e43ba 100644 (file)
@@ -1890,6 +1890,7 @@ def test_traverse_obj(self):
                 {'index': 2},
                 {'index': 3},
             ),
+            'dict': {},
         }
 
         # Test base functionality
@@ -1926,11 +1927,15 @@ def test_traverse_obj(self):
 
         # Test alternative paths
         self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
-                         msg='multiple `path_list` should be treated as alternative paths')
+                         msg='multiple `paths` should be treated as alternative paths')
         self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str',
                          msg='alternatives should exit early')
         self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None,
                          msg='alternatives should return `default` if exhausted')
+        self.assertEqual(traverse_obj(_TEST_DATA, (..., 'fail'), 100), 100,
+                         msg='alternatives should track their own branching return')
+        self.assertEqual(traverse_obj(_TEST_DATA, ('dict', ...), ('data', ...)), list(_TEST_DATA['data']),
+                         msg='alternatives on empty objects should search further')
 
         # Test branch and path nesting
         self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'],
@@ -1963,8 +1968,16 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}),
                          {0: ['https://www.example.com/1', 'https://www.example.com/0']},
                          msg='tripple nesting in dict path should be treated as branches')
-        self.assertEqual(traverse_obj({}, {0: 1}, default=...), {0: ...},
-                         msg='do not remove `None` values when dict key')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
+                         msg='remove `None` values when dict key')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
+                         msg='do not remove `None` values if `default`')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}},
+                         msg='do not remove empty values when dict key')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}},
+                         msg='do not remove empty values when dict key and a default')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []},
+                         msg='if branch in dict key not successful, return `[]`')
 
         # Testing default parameter behavior
         _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
@@ -1981,7 +1994,13 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None,
                          msg='`IndexError` should result in `default`')
         self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1,
-                         msg='if branched but not successfull return `default`, not `[]`')
+                         msg='if branched but not successful return `default` if defined, not `[]`')
+        self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=None), None,
+                         msg='if branched but not successful return `default` even if `default` is `None`')
+        self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail')), [],
+                         msg='if branched but not successful return `[]`, not `default`')
+        self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [],
+                         msg='if branched but object is empty return `[]`, not `default`')
 
         # Testing expected_type behavior
         _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
index d0be7f19ef7189ef4e93e353c268374d912f52d8..7d8e9716265a4e236b8e037ce3789b0d849bf2ef 100644 (file)
@@ -5294,7 +5294,7 @@ def load_plugins(name, suffix, namespace):
 
 
 def traverse_obj(
-        obj, *paths, default=None, expected_type=None, get_all=True,
+        obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
         casesense=True, is_user_input=False, traverse_string=False):
     """
     Safely traverse nested `dict`s and `Sequence`s
@@ -5304,6 +5304,7 @@ def traverse_obj(
     "value"
 
     Each of the provided `paths` is tested and the first producing a valid result will be returned.
+    The next path will also be tested if the path branched but no results could be found.
     A value of None is treated as the absence of a value.
 
     The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@@ -5342,6 +5343,7 @@ def traverse_obj(
     @returns                The result of the object traversal.
                             If successful, `get_all=True`, and the path branches at least once,
                             then a list of results is returned instead.
+                            A list is always returned if the last path branches and no `default` is given.
     """
     is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
     casefold = lambda k: k.casefold() if isinstance(k, str) else k
@@ -5385,7 +5387,7 @@ def apply_key(key, obj):
         elif isinstance(key, dict):
             iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
             yield {k: v if v is not None else default for k, v in iter_obj
-                   if v is not None or default is not None}
+                   if v is not None or default is not NO_DEFAULT}
 
         elif isinstance(obj, dict):
             yield (obj.get(key) if casesense or (key in obj)
@@ -5426,18 +5428,22 @@ def apply_path(start_obj, path):
 
         return has_branched, objs
 
-    def _traverse_obj(obj, path):
+    def _traverse_obj(obj, path, use_list=True):
         has_branched, results = apply_path(obj, path)
         results = LazyList(x for x in map(type_test, results) if x is not None)
-        if results:
-            return results.exhaust() if get_all and has_branched else results[0]
 
-    for path in paths:
-        result = _traverse_obj(obj, path)
+        if get_all and has_branched:
+            return results.exhaust() if results or use_list else None
+
+        return results[0] if results else None
+
+    for index, path in enumerate(paths, 1):
+        use_list = default is NO_DEFAULT and index == len(paths)
+        result = _traverse_obj(obj, path, use_list)
         if result is not None:
             return result
 
-    return default
+    return None if default is NO_DEFAULT else default
 
 
 def traverse_dict(dictn, keys, casesense=True):