jfr.im git - yt-dlp.git/blobdiff - test/test_utils.py
[test] Skip source address tests if the address cannot be bound to (#8900)
[yt-dlp.git] / test / test_utils.py
index dc2d8ce12b5de9ae0e2ce60651dc9fb04a566496..09c648cf8939c84d0cf7e24079a6200f1d1955d5 100644 (file)
@@ -1209,6 +1209,9 @@ def test_js_to_json_edgecases(self):
         on = js_to_json('\'"\\""\'')
         self.assertEqual(json.loads(on), '"""', msg='Unnecessary quote escape should be escaped')
 
+        on = js_to_json('[new Date("spam"), \'("eggs")\']')
+        self.assertEqual(json.loads(on), ['spam', '("eggs")'], msg='Date regex should match a single string')
+
     def test_js_to_json_malformed(self):
         self.assertEqual(js_to_json('42a1'), '42"a1"')
         self.assertEqual(js_to_json('42a-1'), '42"a"-1')
@@ -1220,11 +1223,13 @@ def test_js_to_json_template_literal(self):
         self.assertEqual(js_to_json('`${name}"${name}"`', {'name': '5'}), '"5\\"5\\""')
         self.assertEqual(js_to_json('`${name}`', {}), '"name"')
 
-    def test_js_to_json_map_array_constructors(self):
+    def test_js_to_json_common_constructors(self):
         self.assertEqual(json.loads(js_to_json('new Map([["a", 5]])')), {'a': 5})
         self.assertEqual(json.loads(js_to_json('Array(5, 10)')), [5, 10])
         self.assertEqual(json.loads(js_to_json('new Array(15,5)')), [15, 5])
         self.assertEqual(json.loads(js_to_json('new Map([Array(5, 10),new Array(15,5)])')), {'5': 10, '15': 5})
+        self.assertEqual(json.loads(js_to_json('new Date("123")')), "123")
+        self.assertEqual(json.loads(js_to_json('new Date(\'2023-10-19\')')), "2023-10-19")
 
     def test_extract_attributes(self):
         self.assertEqual(extract_attributes('<e x="y">'), {'x': 'y'})
@@ -2105,6 +2110,8 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})),
                          [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
                          msg='Function in set should be a transformation')
+        self.assertEqual(traverse_obj(_TEST_DATA, ('fail', {lambda _: 'const'})), 'const',
+                         msg='Function in set should always be called')
         if __debug__:
             with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
                 traverse_obj(_TEST_DATA, set())
@@ -2312,23 +2319,6 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [],
                          msg='branching should result in list if `traverse_string`')
 
-        # Test is_user_input behavior
-        _IS_USER_INPUT_DATA = {'range8': list(range(8))}
-        self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
-                                      is_user_input=True), 3,
-                         msg='allow for string indexing if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
-                                           is_user_input=True), tuple(range(8))[3:],
-                              msg='allow for string slice if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
-                                           is_user_input=True), tuple(range(8))[:4:2],
-                              msg='allow step in string slice if `is_user_input`')
-        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
-                                           is_user_input=True), range(8),
-                              msg='`:` should be treated as `...` if `is_user_input`')
-        with self.assertRaises(TypeError, msg='too many params should result in error'):
-            traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True)
-
         # Test re.Match as input obj
         mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123')
         self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None],
@@ -2350,6 +2340,58 @@ def test_traverse_obj(self):
         self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
                          msg='function on a `re.Match` should give group name as well')
 
+        # Test xml.etree.ElementTree.Element as input obj
+        etree = xml.etree.ElementTree.fromstring('''<?xml version="1.0"?>
+        <data>
+            <country name="Liechtenstein">
+                <rank>1</rank>
+                <year>2008</year>
+                <gdppc>141100</gdppc>
+                <neighbor name="Austria" direction="E"/>
+                <neighbor name="Switzerland" direction="W"/>
+            </country>
+            <country name="Singapore">
+                <rank>4</rank>
+                <year>2011</year>
+                <gdppc>59900</gdppc>
+                <neighbor name="Malaysia" direction="N"/>
+            </country>
+            <country name="Panama">
+                <rank>68</rank>
+                <year>2011</year>
+                <gdppc>13600</gdppc>
+                <neighbor name="Costa Rica" direction="W"/>
+                <neighbor name="Colombia" direction="E"/>
+            </country>
+        </data>''')
+        self.assertEqual(traverse_obj(etree, ''), etree,
+                         msg='empty str key should return the element itself')
+        self.assertEqual(traverse_obj(etree, 'country'), list(etree),
+                         msg='str key should lead all children with that tag name')
+        self.assertEqual(traverse_obj(etree, ...), list(etree),
+                         msg='`...` as key should return all children')
+        self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]],
+                         msg='function as key should get element as value')
+        self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]],
+                         msg='function as key should get index as key')
+        self.assertEqual(traverse_obj(etree, 0), etree[0],
+                         msg='int key should return the nth child')
+        self.assertEqual(traverse_obj(etree, './/neighbor/@name'),
+                         ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'],
+                         msg='`@<attribute>` at end of path should give that attribute')
+        self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None],
+                         msg='`@<nonexistant>` at end of path should give `None`')
+        self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'},
+                         msg='`@` should give the full attribute dict')
+        self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'],
+                         msg='`text()` at end of path should give the inner text')
+        self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'],
+                         msg='full python xpath features should be supported')
+        self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein',
+                         msg='special transformations should act on current element')
+        self.assertEqual(traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})), [1, 2008, 141100],
+                         msg='special transformations should act on current element')
+
     def test_http_header_dict(self):
         headers = HTTPHeaderDict()
         headers['ytdl-test'] = b'0'
@@ -2382,6 +2424,11 @@ def test_http_header_dict(self):
         headers4 = HTTPHeaderDict({'ytdl-test': 'data;'})
         self.assertEqual(set(headers4.items()), {('Ytdl-Test', 'data;')})
 
+        # common mistake: strip whitespace from values
+        # https://github.com/yt-dlp/yt-dlp/issues/8729
+        headers5 = HTTPHeaderDict({'ytdl-test': ' data; '})
+        self.assertEqual(set(headers5.items()), {('Ytdl-Test', 'data;')})
+
     def test_extract_basic_auth(self):
         assert extract_basic_auth('http://:foo.bar') == ('http://:foo.bar', None)
         assert extract_basic_auth('http://foo.bar') == ('http://foo.bar', None)
@@ -2405,5 +2452,6 @@ def run_shell(args):
         assert run_shell(['echo', '^!']) == '"^!"\n'
         assert run_shell('echo "^!"') == '"^!"\n'
 
+
 if __name__ == '__main__':
     unittest.main()