jfr.im git - yt-dlp.git/commitdiff
[utils] Improve `traverse_obj`
author pukkandan <redacted>
Wed, 21 Jul 2021 05:47:27 +0000 (11:17 +0530)
committer pukkandan <redacted>
Wed, 21 Jul 2021 06:00:06 +0000 (11:30 +0530)
yt_dlp/extractor/youtube.py
yt_dlp/utils.py

index aa0421a72e011bbda95c7b70f9c80444d57e1fdd..afe31a12dc8cd16a2bc3ea30e5516db2450048ed 100644 (file)
@@ -1929,10 +1929,11 @@ def _extract_signature_timestamp(self, video_id, player_url, ytcfg=None, fatal=F
         return sts
 
     def _mark_watched(self, video_id, player_responses):
-        playback_url = url_or_none((traverse_obj(
-            player_responses, ('playbackTracking', 'videostatsPlaybackUrl', 'baseUrl'),
-            expected_type=str) or [None])[0])
+        playback_url = traverse_obj(
+            player_responses, (..., 'playbackTracking', 'videostatsPlaybackUrl', 'baseUrl'),
+            expected_type=url_or_none, get_all=False)
         if not playback_url:
+            self.report_warning('Unable to mark watched')
             return
         parsed_playback_url = compat_urlparse.urlparse(playback_url)
         qs = compat_urlparse.parse_qs(parsed_playback_url.query)
@@ -2606,8 +2607,7 @@ def _real_extract(self, url):
             self._get_requested_clients(url, smuggled_data),
             video_id, webpage, master_ytcfg, player_url, identity_token))
 
-        get_first = lambda obj, keys, **kwargs: (
-            traverse_obj(obj, (..., *variadic(keys)), **kwargs) or [None])[0]
+        get_first = lambda obj, keys, **kwargs: traverse_obj(obj, (..., *variadic(keys)), **kwargs, get_all=False)
 
         playability_statuses = traverse_obj(
             player_responses, (..., 'playabilityStatus'), expected_type=dict, default=[])
index 4d3cbc7b4b182b71f791b2557f7f7e9739f07054..4d12c0a8e2db5de7e892bc4b255721ea9683b6bc 100644 (file)
@@ -6225,7 +6225,7 @@ def load_plugins(name, suffix, namespace):
 
 
 def traverse_obj(
-        obj, *path_list, default=None, expected_type=None,
+        obj, *path_list, default=None, expected_type=None, get_all=True,
         casesense=True, is_user_input=False, traverse_string=False):
     ''' Traverse nested list/dict/tuple
     @param path_list        A list of paths which are checked one by one.
@@ -6234,7 +6234,8 @@ def traverse_obj(
                             all the keys given in the tuple are traversed, and
                             "..." traverses all the keys in the object
     @param default          Default value to return
-    @param expected_type    Only accept final value of this type
+    @param expected_type    Only accept final value of this type (Can also be any callable)
+    @param get_all          Return all the values obtained from a path or only the first one
     @param casesense        Whether to consider dictionary keys as case sensitive
     @param is_user_input    Whether the keys are generated from user input. If True,
                             strings are converted to int/slice if necessary
@@ -6281,6 +6282,13 @@ def _traverse_obj(obj, path, _current_depth=0):
                     return None
         return obj
 
+    if isinstance(expected_type, type):
+        type_test = lambda val: val if isinstance(val, expected_type) else None
+    elif expected_type is not None:
+        type_test = expected_type
+    else:
+        type_test = lambda val: val
+
     for path in path_list:
         depth = 0
         val = _traverse_obj(obj, path)
@@ -6288,12 +6296,13 @@ def _traverse_obj(obj, path, _current_depth=0):
             if depth:
                 for _ in range(depth - 1):
                     val = itertools.chain.from_iterable(v for v in val if v is not None)
-                val = ([v for v in val if v is not None] if expected_type is None
-                       else [v for v in val if isinstance(v, expected_type)])
+                val = [v for v in map(type_test, val) if v is not None]
                 if val:
+                    return val if get_all else val[0]
+            else:
+                val = type_test(val)
+                if val is not None:
                     return val
-            elif expected_type is None or isinstance(val, expected_type):
-                return val
     return default