diff --git a/nightwatch/__init__.py b/nightwatch/__init__.py index deded32..732155f 100644 --- a/nightwatch/__init__.py +++ b/nightwatch/__init__.py @@ -1 +1 @@ -__version__ = "0.8.2" +__version__ = "0.8.3" diff --git a/nightwatch/client/extra/commands/__init__.py b/nightwatch/client/extra/commands/__init__.py index 3f467ce..f5925cb 100644 --- a/nightwatch/client/extra/commands/__init__.py +++ b/nightwatch/client/extra/commands/__init__.py @@ -1,15 +1,14 @@ # Copyright (c) 2024 iiPython # Modules -from typing import List -from types import FunctionType +from typing import Callable from nightwatch import __version__ from nightwatch.config import config # Main class class BaseCommand(): - def __init__(self, name: str, ui, add_message: FunctionType) -> None: + def __init__(self, name: str, ui, add_message: Callable) -> None: self.name, self.ui = name, ui self.add_message = add_message @@ -21,14 +20,14 @@ class ShrugCommand(BaseCommand): def __init__(self, *args) -> None: super().__init__("shrug", *args) - def on_execute(self, args: List[str]) -> str: - return "¯\_(ツ)_/¯" + def on_execute(self, args: list[str]) -> str: + return r"¯\_(ツ)_/¯" class ConfigCommand(BaseCommand): def __init__(self, *args) -> None: super().__init__("config", *args) - def on_execute(self, args: List[str]) -> None: + def on_execute(self, args: list[str]) -> None: if not args: for line in [ "Nightwatch client configuration", @@ -54,7 +53,7 @@ class HelpCommand(BaseCommand): def __init__(self, *args) -> None: super().__init__("help", *args) - def on_execute(self, args: List[str]) -> None: + def on_execute(self, args: list[str]) -> None: self.print(f"✨ Nightwatch v{__version__}") self.print("Available commands:") for command in self.ui.commands: @@ -64,15 +63,94 @@ class MembersCommand(BaseCommand): def __init__(self, *args) -> None: super().__init__("members", *args) - def on_execute(self, args: List[str]) -> None: + def on_execute(self, args: list[str]) -> None: def members_callback(response: dict): self.print(", ".join(response["data"]["list"])) self.ui.websocket.callback({"type": "members"}, members_callback) +class AdminCommand(BaseCommand): + def __init__(self, *args) -> None: + self.admin = False + super().__init__("admin", *args) + + def on_execute(self, args: list[str]) -> None: + match args: + case [] if not self.admin: + self.ui.websocket.send({"type": "admin"}) + self.print("Run /admin with the admin code in your server console.") + + case [] | ["help"]: + self.print("Available commands:") + if not self.admin: + self.print(" /admin ") + + self.print(" /admin ban ") + self.print(" /admin unban ") + self.print(" /admin ip ") + self.print(" /admin banlist") + self.print(" /admin say ") + + case ["ban", username]: + def on_ban_response(response: dict): + if not response["data"]["success"]: + return self.print(f"(fail) {response['data']['error']}") + + self.print(f"(success) {username} has been banned.") + + self.ui.websocket.callback({"type": "admin", "data": {"command": args}}, on_ban_response) + + case ["unban", username]: + def on_unban_response(response: dict): + if not response["data"]["success"]: + return self.print(f"(fail) {response['data']['error']}") + + self.print(f"(success) {username} has been unbanned.") + + self.ui.websocket.callback({"type": "admin", "data": {"command": args}}, on_unban_response) + + case ["ip", username]: + def on_ip_response(response: dict): + if not response["data"]["success"]: + return self.print(f"(fail) {response['data']['error']}") + + self.print(f"(success) {username}'s IP address is {response['data']['ip']}.") + + self.ui.websocket.callback({"type": "admin", "data": {"command": args}}, on_ip_response) + + case ["banlist"]: + def on_banlist_response(response: dict): + if not response["data"]["banlist"]: + return self.print("(fail) Nobody is banned on this server.") + + self.print("Current banlist:") + self.print(f"{', '.join(f'{v} ({k})' for k, v in response['data']['banlist'].items())}") + + self.ui.websocket.callback({"type": "admin", "data": {"command": args}}, on_banlist_response) + + case ["say", _]: + self.ui.websocket.send({"type": "admin", "data": {"command": args}}) + + case [code]: + if self.admin: + return self.print("(fail) Privileges already escalated.") + + def on_admin_response(response: dict): + if response["data"]["success"] is False: + return self.print("(fail) Invalid admin code specified.") + + self.print("(success) Privileges escalated.") + self.admin = True + + self.ui.websocket.callback({"type": "admin", "data": {"code": code}}, on_admin_response) + + case _: + self.print("Admin command not recognized, try /admin help.") + commands = [ ShrugCommand, ConfigCommand, HelpCommand, - MembersCommand + MembersCommand, + AdminCommand ] diff --git a/nightwatch/client/extra/ui.py b/nightwatch/client/extra/ui.py index 603efdd..df690c9 100644 --- a/nightwatch/client/extra/ui.py +++ b/nightwatch/client/extra/ui.py @@ -107,6 +107,9 @@ def on_message(self, data: dict) -> None: # Push message to screen self.add_message(user["name"], data["text"], color_code) + case "error": + exit(f"Nightwatch Exception\n{'=' * 50}\n\n{data['text']}") + def on_ready(self, loop: urwid.MainLoop, payload: dict) -> None: self.loop = loop self.construct_message("Nightwatch", f"Welcome to {payload['name']}. There are {payload['online']} user(s) online.") diff --git a/nightwatch/server/__init__.py b/nightwatch/server/__init__.py index 03a072c..ca6a494 100644 --- a/nightwatch/server/__init__.py +++ b/nightwatch/server/__init__.py @@ -4,10 +4,11 @@ import orjson from pydantic import ValidationError from websockets import WebSocketCommonProtocol -from websockets.exceptions import ConnectionClosedError +from websockets.exceptions import ConnectionClosed from .utils.commands import registry from .utils.websocket import NightwatchClient +from .utils.modules.admin import admin_module from nightwatch.logging import log @@ -18,6 +19,7 @@ def __init__(self) -> None: def add_client(self, client: WebSocketCommonProtocol) -> None: self.clients[client] = None + setattr(client, "ip", client.request_headers.get("CF-Connecting-IP", client.remote_address[0])) def remove_client(self, client: WebSocketCommonProtocol) -> None: if client in self.clients: @@ -28,11 +30,18 @@ def remove_client(self, client: WebSocketCommonProtocol) -> None: # Socket entrypoint async def connection(websocket: WebSocketCommonProtocol) -> None: client = NightwatchClient(state, websocket) + if websocket.ip in admin_module.banned_users: # type: ignore + return await client.send("error", text = "You have been banned from this server.") + try: log.info(client.id, "Client connected!") async for message in websocket: message = orjson.loads(message) + if not isinstance(message, dict): + await client.send("error", text = "Expected payload is an object.") + continue + if message.get("type") not in registry.commands: await client.send("error", text = "Specified command type does not exist or is missing.") continue @@ -55,7 +64,7 @@ async def connection(websocket: WebSocketCommonProtocol) -> None: except orjson.JSONDecodeError: log.warn(client.id, "Failed to decode JSON from client.") - except ConnectionClosedError: + except ConnectionClosed: log.info(client.id, "Client disconnected!") state.remove_client(websocket) diff --git a/nightwatch/server/utils/commands.py b/nightwatch/server/utils/commands.py index c0fecce..98c4343 100644 --- a/nightwatch/server/utils/commands.py +++ b/nightwatch/server/utils/commands.py @@ -1,6 +1,7 @@ # Copyright (c) 2024 iiPython # Modules +import random from typing import Callable import orjson @@ -8,6 +9,7 @@ from . import models from .websocket import NightwatchClient +from .modules.admin import admin_module from nightwatch.logging import log from nightwatch.config import config @@ -16,6 +18,7 @@ class Constant: SERVER_USER: dict[str, str] = {"name": "Nightwatch", "color": "gray"} SERVER_NAME: str = config["server.name"] or "Untitled Server" + ADMIN_CODE: str = str(random.randint(100000, 999999)) # Handle command registration class CommandRegistry(): @@ -73,3 +76,65 @@ async def command_members(state, client: NightwatchClient) -> None: @registry.command("ping") async def command_ping(state, client: NightwatchClient) -> None: return await client.send("pong") + +# New commands (coming back to this branch) +@registry.command("admin") +async def command_admin(state, client: NightwatchClient, data: models.AdminModel) -> None: + if not client.identified: + return await client.send("error", text = "You cannot enter admin mode while anonymous.") + + # Handle admin commands + if client.admin: + match data.command: + case ["ban", username]: + for client_object, client_username in state.clients.items(): + if client_username == username: + await client_object.send(orjson.dumps({ + "type": "message", + "data": {"text": "You have been banned from this server.", "user": Constant.SERVER_USER} + }).decode()) + await client_object.close() + admin_module.add_ban(client_object.ip, username) + return await client.send("admin", success = True) + + await client.send("admin", success = False, error = "Specified username couldn't be found.") + + case ["unban", username]: + for ip, client_username in admin_module.banned_users.items(): + if client_username == username: + admin_module.unban(ip) + return await client.send("admin", success = True) + + await client.send("admin", success = False, error = "Specified banned user couldn't be found.") + + case ["ip", username]: + for client_object, client_username in state.clients.items(): + if client_username == username: + return await client.send("admin", success = True, ip = client_object.ip) + + await client.send("admin", success = False, error = "Specified username couldn't be found.") + + case ["banlist"]: + await client.send("admin", banlist = admin_module.banned_users) + + case ["say", message]: + websockets.broadcast(state.clients, orjson.dumps({ + "type": "message", + "data": {"text": message, "user": Constant.SERVER_USER} + }).decode()) + + case _: + await client.send("error", text = "Invalid admin command sent, your client might be outdated.") + + return + + # Handle becoming admin + if data.code is None: + return log.info("admin", f"Admin code is {Constant.ADMIN_CODE}") + + if data.code != Constant.ADMIN_CODE: + return await client.send("admin", success = False) + + client.admin = True + log.info("admin", f"{client.user_data['name']} ({client.id}) is now an administrator.") + return await client.send("admin", success = True) diff --git a/nightwatch/server/utils/models.py b/nightwatch/server/utils/models.py index 9675f85..d74993e 100644 --- a/nightwatch/server/utils/models.py +++ b/nightwatch/server/utils/models.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 iiPython # Modules -from typing import Annotated +from typing import Annotated, Optional from pydantic import BaseModel, PlainSerializer, StringConstraints from pydantic_extra_types.color import Color @@ -12,3 +12,7 @@ class IdentifyModel(BaseModel): class MessageModel(BaseModel): text: Annotated[str, StringConstraints(min_length = 1, max_length = 300)] + +class AdminModel(BaseModel): + code: Optional[str] = None + command: Optional[list[str]] = None diff --git a/nightwatch/server/utils/modules/admin.py b/nightwatch/server/utils/modules/admin.py new file mode 100644 index 0000000..254ff90 --- /dev/null +++ b/nightwatch/server/utils/modules/admin.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024 iiPython + +# Modules +import json +from nightwatch.config import config_path + +# Main module +class AdminModule: + def __init__(self) -> None: + self.banfile = config_path.parent / "bans.json" + self.banned_users = json.loads(self.banfile.read_text()) if self.banfile.is_file() else {} + + def save(self) -> None: + self.banfile.write_text(json.dumps(self.banned_users, indent = 4)) + + def add_ban(self, ip: str, username: str) -> None: + self.banned_users[ip] = username + self.save() + + def unban(self, ip: str) -> None: + del self.banned_users[ip] + self.save() + +admin_module = AdminModule() diff --git a/nightwatch/server/utils/websocket.py b/nightwatch/server/utils/websocket.py index e67b067..70b24f3 100644 --- a/nightwatch/server/utils/websocket.py +++ b/nightwatch/server/utils/websocket.py @@ -14,7 +14,7 @@ class NightwatchClient(): data serialization through orjson.""" def __init__(self, state, client: WebSocketCommonProtocol) -> None: self.client = client - self.identified, self.callback = False, None + self.admin, self.identified, self.callback = False, False, None self.state = state self.state.add_client(client)