]> jfr.im git - yt-dlp.git/blobdiff - test/test_traversal.py
Release 2024.04.09
[yt-dlp.git] / test / test_traversal.py
index 3b247d0597b5fb36c3a4cd637eadaf455c2ad13b..9b2a27b0807feefe83fc3065738c076b86ab0bc4 100644 (file)
@@ -1,3 +1,4 @@
+import http.cookies
 import re
 import xml.etree.ElementTree
 
 
 
 class TestTraversal:
-    def test_dict_get(self):
-        FALSE_VALUES = {
-            'none': None,
-            'false': False,
-            'zero': 0,
-            'empty_string': '',
-            'empty_list': [],
-        }
-        d = {**FALSE_VALUES, 'a': 42}
-        assert dict_get(d, 'a') == 42
-        assert dict_get(d, 'b') is None
-        assert dict_get(d, 'b', 42) == 42
-        assert dict_get(d, ('a',)) == 42
-        assert dict_get(d, ('b', 'a')) == 42
-        assert dict_get(d, ('b', 'c', 'a', 'd')) == 42
-        assert dict_get(d, ('b', 'c')) is None
-        assert dict_get(d, ('b', 'c'), 42) == 42
-        for key, false_value in FALSE_VALUES.items():
-            assert dict_get(d, ('b', 'c', key)) is None
-            assert dict_get(d, ('b', 'c', key), skip_false_values=False) == false_value
-
     def test_traversal_base(self):
         assert traverse_obj(_TEST_DATA, ('str',)) == 'str', \
             'allow tuple path'
@@ -94,6 +74,8 @@ def test_traversal_set(self):
             'Function in set should be a transformation'
         assert traverse_obj(_TEST_DATA, (..., {str})) == ['str'], \
             'Type in set should be a type filter'
+        assert traverse_obj(_TEST_DATA, (..., {str, int})) == [100, 'str'], \
+            'Multiple types in set should be a type filter'
         assert traverse_obj(_TEST_DATA, {dict}) == _TEST_DATA, \
             'A single set should be wrapped into a path'
         assert traverse_obj(_TEST_DATA, (..., {str.upper})) == ['STR'], \
@@ -103,7 +85,7 @@ def test_traversal_set(self):
             'Function in set should be a transformation'
         assert traverse_obj(_TEST_DATA, ('fail', {lambda _: 'const'})) == 'const', \
             'Function in set should always be called'
-        # Sets with length != 1 should raise in debug
+        # Sets with length < 1 or > 1 not including only types should raise
         with pytest.raises(Exception):
             traverse_obj(_TEST_DATA, set())
         with pytest.raises(Exception):
@@ -377,3 +359,86 @@ def test_traversal_xml_etree(self):
             'special transformations should act on current element'
         assert traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})) == [1, 2008, 141100], \
             'special transformations should act on current element'
+
+    def test_traversal_unbranching(self):
+        assert traverse_obj(_TEST_DATA, [(100, 1.2), all]) == [100, 1.2], \
+            '`all` should give all results as list'
+        assert traverse_obj(_TEST_DATA, [(100, 1.2), any]) == 100, \
+            '`any` should give the first result'
+        assert traverse_obj(_TEST_DATA, [100, all]) == [100], \
+            '`all` should give list if non branching'
+        assert traverse_obj(_TEST_DATA, [100, any]) == 100, \
+            '`any` should give single item if non branching'
+        assert traverse_obj(_TEST_DATA, [('dict', 'None', 100), all]) == [100], \
+            '`all` should filter `None` and empty dict'
+        assert traverse_obj(_TEST_DATA, [('dict', 'None', 100), any]) == 100, \
+            '`any` should filter `None` and empty dict'
+        assert traverse_obj(_TEST_DATA, [{
+            'all': [('dict', 'None', 100, 1.2), all],
+            'any': [('dict', 'None', 100, 1.2), any],
+        }]) == {'all': [100, 1.2], 'any': 100}, \
+            '`all`/`any` should apply to each dict path separately'
+        assert traverse_obj(_TEST_DATA, [{
+            'all': [('dict', 'None', 100, 1.2), all],
+            'any': [('dict', 'None', 100, 1.2), any],
+        }], get_all=False) == {'all': [100, 1.2], 'any': 100}, \
+            '`all`/`any` should apply to dict regardless of `get_all`'
+        assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, {float}]) is None, \
+            '`all` should reset branching status'
+        assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), any, {float}]) is None, \
+            '`any` should reset branching status'
+        assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, ..., {float}]) == [1.2], \
+            '`all` should allow further branching'
+        assert traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, ..., 'index']) == [0, 1], \
+            '`any` should allow further branching'
+
+    def test_traversal_morsel(self):
+        values = {
+            'expires': 'a',
+            'path': 'b',
+            'comment': 'c',
+            'domain': 'd',
+            'max-age': 'e',
+            'secure': 'f',
+            'httponly': 'g',
+            'version': 'h',
+            'samesite': 'i',
+        }
+        morsel = http.cookies.Morsel()
+        morsel.set('item_key', 'item_value', 'coded_value')
+        morsel.update(values)
+        values['key'] = 'item_key'
+        values['value'] = 'item_value'
+
+        for key, value in values.items():
+            assert traverse_obj(morsel, key) == value, \
+                'Morsel should provide access to all values'
+        assert traverse_obj(morsel, ...) == list(values.values()), \
+            '`...` should yield all values'
+        assert traverse_obj(morsel, lambda k, v: True) == list(values.values()), \
+            'function key should yield all values'
+        assert traverse_obj(morsel, [(None,), any]) == morsel, \
+            'Morsel should not be implicitly changed to dict on usage'
+
+
+class TestDictGet:
+    def test_dict_get(self):
+        FALSE_VALUES = {
+            'none': None,
+            'false': False,
+            'zero': 0,
+            'empty_string': '',
+            'empty_list': [],
+        }
+        d = {**FALSE_VALUES, 'a': 42}
+        assert dict_get(d, 'a') == 42
+        assert dict_get(d, 'b') is None
+        assert dict_get(d, 'b', 42) == 42
+        assert dict_get(d, ('a',)) == 42
+        assert dict_get(d, ('b', 'a')) == 42
+        assert dict_get(d, ('b', 'c', 'a', 'd')) == 42
+        assert dict_get(d, ('b', 'c')) is None
+        assert dict_get(d, ('b', 'c'), 42) == 42
+        for key, false_value in FALSE_VALUES.items():
+            assert dict_get(d, ('b', 'c', key)) is None
+            assert dict_get(d, ('b', 'c', key), skip_false_values=False) == false_value