]> jfr.im git - yt-dlp.git/blob - yt_dlp/networking/_websockets.py
[rh:websockets] Migrate websockets to networking framework (#7720)
[yt-dlp.git] / yt_dlp / networking / _websockets.py
1 from __future__ import annotations
2
3 import io
4 import logging
5 import ssl
6 import sys
7
8 from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket
9 from .common import Response, register_rh, Features
10 from .exceptions import (
11 CertificateVerifyError,
12 HTTPError,
13 RequestError,
14 SSLError,
15 TransportError, ProxyError,
16 )
17 from .websocket import WebSocketRequestHandler, WebSocketResponse
18 from ..compat import functools
19 from ..dependencies import websockets
20 from ..utils import int_or_none
21 from ..socks import ProxyError as SocksProxyError
22
# Refuse to load this module when the optional websockets dependency is missing.
if not websockets:
    raise ImportError('websockets is not installed')

import websockets.version

# Split e.g. '12.0' into (12, 0); components int_or_none cannot parse become None.
websockets_version = tuple(
    int_or_none(component) for component in websockets.version.version.split('.'))
# Enforce the minimum supported websockets release.
if websockets_version < (12, 0):
    raise ImportError('Only websockets>=12.0 is supported')
31
32 import websockets.sync.client
33 from websockets.uri import parse_uri
34
35
class WebsocketsResponseAdapter(WebSocketResponse):
    """
    Expose a websockets sync ClientConnection as a yt-dlp WebSocketResponse.

    The HTTP response of the opening handshake (body, headers, status, reason)
    populates the base Response; send()/recv() are forwarded to the connection
    with library exceptions translated into yt-dlp networking exceptions.
    """

    def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
        handshake_response = wsw.response
        super().__init__(
            # The handshake response may have no body; fall back to empty bytes
            fp=io.BytesIO(handshake_response.body or b''),
            url=url,
            headers=handshake_response.headers,
            status=handshake_response.status_code,
            reason=handshake_response.reason_phrase,
        )
        self.wsw = wsw

    def close(self):
        """Close the websocket connection, then the base response."""
        try:
            self.wsw.close()
        finally:
            # Release the base response even if closing the websocket raised
            super().close()

    def send(self, message):
        """
        Send a message over the websocket connection.

        :param message: data to send, passed through unchanged to ClientConnection.send
        :return: whatever ClientConnection.send returns
        :raises ProxyError: the SOCKS proxy failed
        :raises TransportError: websocket, RuntimeError or timeout errors from the library
        :raises RequestError: the library raised TypeError (presumably an unsupported message type)
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        try:
            return self.wsw.send(message)
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
            raise TransportError(cause=e) from e
        except TypeError as e:
            raise RequestError(cause=e) from e

    def recv(self):
        """
        Receive the next message from the websocket connection.

        :return: whatever ClientConnection.recv returns
        :raises ProxyError: the SOCKS proxy failed
        :raises TransportError: websocket, RuntimeError or timeout errors from the library
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        try:
            return self.wsw.recv()
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
            raise TransportError(cause=e) from e
71
72
@register_rh
class WebsocketsRH(WebSocketRequestHandler):
    """
    Websockets request handler
    https://websockets.readthedocs.io
    https://github.com/python-websockets/websockets
    """
    _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
    _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
    _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
    RH_NAME = 'websockets'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Handlers this instance attaches to the library's (process-global) loggers.
        # They are remembered so close() can detach them; otherwise every new
        # instance leaves a handler behind and log lines are duplicated.
        self._logging_handlers = {}
        for name in ('websockets.client', 'websockets.server'):
            logger = logging.getLogger(name)
            handler = logging.StreamHandler(stream=sys.stdout)
            handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
            self._logging_handlers[name] = handler
            logger.addHandler(handler)
            if self.verbose:
                logger.setLevel(logging.DEBUG)

    def close(self):
        """Detach the logging handlers added in __init__, then defer to the base class."""
        for name, handler in self._logging_handlers.items():
            logging.getLogger(name).removeHandler(handler)
        # Clearing makes repeated close() calls harmless
        self._logging_handlers.clear()
        super().close()

    def _check_extensions(self, extensions):
        super()._check_extensions(extensions)
        # Extensions consumed by _send
        extensions.pop('timeout', None)
        extensions.pop('cookiejar', None)

    def _send(self, request):
        """
        Open a websocket connection for *request* and wrap it in a response.

        A TCP connection is opened here (directly, or via the selected SOCKS
        proxy) and handed to websockets for the TLS and websocket handshakes.

        :raises ProxyError: SOCKS proxy failure
        :raises RequestError: the URL is not a valid websocket URI
        :raises CertificateVerifyError: TLS certificate verification failed
        :raises SSLError: any other TLS error
        :raises HTTPError: the server answered the handshake with a rejecting HTTP status
        :raises TransportError: other socket, timeout or websocket errors
        """
        timeout = float(request.extensions.get('timeout') or self.timeout)
        headers = self._merge_headers(request.headers)
        # An explicit Cookie header takes precedence over the cookiejar
        if 'cookie' not in headers:
            cookiejar = request.extensions.get('cookiejar') or self.cookiejar
            cookie_header = cookiejar.get_cookie_header(request.url)
            if cookie_header:
                headers['cookie'] = cookie_header

        try:
            # Parsed inside the try so an InvalidURI (e.g. missing host, fragment)
            # is mapped to RequestError below instead of escaping unwrapped.
            wsuri = parse_uri(request.url)
            create_conn_kwargs = {
                'source_address': (self.source_address, 0) if self.source_address else None,
                'timeout': timeout,
            }
            proxy = select_proxy(request.url, request.proxies or self.proxies or {})
            # Build the TLS context before opening the socket, so a failure here
            # cannot leave an already-connected socket unclosed.
            ssl_context = self._make_sslcontext() if wsuri.secure else None
            if proxy:
                socks_proxy_options = make_socks_proxy_opts(proxy)
                # Connect to the proxy; the socket factory receives the target
                # (host, port), presumably so the proxy tunnels to it.
                sock = create_connection(
                    address=(socks_proxy_options['addr'], socks_proxy_options['port']),
                    _create_socket_func=functools.partial(
                        create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
                    **create_conn_kwargs,
                )
            else:
                sock = create_connection(
                    address=(wsuri.host, wsuri.port),
                    **create_conn_kwargs,
                )
            conn = websockets.sync.client.connect(
                sock=sock,
                uri=request.url,
                additional_headers=headers,
                open_timeout=timeout,
                user_agent_header=None,  # don't let the library add its own User-Agent
                ssl_context=ssl_context,
                close_timeout=0,  # not ideal, but prevents yt-dlp hanging
            )
            return WebsocketsResponseAdapter(conn, url=request.url)

        # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e
        except ssl.SSLCertVerificationError as e:
            raise CertificateVerifyError(cause=e) from e
        except ssl.SSLError as e:
            raise SSLError(cause=e) from e
        except websockets.exceptions.InvalidStatus as e:
            raise HTTPError(
                Response(
                    fp=io.BytesIO(e.response.body),
                    url=request.url,
                    headers=e.response.headers,
                    status=e.response.status_code,
                    reason=e.response.reason_phrase),
            ) from e
        except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
            raise TransportError(cause=e) from e