mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2024-11-07 20:30:41 -05:00
[networking] Add request handler preference framework (#7603)
Preference functions that take a request and a request handler instance can be registered to prioritize different request handlers per request. Authored by: coletdjnz Co-authored-by: pukkandan <pukkandan.ytdlp@gmail.com>
This commit is contained in:
parent
db97438940
commit
db7b054a61
3 changed files with 65 additions and 11 deletions
|
@ -1035,17 +1035,17 @@ def test_send(self):
|
||||||
assert isinstance(director.send(Request('http://')), FakeResponse)
|
assert isinstance(director.send(Request('http://')), FakeResponse)
|
||||||
|
|
||||||
def test_unsupported_handlers(self):
|
def test_unsupported_handlers(self):
|
||||||
director = RequestDirector(logger=FakeLogger())
|
|
||||||
director.add_handler(FakeRH(logger=FakeLogger()))
|
|
||||||
|
|
||||||
class SupportedRH(RequestHandler):
|
class SupportedRH(RequestHandler):
|
||||||
_SUPPORTED_URL_SCHEMES = ['http']
|
_SUPPORTED_URL_SCHEMES = ['http']
|
||||||
|
|
||||||
def _send(self, request: Request):
|
def _send(self, request: Request):
|
||||||
return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url)
|
return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url)
|
||||||
|
|
||||||
# This handler should by default take preference over FakeRH
|
director = RequestDirector(logger=FakeLogger())
|
||||||
director.add_handler(SupportedRH(logger=FakeLogger()))
|
director.add_handler(SupportedRH(logger=FakeLogger()))
|
||||||
|
director.add_handler(FakeRH(logger=FakeLogger()))
|
||||||
|
|
||||||
|
# First should take preference
|
||||||
assert director.send(Request('http://')).read() == b'supported'
|
assert director.send(Request('http://')).read() == b'supported'
|
||||||
assert director.send(Request('any://')).read() == b''
|
assert director.send(Request('any://')).read() == b''
|
||||||
|
|
||||||
|
@ -1072,6 +1072,27 @@ def _send(self, request: Request):
|
||||||
director.add_handler(UnexpectedRH(logger=FakeLogger))
|
director.add_handler(UnexpectedRH(logger=FakeLogger))
|
||||||
assert director.send(Request('any://'))
|
assert director.send(Request('any://'))
|
||||||
|
|
||||||
|
def test_preference(self):
|
||||||
|
director = RequestDirector(logger=FakeLogger())
|
||||||
|
director.add_handler(FakeRH(logger=FakeLogger()))
|
||||||
|
|
||||||
|
class SomeRH(RequestHandler):
|
||||||
|
_SUPPORTED_URL_SCHEMES = ['http']
|
||||||
|
|
||||||
|
def _send(self, request: Request):
|
||||||
|
return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url)
|
||||||
|
|
||||||
|
def some_preference(rh, request):
|
||||||
|
return (0 if not isinstance(rh, SomeRH)
|
||||||
|
else 100 if 'prefer' in request.headers
|
||||||
|
else -1)
|
||||||
|
|
||||||
|
director.add_handler(SomeRH(logger=FakeLogger()))
|
||||||
|
director.preferences.add(some_preference)
|
||||||
|
|
||||||
|
assert director.send(Request('http://')).read() == b''
|
||||||
|
assert director.send(Request('http://', headers={'prefer': '1'})).read() == b'supported'
|
||||||
|
|
||||||
|
|
||||||
# XXX: do we want to move this to test_YoutubeDL.py?
|
# XXX: do we want to move this to test_YoutubeDL.py?
|
||||||
class TestYoutubeDLNetworking:
|
class TestYoutubeDLNetworking:
|
||||||
|
|
|
@ -34,7 +34,7 @@
|
||||||
from .extractor.openload import PhantomJSwrapper
|
from .extractor.openload import PhantomJSwrapper
|
||||||
from .minicurses import format_text
|
from .minicurses import format_text
|
||||||
from .networking import HEADRequest, Request, RequestDirector
|
from .networking import HEADRequest, Request, RequestDirector
|
||||||
from .networking.common import _REQUEST_HANDLERS
|
from .networking.common import _REQUEST_HANDLERS, _RH_PREFERENCES
|
||||||
from .networking.exceptions import (
|
from .networking.exceptions import (
|
||||||
HTTPError,
|
HTTPError,
|
||||||
NoSupportingHandlers,
|
NoSupportingHandlers,
|
||||||
|
@ -683,7 +683,7 @@ def process_color_policy(stream):
|
||||||
self.params['http_headers'] = HTTPHeaderDict(std_headers, self.params.get('http_headers'))
|
self.params['http_headers'] = HTTPHeaderDict(std_headers, self.params.get('http_headers'))
|
||||||
self._load_cookies(self.params['http_headers'].get('Cookie')) # compat
|
self._load_cookies(self.params['http_headers'].get('Cookie')) # compat
|
||||||
self.params['http_headers'].pop('Cookie', None)
|
self.params['http_headers'].pop('Cookie', None)
|
||||||
self._request_director = self.build_request_director(_REQUEST_HANDLERS.values())
|
self._request_director = self.build_request_director(_REQUEST_HANDLERS.values(), _RH_PREFERENCES)
|
||||||
|
|
||||||
if auto_init and auto_init != 'no_verbose_header':
|
if auto_init and auto_init != 'no_verbose_header':
|
||||||
self.print_debug_header()
|
self.print_debug_header()
|
||||||
|
@ -4077,7 +4077,7 @@ def urlopen(self, req):
|
||||||
except HTTPError as e: # TODO: Remove in a future release
|
except HTTPError as e: # TODO: Remove in a future release
|
||||||
raise _CompatHTTPError(e) from e
|
raise _CompatHTTPError(e) from e
|
||||||
|
|
||||||
def build_request_director(self, handlers):
|
def build_request_director(self, handlers, preferences=None):
|
||||||
logger = _YDLLogger(self)
|
logger = _YDLLogger(self)
|
||||||
headers = self.params['http_headers'].copy()
|
headers = self.params['http_headers'].copy()
|
||||||
proxies = self.proxies.copy()
|
proxies = self.proxies.copy()
|
||||||
|
@ -4106,6 +4106,7 @@ def build_request_director(self, handlers):
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
))
|
))
|
||||||
|
director.preferences.update(preferences or [])
|
||||||
return director
|
return director
|
||||||
|
|
||||||
def encode(self, s):
|
def encode(self, s):
|
||||||
|
|
|
@ -31,8 +31,19 @@
|
||||||
)
|
)
|
||||||
from ..utils.networking import HTTPHeaderDict, normalize_url
|
from ..utils.networking import HTTPHeaderDict, normalize_url
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
RequestData = bytes | Iterable[bytes] | typing.IO | None
|
def register_preference(*handlers: type[RequestHandler]):
|
||||||
|
assert all(issubclass(handler, RequestHandler) for handler in handlers)
|
||||||
|
|
||||||
|
def outer(preference: Preference):
|
||||||
|
@functools.wraps(preference)
|
||||||
|
def inner(handler, *args, **kwargs):
|
||||||
|
if not handlers or isinstance(handler, handlers):
|
||||||
|
return preference(handler, *args, **kwargs)
|
||||||
|
return 0
|
||||||
|
_RH_PREFERENCES.add(inner)
|
||||||
|
return inner
|
||||||
|
return outer
|
||||||
|
|
||||||
|
|
||||||
class RequestDirector:
|
class RequestDirector:
|
||||||
|
@ -40,12 +51,17 @@ class RequestDirector:
|
||||||
|
|
||||||
Helper class that, when given a request, forward it to a RequestHandler that supports it.
|
Helper class that, when given a request, forward it to a RequestHandler that supports it.
|
||||||
|
|
||||||
|
Preference functions in the form of func(handler, request) -> int
|
||||||
|
can be registered into the `preferences` set. These are used to sort handlers
|
||||||
|
in order of preference.
|
||||||
|
|
||||||
@param logger: Logger instance.
|
@param logger: Logger instance.
|
||||||
@param verbose: Print debug request information to stdout.
|
@param verbose: Print debug request information to stdout.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, logger, verbose=False):
|
def __init__(self, logger, verbose=False):
|
||||||
self.handlers: dict[str, RequestHandler] = {}
|
self.handlers: dict[str, RequestHandler] = {}
|
||||||
|
self.preferences: set[Preference] = set()
|
||||||
self.logger = logger # TODO(Grub4k): default logger
|
self.logger = logger # TODO(Grub4k): default logger
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
|
@ -58,6 +74,16 @@ def add_handler(self, handler: RequestHandler):
|
||||||
assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
|
assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
|
||||||
self.handlers[handler.RH_KEY] = handler
|
self.handlers[handler.RH_KEY] = handler
|
||||||
|
|
||||||
|
def _get_handlers(self, request: Request) -> list[RequestHandler]:
|
||||||
|
"""Sorts handlers by preference, given a request"""
|
||||||
|
preferences = {
|
||||||
|
rh: sum(pref(rh, request) for pref in self.preferences)
|
||||||
|
for rh in self.handlers.values()
|
||||||
|
}
|
||||||
|
self._print_verbose('Handler preferences for this request: %s' % ', '.join(
|
||||||
|
f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items()))
|
||||||
|
return sorted(self.handlers.values(), key=preferences.get, reverse=True)
|
||||||
|
|
||||||
def _print_verbose(self, msg):
|
def _print_verbose(self, msg):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self.logger.stdout(f'director: {msg}')
|
self.logger.stdout(f'director: {msg}')
|
||||||
|
@ -73,8 +99,7 @@ def send(self, request: Request) -> Response:
|
||||||
|
|
||||||
unexpected_errors = []
|
unexpected_errors = []
|
||||||
unsupported_errors = []
|
unsupported_errors = []
|
||||||
# TODO (future): add a per-request preference system
|
for handler in self._get_handlers(request):
|
||||||
for handler in reversed(list(self.handlers.values())):
|
|
||||||
self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
|
self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
|
||||||
try:
|
try:
|
||||||
handler.validate(request)
|
handler.validate(request)
|
||||||
|
@ -530,3 +555,10 @@ def info(self):
|
||||||
def getheader(self, name, default=None):
|
def getheader(self, name, default=None):
|
||||||
deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
|
deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
|
||||||
return self.get_header(name, default)
|
return self.get_header(name, default)
|
||||||
|
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
RequestData = bytes | Iterable[bytes] | typing.IO | None
|
||||||
|
Preference = typing.Callable[[RequestHandler, Request], int]
|
||||||
|
|
||||||
|
_RH_PREFERENCES: set[Preference] = set()
|
||||||
|
|
Loading…
Reference in a new issue