]>
Commit | Line | Data |
---|---|---|
e0df8241 JR |
1 | from __future__ import annotations |
2 | ||
3 | import fnmatch | |
4 | import os | |
5 | import subprocess | |
6 | import sys | |
7 | import threading | |
8 | import time | |
9 | import typing as t | |
10 | from itertools import chain | |
11 | from pathlib import PurePath | |
12 | ||
13 | from ._internal import _log | |
14 | ||
15 | # The various system prefixes where imports are found. Base values are | |
16 | # different when running in a virtualenv. All reloaders will ignore the | |
17 | # base paths (usually the system installation). The stat reloader won't | |
18 | # scan the virtualenv paths, it will only include modules that are | |
19 | # already imported. | |
20 | _ignore_always = tuple({sys.base_prefix, sys.base_exec_prefix}) | |
21 | prefix = {*_ignore_always, sys.prefix, sys.exec_prefix} | |
22 | ||
23 | if hasattr(sys, "real_prefix"): | |
24 | # virtualenv < 20 | |
25 | prefix.add(sys.real_prefix) | |
26 | ||
27 | _stat_ignore_scan = tuple(prefix) | |
28 | del prefix | |
29 | _ignore_common_dirs = { | |
30 | "__pycache__", | |
31 | ".git", | |
32 | ".hg", | |
33 | ".tox", | |
34 | ".nox", | |
35 | ".pytest_cache", | |
36 | ".mypy_cache", | |
37 | } | |
38 | ||
39 | ||
def _iter_module_paths() -> t.Iterator[str]:
    """Yield a filesystem path for each imported module that has one.

    Modules without ``__file__`` and modules located under one of the
    ``_ignore_always`` prefixes are skipped.
    """
    # Snapshot the values: the app may import modules while we iterate.
    for module in list(sys.modules.values()):
        candidate = getattr(module, "__file__", None)

        if candidate is None or candidate.startswith(_ignore_always):
            continue

        # The module may live inside a zip file. Strip trailing path
        # components until an actual file on disk is reached.
        while not os.path.isfile(candidate):
            parent = os.path.dirname(candidate)

            if parent == candidate:
                # Nothing left to strip and still no file: give up on
                # this module without yielding anything.
                break

            candidate = parent
        else:
            yield candidate
58 | ||
59 | ||
60 | def _remove_by_pattern(paths: set[str], exclude_patterns: set[str]) -> None: | |
61 | for pattern in exclude_patterns: | |
62 | paths.difference_update(fnmatch.filter(paths, pattern)) | |
63 | ||
64 | ||
def _find_stat_paths(
    extra_files: set[str], exclude_patterns: set[str]
) -> t.Iterable[str]:
    """Collect the individual files the stat reloader should watch.

    The result contains the files of imported modules, ``.py``/``.pyc``
    files found under the ``sys.path`` entries and extra directories that
    are not ignored, and the extra files themselves. Anything matching
    ``exclude_patterns`` is removed at the end.

    System prefixes are not scanned for efficiency; the paths of interest
    to the user are the non-system ones, such as a project root.
    """
    found = set()

    for entry in chain(list(sys.path), extra_files):
        entry = os.path.abspath(entry)

        if os.path.isfile(entry):
            # A zip file on sys.path, or an extra file.
            found.add(entry)
            continue

        # Maps a visited directory to whether it held Python files. The
        # parent of the scan root is seeded as True so the root's own
        # subdirectories are always visited.
        parent_has_py = {os.path.dirname(entry): True}

        for current, subdirs, filenames in os.walk(entry):
            # Skip system prefixes and well-known tool/VCS/cache
            # directories entirely; clearing ``subdirs`` stops os.walk
            # from descending any further.
            if (
                current.startswith(_stat_ignore_scan)
                or os.path.basename(current) in _ignore_common_dirs
            ):
                subdirs.clear()
                continue

            py_files = [f for f in filenames if f.endswith((".py", ".pyc"))]
            found.update(os.path.join(current, f) for f in py_files)
            has_py = bool(py_files)

            # Stop descending once neither this directory nor its parent
            # contained any Python files.
            if not has_py and not parent_has_py[os.path.dirname(current)]:
                subdirs.clear()
                continue

            parent_has_py[current] = has_py

    found.update(_iter_module_paths())
    _remove_by_pattern(found, exclude_patterns)
    return found
117 | ||
118 | ||
def _find_watchdog_paths(
    extra_files: set[str], exclude_patterns: set[str]
) -> t.Iterable[str]:
    """Collect the directories the watchdog reloader should watch.

    Uses the same sources as the stat reloader (``sys.path`` entries,
    extra files and imported modules) but works with directories instead
    of individual files, reduced to their common roots.
    """
    watch_dirs = set()

    for entry in chain(list(sys.path), extra_files):
        entry = os.path.abspath(entry)
        # Files are represented by the directory that contains them.
        watch_dirs.add(os.path.dirname(entry) if os.path.isfile(entry) else entry)

    watch_dirs.update(os.path.dirname(path) for path in _iter_module_paths())
    _remove_by_pattern(watch_dirs, exclude_patterns)
    return _find_common_roots(watch_dirs)
141 | ||
142 | ||
143 | def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: | |
144 | root: dict[str, dict] = {} | |
145 | ||
146 | for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True): | |
147 | node = root | |
148 | ||
149 | for chunk in chunks: | |
150 | node = node.setdefault(chunk, {}) | |
151 | ||
152 | node.clear() | |
153 | ||
154 | rv = set() | |
155 | ||
156 | def _walk(node: t.Mapping[str, dict], path: tuple[str, ...]) -> None: | |
157 | for prefix, child in node.items(): | |
158 | _walk(child, path + (prefix,)) | |
159 | ||
160 | if not node: | |
161 | rv.add(os.path.join(*path)) | |
162 | ||
163 | _walk(root, ()) | |
164 | return rv | |
165 | ||
166 | ||
167 | def _get_args_for_reloading() -> list[str]: | |
168 | """Determine how the script was executed, and return the args needed | |
169 | to execute it again in a new process. | |
170 | """ | |
171 | if sys.version_info >= (3, 10): | |
172 | # sys.orig_argv, added in Python 3.10, contains the exact args used to invoke | |
173 | # Python. Still replace argv[0] with sys.executable for accuracy. | |
174 | return [sys.executable, *sys.orig_argv[1:]] | |
175 | ||
176 | rv = [sys.executable] | |
177 | py_script = sys.argv[0] | |
178 | args = sys.argv[1:] | |
179 | # Need to look at main module to determine how it was executed. | |
180 | __main__ = sys.modules["__main__"] | |
181 | ||
182 | # The value of __package__ indicates how Python was called. It may | |
183 | # not exist if a setuptools script is installed as an egg. It may be | |
184 | # set incorrectly for entry points created with pip on Windows. | |
185 | if getattr(__main__, "__package__", None) is None or ( | |
186 | os.name == "nt" | |
187 | and __main__.__package__ == "" | |
188 | and not os.path.exists(py_script) | |
189 | and os.path.exists(f"{py_script}.exe") | |
190 | ): | |
191 | # Executed a file, like "python app.py". | |
192 | py_script = os.path.abspath(py_script) | |
193 | ||
194 | if os.name == "nt": | |
195 | # Windows entry points have ".exe" extension and should be | |
196 | # called directly. | |
197 | if not os.path.exists(py_script) and os.path.exists(f"{py_script}.exe"): | |
198 | py_script += ".exe" | |
199 | ||
200 | if ( | |
201 | os.path.splitext(sys.executable)[1] == ".exe" | |
202 | and os.path.splitext(py_script)[1] == ".exe" | |
203 | ): | |
204 | rv.pop(0) | |
205 | ||
206 | rv.append(py_script) | |
207 | else: | |
208 | # Executed a module, like "python -m werkzeug.serving". | |
209 | if os.path.isfile(py_script): | |
210 | # Rewritten by Python from "-m script" to "/path/to/script.py". | |
211 | py_module = t.cast(str, __main__.__package__) | |
212 | name = os.path.splitext(os.path.basename(py_script))[0] | |
213 | ||
214 | if name != "__main__": | |
215 | py_module += f".{name}" | |
216 | else: | |
217 | # Incorrectly rewritten by pydevd debugger from "-m script" to "script". | |
218 | py_module = py_script | |
219 | ||
220 | rv.extend(("-m", py_module.lstrip("."))) | |
221 | ||
222 | rv.extend(args) | |
223 | return rv | |
224 | ||
225 | ||
class ReloaderLoop:
    """Base reloader: stores the watch configuration, drives the watch
    loop and restarts the process. ``run_step`` does nothing here and is
    the per-iteration hook.
    """

    # Label used in the "Restarting with ..." log message.
    name = ""

    def __init__(
        self,
        extra_files: t.Iterable[str] | None = None,
        exclude_patterns: t.Iterable[str] | None = None,
        interval: int | float = 1,
    ) -> None:
        """Store the configuration.

        :param extra_files: Additional paths to watch; stored as a set of
            absolute paths.
        :param exclude_patterns: Patterns of paths to leave out; stored
            as a set.
        :param interval: Seconds to sleep between watch steps.
        """
        self.extra_files: set[str] = set(map(os.path.abspath, extra_files or ()))
        self.exclude_patterns: set[str] = set(exclude_patterns or ())
        self.interval = interval

    def __enter__(self) -> ReloaderLoop:
        """Perform setup by running a single watch step, which records
        the initial filesystem state.
        """
        self.run_step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        """Release resources held by the reloader. Nothing to do here."""
        pass

    def run(self) -> None:
        """Run watch steps forever, sleeping for the configured interval
        after each one.
        """
        while True:
            self.run_step()
            time.sleep(self.interval)

    def run_step(self) -> None:
        """Perform one filesystem watch step. Called once for the initial
        state and then repeatedly from :meth:`run`.
        """
        pass

    def restart_with_reloader(self) -> int:
        """Start a new Python interpreter with the same arguments as the
        current one, marked through the environment as the process that
        runs the reloader thread. Repeats for as long as the child exits
        with code 3 and returns any other exit code.
        """
        while True:
            _log("info", f" * Restarting with {self.name}")
            command = _get_args_for_reloading()
            child_env = os.environ.copy()
            child_env["WERKZEUG_RUN_MAIN"] = "true"
            status = subprocess.call(command, env=child_env, close_fds=False)

            # Exit code 3 is the child's request to be restarted.
            if status != 3:
                return status

    def trigger_reload(self, filename: str) -> None:
        """Log the changed file and exit with code 3."""
        self.log_reload(filename)
        sys.exit(3)

    def log_reload(self, filename: str) -> None:
        """Log that a change in ``filename`` was detected."""
        filename = os.path.abspath(filename)
        _log("info", f" * Detected change in {filename!r}, reloading")
285 | ||
286 | ||
class StatReloaderLoop(ReloaderLoop):
    """Reloader that polls the modification time of every watched file."""

    name = "stat"

    def __enter__(self) -> ReloaderLoop:
        # Last seen modification time per path. Reset before the base
        # class runs the first step, which fills it in.
        self.mtimes: dict[str, float] = {}
        return super().__enter__()

    def run_step(self) -> None:
        """Stat every watched file and trigger a reload for the first one
        whose modification time is newer than the recorded one.
        """
        for path in _find_stat_paths(self.extra_files, self.exclude_patterns):
            try:
                current = os.stat(path).st_mtime
            except OSError:
                # The file cannot be stat'ed; skip it.
                continue

            # A path seen for the first time is recorded with its
            # current mtime, so the comparison below cannot fire for it.
            recorded = self.mtimes.setdefault(path, current)

            if current > recorded:
                self.trigger_reload(path)
309 | ||
310 | ||
class WatchdogReloaderLoop(ReloaderLoop):
    """Reloader that uses watchdog filesystem events instead of polling
    modification times.
    """

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        """Create the observer and the event handler. Takes the same
        arguments as :class:`ReloaderLoop`.
        """
        from watchdog.events import EVENT_TYPE_CLOSED
        from watchdog.events import EVENT_TYPE_CREATED
        from watchdog.events import EVENT_TYPE_DELETED
        from watchdog.events import EVENT_TYPE_MODIFIED
        from watchdog.events import EVENT_TYPE_MOVED
        from watchdog.events import FileModifiedEvent
        from watchdog.events import PatternMatchingEventHandler
        from watchdog.observers import Observer

        super().__init__(*args, **kwargs)
        trigger_reload = self.trigger_reload

        # Only these event types reload. An allowlist is used rather
        # than excluding "opened", so that any other event type that
        # does not represent a change (e.g. "closed_no_write") cannot
        # trigger a spurious reload.
        reload_event_types = frozenset(
            {
                EVENT_TYPE_CLOSED,
                EVENT_TYPE_CREATED,
                EVENT_TYPE_DELETED,
                EVENT_TYPE_MODIFIED,
                EVENT_TYPE_MOVED,
            }
        )

        class EventHandler(PatternMatchingEventHandler):
            def on_any_event(self, event: FileModifiedEvent):  # type: ignore
                if event.event_type not in reload_event_types:
                    # Skip events that are not in the allowlist.
                    return

                trigger_reload(event.src_path)

        # Derive a short label from the observer class name by dropping
        # a trailing "observer".
        reloader_name = Observer.__name__.lower()  # type: ignore[attr-defined]

        if reloader_name.endswith("observer"):
            reloader_name = reloader_name[:-8]

        self.name = f"watchdog ({reloader_name})"
        self.observer = Observer()
        # Extra patterns can be non-Python files, match them in addition
        # to all Python files in default and extra directories. Ignore
        # __pycache__ since a change there will always have a change to
        # the source file (or initial pyc file) as well. Ignore Git and
        # Mercurial internal changes.
        extra_patterns = [p for p in self.extra_files if not os.path.isdir(p)]
        self.event_handler = EventHandler(
            patterns=["*.py", "*.pyc", "*.zip", *extra_patterns],
            ignore_patterns=[
                *[f"*/{d}/*" for d in _ignore_common_dirs],
                *self.exclude_patterns,
            ],
        )
        self.should_reload = False

    def trigger_reload(self, filename: str) -> None:
        """Flag that a reload is needed and log the changed file."""
        # This is called inside an event handler, which means throwing
        # SystemExit has no effect.
        # https://github.com/gorakhargosh/watchdog/issues/294
        self.should_reload = True
        self.log_reload(filename)

    def __enter__(self) -> ReloaderLoop:
        """Start the observer, then run the initial watch step."""
        # Maps a watched path to the watch returned by the observer, or
        # to None when scheduling that path failed.
        self.watches: dict[str, t.Any] = {}
        self.observer.start()
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        """Stop the observer and wait for it to finish."""
        self.observer.stop()
        self.observer.join()

    def run(self) -> None:
        """Refresh the watches every interval until an event has flagged
        a reload, then exit with code 3.
        """
        while not self.should_reload:
            self.run_step()
            time.sleep(self.interval)

        sys.exit(3)

    def run_step(self) -> None:
        """Bring the scheduled watches in line with the directories that
        should currently be watched.
        """
        # Start from every known path; whatever is not found again below
        # is stale and gets removed.
        to_delete = set(self.watches)

        for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns):
            if path not in self.watches:
                try:
                    self.watches[path] = self.observer.schedule(
                        self.event_handler, path, recursive=True
                    )
                except OSError:
                    # Remember the failed path with a None watch so it is
                    # not scheduled (and does not fail) again on the next
                    # iteration.
                    self.watches[path] = None

            to_delete.discard(path)

        for path in to_delete:
            watch = self.watches.pop(path, None)

            if watch is not None:
                self.observer.unschedule(watch)
395 | ||
396 | ||
# Registry of reloader implementations by name.
reloader_loops: dict[str, type[ReloaderLoop]] = dict(
    stat=StatReloaderLoop,
    watchdog=WatchdogReloaderLoop,
)

# "auto" picks watchdog when it can be imported and stat otherwise.
try:
    __import__("watchdog.observers")
except ImportError:
    reloader_loops["auto"] = StatReloaderLoop
else:
    reloader_loops["auto"] = WatchdogReloaderLoop
408 | ||
409 | ||
def ensure_echo_on() -> None:
    """Turn terminal echo back on if it is off. Tools such as PDB can
    leave it disabled, which hurts usability after a reload.
    """
    stdin = sys.stdin

    # tcgetattr fails when stdin is not a tty.
    if stdin is None or not stdin.isatty():
        return

    try:
        import termios
    except ImportError:
        # Without termios there is nothing to adjust.
        return

    attributes = termios.tcgetattr(stdin)
    # Index 3 is the item holding the ECHO bit.
    local_flags = attributes[3]

    if local_flags & termios.ECHO:
        return

    attributes[3] = local_flags | termios.ECHO
    termios.tcsetattr(stdin, termios.TCSANOW, attributes)
427 | ||
428 | ||
def run_with_reloader(
    main_func: t.Callable[[], None],
    extra_files: t.Iterable[str] | None = None,
    exclude_patterns: t.Iterable[str] | None = None,
    interval: int | float = 1,
    reloader_type: str = "auto",
) -> None:
    """Run the given function in an independent Python interpreter.

    When ``WERKZEUG_RUN_MAIN`` is ``"true"`` in the environment,
    ``main_func`` runs in a daemon thread while the reloader loop runs in
    the main thread. Otherwise this process only spawns and restarts that
    child and exits with its exit code.

    :param main_func: Callable to run in the reloaded process.
    :param extra_files: Passed to the reloader class.
    :param exclude_patterns: Passed to the reloader class.
    :param interval: Passed to the reloader class.
    :param reloader_type: Key into ``reloader_loops``.
    """
    import signal

    # Turn SIGTERM into a normal interpreter exit.
    signal.signal(signal.SIGTERM, lambda *args: sys.exit(0))
    reloader_cls = reloader_loops[reloader_type]
    reloader = reloader_cls(
        extra_files=extra_files, exclude_patterns=exclude_patterns, interval=interval
    )

    try:
        if os.environ.get("WERKZEUG_RUN_MAIN") != "true":
            sys.exit(reloader.restart_with_reloader())

        ensure_echo_on()
        app_thread = threading.Thread(target=main_func, args=())
        app_thread.daemon = True

        # Enter the reloader to set up initial state, then start the
        # app thread and the reloader update loop.
        with reloader:
            app_thread.start()
            reloader.run()
    except KeyboardInterrupt:
        pass