+ def test_get_elements_html_by_attribute(self):
+ html = self.GET_ELEMENTS_BY_CLASS_TEST_STRING
+
+ self.assertEqual(get_elements_html_by_attribute('class', 'foo bar', html), self.GET_ELEMENTS_BY_CLASS_RES)
+ self.assertEqual(get_elements_html_by_attribute('class', 'foo', html), [])
+ self.assertEqual(get_elements_html_by_attribute('class', 'no-such-foo', html), [])
+
+ def test_get_elements_text_and_html_by_attribute(self):
+ html = self.GET_ELEMENTS_BY_CLASS_TEST_STRING
+
+ self.assertEqual(
+ list(get_elements_text_and_html_by_attribute('class', 'foo bar', html)),
+ list(zip(['nice', 'also nice'], self.GET_ELEMENTS_BY_CLASS_RES)))
+ self.assertEqual(list(get_elements_text_and_html_by_attribute('class', 'foo', html)), [])
+ self.assertEqual(list(get_elements_text_and_html_by_attribute('class', 'no-such-foo', html)), [])
+
+ self.assertEqual(list(get_elements_text_and_html_by_attribute(
+ 'class', 'foo', '<a class="foo">nice</a><span class="foo">nice</span>', tag='a')), [('nice', '<a class="foo">nice</a>')])
+
+ GET_ELEMENT_BY_TAG_TEST_STRING = '''
+ random text lorem ipsum</p>
+ <div>
+ this should be returned
+ <span>this should also be returned</span>
+ <div>
+ this should also be returned
+ </div>
+ closing tag above should not trick, so this should also be returned
+ </div>
+ but this text should not be returned
+ '''
+ GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML = GET_ELEMENT_BY_TAG_TEST_STRING.strip()[32:276]
+ GET_ELEMENT_BY_TAG_RES_OUTERDIV_TEXT = GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML[5:-6]
+ GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML = GET_ELEMENT_BY_TAG_TEST_STRING.strip()[78:119]
+ GET_ELEMENT_BY_TAG_RES_INNERSPAN_TEXT = GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML[6:-7]
+
+ def test_get_element_text_and_html_by_tag(self):
+ html = self.GET_ELEMENT_BY_TAG_TEST_STRING
+
+ self.assertEqual(
+ get_element_text_and_html_by_tag('div', html),
+ (self.GET_ELEMENT_BY_TAG_RES_OUTERDIV_TEXT, self.GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML))
+ self.assertEqual(
+ get_element_text_and_html_by_tag('span', html),
+ (self.GET_ELEMENT_BY_TAG_RES_INNERSPAN_TEXT, self.GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML))
+ self.assertRaises(compat_HTMLParseError, get_element_text_and_html_by_tag, 'article', html)
+
+ def test_iri_to_uri(self):
+ self.assertEqual(
+ iri_to_uri('https://www.google.com/search?q=foo&ie=utf-8&oe=utf-8&client=firefox-b'),
+ 'https://www.google.com/search?q=foo&ie=utf-8&oe=utf-8&client=firefox-b') # Same
+ self.assertEqual(
+ iri_to_uri('https://www.google.com/search?q=Käsesoßenrührlöffel'), # German for cheese sauce stirring spoon
+ 'https://www.google.com/search?q=K%C3%A4seso%C3%9Fenr%C3%BChrl%C3%B6ffel')
+ self.assertEqual(
+ iri_to_uri('https://www.google.com/search?q=lt<+gt>+eq%3D+amp%26+percent%25+hash%23+colon%3A+tilde~#trash=?&garbage=#'),
+ 'https://www.google.com/search?q=lt%3C+gt%3E+eq%3D+amp%26+percent%25+hash%23+colon%3A+tilde~#trash=?&garbage=#')
+ self.assertEqual(
+ iri_to_uri('http://правозащита38.рф/category/news/'),
+ 'http://xn--38-6kcaak9aj5chl4a3g.xn--p1ai/category/news/')
+ self.assertEqual(
+ iri_to_uri('http://www.правозащита38.рф/category/news/'),
+ 'http://www.xn--38-6kcaak9aj5chl4a3g.xn--p1ai/category/news/')
+ self.assertEqual(
+ iri_to_uri('https://i❤.ws/emojidomain/👍👏🤝💪'),
+ 'https://xn--i-7iq.ws/emojidomain/%F0%9F%91%8D%F0%9F%91%8F%F0%9F%A4%9D%F0%9F%92%AA')
+ self.assertEqual(
+ iri_to_uri('http://日本語.jp/'),
+ 'http://xn--wgv71a119e.jp/')
+ self.assertEqual(
+ iri_to_uri('http://导航.中国/'),
+ 'http://xn--fet810g.xn--fiqs8s/')
+
+ def test_clean_podcast_url(self):
+ self.assertEqual(clean_podcast_url('https://www.podtrac.com/pts/redirect.mp3/chtbl.com/track/5899E/traffic.megaphone.fm/HSW7835899191.mp3'), 'https://traffic.megaphone.fm/HSW7835899191.mp3')
+ self.assertEqual(clean_podcast_url('https://play.podtrac.com/npr-344098539/edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3'), 'https://edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3')
+
+ def test_LazyList(self):
+ it = list(range(10))
+
+ self.assertEqual(list(LazyList(it)), it)
+ self.assertEqual(LazyList(it).exhaust(), it)
+ self.assertEqual(LazyList(it)[5], it[5])
+
+ self.assertEqual(LazyList(it)[5:], it[5:])
+ self.assertEqual(LazyList(it)[:5], it[:5])
+ self.assertEqual(LazyList(it)[::2], it[::2])
+ self.assertEqual(LazyList(it)[1::2], it[1::2])
+ self.assertEqual(LazyList(it)[5::-1], it[5::-1])
+ self.assertEqual(LazyList(it)[6:2:-2], it[6:2:-2])
+ self.assertEqual(LazyList(it)[::-1], it[::-1])
+
+ self.assertTrue(LazyList(it))
+ self.assertFalse(LazyList(range(0)))
+ self.assertEqual(len(LazyList(it)), len(it))
+ self.assertEqual(repr(LazyList(it)), repr(it))
+ self.assertEqual(str(LazyList(it)), str(it))
+
+ self.assertEqual(list(LazyList(it, reverse=True)), it[::-1])
+ self.assertEqual(list(reversed(LazyList(it))[::-1]), it)
+ self.assertEqual(list(reversed(LazyList(it))[1:3:7]), it[::-1][1:3:7])
+
+ def test_LazyList_laziness(self):
+
+ def test(ll, idx, val, cache):
+ self.assertEqual(ll[idx], val)
+ self.assertEqual(ll._cache, list(cache))
+
+ ll = LazyList(range(10))
+ test(ll, 0, 0, range(1))
+ test(ll, 5, 5, range(6))
+ test(ll, -3, 7, range(10))
+
+ ll = LazyList(range(10), reverse=True)
+ test(ll, -1, 0, range(1))
+ test(ll, 3, 6, range(10))
+
+ ll = LazyList(itertools.count())
+ test(ll, 10, 10, range(11))
+ ll = reversed(ll)
+ test(ll, -15, 14, range(15))
+
+ def test_format_bytes(self):
+ self.assertEqual(format_bytes(0), '0.00B')
+ self.assertEqual(format_bytes(1000), '1000.00B')
+ self.assertEqual(format_bytes(1024), '1.00KiB')
+ self.assertEqual(format_bytes(1024**2), '1.00MiB')
+ self.assertEqual(format_bytes(1024**3), '1.00GiB')
+ self.assertEqual(format_bytes(1024**4), '1.00TiB')
+ self.assertEqual(format_bytes(1024**5), '1.00PiB')
+ self.assertEqual(format_bytes(1024**6), '1.00EiB')
+ self.assertEqual(format_bytes(1024**7), '1.00ZiB')
+ self.assertEqual(format_bytes(1024**8), '1.00YiB')
+ self.assertEqual(format_bytes(1024**9), '1024.00YiB')
+
+ def test_hide_login_info(self):
+ self.assertEqual(Config.hide_login_info(['-u', 'foo', '-p', 'bar']),
+ ['-u', 'PRIVATE', '-p', 'PRIVATE'])
+ self.assertEqual(Config.hide_login_info(['-u']), ['-u'])
+ self.assertEqual(Config.hide_login_info(['-u', 'foo', '-u', 'bar']),
+ ['-u', 'PRIVATE', '-u', 'PRIVATE'])
+ self.assertEqual(Config.hide_login_info(['--username=foo']),
+ ['--username=PRIVATE'])
+
+    def test_locked_file(self):
+        """Check locked_file's mutual exclusion across write/append/read modes.
+
+        The same file is opened under each mode in MODES in turn. While that lock is
+        held, a second open is attempted in every mode. Writers ('w'/'a') must be
+        blocked. Readers must not be, except that an exclusive lock blocking a
+        reader is tolerated and only reported as a known issue.
+        """
+        TEXT = 'test_locked_file\n'
+        FILE = 'test_locked_file.ytdl'
+        MODES = 'war'  # Order is important: 'w' writes TEXT, 'a' appends TEXT, so 'r' then sees TEXT * 2
+
+        try:
+            for lock_mode in MODES:
+                # NOTE(review): the third argument presumably makes locking non-blocking, so a
+                # conflicting lock raises instead of waiting (see the except clause) — confirm.
+                with locked_file(FILE, lock_mode, False) as f:
+                    if lock_mode == 'r':
+                        self.assertEqual(f.read(), TEXT * 2, 'Wrong file content')
+                    else:
+                        f.write(TEXT)
+                    # While `lock_mode` is still held, probe a second open in every mode
+                    for test_mode in MODES:
+                        testing_write = test_mode != 'r'
+                        try:
+                            with locked_file(FILE, test_mode, False):
+                                pass
+                        except (BlockingIOError, PermissionError):
+                            if not testing_write:  # FIXME: a reader should not be blocked; tolerated for now
+                                print(f'Known issue: Exclusive lock ({lock_mode}) blocks read access ({test_mode})')
+                                continue
+                            # Blocked here: only acceptable for a writer
+                            self.assertTrue(testing_write, f'{test_mode} is blocked by {lock_mode}')
+                        else:
+                            # Not blocked: only acceptable for a reader
+                            self.assertFalse(testing_write, f'{test_mode} is not blocked by {lock_mode}')
+        finally:
+            # Best-effort cleanup of the scratch file
+            with contextlib.suppress(OSError):
+                os.remove(FILE)
+
+ def test_determine_file_encoding(self):
+ self.assertEqual(determine_file_encoding(b''), (None, 0))
+ self.assertEqual(determine_file_encoding(b'--verbose -x --audio-format mkv\n'), (None, 0))
+
+ self.assertEqual(determine_file_encoding(b'\xef\xbb\xbf'), ('utf-8', 3))
+ self.assertEqual(determine_file_encoding(b'\x00\x00\xfe\xff'), ('utf-32-be', 4))
+ self.assertEqual(determine_file_encoding(b'\xff\xfe'), ('utf-16-le', 2))
+
+ self.assertEqual(determine_file_encoding(b'\xff\xfe# coding: utf-8\n--verbose'), ('utf-16-le', 2))
+
+ self.assertEqual(determine_file_encoding(b'# coding: utf-8\n--verbose'), ('utf-8', 0))
+ self.assertEqual(determine_file_encoding(b'# coding: someencodinghere-12345\n--verbose'), ('someencodinghere-12345', 0))
+
+ self.assertEqual(determine_file_encoding(b'#coding:utf-8\n--verbose'), ('utf-8', 0))
+ self.assertEqual(determine_file_encoding(b'# coding: utf-8 \r\n--verbose'), ('utf-8', 0))
+
+ self.assertEqual(determine_file_encoding('# coding: utf-32-be'.encode('utf-32-be')), ('utf-32-be', 0))
+ self.assertEqual(determine_file_encoding('# coding: utf-16-le'.encode('utf-16-le')), ('utf-16-le', 0))
+
+ def test_get_compatible_ext(self):
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None, None], vexts=['mp4'], aexts=['m4a', 'm4a']), 'mkv')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['flv'], aexts=['flv']), 'flv')
+
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['mp4'], aexts=['m4a']), 'mp4')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['mp4'], aexts=['webm']), 'mkv')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['m4a']), 'mkv')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['webm']), 'webm')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['weba']), 'webm')
+
+ self.assertEqual(get_compatible_ext(
+ vcodecs=['h264'], acodecs=['mp4a'], vexts=['mov'], aexts=['m4a']), 'mp4')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=['av01.0.12M.08'], acodecs=['opus'], vexts=['mp4'], aexts=['webm']), 'webm')
+
+ self.assertEqual(get_compatible_ext(
+ vcodecs=['vp9'], acodecs=['opus'], vexts=['webm'], aexts=['webm'], preferences=['flv', 'mp4']), 'mp4')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=['av1'], acodecs=['mp4a'], vexts=['webm'], aexts=['m4a'], preferences=('webm', 'mkv')), 'mkv')
+
+ def test_traverse_obj(self):
+ _TEST_DATA = {
+ 100: 100,
+ 1.2: 1.2,
+ 'str': 'str',
+ 'None': None,
+ '...': ...,
+ 'urls': [
+ {'index': 0, 'url': 'https://www.example.com/0'},
+ {'index': 1, 'url': 'https://www.example.com/1'},
+ ],
+ 'data': (
+ {'index': 2},
+ {'index': 3},
+ ),
+ 'dict': {},
+ }
+
+ # Test base functionality
+ self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
+ msg='allow tuple path')
+ self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str',
+ msg='allow list path')
+ self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str',
+ msg='allow iterable path')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str',
+ msg='single items should be treated as a path')
+ self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA)
+ self.assertEqual(traverse_obj(_TEST_DATA, 100), 100)
+ self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2)
+
+ # Test Ellipsis behavior
+ self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
+ (item for item in _TEST_DATA.values() if item not in (None, {})),
+ msg='`...` should give all non discarded values')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
+ msg='`...` selection for dicts should select all values')
+ self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='nested `...` queries should work')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
+ msg='`...` query result should be flattened')
+ self.assertEqual(traverse_obj(iter(range(4)), ...), list(range(4)),
+ msg='`...` should accept iterables')
+
+ # Test function as key
+ self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
+ [_TEST_DATA['urls']],
+ msg='function as query key should perform a filter based on (key, value)')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
+ msg='exceptions in the query function should be catched')
+ self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
+ msg='function key should accept iterables')
+ if __debug__:
+ with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
+ traverse_obj(_TEST_DATA, lambda a: ...)
+ with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
+ traverse_obj(_TEST_DATA, lambda a, b, c: ...)
+
+ # Test set as key (transformation/type, like `expected_type`)
+ self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper}, )), ['STR'],
+ msg='Function in set should be a transformation')
+ self.assertEqual(traverse_obj(_TEST_DATA, (..., {str})), ['str'],
+ msg='Type in set should be a type filter')
+ self.assertEqual(traverse_obj(_TEST_DATA, {dict}), _TEST_DATA,
+ msg='A single set should be wrapped into a path')
+ self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper})), ['STR'],
+ msg='Transformation function should not raise')
+ self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})),
+ [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
+ msg='Function in set should be a transformation')
+ if __debug__:
+ with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
+ traverse_obj(_TEST_DATA, set())
+ with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
+ traverse_obj(_TEST_DATA, {str.upper, str})
+
+ # Test `slice` as a key
+ _SLICE_DATA = [0, 1, 2, 3, 4]
+ self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None,
+ msg='slice on a dictionary should not throw')
+ self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1],
+ msg='slice key should apply slice to sequence')
+ self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2],
+ msg='slice key should apply slice to sequence')
+ self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2],
+ msg='slice key should apply slice to sequence')
+
+ # Test alternative paths
+ self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
+ msg='multiple `paths` should be treated as alternative paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str',
+ msg='alternatives should exit early')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None,
+ msg='alternatives should return `default` if exhausted')
+ self.assertEqual(traverse_obj(_TEST_DATA, (..., 'fail'), 100), 100,
+ msg='alternatives should track their own branching return')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('dict', ...), ('data', ...)), list(_TEST_DATA['data']),
+ msg='alternatives on empty objects should search further')
+
+ # Test branch and path nesting
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'],
+ msg='tuple as key should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'],
+ msg='list as key should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'],
+ msg='double nesting in path should be treated as paths')
+ self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1],
+ msg='do not fail early on branching')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='tripple nesting in path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (..., 'url')))),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='ellipsis as branch path start gets flattened')
+
+ # Test dictionary as key
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2},
+ msg='dict key should result in a dict with the same keys')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}),
+ {0: 'https://www.example.com/0'},
+ msg='dict key should allow paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}),
+ {0: ['https://www.example.com/0']},
+ msg='tuple in dict path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}),
+ {0: ['https://www.example.com/0']},
+ msg='double nesting in dict path should be treated as paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}),
+ {0: ['https://www.example.com/1', 'https://www.example.com/0']},
+ msg='tripple nesting in dict path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
+ msg='remove `None` values when top level dict key fails')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
+ msg='use `default` if key fails and `default`')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
+ msg='remove empty values when dict key')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: ...},
+ msg='use `default` when dict key and `default`')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
+ msg='remove empty values when nested dict key fails')
+ self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
+ msg='default to dict if pruned')
+ self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {0: ...},
+ msg='default to dict if pruned and default is given')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...), {0: {0: ...}},
+ msg='use nested `default` when nested dict key fails and `default`')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {},
+ msg='remove key if branch in dict key not successful')
+
+ # Testing default parameter behavior
+ _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None,
+ msg='default value should be `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=...), ...,
+ msg='chained fails should result in default')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0,
+ msg='should not short cirquit on `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1,
+ msg='invalid dict key should result in `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1,
+ msg='`None` is a deliberate sentinel and should become `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None,
+ msg='`IndexError` should result in `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1,
+ msg='if branched but not successful return `default` if defined, not `[]`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=None), None,
+ msg='if branched but not successful return `default` even if `default` is `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail')), [],
+ msg='if branched but not successful return `[]`, not `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [],
+ msg='if branched but object is empty return `[]`, not `default`')
+ self.assertEqual(traverse_obj(None, ...), [],
+ msg='if branched but object is `None` return `[]`, not `default`')
+ self.assertEqual(traverse_obj({0: None}, (0, ...)), [],
+ msg='if branched but state is `None` return `[]`, not `default`')
+
+ branching_paths = [
+ ('fail', ...),
+ (..., 'fail'),
+ 100 * ('fail',) + (...,),
+ (...,) + 100 * ('fail',),
+ ]
+ for branching_path in branching_paths:
+ self.assertEqual(traverse_obj({}, branching_path), [],
+ msg='if branched but state is `None`, return `[]` (not `default`)')
+ self.assertEqual(traverse_obj({}, 'fail', branching_path), [],
+ msg='if branching in last alternative and previous did not match, return `[]` (not `default`)')
+ self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x',
+ msg='if branching in last alternative and previous did match, return single value')
+ self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x',
+ msg='if branching in first alternative and non-branching path does match, return single value')
+ self.assertEqual(traverse_obj({}, branching_path, 'fail'), None,
+ msg='if branching in first alternative and non-branching path does not match, return `default`')
+
+ # Testing expected_type behavior
+ _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
+ 'str', msg='accept matching `expected_type` type')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
+ None, msg='reject non matching `expected_type` type')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
+ '0', msg='transform type using type function')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
+ None, msg='wrap expected_type fuction in try_call')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str),
+ ['str'], msg='eliminate items that expected_type fails on')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int),
+ {0: 100}, msg='type as expected_type should filter dict values')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
+ {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
+ self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int),
+ 1, msg='expected_type should not filter non final dict values')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
+ {0: {0: 100}}, msg='expected_type should transform deep dict values')
+ self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)),
+ [{0: ...}, {0: ...}], msg='expected_type should transform branched dict values')
+ self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int),
+ [4], msg='expected_type regression for type matching in tuple branching')
+ self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int),
+ [], msg='expected_type regression for type matching in dict result')
+
+ # Test get_all behavior
+ _GET_ALL_DATA = {'key': [0, 1, 2]}
+ self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', ...), get_all=False), 0,
+ msg='if not `get_all`, return only first matching value')
+ self.assertEqual(traverse_obj(_GET_ALL_DATA, ..., get_all=False), [0, 1, 2],
+ msg='do not overflatten if not `get_all`')
+
+ # Test casesense behavior
+ _CASESENSE_DATA = {
+ 'KeY': 'value0',
+ 0: {
+ 'KeY': 'value1',
+ 0: {'KeY': 'value2'},
+ },
+ }
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None,
+ msg='dict keys should be case sensitive unless `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY',
+ casesense=False), 'value0',
+ msg='allow non matching key case if `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)),
+ casesense=False), ['value1'],
+ msg='allow non matching key case in branch if `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)),
+ casesense=False), ['value2'],
+ msg='allow non matching key case in branch path if `casesense`')
+
+ # Test traverse_string behavior
+ _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2}
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None,
+ msg='do not traverse into string if not `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0),
+ traverse_string=True), 's',
+ msg='traverse into string if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1),
+ traverse_string=True), '.',
+ msg='traverse into converted data if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...),
+ traverse_string=True), 'str',
+ msg='`...` should result in string (same value) if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
+ traverse_string=True), 'sr',
+ msg='`slice` should result in string if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
+ traverse_string=True), 'str',
+ msg='function should result in string if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
+ traverse_string=True), ['s', 'r'],
+ msg='branching should result in list if `traverse_string`')
+ self.assertEqual(traverse_obj({}, (0, ...), traverse_string=True), [],
+ msg='branching should result in list if `traverse_string`')
+ self.assertEqual(traverse_obj({}, (0, lambda x, y: True), traverse_string=True), [],
+ msg='branching should result in list if `traverse_string`')
+ self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [],
+ msg='branching should result in list if `traverse_string`')
+
+ # Test is_user_input behavior
+ _IS_USER_INPUT_DATA = {'range8': list(range(8))}
+ self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
+ is_user_input=True), 3,
+ msg='allow for string indexing if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
+ is_user_input=True), tuple(range(8))[3:],
+ msg='allow for string slice if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
+ is_user_input=True), tuple(range(8))[:4:2],
+ msg='allow step in string slice if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
+ is_user_input=True), range(8),
+ msg='`:` should be treated as `...` if `is_user_input`')
+ with self.assertRaises(TypeError, msg='too many params should result in error'):
+ traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True)
+
+ # Test re.Match as input obj
+ mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123')
+ self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None],
+ msg='`...` on a `re.Match` should give its `groups()`')
+ self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'],
+ msg='function on a `re.Match` should give groupno, value starting at 0')
+ self.assertEqual(traverse_obj(mobj, 'group'), '3',
+ msg='str key on a `re.Match` should give group with that name')
+ self.assertEqual(traverse_obj(mobj, 2), '3',
+ msg='int key on a `re.Match` should give group with that name')
+ self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3',
+ msg='str key on a `re.Match` should respect casesense')
+ self.assertEqual(traverse_obj(mobj, 'fail'), None,
+ msg='failing str key on a `re.Match` should return `default`')
+ self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None,
+ msg='failing str key on a `re.Match` should return `default`')
+ self.assertEqual(traverse_obj(mobj, 8), None,
+ msg='failing int key on a `re.Match` should return `default`')
+ self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
+ msg='function on a `re.Match` should give group name as well')
+