diff --git a/test/test_traversal.py b/test/test_traversal.py index 5d9fbe1d1..9179dadda 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -4,8 +4,18 @@ import pytest -from yt_dlp.utils import dict_get, int_or_none, str_or_none -from yt_dlp.utils.traversal import traverse_obj +from yt_dlp.utils import ( + ExtractorError, + determine_ext, + dict_get, + int_or_none, + str_or_none, +) +from yt_dlp.utils.traversal import ( + traverse_obj, + require, + subs_list_to_dict, +) _TEST_DATA = { 100: 100, @@ -420,6 +430,71 @@ def test_traversal_morsel(self): assert traverse_obj(morsel, [(None,), any]) == morsel, \ 'Morsel should not be implicitly changed to dict on usage' + def test_traversal_filter(self): + data = [None, False, True, 0, 1, 0.0, 1.1, '', 'str', {}, {0: 0}, [], [1]] + + assert traverse_obj(data, [..., filter]) == [True, 1, 1.1, 'str', {0: 0}, [1]], \ + '`filter` should filter falsy values' + + +class TestTraversalHelpers: + def test_traversal_require(self): + with pytest.raises(ExtractorError): + traverse_obj(_TEST_DATA, ['None', {require('value')}]) + assert traverse_obj(_TEST_DATA, ['str', {require('value')}]) == 'str', \ + '`require` should pass through non `None` values' + + def test_subs_list_to_dict(self): + assert traverse_obj([ + {'name': 'de', 'url': 'https://example.com/subs/de.vtt'}, + {'name': 'en', 'url': 'https://example.com/subs/en1.ass'}, + {'name': 'en', 'url': 'https://example.com/subs/en2.ass'}, + ], [..., { + 'id': 'name', + 'url': 'url', + }, all, {subs_list_to_dict}]) == { + 'de': [{'url': 'https://example.com/subs/de.vtt'}], + 'en': [ + {'url': 'https://example.com/subs/en1.ass'}, + {'url': 'https://example.com/subs/en2.ass'}, + ], + }, 'function should build subtitle dict from list of subtitles' + assert traverse_obj([ + {'name': 'de', 'url': 'https://example.com/subs/de.ass'}, + {'name': 'de'}, + {'name': 'en', 'content': 'content'}, + {'url': 'https://example.com/subs/en'}, + ], [..., { + 'id': 'name', + 'data': 'content', + 'url': 'url', + }, all, {subs_list_to_dict}]) == { + 'de': [{'url': 'https://example.com/subs/de.ass'}], + 'en': [{'data': 'content'}], + }, 'subs with mandatory items missing should be filtered' + assert traverse_obj([ + {'url': 'https://example.com/subs/de.ass', 'name': 'de'}, + {'url': 'https://example.com/subs/en', 'name': 'en'}, + ], [..., { + 'id': 'name', + 'ext': ['url', {lambda x: determine_ext(x, default_ext=None)}], + 'url': 'url', + }, all, {subs_list_to_dict(ext='ext')}]) == { + 'de': [{'url': 'https://example.com/subs/de.ass', 'ext': 'ass'}], + 'en': [{'url': 'https://example.com/subs/en', 'ext': 'ext'}], + }, '`ext` should set default ext but leave existing value untouched' + assert traverse_obj([ + {'name': 'en', 'url': 'https://example.com/subs/en2', 'prio': True}, + {'name': 'en', 'url': 'https://example.com/subs/en1', 'prio': False}, + ], [..., { + 'id': 'name', + 'quality': ['prio', {int}], + 'url': 'url', + }, all, {subs_list_to_dict(ext='ext')}]) == {'en': [ + {'url': 'https://example.com/subs/en1', 'ext': 'ext'}, + {'url': 'https://example.com/subs/en2', 'ext': 'ext'}, + ]}, '`quality` key should sort subtitle list accordingly' + class TestDictGet: def test_dict_get(self): diff --git a/yt_dlp/extractor/common.py b/yt_dlp/extractor/common.py index 3430036f4..812fbfa9f 100644 --- a/yt_dlp/extractor/common.py +++ b/yt_dlp/extractor/common.py @@ -573,13 +573,13 @@ class InfoExtractor: def _login_hint(self, method=NO_DEFAULT, netrc=None): password_hint = f'--username and --password, --netrc-cmd, or --netrc ({netrc or self._NETRC_MACHINE}) to provide account credentials' + cookies_hint = 'See https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp for how to manually pass cookies' return { None: '', - 'any': f'Use --cookies, --cookies-from-browser, {password_hint}', + 'any': f'Use --cookies, --cookies-from-browser, {password_hint}. {cookies_hint}', 'password': f'Use {password_hint}', - 'cookies': ( - 'Use --cookies-from-browser or --cookies for the authentication. ' - 'See https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp for how to manually pass cookies'), + 'cookies': f'Use --cookies-from-browser or --cookies for the authentication. {cookies_hint}', + 'session_cookies': f'Use --cookies for the authentication (--cookies-from-browser might not work). {cookies_hint}', }[method if method is not NO_DEFAULT else 'any' if self.supports_login() else 'cookies'] def __init__(self, downloader=None): diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py index 967f01fdf..dd12466b8 100644 --- a/yt_dlp/utils/_utils.py +++ b/yt_dlp/utils/_utils.py @@ -1984,11 +1984,30 @@ def urljoin(base, path): return urllib.parse.urljoin(base, path) -def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1): +def partial_application(func): + sig = inspect.signature(func) + + @functools.wraps(func) + def wrapped(*args, **kwargs): + try: + sig.bind(*args, **kwargs) + except TypeError: + return functools.partial(func, *args, **kwargs) + else: + return func(*args, **kwargs) + + return wrapped + + +@partial_application +def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1, base=None): if get_attr and v is not None: v = getattr(v, get_attr, None) + if invscale == 1 and scale < 1: + invscale = int(1 / scale) + scale = 1 try: - return int(v) * invscale // scale + return (int(v) if base is None else int(v, base=base)) * invscale // scale except (ValueError, TypeError, OverflowError): return default @@ -2006,9 +2025,13 @@ def str_to_int(int_str): return int_or_none(int_str) +@partial_application def float_or_none(v, scale=1, invscale=1, default=None): if v is None: return default + if invscale == 1 and scale < 1: + invscale = int(1 / scale) + scale = 1 try: return float(v) * invscale / scale except (ValueError, TypeError): diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 96eb2eddf..b918487f9 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -1,18 +1,35 @@ +from __future__ import annotations + +import collections import collections.abc import contextlib +import functools import http.cookies import inspect import itertools import re +import typing import xml.etree.ElementTree from ._utils import ( IDENTITY, NO_DEFAULT, + ExtractorError, LazyList, deprecation_warning, + get_elements_html_by_class, + get_elements_html_by_attribute, + get_elements_by_attribute, + get_element_html_by_attribute, + get_element_by_attribute, + get_element_html_by_id, + get_element_by_id, + get_element_html_by_class, + get_elements_by_class, + get_element_text_and_html_by_tag, is_iterable_like, try_call, + url_or_none, variadic, ) @@ -54,6 +71,7 @@ def traverse_obj( Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. - `any`-builtin: Take the first matching object and return it, resetting branching. - `all`-builtin: Take all matching objects and return them as a list, resetting branching. + - `filter`-builtin: Return the value if it is truthy, `None` otherwise. `tuple`, `list`, and `dict` all support nested paths and branches. @@ -247,6 +265,10 @@ def apply_path(start_obj, path, test_type): objs = (list(filtered_objs),) continue + if key is filter: + objs = filter(None, objs) + continue + if __debug__ and callable(key): # Verify function signature inspect.signature(key).bind(None, None) @@ -277,13 +299,143 @@ def _traverse_obj(obj, path, allow_empty, test_type): return results[0] if results else {} if allow_empty and is_dict else None for index, path in enumerate(paths, 1): - result = _traverse_obj(obj, path, index == len(paths), True) - if result is not None: - return result + is_last = index == len(paths) + try: + result = _traverse_obj(obj, path, is_last, True) + if result is not None: + return result + except _RequiredError as e: + if is_last: + # Reraise to get cleaner stack trace + raise ExtractorError(e.orig_msg, expected=e.expected) from None return None if default is NO_DEFAULT else default +def value(value, /): + return lambda _: value + + +def require(name, /, *, expected=False): + def func(value): + if value is None: + raise _RequiredError(f'Unable to extract {name}', expected=expected) + + return value + + return func + + +class _RequiredError(ExtractorError): + pass + + +@typing.overload +def subs_list_to_dict(*, ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ... + + +@typing.overload +def subs_list_to_dict(subs: list[dict] | None, /, *, ext: str | None = None) -> dict[str, list[dict]]: ... + + +def subs_list_to_dict(subs: list[dict] | None = None, /, *, ext=None): + """ + Convert subtitles from a traversal into a subtitle dict. + The path should have an `all` immediately before this function. + + Arguments: + `ext` The default value for `ext` in the subtitle dict + + In the dict you can set the following additional items: + `id` The subtitle id to sort the dict into + `quality` The sort order for each subtitle + """ + if subs is None: + return functools.partial(subs_list_to_dict, ext=ext) + + result = collections.defaultdict(list) + + for sub in subs: + if not url_or_none(sub.get('url')) and not sub.get('data'): + continue + sub_id = sub.pop('id', None) + if sub_id is None: + continue + if ext is not None and not sub.get('ext'): + sub['ext'] = ext + result[sub_id].append(sub) + result = dict(result) + + for subs in result.values(): + subs.sort(key=lambda x: x.pop('quality', 0) or 0) + + return result + + +@typing.overload +def find_element(*, attr: str, value: str, tag: str | None = None, html=False): ... + + +@typing.overload +def find_element(*, cls: str, html=False): ... + + +@typing.overload +def find_element(*, id: str, tag: str | None = None, html=False): ... + + +@typing.overload +def find_element(*, tag: str, html=False): ... + + +def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False): + # deliberately using `id=` and `cls=` for ease of readability + assert tag or id or cls or (attr and value), 'One of tag, id, cls or (attr AND value) is required' + if not tag: + tag = r'[\w:.-]+' + + if attr and value: + assert not cls, 'Cannot match both attr and cls' + assert not id, 'Cannot match both attr and id' + func = get_element_html_by_attribute if html else get_element_by_attribute + return functools.partial(func, attr, value, tag=tag) + + elif cls: + assert not id, 'Cannot match both cls and id' + assert tag is None, 'Cannot match both cls and tag' + func = get_element_html_by_class if html else get_elements_by_class + return functools.partial(func, cls) + + elif id: + func = get_element_html_by_id if html else get_element_by_id + return functools.partial(func, id, tag=tag) + + index = int(bool(html)) + return lambda html: get_element_text_and_html_by_tag(tag, html)[index] + + +@typing.overload +def find_elements(*, cls: str, html=False): ... + + +@typing.overload +def find_elements(*, attr: str, value: str, tag: str | None = None, html=False): ... + + +def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False): + # deliberately using `cls=` for ease of readability + assert cls or (attr and value), 'One of cls or (attr AND value) is required' + + if attr and value: + assert not cls, 'Cannot match both attr and cls' + func = get_elements_html_by_attribute if html else get_elements_by_attribute + return functools.partial(func, attr, value, tag=tag or r'[\w:.-]+') + + assert not tag, 'Cannot match both cls and tag' + func = get_elements_html_by_class if html else get_elements_by_class + return functools.partial(func, cls) + + def get_first(obj, *paths, **kwargs): return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)