]>
Commit | Line | Data |
---|---|---|
1 | from __future__ import annotations | |
2 | ||
3 | import io | |
4 | import logging | |
5 | import ssl | |
6 | import sys | |
7 | ||
8 | from ._helper import ( | |
9 | create_connection, | |
10 | create_socks_proxy_socket, | |
11 | make_socks_proxy_opts, | |
12 | select_proxy, | |
13 | ) | |
14 | from .common import Features, Response, register_rh | |
15 | from .exceptions import ( | |
16 | CertificateVerifyError, | |
17 | HTTPError, | |
18 | ProxyError, | |
19 | RequestError, | |
20 | SSLError, | |
21 | TransportError, | |
22 | ) | |
23 | from .websocket import WebSocketRequestHandler, WebSocketResponse | |
24 | from ..compat import functools | |
25 | from ..dependencies import websockets | |
26 | from ..socks import ProxyError as SocksProxyError | |
27 | from ..utils import int_or_none | |
28 | ||
if not websockets:
    raise ImportError('websockets is not installed')

import websockets.version


def _parse_websockets_version(version):
    """
    Return the leading integer components of a dotted ``version`` string as a tuple.

    Parsing stops at the first component int_or_none cannot convert. For example,
    '12.0rc1' yields (12,) rather than (12, None). A None element would make the
    ``< (12, 0)`` comparison below raise TypeError instead of the intended ImportError.
    """
    components = []
    for part in version.split('.'):
        number = int_or_none(part)
        if number is None:
            break
        components.append(number)
    return tuple(components)


websockets_version = _parse_websockets_version(websockets.version.version)
if websockets_version < (12, 0):
    raise ImportError('Only websockets>=12.0 is supported')
37 | ||
38 | import websockets.sync.client | |
39 | from websockets.uri import parse_uri | |
40 | ||
41 | ||
class WebsocketsResponseAdapter(WebSocketResponse):
    """
    Adapt a websockets sync ``ClientConnection`` to the WebSocketResponse interface.

    The connection's handshake response provides the status, reason, headers and
    body of the response. send()/recv() forward to the connection and re-raise
    library errors as yt-dlp networking exceptions.
    """

    def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
        handshake = wsw.response
        # The handshake response body may be missing/empty; use b'' in that case
        initial_body = handshake.body or b''
        super().__init__(
            fp=io.BytesIO(initial_body),
            url=url,
            headers=handshake.headers,
            status=handshake.status_code,
            reason=handshake.reason_phrase,
        )
        # Underlying websockets connection used by send(), recv() and close()
        self.wsw = wsw

    def close(self):
        """Close the websocket connection first, then the response object."""
        connection = self.wsw
        connection.close()
        super().close()

    def send(self, message):
        """
        Send ``message`` over the websocket and return the library's result.

        https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send

        Raises:
            TransportError: on a websockets error, RuntimeError or TimeoutError.
            ProxyError: on a SOCKS proxy error.
            RequestError: on TypeError (presumably an unsupported message type,
                per the websockets docs linked above).
        """
        connection = self.wsw
        try:
            return connection.send(message)
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as err:
            raise TransportError(cause=err) from err
        except SocksProxyError as err:
            raise ProxyError(cause=err) from err
        except TypeError as err:
            raise RequestError(cause=err) from err

    def recv(self):
        """
        Receive the next message from the websocket.

        https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv

        Raises:
            ProxyError: on a SOCKS proxy error.
            TransportError: on a websockets error, RuntimeError or TimeoutError.
        """
        connection = self.wsw
        try:
            return connection.recv()
        except SocksProxyError as err:
            raise ProxyError(cause=err) from err
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as err:
            raise TransportError(cause=err) from err
77 | ||
78 | ||
@register_rh
class WebsocketsRH(WebSocketRequestHandler):
    """
    Websockets request handler
    https://websockets.readthedocs.io
    https://github.com/python-websockets/websockets
    """
    _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
    _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
    _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
    RH_NAME = 'websockets'

    def __init__(self, *args, **kwargs):
        """
        Attach stdout logging handlers to the websockets client/server loggers.

        Each handler prefixes messages with RH_NAME. When ``self.verbose`` is set,
        the loggers are switched to DEBUG. The handlers are remembered so close()
        can detach them again.
        """
        super().__init__(*args, **kwargs)
        self.__logging_handlers = {}
        for name in ('websockets.client', 'websockets.server'):
            logger = logging.getLogger(name)
            handler = logging.StreamHandler(stream=sys.stdout)
            handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
            self.__logging_handlers[name] = handler
            logger.addHandler(handler)
            if self.verbose:
                # NOTE(review): this level is not restored by close()
                logger.setLevel(logging.DEBUG)

    def _check_extensions(self, extensions):
        """
        Validate request extensions, then mark 'timeout' and 'cookiejar' as handled.

        Both are read by _send(). They are popped here, presumably because the base
        class treats leftover keys as unsupported — confirm in WebSocketRequestHandler.
        """
        super()._check_extensions(extensions)
        extensions.pop('timeout', None)
        extensions.pop('cookiejar', None)

    def close(self):
        """Detach the logging handlers added in __init__, then close the base handler."""
        # Remove the logging handler that contains a reference to our logger
        # See: https://github.com/yt-dlp/yt-dlp/issues/8922
        for name, handler in self.__logging_handlers.items():
            logging.getLogger(name).removeHandler(handler)
        # Chain to the base class, as WebsocketsResponseAdapter.close() does
        super().close()

    def _send(self, request):
        """
        Open a websocket connection for ``request``.

        Returns:
            WebsocketsResponseAdapter wrapping the connected websockets client.

        Raises:
            RequestError: websockets rejected the URL (InvalidURI).
            ProxyError: a SOCKS proxy error occurred.
            CertificateVerifyError: TLS certificate verification failed.
            SSLError: any other TLS error.
            HTTPError: websockets raised InvalidStatus for the handshake response.
            TransportError: any other OSError, TimeoutError or websockets error.
        """
        # A falsy 'timeout' extension (None or 0) falls back to the handler timeout
        timeout = float(request.extensions.get('timeout') or self.timeout)
        headers = self._merge_headers(request.headers)
        # An explicit cookie header takes precedence over the cookie jar
        if 'cookie' not in headers:
            cookiejar = request.extensions.get('cookiejar') or self.cookiejar
            cookie_header = cookiejar.get_cookie_header(request.url)
            if cookie_header:
                headers['cookie'] = cookie_header

        # parse_uri can raise InvalidURI before the main try block below is entered.
        # Convert it here so it surfaces as RequestError, the same as the InvalidURI
        # that connect() may raise, instead of escaping as a raw websockets exception.
        try:
            wsuri = parse_uri(request.url)
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e

        create_conn_kwargs = {
            'source_address': (self.source_address, 0) if self.source_address else None,
            'timeout': timeout,
        }
        proxy = select_proxy(request.url, request.proxies or self.proxies or {})
        try:
            if proxy:
                # Connect to the proxy address. The socket factory receives the
                # websocket host/port, presumably to perform the SOCKS handshake.
                socks_proxy_options = make_socks_proxy_opts(proxy)
                sock = create_connection(
                    address=(socks_proxy_options['addr'], socks_proxy_options['port']),
                    _create_socket_func=functools.partial(
                        create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
                    **create_conn_kwargs
                )
            else:
                sock = create_connection(
                    address=(wsuri.host, wsuri.port),
                    **create_conn_kwargs
                )
            conn = websockets.sync.client.connect(
                sock=sock,
                uri=request.url,
                additional_headers=headers,
                open_timeout=timeout,
                # Disable the library's default User-Agent; headers come from _merge_headers
                user_agent_header=None,
                ssl_context=self._make_sslcontext() if wsuri.secure else None,
                close_timeout=0,  # not ideal, but prevents yt-dlp hanging
            )
            return WebsocketsResponseAdapter(conn, url=request.url)

        # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
        # Keep SocksProxyError ahead of the OSError clause in case it derives from OSError.
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e
        # SSLCertVerificationError subclasses ssl.SSLError, so it must come first
        except ssl.SSLCertVerificationError as e:
            raise CertificateVerifyError(cause=e) from e
        except ssl.SSLError as e:
            raise SSLError(cause=e) from e
        except websockets.exceptions.InvalidStatus as e:
            raise HTTPError(
                Response(
                    # Same empty-body fallback as WebsocketsResponseAdapter
                    fp=io.BytesIO(e.response.body or b''),
                    url=request.url,
                    headers=e.response.headers,
                    status=e.response.status_code,
                    reason=e.response.reason_phrase),
            ) from e
        except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
            raise TransportError(cause=e) from e