[utils] traverse_obj: More fixes (#6959)

- Fix result when branching with `traverse_string`
- Fix `slice` path on `dict`s
- Fix tests and docstrings from 21b5ec86c2
- Add `is_iterable_like` helper function

Authored by: Grub4K
This commit is contained in:
Simon Sawicki 2023-04-30 19:50:22 +02:00 committed by GitHub
parent 4d9280c9c8
commit b079c26f0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 12 deletions

View file

@ -2016,7 +2016,7 @@ def test_traverse_obj(self):
msg='nested `...` queries should work') msg='nested `...` queries should work')
self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4), self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
msg='`...` query result should be flattened') msg='`...` query result should be flattened')
self.assertEqual(traverse_obj(range(4), ...), list(range(4)), self.assertEqual(traverse_obj(iter(range(4)), ...), list(range(4)),
msg='`...` should accept iterables') msg='`...` should accept iterables')
# Test function as key # Test function as key
@ -2025,7 +2025,7 @@ def test_traverse_obj(self):
msg='function as query key should perform a filter based on (key, value)') msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
msg='exceptions in the query function should be catched') msg='exceptions in the query function should be catched')
self.assertEqual(traverse_obj(range(4), lambda _, x: x % 2 == 0), [0, 2], self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
msg='function key should accept iterables') msg='function key should accept iterables')
if __debug__: if __debug__:
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
@ -2051,6 +2051,17 @@ def test_traverse_obj(self):
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, {str.upper, str}) traverse_obj(_TEST_DATA, {str.upper, str})
# Test `slice` as a key
_SLICE_DATA = [0, 1, 2, 3, 4]
self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None,
msg='slice on a dictionary should not throw')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1],
msg='slice key should apply slice to sequence')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2],
msg='slice key should apply slice to sequence')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2],
msg='slice key should apply slice to sequence')
# Test alternative paths # Test alternative paths
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
msg='multiple `paths` should be treated as alternative paths') msg='multiple `paths` should be treated as alternative paths')
@ -2234,6 +2245,12 @@ def test_traverse_obj(self):
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
traverse_string=True), ['s', 'r'], traverse_string=True), ['s', 'r'],
msg='branching should result in list if `traverse_string`') msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj({}, (0, ...), traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj({}, (0, lambda x, y: True), traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
# Test is_user_input behavior # Test is_user_input behavior
_IS_USER_INPUT_DATA = {'range8': list(range(8))} _IS_USER_INPUT_DATA = {'range8': list(range(8))}

View file

@ -3273,8 +3273,14 @@ def multipart_encode(data, boundary=None):
return out, content_type return out, content_type
def variadic(x, allowed_types=(str, bytes, dict)): def is_iterable_like(x, allowed_types=collections.abc.Iterable, blocked_types=NO_DEFAULT):
return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,) if blocked_types is NO_DEFAULT:
blocked_types = (str, bytes, collections.abc.Mapping)
return isinstance(x, allowed_types) and not isinstance(x, blocked_types)
def variadic(x, allowed_types=NO_DEFAULT):
return x if is_iterable_like(x, blocked_types=allowed_types) else (x,)
def dict_get(d, key_or_keys, default=None, skip_false_values=True): def dict_get(d, key_or_keys, default=None, skip_false_values=True):
@ -5467,7 +5473,7 @@ def traverse_obj(
obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True, obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
casesense=True, is_user_input=False, traverse_string=False): casesense=True, is_user_input=False, traverse_string=False):
""" """
Safely traverse nested `dict`s and `Sequence`s Safely traverse nested `dict`s and `Iterable`s
>>> obj = [{}, {"key": "value"}] >>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key")) >>> traverse_obj(obj, (1, "key"))
@ -5475,7 +5481,7 @@ def traverse_obj(
Each of the provided `paths` is tested and the first producing a valid result will be returned. Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found. The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@ -5492,7 +5498,7 @@ def traverse_obj(
Read as: `[traverse_obj(obj, branch) for branch in branches]`. Read as: `[traverse_obj(obj, branch) for branch in branches]`.
- `function`: Branch out and return values filtered by the function. - `function`: Branch out and return values filtered by the function.
Read as: `[value for key, value in obj if function(key, value)]`. Read as: `[value for key, value in obj if function(key, value)]`.
For `Sequence`s, `key` is the index of the value. For `Iterable`s, `key` is the index of the value.
For `re.Match`es, `key` is the group number (0 = full match) For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given. as well as additionally any group names, if given.
- `dict` Transform the current object and return a matching dict. - `dict` Transform the current object and return a matching dict.
@ -5540,7 +5546,9 @@ def apply_key(key, obj, is_last):
result = None result = None
if obj is None and traverse_string: if obj is None and traverse_string:
pass if key is ... or callable(key) or isinstance(key, slice):
branching = True
result = ()
elif key is None: elif key is None:
result = obj result = obj
@ -5563,7 +5571,7 @@ def apply_key(key, obj, is_last):
branching = True branching = True
if isinstance(obj, collections.abc.Mapping): if isinstance(obj, collections.abc.Mapping):
result = obj.values() result = obj.values()
elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)): elif is_iterable_like(obj):
result = obj result = obj
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
result = obj.groups() result = obj.groups()
@ -5577,7 +5585,7 @@ def apply_key(key, obj, is_last):
branching = True branching = True
if isinstance(obj, collections.abc.Mapping): if isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items() iter_obj = obj.items()
elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)): elif is_iterable_like(obj):
iter_obj = enumerate(obj) iter_obj = enumerate(obj)
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
iter_obj = itertools.chain( iter_obj = itertools.chain(
@ -5601,7 +5609,7 @@ def apply_key(key, obj, is_last):
} or None } or None
elif isinstance(obj, collections.abc.Mapping): elif isinstance(obj, collections.abc.Mapping):
result = (obj.get(key) if casesense or (key in obj) else result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
next((v for k, v in obj.items() if casefold(k) == key), None)) next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
@ -5613,7 +5621,7 @@ def apply_key(key, obj, is_last):
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
elif isinstance(key, (int, slice)): elif isinstance(key, (int, slice)):
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)): if is_iterable_like(obj, collections.abc.Sequence):
branching = isinstance(key, slice) branching = isinstance(key, slice)
with contextlib.suppress(IndexError): with contextlib.suppress(IndexError):
result = obj[key] result = obj[key]