[utils] Allow partial application for even more functions (#11437)

Fixes b6dc2c49e8

Authored by: Grub4K
This commit is contained in:
Simon Sawicki 2024-11-02 21:42:00 +01:00 committed by GitHub
parent a6783a3b99
commit 422195ec70
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 20 additions and 0 deletions

View file

@ -9,6 +9,7 @@
determine_ext, determine_ext,
dict_get, dict_get,
int_or_none, int_or_none,
join_nonempty,
str_or_none, str_or_none,
) )
from yt_dlp.utils.traversal import ( from yt_dlp.utils.traversal import (
@ -16,6 +17,7 @@
subs_list_to_dict, subs_list_to_dict,
traverse_obj, traverse_obj,
trim_str, trim_str,
unpack,
) )
_TEST_DATA = { _TEST_DATA = {
@ -510,6 +512,15 @@ def test_trim_str(self):
assert trim_str(start='abc', end='abc')('abc') == '' assert trim_str(start='abc', end='abc')('abc') == ''
assert trim_str(start='', end='')('abc') == 'abc' assert trim_str(start='', end='')('abc') == 'abc'
def test_unpack(self):
assert unpack(lambda *x: ''.join(map(str, x)))([1, 2, 3]) == '123'
assert unpack(join_nonempty)([1, 2, 3]) == '1-2-3'
assert unpack(join_nonempty(delim=' '))([1, 2, 3]) == '1 2 3'
with pytest.raises(TypeError):
unpack(join_nonempty)()
with pytest.raises(TypeError):
unpack()
class TestDictGet: class TestDictGet:
def test_dict_get(self): def test_dict_get(self):

View file

@ -5294,6 +5294,7 @@ def make_archive_id(ie, video_id):
return f'{ie_key.lower()} {video_id}' return f'{ie_key.lower()} {video_id}'
@partial_application
def truncate_string(s, left, right=0): def truncate_string(s, left, right=0):
assert left > 3 and right >= 0 assert left > 3 and right >= 0
if s is None or len(s) <= left + right: if s is None or len(s) <= left + right:

View file

@ -449,6 +449,14 @@ def trim(s):
return trim return trim
def unpack(func):
@functools.wraps(func)
def inner(items, **kwargs):
return func(*items, **kwargs)
return inner
def get_first(obj, *paths, **kwargs): def get_first(obj, *paths, **kwargs):
return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False) return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)