]>
Commit | Line | Data |
---|---|---|
1 | from __future__ import annotations | |
2 | ||
3 | import re | |
4 | from abc import ABC | |
5 | from dataclasses import dataclass | |
6 | from typing import Any | |
7 | ||
8 | from .common import RequestHandler, register_preference | |
9 | from .exceptions import UnsupportedRequest | |
10 | from ..compat.types import NoneType | |
11 | from ..utils import classproperty, join_nonempty | |
12 | from ..utils.networking import std_headers | |
13 | ||
14 | ||
@dataclass(order=True, frozen=True)
class ImpersonateTarget:
    """
    Describes a client (and optionally its OS) that a request should impersonate.

    Parameters:
    @param client: name of the client to impersonate
    @param version: version of that client
    @param os: operating system of the client
    @param os_version: version of that operating system

    Note: a field left as None is a wildcard and matches any value.

    """
    client: str | None = None
    version: str | None = None
    os: str | None = None
    os_version: str | None = None

    def __post_init__(self):
        # A version only makes sense together with the thing it versions.
        if self.version and not self.client:
            raise ValueError('client is required if version is set')
        if self.os_version and not self.os:
            raise ValueError('os is required if os_version is set')

    def __contains__(self, target: ImpersonateTarget):
        """Return True if `target` is compatible with this target field-by-field."""
        if not isinstance(target, ImpersonateTarget):
            return False
        field_pairs = (
            (self.client, target.client),
            (self.version, target.version),
            (self.os, target.os),
            (self.os_version, target.os_version),
        )
        # None on either side acts as a wildcard; otherwise both values must be equal.
        return all(
            mine is None or theirs is None or mine == theirs
            for mine, theirs in field_pairs
        )

    def __str__(self):
        client_part = join_nonempty(self.client, self.version)
        os_part = join_nonempty(self.os, self.os_version)
        # Drop the separator when no OS information is present.
        return f'{client_part}:{os_part}'.rstrip(':')

    @classmethod
    def from_str(cls, target: str):
        """
        Parse a "client-version:os-os_version" string; every part is optional.

        Raises ValueError if the string does not match that shape.
        """
        match = re.fullmatch(r'(?:(?P<client>[^:-]+)(?:-(?P<version>[^:-]+))?)?(?::(?:(?P<os>[^:-]+)(?:-(?P<os_version>[^:-]+))?)?)?', target)
        if match is None:
            raise ValueError(f'Invalid impersonate target "{target}"')
        return cls(**match.groupdict())
59 | ||
60 | ||
class ImpersonateRequestHandler(RequestHandler, ABC):
    """
    Base class for request handlers that support browser impersonation.

    This provides a method for checking the validity of the impersonate extension,
    which can be used in _check_extensions.

    Impersonate targets consist of a client, version, os and os_version.
    See the ImpersonateTarget class for more details.

    The following may be defined:
     - `_SUPPORTED_IMPERSONATE_TARGET_MAP`: a dict mapping supported targets to a custom object.
                Any Request with an impersonate target not matching a key of this dict
                will raise an UnsupportedRequest.
                Set to None (or leave empty) to disable this check.
                Note: Entries are in order of preference

    Parameters:
    @param impersonate: the default impersonate target to use for requests.
                        Set to None to disable impersonation.
    """
    _SUPPORTED_IMPERSONATE_TARGET_MAP: dict[ImpersonateTarget, Any] | None = {}

    def __init__(self, *, impersonate: ImpersonateTarget | None = None, **kwargs):
        super().__init__(**kwargs)
        self.impersonate = impersonate

    def _check_impersonate_target(self, target: ImpersonateTarget | None):
        """
        Raise UnsupportedRequest if `target` does not resolve to a supported target.

        A None target is always accepted. A handler that declares no supported
        targets (empty or None map) also accepts everything.
        """
        assert isinstance(target, (ImpersonateTarget, NoneType))
        if target is None or not self.supported_targets:
            return
        if not self.is_supported_target(target):
            raise UnsupportedRequest(f'Unsupported impersonate target: {target}')

    def _check_extensions(self, extensions):
        super()._check_extensions(extensions)
        if 'impersonate' in extensions:
            self._check_impersonate_target(extensions.get('impersonate'))

    def _validate(self, request):
        super()._validate(request)
        # The handler-wide default is validated too: _get_request_target falls back
        # to it whenever the request carries no impersonate extension.
        self._check_impersonate_target(self.impersonate)

    def _resolve_target(self, target: ImpersonateTarget | None):
        """
        Resolve a target to a supported target.

        Returns the first supported target, in map order (i.e. order of
        preference), that `target` matches. Matching uses
        ImpersonateTarget.__contains__, where None fields act as wildcards.
        Returns None if `target` is None or nothing matches.
        """
        if target is None:
            return None
        for supported_target in self.supported_targets:
            if target in supported_target:
                if self.verbose:
                    self._logger.stdout(
                        f'{self.RH_NAME}: resolved impersonate target {target} to {supported_target}')
                return supported_target
        return None

    @classproperty
    def supported_targets(cls) -> tuple[ImpersonateTarget, ...]:
        """The supported targets in order of preference; empty if the map is empty or None."""
        # The map may be None to disable target checking (see class docstring);
        # iterating a dict yields its keys, and None becomes an empty tuple
        # instead of failing on None.keys().
        return tuple(cls._SUPPORTED_IMPERSONATE_TARGET_MAP or ())

    def is_supported_target(self, target: ImpersonateTarget):
        """Return True if `target` resolves to one of the supported targets."""
        assert isinstance(target, ImpersonateTarget)
        return self._resolve_target(target) is not None

    def _get_request_target(self, request):
        """
        Get the resolved target for the request.

        The request's impersonate extension takes precedence over the handler default.
        """
        return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)

    def _get_impersonate_headers(self, request):
        """
        Build the headers for `request` via self._merge_headers.

        When the request resolves to an impersonate target, any header whose value
        equals the std_headers default is removed. Presumably this lets the
        impersonation backend supply its own browser-specific values.
        """
        headers = self._merge_headers(request.headers)
        if self._get_request_target(request) is not None:
            # remove all headers present in std_headers
            # TODO: change this to not depend on std_headers
            for k, v in std_headers.items():
                if headers.get(k) == v:
                    headers.pop(k)
        return headers
135 | ||
136 | ||
@register_preference(ImpersonateRequestHandler)
def impersonate_preference(rh, request):
    """
    Preference score for impersonation-capable handlers.

    Returns 1000 when the request carries a truthy impersonate extension or the
    handler has a default impersonate target, and 0 otherwise.
    """
    wants_impersonation = request.extensions.get('impersonate') or rh.impersonate
    return 1000 if wants_impersonation else 0