]>
Commit | Line | Data |
---|---|---|
ccfd70f4 | 1 | from __future__ import annotations |
2 | ||
3 | import io | |
4 | import logging | |
5 | import ssl | |
6 | import sys | |
7 | ||
8 | from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket | |
9 | from .common import Response, register_rh, Features | |
10 | from .exceptions import ( | |
11 | CertificateVerifyError, | |
12 | HTTPError, | |
13 | RequestError, | |
14 | SSLError, | |
15 | TransportError, ProxyError, | |
16 | ) | |
17 | from .websocket import WebSocketRequestHandler, WebSocketResponse | |
18 | from ..compat import functools | |
19 | from ..dependencies import websockets | |
20 | from ..utils import int_or_none | |
21 | from ..socks import ProxyError as SocksProxyError | |
22 | ||
# A falsy `websockets` means the package is not installed; refuse to load this module.
if not websockets:
    raise ImportError('websockets is not installed')

import websockets.version

# Turn the dotted version string into a tuple of components so it can be
# compared against the minimum supported release.
websockets_version = tuple(
    int_or_none(component)
    for component in websockets.version.version.split('.')
)
if websockets_version < (12, 0):
    raise ImportError('Only websockets>=12.0 is supported')
31 | ||
32 | import websockets.sync.client | |
33 | from websockets.uri import parse_uri | |
34 | ||
35 | ||
class WebsocketsResponseAdapter(WebSocketResponse):
    """Wrap an open ``websockets`` sync client connection as a ``WebSocketResponse``.

    The handshake response stored on the connection supplies the body, headers,
    status and reason. ``send`` and ``recv`` forward to the connection and
    translate the exceptions it raises into this package's exception types.
    """

    def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
        """
        @param wsw: the connected ``websockets`` client connection to wrap.
        @param url: the URL reported on the response.
        """
        super().__init__(
            # The handshake response may have no body; use an empty one then.
            fp=io.BytesIO(wsw.response.body or b''),
            url=url,
            headers=wsw.response.headers,
            status=wsw.response.status_code,
            reason=wsw.response.reason_phrase,
        )
        self.wsw = wsw

    def close(self):
        """Close the websocket connection, then run the base class cleanup."""
        try:
            self.wsw.close()
        finally:
            # Always run the base cleanup, even if closing the connection raised,
            # so the response is not left half-closed.
            super().close()

    def send(self, message):
        """Send ``message`` over the connection.

        Raises ProxyError for SOCKS proxy failures, TransportError for
        websocket/runtime/timeout failures and RequestError when the
        connection rejects the message with a TypeError.
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        try:
            return self.wsw.send(message)
        # Proxy errors are checked first, in the same order as recv().
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
            raise TransportError(cause=e) from e
        except TypeError as e:
            raise RequestError(cause=e) from e

    def recv(self):
        """Receive the next message from the connection.

        Raises ProxyError for SOCKS proxy failures and TransportError for
        websocket/runtime/timeout failures.
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        try:
            return self.wsw.recv()
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
            raise TransportError(cause=e) from e
71 | ||
72 | ||
@register_rh
class WebsocketsRH(WebSocketRequestHandler):
    """
    Websockets request handler
    https://websockets.readthedocs.io
    https://github.com/python-websockets/websockets
    """
    _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
    _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
    _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
    RH_NAME = 'websockets'

    def __init__(self, *args, **kwargs):
        """Initialise the handler and route the ``websockets`` loggers to stdout.

        The installed logging handlers are remembered so that close() can
        detach them again; the loggers are process-global, so handlers left
        behind would pile up and duplicate every message with each new instance.
        """
        super().__init__(*args, **kwargs)
        self.__logging_handlers = {}
        for name in ('websockets.client', 'websockets.server'):
            logger = logging.getLogger(name)
            handler = logging.StreamHandler(stream=sys.stdout)
            handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
            self.__logging_handlers[name] = handler
            logger.addHandler(handler)
            if self.verbose:
                logger.setLevel(logging.DEBUG)

    def close(self):
        """Detach the logging handlers installed by __init__ and close the handler."""
        for name, handler in self.__logging_handlers.items():
            logging.getLogger(name).removeHandler(handler)
        # Emptying the mapping makes a repeated close() a no-op.
        self.__logging_handlers.clear()
        # NOTE(review): assumes the base request handler defines close() - confirm.
        super().close()

    def _check_extensions(self, extensions):
        """Run the base class check, then remove the extensions used by _send."""
        super()._check_extensions(extensions)
        # NOTE(review): popping presumably marks these as supported - confirm against the base class.
        extensions.pop('timeout', None)
        extensions.pop('cookiejar', None)

    def _send(self, request):
        """Open a websocket connection for ``request`` and return it wrapped.

        The socket is created here (through a SOCKS proxy when one is selected
        for the URL, directly otherwise) and handed to ``websockets`` for the
        handshake.

        Returns a WebsocketsResponseAdapter around the open connection.
        Raises ProxyError, RequestError, CertificateVerifyError, SSLError,
        HTTPError or TransportError depending on what failed.
        """
        timeout = float(request.extensions.get('timeout') or self.timeout)
        headers = self._merge_headers(request.headers)
        # An explicit Cookie header wins; otherwise build one from the cookiejar.
        if 'cookie' not in headers:
            cookiejar = request.extensions.get('cookiejar') or self.cookiejar
            cookie_header = cookiejar.get_cookie_header(request.url)
            if cookie_header:
                headers['cookie'] = cookie_header

        wsuri = parse_uri(request.url)
        create_conn_kwargs = {
            'source_address': (self.source_address, 0) if self.source_address else None,
            'timeout': timeout,
        }
        proxy = select_proxy(request.url, request.proxies or self.proxies or {})
        try:
            if proxy:
                # Connect to the proxy address; the socket factory is given the
                # websocket host/port as the destination behind the proxy.
                socks_proxy_options = make_socks_proxy_opts(proxy)
                sock = create_connection(
                    address=(socks_proxy_options['addr'], socks_proxy_options['port']),
                    _create_socket_func=functools.partial(
                        create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
                    **create_conn_kwargs
                )
            else:
                sock = create_connection(
                    address=(wsuri.host, wsuri.port),
                    **create_conn_kwargs
                )
            conn = websockets.sync.client.connect(
                sock=sock,
                uri=request.url,
                additional_headers=headers,
                open_timeout=timeout,
                user_agent_header=None,
                ssl_context=self._make_sslcontext() if wsuri.secure else None,
                close_timeout=0,  # not ideal, but prevents yt-dlp hanging
            )
            return WebsocketsResponseAdapter(conn, url=request.url)

        # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e
        # Certificate failures are a subclass of ssl.SSLError, so check them first.
        except ssl.SSLCertVerificationError as e:
            raise CertificateVerifyError(cause=e) from e
        except ssl.SSLError as e:
            raise SSLError(cause=e) from e
        except websockets.exceptions.InvalidStatus as e:
            # The handshake was rejected; report the rejecting response as an HTTP error.
            raise HTTPError(
                Response(
                    fp=io.BytesIO(e.response.body),
                    url=request.url,
                    headers=e.response.headers,
                    status=e.response.status_code,
                    reason=e.response.reason_phrase),
            ) from e
        except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
            raise TransportError(cause=e) from e