jfr.im git - yt-dlp.git/commitdiff
Improve `traverse_obj`
author pukkandan <redacted>
Sat, 10 Jul 2021 22:14:39 +0000 (03:44 +0530)
committer pukkandan <redacted>
Sat, 10 Jul 2021 23:16:53 +0000 (04:46 +0530)
yt_dlp/utils.py

index 888cfbb7e8e260e6ba8644c77963e91ffa18197d..8f9cb46f64f3cbb766688d497fa3ce64d6b0cda8 100644 (file)
@@ -6224,37 +6224,49 @@ def load_plugins(name, suffix, namespace):
     return classes
 
 
-def traverse_obj(obj, keys, *, casesense=True, is_user_input=False, traverse_string=False):
+def traverse_obj(
+        obj, *key_list, default=None, expected_type=None,
+        casesense=True, is_user_input=False, traverse_string=False):
     ''' Traverse nested list/dict/tuple
+    @param default          Default value to return
+    @param expected_type    Only accept final value of this type
     @param casesense        Whether to consider dictionary keys as case sensitive
     @param is_user_input    Whether the keys are generated from user input. If True,
                             strings are converted to int/slice if necessary
     @param traverse_string  Whether to traverse inside strings. If True, any
                             non-compatible object will also be converted into a string
     '''
-    keys = list(keys)[::-1]
-    while keys:
-        key = keys.pop()
-        if isinstance(obj, dict):
-            assert isinstance(key, compat_str)
-            if not casesense:
-                obj = {k.lower(): v for k, v in obj.items()}
-                key = key.lower()
-            obj = obj.get(key)
-        else:
-            if is_user_input:
-                key = (int_or_none(key) if ':' not in key
-                       else slice(*map(int_or_none, key.split(':'))))
-                if key is None:
+    if not casesense:
+        _lower = lambda k: k.lower() if isinstance(k, str) else k
+        key_list = ((_lower(k) for k in keys) for keys in key_list)
+
+    def _traverse_obj(obj, keys):
+        for key in list(keys):
+            if isinstance(obj, dict):
+                obj = (obj.get(key) if casesense or (key in obj)
+                       else next((v for k, v in obj.items() if _lower(k) == key), None))
+            else:
+                if is_user_input:
+                    key = (int_or_none(key) if ':' not in key
+                           else slice(*map(int_or_none, key.split(':'))))
+                if not isinstance(key, (int, slice)):
                     return None
-            if not isinstance(obj, (list, tuple)):
-                if traverse_string:
-                    obj = compat_str(obj)
-                else:
+                if not isinstance(obj, (list, tuple)):
+                    if not traverse_string:
+                        return None
+                    obj = str(obj)
+                try:
+                    obj = obj[key]
+                except IndexError:
                     return None
-            assert isinstance(key, (int, slice))
-            obj = try_get(obj, lambda x: x[key])
-    return obj
+        return obj
+
+    for keys in key_list:
+        val = _traverse_obj(obj, keys)
+        if val is not None:
+            if expected_type is None or isinstance(val, expected_type):
+                return val
+    return default
 
 
 def traverse_dict(dictn, keys, casesense=True):