jfr.im git - yt-dlp.git/blob - yt_dlp/networking/_websockets.py
[cleanup] Add more ruff rules (#10149)
[yt-dlp.git] / yt_dlp / networking / _websockets.py
1 from __future__ import annotations
2
3 import contextlib
4 import io
5 import logging
6 import ssl
7 import sys
8
9 from ._helper import (
10 create_connection,
11 create_socks_proxy_socket,
12 make_socks_proxy_opts,
13 select_proxy,
14 )
15 from .common import Features, Response, register_rh
16 from .exceptions import (
17 CertificateVerifyError,
18 HTTPError,
19 ProxyError,
20 RequestError,
21 SSLError,
22 TransportError,
23 )
24 from .websocket import WebSocketRequestHandler, WebSocketResponse
25 from ..compat import functools
26 from ..dependencies import websockets
27 from ..socks import ProxyError as SocksProxyError
28 from ..utils import int_or_none
29
# `websockets` here is the object exposed by ..dependencies; it is falsy when the
# package is not installed. Raising ImportError aborts the import of this module,
# so none of the definitions below (including the @register_rh registration) run.
if not websockets:
    raise ImportError('websockets is not installed')

import websockets.version

# Parse the installed version string into a tuple for comparison against the minimum.
# NOTE(review): int_or_none presumably maps a non-numeric component (e.g. the '0rc1'
# in '12.0rc1') to None; comparing (12, None) against (12, 0) would then raise
# TypeError rather than ImportError -- confirm which version strings can occur.
websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
if websockets_version < (12, 0):
    raise ImportError('Only websockets>=12.0 is supported')

# Imported only after the version check passed; the sync client API and the URI
# parser are what the request handler below uses.
import websockets.sync.client
from websockets.uri import parse_uri
41
# In websockets Connection, recv_exc and recv_events_exc are defined
# after the recv events handler thread is started [1].
# On our CI using PyPy, in some cases a race condition may occur
# where the recv events handler thread tries to use these attributes before they are defined [2].
# 1: https://github.com/python-websockets/websockets/blame/de768cf65e7e2b1a3b67854fb9e08816a5ff7050/src/websockets/sync/connection.py#L93
# 2: "AttributeError: 'ClientConnection' object has no attribute 'recv_events_exc'. Did you mean: 'recv_events'?"
# Workaround: give both attributes a class-level default of None, so a read that
# happens before the instance assigns them finds the class attribute instead of
# raising AttributeError.
import websockets.sync.connection  # isort: split
# Best-effort patch: any failure here is suppressed so the module still imports.
with contextlib.suppress(Exception):
    # > 12.0
    websockets.sync.connection.Connection.recv_exc = None
    # 12.0
    websockets.sync.connection.Connection.recv_events_exc = None
54
55
class WebsocketsResponseAdapter(WebSocketResponse):
    """
    Present a connected `websockets` sync client connection as a WebSocketResponse.

    The HTTP handshake response (status code, reason phrase, headers, body) becomes
    the response metadata. send()/recv() forward to the connection and re-raise
    library errors as yt-dlp networking exceptions.
    """

    def __init__(self, ws: websockets.sync.client.ClientConnection, url):
        handshake = ws.response
        super().__init__(
            fp=io.BytesIO(handshake.body or b''),
            url=url,
            headers=handshake.headers,
            status=handshake.status_code,
            reason=handshake.reason_phrase,
        )
        self._ws = ws

    def close(self):
        """Close the websocket connection first, then the response object."""
        connection = self._ws
        connection.close()
        super().close()

    def send(self, message):
        """
        Send `message` over the websocket and return the library's result.

        Raises:
            TransportError: websockets protocol errors, RuntimeError or TimeoutError.
            ProxyError: a SOCKS proxy failure.
            RequestError: a TypeError from the library (e.g. an unsupported message type).
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        try:
            outcome = self._ws.send(message)
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as err:
            raise TransportError(cause=err) from err
        except SocksProxyError as err:
            raise ProxyError(cause=err) from err
        except TypeError as err:
            raise RequestError(cause=err) from err
        else:
            return outcome

    def recv(self):
        """
        Receive the next message from the websocket.

        Raises:
            ProxyError: a SOCKS proxy failure.
            TransportError: websockets protocol errors, RuntimeError or TimeoutError.
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        try:
            incoming = self._ws.recv()
        except SocksProxyError as err:
            raise ProxyError(cause=err) from err
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as err:
            raise TransportError(cause=err) from err
        else:
            return incoming
91
92
@register_rh
class WebsocketsRH(WebSocketRequestHandler):
    """
    Websockets request handler
    https://websockets.readthedocs.io
    https://github.com/python-websockets/websockets
    """
    _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
    _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
    _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
    RH_NAME = 'websockets'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attach a stdout handler (prefixed with RH_NAME) to the websockets library
        # loggers. The handlers are remembered so close() can detach them again.
        self.__logging_handlers = {}
        for name in ('websockets.client', 'websockets.server'):
            logger = logging.getLogger(name)
            handler = logging.StreamHandler(stream=sys.stdout)
            handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
            self.__logging_handlers[name] = handler
            logger.addHandler(handler)
            if self.verbose:
                # NOTE: this sets the level on a process-wide logger; close() does not revert it
                logger.setLevel(logging.DEBUG)

    def _check_extensions(self, extensions):
        super()._check_extensions(extensions)
        # Accept the 'timeout' and 'cookiejar' extensions. They are presumably read
        # in _send via _calculate_timeout() / _get_cookiejar(), and presumably the base
        # class rejects any extension left in the dict -- confirm against the base class.
        extensions.pop('timeout', None)
        extensions.pop('cookiejar', None)

    def close(self):
        # Remove the logging handler that contains a reference to our logger
        # See: https://github.com/yt-dlp/yt-dlp/issues/8922
        for name, handler in self.__logging_handlers.items():
            logging.getLogger(name).removeHandler(handler)

    def _send(self, request):
        """
        Open a websocket connection for `request`, directly or through a SOCKS proxy.

        Returns a WebsocketsResponseAdapter wrapping the connected client.

        Raises:
            RequestError: the URL is not a valid websocket URI.
            ProxyError: a SOCKS proxy failure.
            CertificateVerifyError / SSLError: TLS certificate or other TLS failures.
            HTTPError: the server answered the upgrade request with an HTTP status.
            TransportError: other socket, timeout or websocket protocol errors.
        """
        timeout = self._calculate_timeout(request)
        headers = self._merge_headers(request.headers)
        # Only add cookies from the jar when the caller did not set a cookie header
        if 'cookie' not in headers:
            cookiejar = self._get_cookiejar(request)
            cookie_header = cookiejar.get_cookie_header(request.url)
            if cookie_header:
                headers['cookie'] = cookie_header

        try:
            wsuri = parse_uri(request.url)
        except websockets.exceptions.InvalidURI as e:
            # Map an invalid URI to RequestError here as well, so it does not escape
            # as a raw websockets exception (this call is not covered by the
            # connection try-block below).
            raise RequestError(cause=e) from e

        create_conn_kwargs = {
            'source_address': (self.source_address, 0) if self.source_address else None,
            'timeout': timeout,
        }
        proxy = select_proxy(request.url, self._get_proxies(request))
        try:
            if proxy:
                # Connect to the proxy; the SOCKS socket factory then tunnels to the target host/port
                socks_proxy_options = make_socks_proxy_opts(proxy)
                sock = create_connection(
                    address=(socks_proxy_options['addr'], socks_proxy_options['port']),
                    _create_socket_func=functools.partial(
                        create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
                    **create_conn_kwargs,
                )
            else:
                sock = create_connection(
                    address=(wsuri.host, wsuri.port),
                    **create_conn_kwargs,
                )
            conn = websockets.sync.client.connect(
                sock=sock,
                uri=request.url,
                additional_headers=headers,
                open_timeout=timeout,
                user_agent_header=None,
                ssl_context=self._make_sslcontext() if wsuri.secure else None,
                close_timeout=0,  # not ideal, but prevents yt-dlp hanging
            )
            return WebsocketsResponseAdapter(conn, url=request.url)

        # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e
        except ssl.SSLCertVerificationError as e:
            raise CertificateVerifyError(cause=e) from e
        except ssl.SSLError as e:
            raise SSLError(cause=e) from e
        except websockets.exceptions.InvalidStatus as e:
            raise HTTPError(
                Response(
                    # Same empty-body fallback as WebsocketsResponseAdapter
                    fp=io.BytesIO(e.response.body or b''),
                    url=request.url,
                    headers=e.response.headers,
                    status=e.response.status_code,
                    reason=e.response.reason_phrase),
            ) from e
        except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
            raise TransportError(cause=e) from e