diff --git a/config.txt b/config.txt index 9ee71d4b..f69fd5ed 100644 --- a/config.txt +++ b/config.txt @@ -67,10 +67,9 @@ static_root = static ## Saved Tests -# Where to keep files when users `save` them. Note that files will be automatically -# deleted from this directory, so it needs to be only used for REDbot. +# Directory to keep test results when users `save` them. # Comment out to disable saving. -save_dir = /tmp/redbot/ +save_dir = /tmp/ # How long to store things when users save them, in days. save_days = 30 diff --git a/pyproject.toml b/pyproject.toml index 11006353..2802078f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "markdown >= 3.4.4", "MarkupSafe >= 2.1.3", "netaddr >= 0.9.0", - "thor >= 0.9.6", + "thor >= 0.9.9", "typing-extensions >= 4.8.0", ] diff --git a/redbot/daemon.py b/redbot/daemon.py index d2a636bd..ea155a80 100755 --- a/redbot/daemon.py +++ b/redbot/daemon.py @@ -13,9 +13,11 @@ import locale import os from pstats import Stats +import signal import sys import traceback -from typing import Dict, Optional +from types import FrameType +from typing import Dict, Optional, Any, Union from urllib.parse import urlsplit from importlib_resources import files as resource_files @@ -27,7 +29,7 @@ import redbot from redbot.type import RawHeaderListType from redbot.webui import RedWebUi -from redbot.webui.saved_tests import clean_saved_tests +from redbot.webui.saved_tests import SavedTests if os.environ.get("SYSTEMD_WATCHDOG"): try: @@ -64,6 +66,14 @@ def __init__(self, config: SectionProxy) -> None: if self.config.get("extra_base_dir"): self.extra_files = self.walk_files(self.config["extra_base_dir"]) + # Set up signal handlers + signal.signal(signal.SIGINT, self.shutdown_handler) + signal.signal(signal.SIGABRT, self.shutdown_handler) + signal.signal(signal.SIGTERM, self.shutdown_handler) + + # open the save db + self.saved = SavedTests(config, self.console) + # Start garbage collection if config.get("save_dir", ""): thor.schedule(10, self.gc_state) @@ -74,20 +84,22 @@ def __init__(self, config: SectionProxy) -> None: self.config.getint("port", fallback=8000), ) server.on("exchange", self.handler) - try: - thor.run() - except KeyboardInterrupt: - self.console("Stopping...") - thor.stop() + thor.run() def watchdog_ping(self) -> None: notify(Notification.WATCHDOG) thor.schedule(self.watchdog_freq, self.watchdog_ping) def gc_state(self) -> None: - clean_saved_tests(self.config) + self.saved.clean() thor.schedule(self.config.getint("gc_mins", fallback=2) * 60, self.gc_state) + def shutdown_handler(self, sig: int, frame: Union[FrameType, None]) -> Any: + self.console("Stopping...") + thor.stop() + self.saved.shutdown() + sys.exit(0) + def walk_files(self, dir_name: str, uri_base: bytes = b"") -> Dict[bytes, bytes]: out: Dict[bytes, bytes] = {} for root, _, files in os.walk(dir_name): @@ -158,6 +170,7 @@ def request_done(self, trailers: RawHeaderListType) -> None: try: RedWebUi( self.server.config, + self.server.saved, self.method.decode(self.server.config["charset"]), p_uri.query, self.req_hdrs, @@ -179,6 +192,7 @@ def request_done(self, trailers: RawHeaderListType) -> None: dump = traceback.format_exc() thor.stop() self.server.console(dump) + self.server.saved.shutdown() sys.exit(1) else: return self.serve_static(p_uri.path) diff --git a/redbot/formatter/__init__.py b/redbot/formatter/__init__.py index 4f3cb512..c6c9fc2f 100644 --- a/redbot/formatter/__init__.py +++ b/redbot/formatter/__init__.py @@ -2,7 +2,6 @@ Formatters for REDbot output. """ - from collections import defaultdict from configparser import SectionProxy import inspect diff --git a/redbot/formatter/har.py b/redbot/formatter/har.py index 9296086d..85a8e956 100644 --- a/redbot/formatter/har.py +++ b/redbot/formatter/har.py @@ -2,7 +2,6 @@ HAR Formatter for REDbot. """ - import datetime import json from typing import Optional, Any, Dict, List diff --git a/redbot/resource/__init__.py b/redbot/resource/__init__.py index e005b107..10f02ab9 100644 --- a/redbot/resource/__init__.py +++ b/redbot/resource/__init__.py @@ -48,6 +48,7 @@ def __init__(self, config: SectionProxy, descend: bool = False) -> None: self.ims_support: bool = False self.gzip_support: bool = False self.gzip_savings: int = 0 + self.save_expires: float = 0.0 self._task_map: Set[RedFetcher] = set([]) self.subreqs = {ac.check_name: ac(config, self) for ac in active_checks} self.once("fetch_done", self.run_active_checks) diff --git a/redbot/resource/fetch.py b/redbot/resource/fetch.py index 2bff7431..ee876490 100644 --- a/redbot/resource/fetch.py +++ b/redbot/resource/fetch.py @@ -81,7 +81,10 @@ def __getstate__(self) -> Dict[str, Any]: del state["exchange"] except KeyError: pass - del state["response_content_processors"] + try: + del state["response_content_processors"] + except KeyError: + pass return state def __repr__(self) -> str: diff --git a/redbot/type.py b/redbot/type.py index 3e8ea573..370742f0 100644 --- a/redbot/type.py +++ b/redbot/type.py @@ -11,11 +11,8 @@ class HttpResponseExchange(Protocol): def response_start( self, status_code: bytes, status_phrase: bytes, res_hdrs: RawHeaderListType - ) -> None: - ... + ) -> None: ... - def response_body(self, chunk: bytes) -> None: - ... + def response_body(self, chunk: bytes) -> None: ... - def response_done(self, trailers: RawHeaderListType) -> None: - ... + def response_done(self, trailers: RawHeaderListType) -> None: ... diff --git a/redbot/webui/__init__.py b/redbot/webui/__init__.py index 78686a18..a6530bf0 100644 --- a/redbot/webui/__init__.py +++ b/redbot/webui/__init__.py @@ -20,12 +20,7 @@ from redbot import __version__ from redbot.webui.captcha import CaptchaHandler from redbot.webui.ratelimit import ratelimiter -from redbot.webui.saved_tests import ( - init_save_file, - save_test, - extend_saved_test, - load_saved_test, -) +from redbot.webui.saved_tests import SavedTests from redbot.webui.slack import slack_run, slack_auth from redbot.resource import HttpResource from redbot.formatter import find_formatter, html, Formatter @@ -50,6 +45,7 @@ class RedWebUi: def __init__( self, config: SectionProxy, + saved: SavedTests, method: str, query_string: bytes, req_headers: RawHeaderListType, @@ -58,13 +54,15 @@ def __init__( client_ip: str, console: Callable[[str], Optional[int]] = sys.stderr.write, ) -> None: - self.config: SectionProxy = config + self.config = config + self.saved = saved + self.method = method self.charset = self.config["charset"] self.charset_bytes = self.charset.encode("ascii") self.query_string = parse_qs(query_string.decode(self.charset, "replace")) self.req_headers = req_headers self.req_body = req_body - self.body_args = {} + self.body_args: Dict[str, list[str]] = {} self.exchange = exchange self.client_ip = client_ip self.console = console # function to log errors to @@ -88,25 +86,38 @@ def __init__( if not self.descend: self.check_name = self.query_string.get("check_name", [None])[0] - self.save_path: str self.timeout: Optional[thor.loop.ScheduledEvent] = None self.nonce: str = standard_b64encode( getrandbits(128).to_bytes(16, "big") ).decode("ascii") self.start = time.time() + self.handle_request() - if method == "POST": + def handle_request(self) -> None: + if self.method == "POST": req_ct = get_header(self.req_headers, b"content-type") if req_ct and req_ct[-1].lower() == b"application/x-www-form-urlencoded": - self.body_args = parse_qs(req_body.decode(self.charset, "replace")) + self.body_args = parse_qs(self.req_body.decode(self.charset, "replace")) if ( "save" in self.query_string and self.config.get("save_dir", "") and self.test_id ): - extend_saved_test(self) + try: + self.saved.extend(self.test_id) + except KeyError: + return self.error_response( + None, + b"404", + b"Not Found", + f"Can't find the test ID {self.test_id}", + ) + location = b"?id=%s" % self.test_id.encode("ascii") + if self.descend: + location = b"%s&descend=True" % location + self.redirect_response(location) elif "slack" in self.query_string: slack_run(self) elif "client_error" in self.query_string: @@ -115,31 +126,78 @@ def __init__( self.run_test() else: self.show_default() - elif method in ["GET", "HEAD"]: + elif self.method in ["GET", "HEAD"]: if self.test_id: - load_saved_test(self) + self.show_saved_test() elif "code" in self.query_string: slack_auth(self) else: self.show_default() else: self.error_response( - find_formatter("html")( - self.config, - HttpResource(self.config), - self.output, - { - "nonce": self.nonce, - }, - ), + None, b"405", b"Method Not Allowed", "Method Not Allowed", ) + return None + + def show_saved_test(self) -> None: + """Show a saved test.""" + try: + top_resource = self.saved.load(self) + except ValueError: + return self.error_response( + None, + b"400", + b"Bad Request", + "Saved tests are not available.", + ) + except KeyError: + return self.error_response( + None, + b"404", + b"Not Found", + f"Can't find the test ID {self.test_id}", + ) + + if self.check_name: + display_resource = cast( + HttpResource, top_resource.subreqs.get(self.check_name, top_resource) + ) + else: + display_resource = top_resource + + formatter = find_formatter(self.format, "html", top_resource.descend)( + self.config, + display_resource, + self.output, + { + "allow_save": True, + "is_saved": True, + "test_id": self.test_id, + "nonce": self.nonce, + }, + ) + self.exchange.response_start( + b"200", + b"OK", + [ + (b"Content-Type", formatter.content_type()), + (b"Cache-Control", b"max-age=3600, must-revalidate"), + ], + ) + + @thor.events.on(formatter) + def formatter_done() -> None: + self.exchange.response_done([]) + + formatter.bind_resource(display_resource) + return None def run_test(self) -> None: """Test a URI.""" - self.test_id = init_save_file(self) + self.test_id = self.saved.get_test_id() top_resource = HttpResource(self.config, descend=self.descend) top_resource.set_request(self.test_uri, headers=self.req_hdrs) formatter = find_formatter(self.format, "html", self.descend)( @@ -224,7 +282,7 @@ def formatter_done() -> None: self.timeout.delete() self.timeout = None self.exchange.response_done([]) - save_test(self, top_resource) + self.saved.save(self, top_resource) # log excessive traffic ti = sum( @@ -313,7 +371,7 @@ def show_default(self) -> None: def error_response( self, - formatter: Formatter, + formatter: Optional[Formatter], status_code: bytes, status_phrase: bytes, message: str, @@ -323,6 +381,15 @@ def error_response( if self.timeout: self.timeout.delete() self.timeout = None + if formatter is None: + formatter = find_formatter("html")( + self.config, + HttpResource(self.config), + self.output, + { + "nonce": self.nonce, + }, + ) self.exchange.response_start( status_code, status_phrase, @@ -341,6 +408,11 @@ def error_response( if log_message: self.error_log(log_message) + def redirect_response(self, location: bytes) -> None: + self.exchange.response_start(b"303", b"See Other", [(b"Location", location)]) + self.output("Redirecting...") + self.exchange.response_done([]) + def output(self, chunk: str) -> None: self.exchange.response_body(chunk.encode(self.charset, "replace")) diff --git a/redbot/webui/saved_tests.py b/redbot/webui/saved_tests.py index cd0aab46..07d61234 100644 --- a/redbot/webui/saved_tests.py +++ b/redbot/webui/saved_tests.py @@ -1,169 +1,94 @@ from configparser import SectionProxy -import gzip import os -import pickle -import tempfile +import random +import shelve +import string import time -from typing import TYPE_CHECKING, cast, IO, Tuple, Optional -import zlib +from typing import TYPE_CHECKING, Tuple, Optional, Callable, Dict, cast -import thor - -from redbot.formatter import find_formatter from redbot.resource import HttpResource if TYPE_CHECKING: - from redbot.webui import RedWebUi # pylint: disable=cyclic-import,unused-import - - -def init_save_file(webui: "RedWebUi") -> Optional[str]: - if webui.config.get("save_dir", "") and os.path.exists(webui.config["save_dir"]): + from redbot.webui import RedWebUi + + +class SavedTests: + """ + Save test results for later display. + """ + + def __init__(self, config: SectionProxy, console: Callable[[str], None]) -> None: + self.config = config + self.console = console + self.save_db: Optional[shelve.Shelf] = None + self.expiry_cache: Dict[str, float] = {} + save_dir = config.get("save_dir", "") + if not save_dir: + return + if not os.path.exists(save_dir): + self.console(f"WARNING: Save directory '{save_dir}' does not exist.") + return try: - fd, webui.save_path = tempfile.mkstemp( - prefix="", dir=webui.config["save_dir"] + self.save_db = shelve.open(f"{save_dir}/redbot_saved_tests") + except (OSError, IOError) as why: + self.console(f"WARNING: Save DB not initialised: {why}") + + def shutdown(self) -> None: + if self.save_db is not None: + self.save_db.close() + self.expiry_cache.clear() + + def get_test_id(self) -> str: + """Get a unique test id.""" + test_id = "".join(random.choice(string.ascii_lowercase) for i in range(16)) + if self.save_db is None: + return test_id + if test_id in self.save_db.keys(): + return self.get_test_id() + return test_id + + def save(self, webui: "RedWebUi", top_resource: HttpResource) -> None: + """Save a test by test_id.""" + if webui.test_id and self.save_db is not None: + top_resource.save_expires = ( + time.time() + self.config.getint("no_save_mins", fallback=20) * 60 ) - os.close(fd) - return os.path.split(webui.save_path)[1] - except OSError: - # Don't try to store it. - pass - return None # should already be None, but make sure - - -def save_test(webui: "RedWebUi", top_resource: HttpResource) -> None: - """Save a test by test_id.""" - if webui.test_id: - try: - with cast(IO[bytes], gzip.open(webui.save_path, "w")) as tmp_file: - pickle.dump(top_resource, tmp_file) - except (OSError, zlib.error, pickle.PickleError): - pass # we don't cry if we can't store it. - - -def extend_saved_test(webui: "RedWebUi") -> None: - """Extend the expiry time of a previously run test_id.""" - assert webui.test_id, "test_id not set in extend_saved_test" - try: - # touch the save file so it isn't deleted. - now = time.time() - os.utime( - os.path.join(webui.config["save_dir"], webui.test_id), - ( - now, - now + (webui.config.getint("save_days", fallback=30) * 24 * 60 * 60), - ), - ) - location = b"?id=%s" % webui.test_id.encode("ascii") - if webui.descend: - location = b"%s&descend=True" % location - webui.exchange.response_start(b"303", b"See Other", [(b"Location", location)]) - webui.output("Redirecting to the saved test page...") - except OSError: - webui.exchange.response_start( - b"500", - b"Internal Server Error", - [(b"Content-Type", b"text/html; charset=%s" % webui.charset_bytes)], - ) - webui.output("Sorry, I couldn't save that.") - webui.exchange.response_done([]) - - -def clean_saved_tests(config: SectionProxy) -> Tuple[int, int, int]: - """Clean old files from the saved tests directory.""" - now = time.time() - state_dir = config["save_dir"] - if not os.path.exists(state_dir): - return (0, 0, 0) - save_secs = config.getint("no_save_mins", fallback=20) * 60 - files = [ - os.path.join(state_dir, f) - for f in os.listdir(state_dir) - if os.path.isfile(os.path.join(state_dir, f)) - ] - removed = 0 - errors = 0 - for path in files: - try: - mtime = os.path.getmtime(path) - except OSError: - errors += 1 - continue - if now - mtime > save_secs: - try: - os.remove(path) - removed += 1 - except OSError: - errors += 1 - continue - return (len(files), removed, errors) - - -def load_saved_test(webui: "RedWebUi") -> None: - """Load a saved test by test_id.""" - assert webui.test_id, "test_id not set in load_saved_test" - try: - with cast( - IO[bytes], - gzip.open( - os.path.join(webui.config["save_dir"], os.path.basename(webui.test_id)) - ), - ) as fd: - mtime = os.fstat(fd.fileno()).st_mtime - is_saved = mtime > time.time() - top_resource = pickle.load(fd) - except (OSError, TypeError): - webui.exchange.response_start( - b"404", - b"Not Found", - [ - (b"Content-Type", b"text/html; charset=%s" % webui.charset_bytes), - (b"Cache-Control", b"max-age=600, must-revalidate"), - ], - ) - webui.output("I'm sorry, I can't find that saved response.") - webui.exchange.response_done([]) - return - except (pickle.PickleError, zlib.error, EOFError): - webui.exchange.response_start( - b"500", - b"Internal Server Error", - [ - (b"Content-Type", b"text/html; charset=%s" % webui.charset_bytes), - (b"Cache-Control", b"max-age=600, must-revalidate"), - ], + self.save_db[webui.test_id] = top_resource + self.expiry_cache[webui.test_id] = top_resource.save_expires + + def extend(self, test_id: str) -> None: + """Extend the expiry time of a previously run test_id.""" + if self.save_db is None: + return + entry = self.save_db[test_id] + entry.save_expires = ( + time.time() + self.config.getint("save_days", fallback=30) * 24 * 60 * 60 ) - webui.output("I'm sorry, I had a problem loading that.") - webui.exchange.response_done([]) - return + self.save_db[test_id] = entry + self.expiry_cache[test_id] = entry.save_expires - if webui.check_name: - display_resource = top_resource.subreqs.get(webui.check_name, top_resource) - else: - display_resource = top_resource - - formatter = find_formatter(webui.format, "html", top_resource.descend)( - webui.config, - display_resource, - webui.output, - { - "allow_save": (not is_saved), - "is_saved": True, - "test_id": webui.test_id, - "nonce": webui.nonce, - }, - ) - - webui.exchange.response_start( - b"200", - b"OK", - [ - (b"Content-Type", formatter.content_type()), - (b"Cache-Control", b"max-age=3600, must-revalidate"), - ], - ) - - @thor.events.on(formatter) - def formatter_done() -> None: - webui.exchange.response_done([]) - - formatter.bind_resource(display_resource) + def clean(self) -> Tuple[int, int, int]: + """Clean old files from the saved tests directory.""" + if self.save_db is None: + return (0, 0, 0) + now = time.time() + count = removed = errors = 0 + for test_id in self.save_db: + count += 1 + if not test_id in self.expiry_cache: + entry = self.save_db[test_id] + self.expiry_cache[test_id] = entry.save_expires + if self.expiry_cache[test_id] < now: + try: + del self.save_db[test_id] + del self.expiry_cache[test_id] + removed += 1 + except KeyError: + errors += 1 + return (count, removed, errors) + + def load(self, webui: "RedWebUi") -> HttpResource: + """Return a saved test by test_id.""" + if not webui.test_id or self.save_db is None: + raise ValueError + return cast(HttpResource, self.save_db[webui.test_id]) diff --git a/redbot/webui/slack.py b/redbot/webui/slack.py index bf743aaf..c31f1800 100644 --- a/redbot/webui/slack.py +++ b/redbot/webui/slack.py @@ -10,7 +10,6 @@ from redbot.resource import HttpResource from redbot.resource.fetch import RedHttpClient from redbot.webui.ratelimit import ratelimiter -from redbot.webui.saved_tests import init_save_file, save_test if TYPE_CHECKING: from redbot.webui import RedWebUi # pylint: disable=cyclic-import,unused-import @@ -19,7 +18,7 @@ def slack_run(webui: "RedWebUi") -> None: """Handle a slack request.""" webui.test_uri = webui.body_args.get("text", [""])[0].strip() - webui.test_id = init_save_file(webui) + webui.test_id = webui.saved.get_test_id() slack_response_uri = webui.body_args.get("response_url", [""])[0].strip() resource = HttpResource(webui.config) formatter = slack.SlackFormatter( @@ -87,7 +86,7 @@ def formatter_done() -> None: if webui.timeout: webui.timeout.delete() webui.timeout = None - save_test(webui, top_resource) + webui.saved.save(webui, top_resource) top_resource.check()