mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2024-11-21 20:46:36 -05:00
[utils] traverse_obj
: Allow iterables in traversal (#6902)
Authored by: Grub4K
This commit is contained in:
parent
c16644642b
commit
21b5ec86c2
2 changed files with 7 additions and 4 deletions
|
@ -2016,6 +2016,8 @@ def test_traverse_obj(self):
|
||||||
msg='nested `...` queries should work')
|
msg='nested `...` queries should work')
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
|
self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
|
||||||
msg='`...` query result should be flattened')
|
msg='`...` query result should be flattened')
|
||||||
|
self.assertEqual(traverse_obj(range(4), ...), list(range(4)),
|
||||||
|
msg='`...` should accept iterables')
|
||||||
|
|
||||||
# Test function as key
|
# Test function as key
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
|
self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
|
||||||
|
@ -2023,6 +2025,8 @@ def test_traverse_obj(self):
|
||||||
msg='function as query key should perform a filter based on (key, value)')
|
msg='function as query key should perform a filter based on (key, value)')
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
|
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
|
||||||
msg='exceptions in the query function should be catched')
|
msg='exceptions in the query function should be catched')
|
||||||
|
self.assertEqual(traverse_obj(range(4), lambda _, x: x % 2 == 0), [0, 2],
|
||||||
|
msg='function key should accept iterables')
|
||||||
if __debug__:
|
if __debug__:
|
||||||
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
|
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
|
||||||
traverse_obj(_TEST_DATA, lambda a: ...)
|
traverse_obj(_TEST_DATA, lambda a: ...)
|
||||||
|
|
|
@ -5528,7 +5528,6 @@ def traverse_obj(
|
||||||
If no `default` is given and the last path branches, a `list` of results
|
If no `default` is given and the last path branches, a `list` of results
|
||||||
is always returned. If a path ends on a `dict` that result will always be a `dict`.
|
is always returned. If a path ends on a `dict` that result will always be a `dict`.
|
||||||
"""
|
"""
|
||||||
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
|
|
||||||
casefold = lambda k: k.casefold() if isinstance(k, str) else k
|
casefold = lambda k: k.casefold() if isinstance(k, str) else k
|
||||||
|
|
||||||
if isinstance(expected_type, type):
|
if isinstance(expected_type, type):
|
||||||
|
@ -5564,7 +5563,7 @@ def apply_key(key, obj, is_last):
|
||||||
branching = True
|
branching = True
|
||||||
if isinstance(obj, collections.abc.Mapping):
|
if isinstance(obj, collections.abc.Mapping):
|
||||||
result = obj.values()
|
result = obj.values()
|
||||||
elif is_sequence(obj):
|
elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
|
||||||
result = obj
|
result = obj
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
result = obj.groups()
|
result = obj.groups()
|
||||||
|
@ -5578,7 +5577,7 @@ def apply_key(key, obj, is_last):
|
||||||
branching = True
|
branching = True
|
||||||
if isinstance(obj, collections.abc.Mapping):
|
if isinstance(obj, collections.abc.Mapping):
|
||||||
iter_obj = obj.items()
|
iter_obj = obj.items()
|
||||||
elif is_sequence(obj):
|
elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
|
||||||
iter_obj = enumerate(obj)
|
iter_obj = enumerate(obj)
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
iter_obj = itertools.chain(
|
iter_obj = itertools.chain(
|
||||||
|
@ -5614,7 +5613,7 @@ def apply_key(key, obj, is_last):
|
||||||
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
|
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
|
||||||
|
|
||||||
elif isinstance(key, (int, slice)):
|
elif isinstance(key, (int, slice)):
|
||||||
if is_sequence(obj):
|
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)):
|
||||||
branching = isinstance(key, slice)
|
branching = isinstance(key, slice)
|
||||||
with contextlib.suppress(IndexError):
|
with contextlib.suppress(IndexError):
|
||||||
result = obj[key]
|
result = obj[key]
|
||||||
|
|
Loading…
Reference in a new issue