Skip to content

Commit

Permalink
improve bot library
Browse files Browse the repository at this point in the history
  • Loading branch information
iiPythonx committed Nov 26, 2024
1 parent 5521491 commit 320c1e2
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions nightwatch/bot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

# Exceptions
class AuthorizationFailed(Exception):
pass
def __init__(self, message: str, json: dict | None = None) -> None:
super().__init__(message)
self.json = json

# Handle state
class ClientState:
Expand Down Expand Up @@ -64,6 +66,12 @@ def __init__(self) -> None:
self.__state = ClientState()
self.__session = requests.Session()

# Public attributes (provided just for the hell of it)
self.user: User | None = None
"""The current user this client is connected as."""
self.address: str | None = None
"""The address this client is connected to."""

# Events (for overwriting)
async def on_connect(self, ctx: Context) -> None:
"""Listen to the :connect: event."""
Expand Down Expand Up @@ -105,12 +113,12 @@ async def __authorize(self, username: str, hex: str, address: str) -> tuple[str,
# Handle payload
payload = response.json()
if payload["code"] != 200:
raise AuthorizationFailed(response)
raise AuthorizationFailed("Connection failed!", payload)

return host, int(port), f"ws{protocol}://", payload["authorization"]

except requests.RequestException:
raise AuthorizationFailed("Connection failed!")
except requests.RequestException as e:
raise AuthorizationFailed("Connection failed!", e.response.json() if e.response is not None else None)

async def __match_event(self, event: dict[str, typing.Any]) -> None:
match event:
Expand All @@ -129,6 +137,9 @@ async def __match_event(self, event: dict[str, typing.Any]) -> None:

case {"type": "join", "data": payload}:
user = from_dict(User, payload["user"])
if user == self.user:
return

self.__state.user_list.append(user)
await self.on_join(Context(self.__state, user = user))

Expand All @@ -137,16 +148,22 @@ async def __match_event(self, event: dict[str, typing.Any]) -> None:
self.__state.user_list.remove(user)
await self.on_leave(Context(self.__state, user = user))

async def __event_loop(self, username: str, hex: str, address: str) -> None:
async def event_loop(self, username: str, hex: str, address: str) -> None:
"""Establish a connection and listen to websocket messages.
This method shouldn't be called directly, use :Client.run: instead."""

host, port, protocol, auth = await self.__authorize(username, hex, address)
self.user, self.address = User(username, hex, False, True), address

async with connect(f"{protocol}{host}:{port}/api/ws?authorization={auth}") as socket:
self.__state.socket = socket
while socket.state == 1:
await self.__match_event(orjson.loads(await socket.recv()))

async def close(self) -> None:
"""Closes the websocket connection."""
await self.__state.socket.close()

def run(
self,
username: str,
Expand All @@ -160,4 +177,4 @@ def run(
:hex: (str) -- the hex color code to connect with
:address: (str) -- the FQDN to connect to
"""
asyncio.run(self.__event_loop(username, hex, address))
asyncio.run(self.event_loop(username, hex, address))

0 comments on commit 320c1e2

Please sign in to comment.