class IPv6ThreadingTCPServer(ThreadingTCPServer):
    """``ThreadingTCPServer`` variant that listens on an IPv6 socket.

    ``socketserver`` picks the socket family from the ``address_family`` class
    attribute, so overriding it is all that is needed.
    """

    address_family = socket.AF_INET6
+
+
class SocksHTTPTestRequestHandler(http.server.BaseHTTPRequestHandler, SocksTestRequestHandler):
    """HTTP endpoint served behind the SOCKS test proxy.

    ``GET /socks_info`` answers with ``self.socks_info`` (an attribute supplied by
    ``SocksTestRequestHandler``) serialized as JSON. Any other path gets a 404, so
    clients fail with a clear status instead of an empty reply.
    """

    def do_GET(self):
        if self.path != '/socks_info':
            self.send_error(404)
            return
        # Encode before measuring so Content-Length is the byte length actually written.
        payload = json.dumps(self.socks_info.copy()).encode()
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.send_header('Content-Length', str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)
+
+
class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
    """WebSocket endpoint served behind the SOCKS test proxy.

    After the handshake it answers every ``'socks_info'`` text message with
    ``self.socks_info`` serialized as JSON. It keeps serving until the client
    closes the connection.
    """

    def handle(self):
        # Imported lazily, presumably so this module loads without websockets installed.
        import websockets.sync.server
        protocol = websockets.ServerProtocol()
        connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
        try:
            connection.handshake()
            # Reply on request instead of pushing data and closing right after the
            # handshake. The latter can close the connection before the client's own
            # send() goes through.
            for message in connection:
                if message == 'socks_info':
                    connection.send(json.dumps(self.socks_info))
        finally:
            connection.close()
+
+
@contextlib.contextmanager
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
    """Run a SOCKS test server on a background thread for the duration of a ``with`` block.

    Each accepted connection is handled by
    ``socks_server_class(request_handler, socks_server_kwargs, request, client_address, server)``.
    ``request_handler`` and the kwargs dict are pre-bound with ``functools.partial`` ahead
    of the usual ``socketserver`` handler arguments.

    :param socks_server_class: SOCKS protocol handler class.
    :param request_handler: handler forwarded to ``socks_server_class``.
    :param bind_ip: address to listen on; defaults to ``'127.0.0.1'``. A value containing
        no ``'.'`` is treated as IPv6 and served by ``IPv6ThreadingTCPServer``.
    :param socks_server_kwargs: options forwarded to ``socks_server_class`` as a single dict.
    :yields: ``'host:port'`` of the listening server, or ``'[host]:port'`` for IPv6. The port
        is chosen by the OS because the server binds port 0.
    """
    server = server_thread = None
    try:
        bind_address = bind_ip or '127.0.0.1'
        is_ipv4 = '.' in bind_address
        server_type = ThreadingTCPServer if is_ipv4 else IPv6ThreadingTCPServer
        server = server_type(
            (bind_address, 0), functools.partial(socks_server_class, request_handler, socks_server_kwargs))
        server_port = http_server_port(server)
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()
        # Record the thread only once it is running. shutdown() waits for serve_forever()
        # to stop, so it would block forever if the loop was never started.
        server_thread = thread
        yield f'{bind_address}:{server_port}' if is_ipv4 else f'[{bind_address}]:{server_port}'
    finally:
        # Guard every step so a failure during setup surfaces unchanged, instead of
        # being masked by an AttributeError on None or a hang in shutdown().
        if server is not None:
            if server_thread is not None:
                server.shutdown()
            server.server_close()
        if server_thread is not None:
            server_thread.join(2.0)
+
+
class SocksProxyTestContext(abc.ABC):
    """Base class for the per-transport helpers used by the SOCKS proxy tests.

    A subclass names the handler that runs behind the proxy in
    ``REQUEST_HANDLER_CLASS`` and implements ``socks_info_request``.
    """

    # Handler class handed to the module-level ``socks_server``; set by subclasses.
    REQUEST_HANDLER_CLASS = None

    def socks_server(self, server_class, *args, **kwargs):
        """Return the ``socks_server`` context manager for ``server_class``.

        ``REQUEST_HANDLER_CLASS`` is used as the request handler. Extra positional
        and keyword arguments are passed straight through.
        """
        target_handler = self.REQUEST_HANDLER_CLASS
        return socks_server(server_class, target_handler, *args, **kwargs)

    @abc.abstractmethod
    def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs) -> dict:
        """Send a request through ``handler`` and return the socks_info dict reported back."""
+
+
class HTTPSocksTestProxyContext(SocksProxyTestContext):
    """Test context that fetches socks_info over plain HTTP (``GET /socks_info``)."""

    REQUEST_HANDLER_CLASS = SocksHTTPTestRequestHandler

    def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
        """GET ``/socks_info`` through ``handler`` and return the decoded JSON body.

        ``target_domain`` defaults to ``127.0.0.1`` and ``target_port`` to ``40000``.
        ``req_kwargs`` (e.g. ``proxies``) are forwarded to ``Request``.
        """
        request = Request(f'http://{target_domain or "127.0.0.1"}:{target_port or "40000"}/socks_info', **req_kwargs)
        handler.validate(request)
        # Close the response once its body is read, so the connection is released.
        with contextlib.closing(handler.send(request)) as response:
            return json.loads(response.read().decode())
+
+
class WebSocketSocksTestProxyContext(SocksProxyTestContext):
    """Test context that fetches socks_info over a WebSocket connection."""

    REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler

    def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
        """Open a ``ws://`` connection through ``handler``, ask for socks_info and decode it.

        ``target_domain`` defaults to ``127.0.0.1`` and ``target_port`` to ``40000``.
        ``req_kwargs`` (e.g. ``proxies``) are forwarded to ``Request``.
        """
        request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
        handler.validate(request)
        ws = handler.send(request)
        try:
            ws.send('socks_info')
            socks_info = ws.recv()
        finally:
            # Always release the connection, even when send/recv fails.
            ws.close()
        return json.loads(socks_info)
+
+
# Maps the ``ctx`` parametrization id to its transport-specific test context class.
CTX_MAP = dict(
    http=HTTPSocksTestProxyContext,
    ws=WebSocketSocksTestProxyContext,
)
+
+
@pytest.fixture(scope='module')
def ctx(request):
    """Instantiate the test context chosen by indirect parametrization (``request.param``)."""
    context_class = CTX_MAP[request.param]
    return context_class()
+
+
@pytest.mark.parametrize(
    'handler,ctx', [
        ('Urllib', 'http'),
        ('Requests', 'http'),
        ('Websockets', 'ws'),
        ('CurlCFFI', 'http'),
    ], indirect=True)
class TestSocks4Proxy:
    """SOCKS4 / SOCKS4a proxy tests.

    Every test runs once per (request handler, transport context) pair listed in
    the class-level parametrization; both fixtures are resolved indirectly.
    """

    def test_socks4_no_auth(self, handler, ctx):
        with handler() as rh, ctx.socks_server(Socks4ProxyHandler) as proxy_address:
            info = ctx.socks_info_request(rh, proxies={'all': f'socks4://{proxy_address}'})
            assert info['version'] == 4

    def test_socks4_auth(self, handler, ctx):
        with handler() as rh, ctx.socks_server(Socks4ProxyHandler, user_id='user') as proxy_address:
            # Without the expected user id the request must be refused.
            with pytest.raises(ProxyError):
                ctx.socks_info_request(rh, proxies={'all': f'socks4://{proxy_address}'})
            info = ctx.socks_info_request(rh, proxies={'all': f'socks4://user:@{proxy_address}'})
            assert info['version'] == 4

    def test_socks4a_ipv4_target(self, handler, ctx):
        with ctx.socks_server(Socks4ProxyHandler) as proxy_address:
            proxies = {'all': f'socks4a://{proxy_address}'}
            with handler(proxies=proxies) as rh:
                info = ctx.socks_info_request(rh, target_domain='127.0.0.1')
                assert info['version'] == 4
                # The target must arrive either as an IPv4 address or as a domain, not both.
                ipv4_matches = info['ipv4_address'] == '127.0.0.1'
                domain_matches = info['domain_address'] == '127.0.0.1'
                assert ipv4_matches != domain_matches

    def test_socks4a_domain_target(self, handler, ctx):
        with ctx.socks_server(Socks4ProxyHandler) as proxy_address:
            proxies = {'all': f'socks4a://{proxy_address}'}
            with handler(proxies=proxies) as rh:
                info = ctx.socks_info_request(rh, target_domain='localhost')
                assert info['version'] == 4
                assert info['ipv4_address'] is None
                assert info['domain_address'] == 'localhost'

    def test_ipv4_client_source_address(self, handler, ctx):
        with ctx.socks_server(Socks4ProxyHandler) as proxy_address:
            source_address = f'127.0.0.{random.randint(5, 255)}'
            verify_address_availability(source_address)
            proxies = {'all': f'socks4://{proxy_address}'}
            with handler(proxies=proxies, source_address=source_address) as rh:
                info = ctx.socks_info_request(rh)
                assert info['client_address'][0] == source_address
                assert info['version'] == 4

    @pytest.mark.parametrize('reply_code', [
        Socks4CD.REQUEST_REJECTED_OR_FAILED,
        Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
        Socks4CD.REQUEST_REJECTED_DIFFERENT_USERID,
    ])
    def test_socks4_errors(self, handler, ctx, reply_code):
        with ctx.socks_server(Socks4ProxyHandler, cd_reply=reply_code) as proxy_address:
            proxies = {'all': f'socks4://{proxy_address}'}
            with handler(proxies=proxies) as rh, pytest.raises(ProxyError):
                ctx.socks_info_request(rh)

    def test_ipv6_socks4_proxy(self, handler, ctx):
        with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as proxy_address:
            proxies = {'all': f'socks4://{proxy_address}'}
            with handler(proxies=proxies) as rh:
                info = ctx.socks_info_request(rh, target_domain='127.0.0.1')
                assert info['client_address'][0] == '::1'
                assert info['ipv4_address'] == '127.0.0.1'
                assert info['version'] == 4

    def test_timeout(self, handler, ctx):
        # sleep=2 presumably stalls the proxy longer than the 0.5s client timeout — confirm in Socks4ProxyHandler.
        with ctx.socks_server(Socks4ProxyHandler, sleep=2) as proxy_address:
            proxies = {'all': f'socks4://{proxy_address}'}
            with handler(proxies=proxies, timeout=0.5) as rh, pytest.raises(TransportError):
                ctx.socks_info_request(rh)
+
+
+@pytest.mark.parametrize(
+ 'handler,ctx', [
+ ('Urllib', 'http'),
+ ('Requests', 'http'),
+ ('Websockets', 'ws'),
+ ('CurlCFFI', 'http')
+ ], indirect=True)
+class TestSocks5Proxy:
+
+ def test_socks5_no_auth(self, handler, ctx):
+ with ctx.socks_server(Socks5ProxyHandler) as server_address:
+ with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
+ response = ctx.socks_info_request(rh)
+ assert response['auth_methods'] == [0x0]
+ assert response['version'] == 5
+
+ def test_socks5_user_pass(self, handler, ctx):
+ with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
+ with handler() as rh:
+ with pytest.raises(ProxyError):
+ ctx.socks_info_request(rh, proxies={'all': f'socks5://{server_address}'})
+
+ response = ctx.socks_info_request(
+ rh, proxies={'all': f'socks5://test:testpass@{server_address}'})
+
+ assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
+ assert response['version'] == 5
+
+ def test_socks5_ipv4_target(self, handler, ctx):
+ with ctx.socks_server(Socks5ProxyHandler) as server_address:
+ with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
+ response = ctx.socks_info_request(rh, target_domain='127.0.0.1')
+ assert response['ipv4_address'] == '127.0.0.1'
+ assert response['version'] == 5
+
+ def test_socks5_domain_target(self, handler, ctx):
+ with ctx.socks_server(Socks5ProxyHandler) as server_address:
+ with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
+ response = ctx.socks_info_request(rh, target_domain='localhost')
+ assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
+ assert response['version'] == 5
+
+ def test_socks5h_domain_target(self, handler, ctx):
+ with ctx.socks_server(Socks5ProxyHandler) as server_address:
+ with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
+ response = ctx.socks_info_request(rh, target_domain='localhost')
+ assert response['ipv4_address'] is None
+ assert response['domain_address'] == 'localhost'
+ assert response['version'] == 5