diff --git a/pyproject.toml b/pyproject.toml index ab0935d..f14ed30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ exclude = [ "**/member_counter.py", "**/remindme.py", "**/snailrace.py", - "**/starboard.py", "**/uptime.py", "**/working_on.py", "**/utils/command_utils.py", diff --git a/uqcsbot/bot.py b/uqcsbot/bot.py index e2b045c..2d80f5f 100644 --- a/uqcsbot/bot.py +++ b/uqcsbot/bot.py @@ -29,7 +29,9 @@ def __init__(self, *args: Any, **kwargs: Any): # Important channel names & constants go here self.ADMIN_ALERTS_CNAME = "admin-alerts" self.GENERAL_CNAME = "general" + self.STARBOARD_CNAME = "starboard" self.BOT_TIMEZONE = timezone("Australia/Brisbane") + self.STARBOARD_ENAME = "starhaj" self.uqcs_server: discord.Guild diff --git a/uqcsbot/models.py b/uqcsbot/models.py index 42bae75..8c1ab6d 100644 --- a/uqcsbot/models.py +++ b/uqcsbot/models.py @@ -58,15 +58,14 @@ class Starboard(Base): # composite key on recv, sent. - # recv == null implies deleted recv message. - # recv_location == null implies deleted recv channel. recv should also be null. - # sent == null implies blacklisted recv message. + # see starboard.StarboardMsgPair recv: Mapped[Optional[int]] = mapped_column( "recv", BigInteger, primary_key=True, nullable=True ) recv_location: Mapped[Optional[int]] = mapped_column( "recv_location", BigInteger, nullable=True, unique=False ) - sent: Mapped[Optional[int]] = mapped_column( + # external Optional is needed as a type hint that the column can return None values. + sent: Optional[Mapped[Optional[int]]] = mapped_column( "sent", BigInteger, primary_key=True, nullable=True, unique=True ) diff --git a/uqcsbot/starboard.py b/uqcsbot/starboard.py index ed727d1..50a1afb 100644 --- a/uqcsbot/starboard.py +++ b/uqcsbot/starboard.py @@ -1,7 +1,6 @@ import os, time from threading import Timer -from typing import Tuple, List -from zoneinfo import ZoneInfo +from typing import Tuple, List, Union import discord from discord import app_commands @@ -9,30 +8,64 @@ from sqlalchemy.sql.expression import and_ from uqcsbot import models +from uqcsbot.bot import UQCSBot from uqcsbot.utils.err_log_utils import FatalErrorWithLog -class BlacklistedMessageError(Exception): - # always caught. used to differentiate "starboard message doesn't exist" and "starboard message is blacklisted" - pass +class StarboardMsgPair(object): + def __init__( + self, + recv: Union[discord.Message, None], + sent: Union[discord.Message, None], + recv_channel: Union[ + discord.abc.GuildChannel, discord.Thread, discord.abc.PrivateChannel, None + ], + blacklist: bool = False, + ): + self.recv = recv + self.sent = sent + self.recv_channel: Union[discord.TextChannel, None] = ( + recv_channel + if recv_channel is None or isinstance(recv_channel, discord.TextChannel) + else None + ) + self.blacklist = blacklist + def dangerous_recv(self) -> discord.Message: + if self.recv is None: + raise RuntimeError() + return self.recv -class Starboard(commands.Cog): - SB_CHANNEL_NAME = "starboard" - EMOJI_NAME = "starhaj" - MODLOG_CHANNEL_NAME = "admin-alerts" - BRISBANE_TZ = ZoneInfo("Australia/Brisbane") + def dangerous_sent(self) -> discord.Message: + if self.sent is None: + raise RuntimeError() + return self.sent + + def original_msg_deleted(self) -> bool: + return self.recv is None + + def original_channel_deleted(self) -> bool: + return self.recv_channel is None - def __init__(self, bot: commands.Bot): + def is_blacklisted(self) -> bool: + return self.blacklist and self.sent is None + + +class Starboard(commands.Cog): + def __init__(self, bot: UQCSBot): self.bot = bot - self.base_threshold = int(os.environ.get("SB_BASE_THRESHOLD")) - self.big_threshold = int(os.environ.get("SB_BIG_THRESHOLD")) - self.ratelimit = int(os.environ.get("SB_RATELIMIT")) + + if (base := os.environ.get("SB_BASE_THRESHOLD")) is not None: + self.base_threshold = int(base) + if (big := os.environ.get("SB_BIG_THRESHOLD")) is not None: + self.big_threshold = int(big) + if (limit := os.environ.get("SB_RATELIMIT")) is not None: + self.ratelimit = int(limit) # messages that are temp blocked from being resent to the starboard - self.base_blocked_messages = [] + self.base_blocked_messages: List[int] = [] # messages that are temp blocked from being repinned in the starboard - self.big_blocked_messages = [] + self.big_blocked_messages: List[int] = [] self.unblacklist_menu = app_commands.ContextMenu( name="Starboard Unblacklist", @@ -53,13 +86,22 @@ async def on_ready(self): N.B. this does assume the server only has one channel called "starboard" and one emoji called "starhaj". If this assumption stops holding, we may need to move back to IDs (cringe) """ - self.starboard_emoji = discord.utils.get(self.bot.emojis, name=self.EMOJI_NAME) - self.starboard_channel = discord.utils.get( - self.bot.get_all_channels(), name=self.SB_CHANNEL_NAME - ) - self.modlog = discord.utils.get( - self.bot.get_all_channels(), name=self.MODLOG_CHANNEL_NAME - ) + if ( + emoji := discord.utils.get(self.bot.emojis, name=self.bot.STARBOARD_ENAME) + ) is not None: + self.starboard_emoji = emoji + if ( + channel := discord.utils.get( + self.bot.get_all_channels(), name=self.bot.STARBOARD_CNAME + ) + ) is not None and isinstance(channel, discord.TextChannel): + self.starboard_channel = channel + if ( + log := discord.utils.get( + self.bot.get_all_channels(), name=self.bot.ADMIN_ALERTS_CNAME + ) + ) is not None and isinstance(log, discord.TextChannel): + self.modlog = log @app_commands.command() @app_commands.default_permissions(manage_guild=True) @@ -92,32 +134,33 @@ async def cleanup_starboard(self, interaction: discord.Interaction): .filter(models.Starboard.sent == message.id) .one_or_none() ) - if query is None and message.author.id == self.bot.user.id: + if query is None and message.author.id == self.bot.safe_user.id: # only delete messages that uqcsbot itself sent await message.delete() - elif message.author.id == self.bot.user.id: - try: - recieved_msg, starboard_msg = await self._lookup_from_id( - self.starboard_channel.id, message.id - ) - except BlacklistedMessageError: - if starboard_msg is not None: - await starboard_msg.delete() + elif message.author.id == self.bot.safe_user.id: + pair = await self._lookup_from_id(self.starboard_channel.id, message.id) + + if pair.is_blacklisted() and pair.sent is not None: + await pair.sent.delete() new_reaction_count = await self._count_num_reacts( - (recieved_msg, starboard_msg) - ) - await self._process_sb_updates( - new_reaction_count, recieved_msg, starboard_msg + (pair.recv, pair.sent) ) + await self._process_sb_updates(new_reaction_count, pair.recv, pair.sent) db_session.close() await interaction.followup.send("Finished cleaning up.") async def _blacklist_log( - self, message: discord.Message, user: discord.Member, blacklist: bool + self, + message: discord.Message, + user: Union[discord.Member, discord.User, str], + blacklist: bool, ): """Logs the use of a blacklist/unblacklist command to the modlog.""" + if not isinstance(message.author, discord.Member): + return + state = "blacklisted" if blacklist else "unblacklisted" embed = discord.Embed( @@ -127,7 +170,7 @@ async def _blacklist_log( name=message.author.display_name, icon_url=message.author.display_avatar.url ) embed.set_footer( - text=message.created_at.astimezone(tz=self.BRISBANE_TZ).strftime( + text=message.created_at.astimezone(tz=self.bot.BOT_TIMEZONE).strftime( "%b %d, %H:%M:%S" ) ) @@ -232,13 +275,37 @@ def _rm_big_ratelimit(self, id: int) -> None: """Callback to remove a message from the big-ratelimited list""" self.big_blocked_messages.remove(id) - def _starboard_db_add(self, recv: int, recv_location: int, sent: int) -> None: + def _starboard_db_add(self, recv: int, recv_location: int) -> None: """Creates a starboard DB entry. Only called from _process_updates when the messages are not None, so doesn't need to handle any None-checks - we can just pass ints straight in.""" db_session = self.bot.create_db_session() db_session.add( - models.Starboard(recv=recv, recv_location=recv_location, sent=sent) + models.Starboard(recv=recv, recv_location=recv_location, sent=None) + ) + db_session.commit() + db_session.close() + + def _starboard_db_finish(self, recv: int, recv_location: int, sent: int) -> None: + """Finalises a starboard DB entry. Finishes the job that _starboard_db_add starts - it adds the sent-message-id.""" + db_session = self.bot.create_db_session() + entry = ( + db_session.query(models.Starboard) + .filter( + and_( + models.Starboard.recv == recv, + models.Starboard.recv_location == recv_location, + ) + ) + .one_or_none() ) + + if entry is not None: + entry.sent = sent + else: + raise FatalErrorWithLog( + self.bot, f"Finishable-entry for message {recv} not found!" + ) + db_session.commit() db_session.close() @@ -251,15 +318,10 @@ def _starboard_db_remove( sent_id = sent.id if sent is not None else None db_session = self.bot.create_db_session() - entry = ( - db_session.query(models.Starboard) - .filter( - and_(models.Starboard.recv == recv_id, models.Starboard.sent == sent_id) - ) - .one_or_none() + entry = db_session.query(models.Starboard).filter( + and_(models.Starboard.recv == recv_id, models.Starboard.sent == sent_id) ) - if entry is not None: - entry.delete(synchronize_session=False) + entry.delete(synchronize_session=False) db_session.commit() db_session.close() @@ -277,7 +339,7 @@ async def _fetch_message_or_none( async def _lookup_from_id( self, channel_id: int, message_id: int - ) -> Tuple[discord.Message | None, discord.Message | None]: + ) -> StarboardMsgPair: """Takes a channel ID and a message ID. These may be for a starboard message or a recieved message. Returns (Recieved Message, Starboard Message), or Nones, as applicable. @@ -295,30 +357,33 @@ async def _lookup_from_id( ) if entry is not None: - if entry.recv_location is not None: - channel = self.bot.get_channel(entry.recv_location) - - return ( + if entry.recv_location is not None and ( + (channel := self.bot.get_channel(entry.recv_location)) is not None + and isinstance(channel, discord.TextChannel) + ): + return StarboardMsgPair( await self._fetch_message_or_none(channel, entry.recv), await self._fetch_message_or_none( self.starboard_channel, entry.sent ), + channel, ) else: - # if the recieved location is None, then the message won't exist either. + # if the recieved location is None or un-gettable, then the message won't exist either. # The only thing we can possibly return is a starboard message. - return ( + return StarboardMsgPair( None, await self._fetch_message_or_none( self.starboard_channel, entry.sent ), + None, ) else: if message_id == 1076779482637144105: # This is Isaac's initial-starboard message. I know, IDs are bad. BUT # consider that this doesn't cause any of the usual ID-related issues # like breaking lookups in other servers. - return + return StarboardMsgPair(None, None, None, True) raise FatalErrorWithLog( client=self.bot, @@ -332,6 +397,10 @@ async def _lookup_from_id( .filter(models.Starboard.recv == message_id) .one_or_none() ) + if ( + safe_channel := self.bot.get_channel(channel_id) + ) is None or not isinstance(safe_channel, discord.TextChannel): + return StarboardMsgPair(None, None, None, True) if entry is not None: if entry.recv_location != channel_id: @@ -340,21 +409,21 @@ async def _lookup_from_id( message=f"Starboard state error: Recieved message ({message_id}) from different channel ({channel_id}) to what the DB expects ({entry.recv_location})!", ) elif entry.sent is None: - raise BlacklistedMessageError() + return StarboardMsgPair(None, None, None, True) channel = self.bot.get_channel(channel_id) - return ( - await self._fetch_message_or_none(channel, message_id), + return StarboardMsgPair( + await self._fetch_message_or_none(safe_channel, message_id), await self._fetch_message_or_none( self.starboard_channel, entry.sent ), + channel, ) else: - channel = self.bot.get_channel(channel_id) - - return ( - await self._fetch_message_or_none(channel, message_id), + return StarboardMsgPair( + await self._fetch_message_or_none(safe_channel, message_id), + None, None, ) @@ -365,8 +434,8 @@ async def _count_num_reacts( Returns the number of unique reactions across both messages, not including the authors of those messages. """ - users = [] - authors = [] + users: List[int] = [] + authors: List[int] = [] for message in data: if message is None: @@ -386,15 +455,22 @@ async def _count_num_reacts( def _generate_message_text( self, reaction_count: int, recieved_msg: (discord.Message | None) ) -> str: - return ( - f"{str(self.starboard_emoji)} {reaction_count} | " - f"{('#' + recieved_msg.channel.name) if recieved_msg is not None else 'OoOoO... ghost message!'}" - ) + if recieved_msg is None or not isinstance( + recieved_msg.channel, discord.TextChannel + ): + name = "OoOoO... ghost message!" + else: + name = "#" + recieved_msg.channel.name + + return f"{str(self.starboard_emoji)} {reaction_count} | {name}" def _generate_message_embeds( self, recieved_msg: discord.Message ) -> List[discord.Embed]: """Creates the starboard embed for a message, including author, colours, replies, etc.""" + if not isinstance(recieved_msg.author, discord.Member): + raise RuntimeError() + embed = discord.Embed( color=recieved_msg.author.top_role.color, description=recieved_msg.content ) @@ -403,7 +479,7 @@ def _generate_message_embeds( icon_url=recieved_msg.author.display_avatar.url, ) embed.set_footer( - text=recieved_msg.created_at.astimezone(tz=self.BRISBANE_TZ).strftime( + text=recieved_msg.created_at.astimezone(tz=self.bot.BOT_TIMEZONE).strftime( "%b %d, %H:%M:%S" ) ) @@ -412,8 +488,13 @@ def _generate_message_embeds( embed.set_image(url=recieved_msg.attachments[0].url) # only takes the first attachment to avoid sending large numbers of images to starboard. - if recieved_msg.reference is not None and not isinstance( - recieved_msg.reference.resolved, discord.DeletedReferencedMessage + if ( + recieved_msg.reference is not None + and recieved_msg.reference.resolved is not None + and not isinstance( + recieved_msg.reference.resolved, discord.DeletedReferencedMessage + ) + and isinstance(recieved_msg.reference.resolved.author, discord.Member) ): # if the reference exists, add it. isinstance here just tightens race conditions; we check # that messages aren't deleted before calling this anyway. @@ -429,7 +510,7 @@ def _generate_message_embeds( replied.set_footer( text=recieved_msg.reference.resolved.created_at.astimezone( - tz=self.BRISBANE_TZ + tz=self.bot.BOT_TIMEZONE ).strftime("%b %d, %H:%M:%S") ) @@ -454,6 +535,9 @@ async def _process_sb_updates( ) ) ): + # Add message to the DB as a fake blacklist entry. + self._starboard_db_add(recieved_msg.id, recieved_msg.channel.id) + # Above threshold, not blocked, not replying to deleted msg, no current starboard message? post it. new_sb_message = await self.starboard_channel.send( content=self._generate_message_text(reaction_count, recieved_msg), @@ -471,7 +555,7 @@ async def _process_sb_updates( ) # recieved_msg isn't None and we just sent the sb message, so it also shouldn't be None - self._starboard_db_add( + self._starboard_db_finish( recieved_msg.id, recieved_msg.channel.id, new_sb_message.id ) @@ -522,23 +606,24 @@ async def on_raw_reaction_add( ) -> None: if ( payload.emoji != self.starboard_emoji + or payload.guild_id is None or self.bot.get_guild(payload.guild_id) is None ): return - try: - recieved_msg, starboard_msg = await self._lookup_from_id( - payload.channel_id, payload.message_id - ) - except BlacklistedMessageError: + pair = await self._lookup_from_id(payload.channel_id, payload.message_id) + if pair.is_blacklisted(): return - if recieved_msg is not None and recieved_msg.author.id == payload.user_id: - # payload.member is guaranteed to be available because we're adding and we're in a server - return await recieved_msg.remove_reaction(payload.emoji, payload.member) + if ( + pair.recv is not None + and pair.recv.author.id == payload.user_id + and payload.member is not None + ): + return await pair.recv.remove_reaction(payload.emoji, payload.member) - new_reaction_count = await self._count_num_reacts((recieved_msg, starboard_msg)) - await self._process_sb_updates(new_reaction_count, recieved_msg, starboard_msg) + new_reaction_count = await self._count_num_reacts((pair.recv, pair.sent)) + await self._process_sb_updates(new_reaction_count, pair.recv, pair.sent) @commands.Cog.listener() async def on_raw_reaction_remove( @@ -546,36 +631,31 @@ async def on_raw_reaction_remove( ) -> None: if ( payload.emoji != self.starboard_emoji + or payload.guild_id is None or self.bot.get_guild(payload.guild_id) is None ): return - try: - recieved_msg, starboard_msg = await self._lookup_from_id( - payload.channel_id, payload.message_id - ) - except BlacklistedMessageError: + pair = await self._lookup_from_id(payload.channel_id, payload.message_id) + if pair.is_blacklisted(): return - new_reaction_count = await self._count_num_reacts((recieved_msg, starboard_msg)) - await self._process_sb_updates(new_reaction_count, recieved_msg, starboard_msg) + new_reaction_count = await self._count_num_reacts((pair.recv, pair.sent)) + await self._process_sb_updates(new_reaction_count, pair.recv, pair.sent) @commands.Cog.listener() async def on_raw_reaction_clear( self, payload: discord.RawReactionClearEvent ) -> None: - if self.bot.get_guild(payload.guild_id) is None: + if payload.guild_id is None or self.bot.get_guild(payload.guild_id) is None: return - try: - recieved_msg, starboard_msg = await self._lookup_from_id( - payload.channel_id, payload.message_id - ) - except BlacklistedMessageError: + pair = await self._lookup_from_id(payload.channel_id, payload.message_id) + if pair.is_blacklisted(): return - new_reaction_count = await self._count_num_reacts((recieved_msg, starboard_msg)) - await self._process_sb_updates(new_reaction_count, recieved_msg, starboard_msg) + new_reaction_count = await self._count_num_reacts((pair.recv, pair.sent)) + await self._process_sb_updates(new_reaction_count, pair.recv, pair.sent) @commands.Cog.listener() async def on_raw_reaction_clear_emoji( @@ -583,20 +663,46 @@ async def on_raw_reaction_clear_emoji( ) -> None: if ( payload.emoji != self.starboard_emoji + or payload.guild_id is None or self.bot.get_guild(payload.guild_id) is None ): return - try: - recieved_msg, starboard_msg = await self._lookup_from_id( - payload.channel_id, payload.message_id - ) - except BlacklistedMessageError: + pair = await self._lookup_from_id(payload.channel_id, payload.message_id) + if pair.is_blacklisted(): return - new_reaction_count = await self._count_num_reacts((recieved_msg, starboard_msg)) - await self._process_sb_updates(new_reaction_count, recieved_msg, starboard_msg) + new_reaction_count = await self._count_num_reacts((pair.recv, pair.sent)) + await self._process_sb_updates(new_reaction_count, pair.recv, pair.sent) + + @commands.Cog.listener() + async def on_raw_message_delete( + self, payload: discord.RawMessageDeleteEvent + ) -> None: + """ + Fallback that blacklists messages whenever a starboard message is deleted. See documentation for context blacklist commands. + """ + if payload.channel_id != self.starboard_channel.id: + return + + # resolve the message objects before blacklisting + pair = await self._lookup_from_id(payload.channel_id, payload.message_id) + + db_session = self.bot.create_db_session() + # can't use the db query functions for this, they error out if a message hits the blacklist + entry = db_session.query(models.Starboard).filter( + models.Starboard.sent == payload.message_id + ) + query_val = entry.one_or_none() + + if query_val is not None: + query_val.sent = None + + db_session.commit() + db_session.close() + + await self._blacklist_log(pair.dangerous_recv(), "Automatically", True) -async def setup(bot: commands.Bot): +async def setup(bot: UQCSBot): await bot.add_cog(Starboard(bot))