From 6bb8dcb7ad32640f5718db87c3678909050468e8 Mon Sep 17 00:00:00 2001 From: odrevet Date: Tue, 21 May 2024 20:48:13 +0200 Subject: [PATCH] add option FORCE_USER_ROLE: force api calls to use role:user instead of system --- sgpt/app.py | 5 +++++ sgpt/config.py | 1 + sgpt/handlers/chat_handler.py | 5 +++-- sgpt/handlers/default_handler.py | 5 +++-- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/sgpt/app.py b/sgpt/app.py index bd092165..7b5ad842 100644 --- a/sgpt/app.py +++ b/sgpt/app.py @@ -155,6 +155,11 @@ def main( callback=inst_funcs, hidden=True, # Hiding since should be used only once. ), + force_user_role: bool = typer.Option( + cfg.get("FORCE_USER_ROLE") == "true", + help="Force role: user in API calls", + rich_help_panel="Chat Options", + ), ) -> None: stdin_passed = not sys.stdin.isatty() diff --git a/sgpt/config.py b/sgpt/config.py index 981cba74..4c6e95b3 100644 --- a/sgpt/config.py +++ b/sgpt/config.py @@ -34,6 +34,7 @@ "API_BASE_URL": os.getenv("API_BASE_URL", "default"), "PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"), "USE_LITELLM": os.getenv("USE_LITELLM", "false"), + "FORCE_USER_ROLE": os.getenv("FORCE_USER_ROLE", "false"), # New features might add their own config variables here. } diff --git a/sgpt/handlers/chat_handler.py b/sgpt/handlers/chat_handler.py index 6ba0a18d..60bfc456 100644 --- a/sgpt/handlers/chat_handler.py +++ b/sgpt/handlers/chat_handler.py @@ -14,7 +14,7 @@ CHAT_CACHE_LENGTH = int(cfg.get("CHAT_CACHE_LENGTH")) CHAT_CACHE_PATH = Path(cfg.get("CHAT_CACHE_PATH")) - +FORCE_USER_ROLE = cfg.get("FORCE_USER_ROLE") == "true" class ChatSession: """ @@ -168,7 +168,8 @@ def validate(self) -> None: def make_messages(self, prompt: str) -> List[Dict[str, str]]: messages = [] if not self.initiated: - messages.append({"role": "system", "content": self.role.role}) + role = "system" if FORCE_USER_ROLE == False else "user" + messages.append({"role": role, "content": self.role.role}) messages.append({"role": "user", "content": prompt}) return messages diff --git a/sgpt/handlers/default_handler.py b/sgpt/handlers/default_handler.py index e0fdad13..6640c69d 100644 --- a/sgpt/handlers/default_handler.py +++ b/sgpt/handlers/default_handler.py @@ -7,7 +7,7 @@ CHAT_CACHE_LENGTH = int(cfg.get("CHAT_CACHE_LENGTH")) CHAT_CACHE_PATH = Path(cfg.get("CHAT_CACHE_PATH")) - +FORCE_USER_ROLE = cfg.get("FORCE_USER_ROLE") == "true" class DefaultHandler(Handler): def __init__(self, role: SystemRole, markdown: bool) -> None: @@ -15,8 +15,9 @@ def __init__(self, role: SystemRole, markdown: bool) -> None: self.role = role def make_messages(self, prompt: str) -> List[Dict[str, str]]: + role = "system" if FORCE_USER_ROLE == False else "user" messages = [ - {"role": "system", "content": self.role.role}, + {"role": role, "content": self.role.role}, {"role": "user", "content": prompt}, ] return messages