-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
235 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) 2024 iiPython | ||
|
||
# Modules | ||
from nightwatch.bot import Client | ||
|
||
# Create client | ||
class NextgenerationBot(Client): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
# Handle events | ||
# async def on_message(self, ctx) -> None: | ||
# print(f"Connected to '{ctx.rics.name}' as {self.user.name}!") | ||
|
||
NextgenerationBot().run( | ||
username = "Next-gen Bot", | ||
hex = "ff0000", | ||
address = "localhost:8000" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = "0.10.4" | ||
__version__ = "0.10.5" | ||
|
||
import re | ||
HEX_COLOR_REGEX = re.compile(r"^[A-Fa-f0-9]{6}$") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,217 @@ | ||
# Copyright (c) 2024 iiPython | ||
|
||
# Modules | ||
import typing | ||
import asyncio | ||
from typing import Callable | ||
from dataclasses import dataclass, fields, is_dataclass | ||
|
||
import orjson | ||
import requests | ||
from websockets import connect | ||
from websockets.asyncio.client import ClientConnection | ||
|
||
from .types import User, Message, Server | ||
# Typing | ||
T = typing.TypeVar("T") | ||
|
||
def from_dict(cls: typing.Type[T], data: dict) -> T: | ||
if not is_dataclass(cls): | ||
raise ValueError(f"{cls} is not a dataclass") | ||
|
||
field_types = {f.name: f.type for f in fields(cls)} | ||
instance_data = {} | ||
|
||
for key, value in data.items(): | ||
if key in field_types: | ||
field_type = field_types[key] | ||
if is_dataclass(field_type) and isinstance(value, dict): | ||
instance_data[key] = from_dict(field_type, value) # type: ignore | ||
|
||
else: | ||
instance_data[key] = value | ||
|
||
return cls(**instance_data) | ||
|
||
@dataclass | ||
class User: | ||
name: str | ||
hex: str | ||
admin: bool | ||
|
||
def __repr__(self) -> str: | ||
return f"<User name='{self.name}' hex='{self.hex}' admin={self.admin}>" | ||
|
||
@dataclass | ||
class Message: | ||
user: User | ||
message: str | ||
time: int | ||
|
||
def __repr__(self) -> str: | ||
return f"<Message user='{self.user}' message='{self.message}' time={self.time}>" | ||
|
||
@dataclass | ||
class RicsInfo: | ||
name: str | ||
users: list[User] | ||
chat_logs: list[Message] | ||
|
||
def __repr__(self) -> str: | ||
return f"<RicsInfo name='{self.name}' users=[...] chat_logs=[...]>" | ||
|
||
# Exceptions | ||
class AuthorizationFailed(Exception): | ||
pass | ||
|
||
# Handle state | ||
class ClientState: | ||
def __init__(self) -> None: | ||
self.__state = {} | ||
|
||
# Typing | ||
self.user_list: list[User] | ||
self.chat_logs: list[Message] | ||
self.rics_info: dict[str, str] | ||
self.socket : ClientConnection | ||
|
||
def __getitem__(self, key: str) -> typing.Any: | ||
return self.__state.get(key) | ||
|
||
def __setitem__(self, key: str, value: typing.Any) -> None: | ||
self.__state[key] = value | ||
|
||
# Context | ||
class Context: | ||
def __init__(self, socket, message: Message | None, server: Server) -> None: | ||
self.socket, self.message, self.server = socket, message, server | ||
def __init__( | ||
self, | ||
state: ClientState, | ||
message: typing.Optional[Message] = None, | ||
user: typing.Optional[User] = None | ||
) -> None: | ||
self.state = state | ||
self.rics = RicsInfo( | ||
name = state.rics_info["name"], | ||
users = state.user_list, | ||
chat_logs = state.chat_logs | ||
) | ||
|
||
if message is not None: | ||
self.message = message | ||
|
||
if user is not None: | ||
self.user = user | ||
|
||
async def send(self, message: str) -> None: | ||
await self.socket.send(orjson.dumps({"type": "message", "data": {"message": message}}), text = True) | ||
await self.state.socket.send("PENIS!!") | ||
await self.state.socket.send(orjson.dumps({"type": "message", "data": {"message": message}})) | ||
|
||
async def reply(self, message: str) -> None: | ||
if self.message is None: | ||
raise RuntimeError("Cannot reply to a context with no message information!") | ||
|
||
await self.send(f"[↑ {self.message.user.name}] {message}") | ||
await self.send(f"[↑ {self.user.name}] {message}") | ||
|
||
async def run_command(self, command: str, data: dict) -> dict: | ||
await self.socket.send(orjson.dumps({"type": command, "data": data})) | ||
return orjson.loads(await self.socket.recv()) | ||
def __repr__(self) -> str: | ||
return f"<Context rics={self.rics} message={getattr(self, 'message', None)} user={getattr(self, 'user', None)}>" | ||
|
||
# Main client class | ||
class Client: | ||
def __init__(self) -> None: | ||
self.callbacks: dict[str, Callable] = {} | ||
|
||
# Handle user data | ||
def setup_profile(self, username: str, hex: str) -> None: | ||
"""Initialize the user profile with the given username and hex.""" | ||
self.user = {"username": f"[BOT] {username}", "hex": hex} | ||
|
||
async def send(self, type: str, **data) -> None: | ||
"""Send the given type and payload to the server.""" | ||
await self.socket.send(orjson.dumps({"type": type, "data": data}), text = True) | ||
|
||
async def connect(self, address: str) -> None: | ||
"""Connect to the given Nightwatch server and begin handling messages.""" | ||
|
||
# Check if we're missing user information | ||
if not hasattr(self, "user"): | ||
raise ValueError("No user information has been provided yet!") | ||
|
||
# Parse the address | ||
address_parts = address.split(":") | ||
host, port = address_parts[0], 443 if len(address_parts) == 1 else int(address_parts[1]) | ||
|
||
# Send authorization request | ||
protocol, url = "s" if port == 443 else "", f"{host}:{port}" | ||
authorization = requests.post(f"http{protocol}://{url}/api/join", json = self.user).json()["authorization"] | ||
|
||
# Connect to websocket gateway | ||
async with connect(f"ws{protocol}://{url}/api/ws?authorization={authorization}") as socket: | ||
self.socket, self.server = socket, None | ||
while socket: | ||
match orjson.loads(await socket.recv()): | ||
case {"type": "rics-info", "data": {"name": name, "message-log": message_log, "user-list": user_list}}: | ||
self.server = Server(url, name, [User(**user) for user in user_list]) | ||
if "connected" in self.callbacks: | ||
await self.callbacks["connected"](Context(socket, None, self.server)) | ||
|
||
if "message-log" in self.callbacks: | ||
await self.callbacks["message-log"]( | ||
Context(socket, None, self.server), | ||
[Message.from_payload(message) for message in message_log] | ||
) | ||
|
||
case {"type": "message", "data": payload} if self.server and "message" in self.callbacks: | ||
message = Message(**payload) | ||
await self.callbacks["message"](Context(socket, message, self.server), message) | ||
|
||
case {"type": "join", "data": {"user": user, "time": _}} if self.server and "join" in self.callbacks: | ||
await self.callbacks["join"](Context(socket, None, self.server), User.from_payload(user)) | ||
|
||
case {"type": "leave", "data": {"user": user, "time": _}} if self.server and "leave" in self.callbacks: | ||
await self.callbacks["leave"](Context(socket, None, self.server), User.from_payload(user)) | ||
|
||
def run(self, address: str) -> None: | ||
"""Passthrough method to run :client.connect: asynchronously and start the event loop. | ||
This is the recommended method to use when launching a client.""" | ||
asyncio.run(self.connect(address)) | ||
|
||
# Handle event connections | ||
def event(self, event_name: str) -> Callable: | ||
"""Attach a listener to a specific Nightwatch event.""" | ||
def internal_callback(func: Callable) -> None: | ||
self.callbacks[event_name] = func | ||
|
||
return internal_callback | ||
self.__state = ClientState() | ||
self.__session = requests.Session() | ||
|
||
# Events (for overwriting) | ||
async def on_connect(self, ctx) -> None: | ||
print(ctx) | ||
|
||
async def on_message(self, ctx: Context) -> None: | ||
if ctx.message.message == "fuck you": | ||
await ctx.reply("FUCK YOU1!!!!!!!!!!!") | ||
|
||
if ctx.message.message == "what": | ||
await ctx.send("test message") | ||
|
||
async def on_join(self, ctx) -> None: | ||
pass | ||
|
||
async def on_leave(self, ctx) -> None: | ||
pass | ||
|
||
# Handle running | ||
async def __authorize(self, username: str, hex: str, address: str) -> tuple[str, int, str, str]: | ||
"""Given an authorization payload, attempt an authorization request. | ||
Return: | ||
:host: (str) -- hostname of the backend | ||
:port: (int) -- port of the backend | ||
:protocol: (str) -- ws(s):// depending on the port | ||
:auth: (str) -- authorization code""" | ||
host, port = (address if ":" in address else f"{address}:443").split(":") | ||
protocol = "s" if port == "443" else "" | ||
|
||
# Establish authorization | ||
try: | ||
response = self.__session.post( | ||
f"http{protocol}://{host}:{port}/api/join", | ||
json = { | ||
"username": username, | ||
"hex": hex | ||
}, | ||
timeout = 5 | ||
) | ||
response.raise_for_status() | ||
|
||
# Handle payload | ||
payload = response.json() | ||
if payload["code"] != 200: | ||
raise AuthorizationFailed(response) | ||
|
||
return host, int(port), f"ws{protocol}://", payload["authorization"] | ||
|
||
except requests.RequestException: | ||
raise AuthorizationFailed("Connection failed!") | ||
|
||
async def __match_event(self, event: dict[str, typing.Any]) -> None: | ||
match event: | ||
case {"type": "rics-info", "data": payload}: | ||
self.__state.chat_logs = [from_dict(Message, message) for message in payload["message-log"]] | ||
self.__state.user_list = [from_dict(User, user) for user in payload["user-list"]] | ||
self.__state.rics_info = {"name": payload["name"]} | ||
await self.on_connect(Context(self.__state)) | ||
|
||
case {"type": "message", "data": payload}: | ||
message = from_dict(Message, payload) | ||
|
||
# Propagate | ||
await self.on_message(Context(self.__state, message)) | ||
self.__state.chat_logs.append(message) | ||
|
||
case {"type": "join", "data": payload}: | ||
user = from_dict(User, payload["user"]) | ||
self.__state.user_list.append(user) | ||
await self.on_join(Context(self.__state, user = user)) | ||
|
||
case {"type": "leave", "data": payload}: | ||
user = from_dict(User, payload["user"]) | ||
self.__state.user_list.remove(user) | ||
await self.on_leave(Context(self.__state, user = user)) | ||
|
||
async def __event_loop(self, username: str, hex: str, address: str) -> None: | ||
"""Establish a connection and listen to websocket messages. | ||
This method shouldn't be called directly, use :Client.run: instead.""" | ||
|
||
host, port, protocol, auth = await self.__authorize(username, hex, address) | ||
async with connect(f"{protocol}{host}:{port}/api/ws?authorization={auth}") as socket: | ||
self.__state.socket = socket | ||
while socket.state == 1: | ||
await self.__match_event(orjson.loads(await socket.recv())) | ||
|
||
def run( | ||
self, | ||
username: str, | ||
hex: str, | ||
address: str | ||
): | ||
"""Start the client and run the event loop. | ||
Arguments: | ||
:username: (str) -- the username to connect with | ||
:hex: (str) -- the hex color code to connect with | ||
:address: (str) -- the FQDN to connect to | ||
""" | ||
asyncio.run(self.__event_loop(username, hex, address)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters