diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py
index 336736e2..248c7f07 100644
--- a/lumen/ai/agents.py
+++ b/lumen/ai/agents.py
@@ -29,7 +29,7 @@
 from .config import FUZZY_TABLE_LENGTH
 from .controls import SourceControls
 from .embeddings import Embeddings
-from .llm import Llm
+from .llm import Llm, Message
 from .memory import memory
 from .models import (
     DataRequired, FuzzyTable, JoinRequired, Sql, TableJoins, Topic,
@@ -68,6 +68,8 @@ class Agent(Viewer):
 
     provides = param.List(default=[], readonly=True)
 
+    _steps_layout = param.ClassSelector(default=None, class_=pn.Column)
+
     _extensions = ()
 
     _max_width = 1200
@@ -88,7 +90,7 @@ def _exception_handler(exception):
             )
 
         if "interface" not in params:
-            params["interface"] = ChatInterface(callback=self._chat_invoke)
+            params["interface"] = ChatInterface(callback=self._interface_callback)
         super().__init__(**params)
         if not self.debug:
             pn.config.exception_handler = _exception_handler
@@ -105,15 +107,15 @@ async def applies(cls) -> bool:
         """
         return True
 
-    async def _chat_invoke(self, contents: list | str, user: str, instance: ChatInterface):
-        await self.invoke(contents)
+    async def _interface_callback(self, contents: list | str, user: str, instance: ChatInterface):
+        await self.respond(contents)
         self._retries_left = 1
 
     def __panel__(self):
         return self.interface
 
     async def _system_prompt_with_context(
-        self, messages: list | str, context: str = ""
+        self, messages: list, context: str = ""
     ) -> str:
         system_prompt = self.system_prompt
         if self.embeddings:
@@ -122,7 +124,7 @@ async def _system_prompt_with_context(
             system_prompt += f"\n### CONTEXT: {context}"
         return system_prompt
 
-    async def _get_closest_tables(self, messages: list | str, tables: list[str], n: int = 3) -> list[str]:
+    async def _get_closest_tables(self, messages: list[Message], tables: list[str], n: int = 3) -> list[str]:
         system = (
             f"You are great at extracting keywords based on the user query to find the correct table. "
             f"The current table selected: `{memory.get('current_table', 'N/A')}`. "
@@ -178,23 +180,31 @@ async def _select_table(self, tables):
         self.interface.pop(-1)
         return tables
 
-    async def requirements(self, messages: list | str):
+    async def requirements(self, messages: list[Message]) -> list[str]:
         return self.requires
 
-    async def answer(self, messages: list | str):
+    async def respond(
+        self, messages: list[Message],
+        title: str = "",
+        render_output: bool = True,
+        steps_layout: pn.Column | None = None
+    ) -> None:
+        self._steps_layout = steps_layout
         system_prompt = await self._system_prompt_with_context(messages)
+        response = self.llm.stream(
+            messages, system=system_prompt, response_model=self.response_model, field="output"
+        )
+
+        if not render_output:
+            return
+
         message = None
-        async for output in self.llm.stream(
-            messages, system=system_prompt, response_model=self.response_model, field="output"
-        ):
+        async for output_chunk in response:
             message = self.interface.stream(
-                output, replace=True, message=message, user=self.user, max_width=self._max_width
+                output_chunk, replace=True, message=message, user=self.user, max_width=self._max_width
             )
 
-    async def invoke(self, messages: list | str):
-        await self.answer(messages)
-
 
 class SourceAgent(Agent):
     """
@@ -212,15 +222,20 @@ class SourceAgent(Agent):
 
     _extensions = ('filedropper',)
 
-    async def answer(self, messages: list | str):
+    async def respond(
+        self, messages: list,
+        title: str = "",
+        render_output: bool = True,
+        steps_layout: pn.Column | None = None
+    ) -> None:
+        self._steps_layout = steps_layout
+        if not render_output:
+            return
+
         source_controls = SourceControls(multiple=True, replace_controls=True, select_existing=False)
-        self.interface.send(source_controls, respond=False, user="SourceAgent")
         while not source_controls._add_button.clicks > 0:
             await asyncio.sleep(0.05)
 
-    async def invoke(self, messages: list[str] | str):
-        await self.answer(messages)
-
 
 class ChatAgent(Agent):
     """
@@ -245,13 +260,13 @@ class ChatAgent(Agent):
     requires = param.List(default=["current_source"], readonly=True)
 
     @retry_llm_output()
-    async def requirements(self, messages: list | str, errors=None):
+    async def requirements(self, messages: list[Message], errors=None):
         if 'current_data' in memory:
             return self.requires
 
         available_sources = memory["available_sources"]
         _, tables_schema_str = await gather_table_sources(available_sources)
-        with self.interface.add_step(title="Checking if data is required") as step:
+        with self.interface.add_step(title="Checking if data is required", steps_layout=self._steps_layout) as step:
             response = self.llm.stream(
                 messages,
                 system=(
@@ -270,7 +285,7 @@ async def requirements(self, messages: list | str, errors=None):
         return self.requires
 
     async def _system_prompt_with_context(
-        self, messages: list | str, context: str = ""
+        self, messages: list[Message], context: str = ""
     ) -> str:
         source = memory.get("current_source")
         if not source:
@@ -321,7 +336,7 @@ class ChatDetailsAgent(ChatAgent):
     )
 
     async def _system_prompt_with_context(
-        self, messages: list | str, context: str = ""
+        self, messages: list[Message], context: str = ""
     ) -> str:
         system_prompt = self.system_prompt
         topic = (await self.llm.invoke(
@@ -352,7 +367,16 @@ class LumenBaseAgent(Agent):
 
     _max_width = None
 
-    def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = None, **kwargs):
+    def _render_lumen(
+        self,
+        component: Component,
+        message: pn.chat.ChatMessage = None,
+        render_output: bool = True,
+        **kwargs
+    ):
+        if not render_output:
+            return
+
         out = self._output_type(component=component, **kwargs)
         message_kwargs = dict(value=out, user=self.user)
         self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width)
@@ -384,13 +408,23 @@ def _use_table(self, event):
         table = self._df.iloc[event.row, 0]
         self.interface.send(f"Show the table: {table!r}")
 
-    async def answer(self, messages: list | str):
+    async def respond(
+        self,
+        messages: list[Message],
+        title: str = "",
+        render_output: bool = True,
+        steps_layout: pn.Column | None = None
+    ) -> None:
+        self._steps_layout = steps_layout
         tables = []
         for source in memory['available_sources']:
             tables += source.get_tables()
         if not tables:
             return
 
+        if not render_output:
+            return
+
         self._df = pd.DataFrame({"Table": tables})
         table_list = pn.widgets.Tabulator(
             self._df,
@@ -405,10 +439,6 @@ async def answer(self, messages: list | str):
         )
         table_list.on_click(self._use_table)
         self.interface.stream(table_list, user="Lumen")
-        return tables
-
-    async def invoke(self, messages: list | str):
-        await self.answer(messages)
 
 
 class SQLAgent(LumenBaseAgent):
@@ -434,7 +464,9 @@ class SQLAgent(LumenBaseAgent):
 
     _extensions = ('codeeditor', 'tabulator',)
 
-    async def _select_relevant_table(self, messages: list | str) -> tuple[str, BaseSQLSource]:
+    _output_type = SQLOutput
+
+    async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource]:
         """Select the most relevant table based on the user query."""
         available_sources = memory["available_sources"]
 
@@ -446,7 +478,7 @@ async def _select_relevant_table(self, messages: list | str) -> tuple[str, BaseS
         elif len(tables) == 1:
             table = tables[0]
         else:
-            with self.interface.add_step(title="Choosing the most relevant table...") as step:
+            with self.interface.add_step(title="Choosing the most relevant table...", steps_layout=self._steps_layout) as step:
                 closest_tables = memory.pop("closest_tables", [])
                 if closest_tables:
                     tables = closest_tables
@@ -476,14 +508,15 @@ async def _select_relevant_table(self, messages: list | str) -> tuple[str, BaseS
 
         return table, source
 
-    def _render_sql(self, query):
-        pipeline = memory['current_pipeline']
-        out = SQLOutput(component=pipeline, spec=query.rstrip(';'))
-        self.interface.stream(out, user="SQL", replace=True, max_width=self._max_width)
-        return out
-
     @retry_llm_output()
-    async def _create_valid_sql(self, messages, system, tables_to_source, errors=None):
+    async def _create_valid_sql(
+        self,
+        messages: list[Message],
+        system: str,
+        tables_to_source,
+        title: str,
+        errors=None
+    ):
         if errors:
             last_query = self.interface.serialize()[-1]["content"].replace("```sql", "").rstrip("```").strip()
             errors = '\n'.join(errors)
@@ -499,7 +532,7 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=Non
             }
         ]
 
-        with self.interface.add_step(title="Creating SQL query...", success_title="SQL Query") as step:
+        with self.interface.add_step(title=title or "SQL query", steps_layout=self._steps_layout) as step:
             response = self.llm.stream(messages, system=system, response_model=Sql)
             sql_query = None
             async for output in response:
@@ -574,8 +607,13 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=Non
         memory["current_sql"] = sql_query
         return sql_query
 
-    async def check_join_required(self, messages, schema, table):
-        with self.interface.add_step(title="Checking if join is required", user="Assistant") as step:
+    async def _check_join_required(
+        self,
+        messages: list[Message],
+        schema,
+        table: str
+    ):
+        with self.interface.add_step(title="Checking if join is required", steps_layout=self._steps_layout) as step:
             join_prompt = render_template(
                 "join_required.jinja2",
                 schema=yaml.dump(schema),
@@ -592,7 +630,7 @@ async def check_join_required(self, messages, schema, table):
             step.success_title = 'Query requires join' if join_required else 'No join required'
         return join_required
 
-    async def find_join_tables(self, messages: list | str):
+    async def find_join_tables(self, messages: list):
         multi_source = len(memory['available_sources']) > 1
         if multi_source:
             available_tables = [
@@ -606,7 +644,7 @@ async def find_join_tables(self, messages: list | str):
             "find_joins.jinja2", available_tables=available_tables
         )
 
-        with self.interface.add_step(title="Determining tables required for join") as step:
+        with self.interface.add_step(title="Determining tables required for join", steps_layout=self._steps_layout) as step:
             output = await self.llm.invoke(
                 messages,
                 system=find_joins_prompt,
@@ -640,26 +678,33 @@ async def find_join_tables(self, messages: list | str):
                 tables_to_source[a_table] = a_source
         return tables_to_source
 
-    async def answer(self, messages: list | str):
+    async def respond(
+        self,
+        messages: list[Message],
+        title: str = "",
+        render_output: bool = True,
+        steps_layout: pn.Column | None = None
+    ) -> None:
         """
         Steps:
         1. Retrieve the current source and table from memory.
         2. If the source lacks a `get_sql_expr` method, return `None`.
         3. Fetch the schema for the current table using `get_schema` without min/max values.
-        4. Determine if a join is required by calling `check_join_required`.
+        4. Determine if a join is required by calling `_check_join_required`.
         5. If required, find additional tables via `find_join_tables`; otherwise, use the current source and table.
         6. For each source and table, get the schema and SQL expression, storing them in `table_schemas`.
         7. Render the SQL prompt using the table schemas, dialect, join status, and table.
         8. If a join is required, remove source/table prefixes from the last message.
        9. Construct the SQL query with `_create_valid_sql`.
         """
+        self._steps_layout = steps_layout
         table, source = await self._select_relevant_table(messages)
 
         if not hasattr(source, "get_sql_expr"):
            return None
 
         schema = await get_schema(source, table, include_min_max=False)
-        join_required = await self.check_join_required(messages, schema, table)
+        join_required = await self._check_join_required(messages, schema, table)
         if join_required:
             tables_to_source = await self.find_join_tables(messages)
         else:
@@ -698,13 +743,9 @@ async def answer(self, messages: list | str):
 
         if join_required:
             # Remove source prefixes message, e.g. //<source>//<table>
             messages[-1]["content"] = re.sub(r"//[^/]+//", "", messages[-1]["content"])
-        sql_query = await self._create_valid_sql(messages, system, tables_to_source)
-        print(sql_query)
-        return sql_query
-
-    async def invoke(self, messages: list | str):
-        sql_query = await self.answer(messages)
-        self._render_sql(sql_query)
+        sql_query = await self._create_valid_sql(messages, system, tables_to_source, title)
+        pipeline = memory['current_pipeline']
+        self._render_lumen(pipeline, spec=sql_query, render_output=render_output)
 
 
 class BaseViewAgent(LumenBaseAgent):
@@ -716,7 +757,14 @@ class BaseViewAgent(LumenBaseAgent):
     async def _extract_spec(self, model: BaseModel):
         return dict(model)
 
-    async def answer(self, messages: list | str) -> hvPlotUIView:
+    async def respond(
+        self,
+        messages: list[Message],
+        title: str = "",
+        render_output: bool = True,
+        steps_layout: pn.Column | None = None
+    ) -> None:
+        self._steps_layout = steps_layout
         pipeline = memory["current_pipeline"]
 
         # Write prompts
@@ -740,15 +788,12 @@ async def answer(self, messages: list | str) -> hvPlotUIView:
         spec = await self._extract_spec(output)
         chain_of_thought = spec.pop("chain_of_thought", None)
         if chain_of_thought:
-            with self.interface.add_step(title="Generating view...") as step:
+            with self.interface.add_step(title=title or "Generating view...", steps_layout=self._steps_layout) as step:
                 step.stream(chain_of_thought)
         print(f"{self.name} settled on {spec=!r}.")
         memory["current_view"] = dict(spec, type=self.view_type)
-        return self.view_type(pipeline=pipeline, **spec)
-
-    async def invoke(self, messages: list | str):
-        view = await self.answer(messages)
-        self._render_lumen(view)
+        view = self.view_type(pipeline=pipeline, **spec)
+        self._render_lumen(view, render_output=render_output)
 
 
 class hvPlotAgent(BaseViewAgent):
@@ -829,7 +874,7 @@ class VegaLiteAgent(BaseViewAgent):
     def _get_model(cls, schema):
         return VegaLiteSpec
 
-    async def _extract_spec(self, model):
+    async def _extract_spec(self, model: VegaLiteSpec):
         vega_spec = json.loads(model.json_spec)
         if "$schema" not in vega_spec:
             vega_spec["$schema"] = "https://vega.github.io/schema/vega-lite/v5.json"
@@ -861,7 +906,10 @@ class AnalysisAgent(LumenBaseAgent):
     _output_type = AnalysisOutput
 
     async def _system_prompt_with_context(
-        self, messages: list | str, context: str = "", analyses: list[Analysis] = []
+        self,
+        messages: list[Message],
+        context: str = "",
+        analyses: list[Analysis] = []
     ) -> str:
         system_prompt = self.system_prompt
         for name, analysis in analyses.items():
@@ -880,7 +928,15 @@ async def _system_prompt_with_context(
             system_prompt += f"\n### CONTEXT: {context}".strip()
         return system_prompt
 
-    async def answer(self, messages: list | str, agents: list[Agent] | None = None):
+    async def respond(
+        self,
+        messages: list[Message],
+        title: str = "",
+        render_output: bool = True,
+        steps_layout: pn.Column | None = None,
+        agents: list[Agent] | None = None
+    ) -> None:
+        self._steps_layout = steps_layout
         pipeline = memory['current_pipeline']
         analyses = {a.name: a for a in self.analyses if await a.applies(pipeline)}
         if not analyses:
@@ -888,13 +944,13 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None):
             return None
 
         # Short cut analysis selection if there's an exact match
-        if isinstance(messages, list) and messages:
+        if len(messages):
             analysis = messages[0].get('content').replace('Apply ', '')
             if analysis in analyses:
                 analyses = {analysis: analyses[analysis]}
 
         if len(analyses) > 1:
-            with self.interface.add_step(title="Choosing the most relevant analysis...", user="Assistant") as step:
+            with self.interface.add_step(title="Choosing the most relevant analysis...", steps_layout=self._steps_layout) as step:
                 type_ = Literal[tuple(analyses)]
                 analysis_model = create_model(
                     "Analysis",
@@ -912,7 +968,7 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None):
         else:
             analysis_name = next(iter(analyses))
 
-        with self.interface.add_step(title="Creating view...", user="Assistant") as step:
+        with self.interface.add_step(title=title or "Creating view...", steps_layout=self._steps_layout) as step:
             await asyncio.sleep(0.1)  # necessary to give it time to render before calling sync function...
 
             analysis_callable = analyses[analysis_name].instance(agents=agents)
@@ -943,12 +999,9 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None):
             else:
                 step.success_title = "Configure the analysis"
                 view = None
-        return view
 
-    async def invoke(self, messages: list | str, agents=None):
-        view = await self.answer(messages, agents=agents)
         analysis = memory["current_analysis"]
         if view is None and analysis.autorun:
             self.interface.stream('Failed to find an analysis that applies to this data')
         else:
-            self._render_lumen(view, analysis=analysis, pipeline=memory['current_pipeline'])
+            self._render_lumen(view, analysis=analysis, pipeline=memory['current_pipeline'], render_output=render_output)
diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py
index 338c16c7..13ca339f 100644
--- a/lumen/ai/assistant.py
+++ b/lumen/ai/assistant.py
@@ -9,19 +9,20 @@
 import param
 import yaml
 
-from panel import bind
+from panel import Card, bind
 from panel.chat import ChatInterface, ChatStep
 from panel.layout import Column, FlexBox, Tabs
 from panel.pane import HTML, Markdown
 from panel.viewable import Viewer
 from panel.widgets import Button, FileDownload
+from pydantic import BaseModel
 
 from .agents import (
     Agent, AnalysisAgent, ChatAgent, SQLAgent,
 )
 from .config import DEMO_MESSAGES, GETTING_STARTED_SUGGESTIONS
 from .export import export_notebook
-from .llm import Llama, Llm
+from .llm import Llama, Llm, Message
 from .logs import ChatLogs
 from .memory import memory
 from .models import Validity, make_agent_model, make_plan_models
@@ -29,11 +30,27 @@
 
 if TYPE_CHECKING:
     from panel.chat.step import ChatStep
-    from pydantic import BaseModel
 
     from ..sources import Source
 
 
+class AgentChainLink(param.Parameterized):
+    """
+    A link in the chain of agents to
+    be executed.
+    """
+
+    agent = param.ClassSelector(class_=Agent)
+
+    provides = param.ClassSelector(class_=(set, list), default=set())
+
+    instruction = param.String(default="")
+
+    title = param.String(default="")
+
+    render_output = param.Boolean(default=False)
+
+
 class Assistant(Viewer):
     """
     An Assistant handles multiple agents.
@@ -196,7 +213,7 @@ async def use_suggestion(event):
                 else:
                     print("No analysis agent found.")
                     return
-                await agent.invoke([{'role': 'user', 'content': contents}], agents=self.agents)
+                await agent.respond([{'role': 'user', 'content': contents}], agents=self.agents)
                 await self._add_analysis_suggestions()
             else:
                 self.interface.send(contents)
@@ -336,7 +353,13 @@ async def _fill_model(self, messages, system, agent_model, errors=None):
         )
         return out
 
-    async def _choose_agent(self, messages: list | str, agents: list[Agent] | None = None, primary: bool = False, unmet_dependencies: tuple[str] | None = None):
+    async def _choose_agent(
+        self,
+        messages: list[Message],
+        agents: list[Agent] | None = None,
+        primary: bool = False,
+        unmet_dependencies: tuple[str] | None = None
+    ):
         if agents is None:
             agents = self.agents
         agents = [agent for agent in agents if await agent.applies()]
@@ -353,7 +376,7 @@ async def _choose_agent(self, messages: list | str, agents: list[Agent] | None =
         )
         return await self._fill_model(messages, system, agent_model)
 
-    async def _resolve_dependencies(self, messages, agents: dict[str, Agent]) -> list[tuple(Agent, any)]:
+    async def _resolve_dependencies(self, messages, agents: dict[str, Agent]) -> list[AgentChainLink]:
         if len(agents) == 1:
             agent = next(iter(agents.values()))
         else:
@@ -383,11 +406,17 @@ async def _resolve_dependencies(self, messages, agents: dict[str, Agent]) -> lis
                 if output.agent is None:
                     continue
                 subagent = agents[output.agent]
-                agent_chain.append((subagent, unmet_dependencies, output.chain_of_thought))
+                agent_chain.append(
+                    AgentChainLink(
+                        agent=subagent,
+                        provides=unmet_dependencies,
+                        instruction=output.chain_of_thought,
+                    )
+                )
                 step.success_title = f"Solved a dependency with {output.agent}"
-        return agent_chain[::-1]+[(agent, (), None)]
+        return agent_chain[::-1] + [AgentChainLink(agent=agent)]
 
-    async def _get_agent(self, messages: list | str):
+    async def _get_agent_chain_link(self, messages: list[Message]) -> AgentChainLink | None:
         if len(self.agents) == 1:
             return self.agents[0]
 
@@ -397,16 +426,19 @@ async def _get_agent(self, messages: list | str):
         if not agent_chain:
             return
 
-        selected = agent = agent_chain[-1][0]
-        print(f"Assistant decided on \033[95m{agent!r}\033[0m")
-        for subagent, deps, instruction in agent_chain[:-1]:
+        for agent_chain_link in agent_chain[:-1]:
+            subagent = agent_chain_link.agent
+            instruction = agent_chain_link.instruction
+            title = agent_chain_link.title.capitalize()
+            render_output = agent_chain_link.render_output
+
             agent_name = type(subagent).name.replace('Agent', '')
             with self.interface.add_step(title=f"Querying {agent_name} agent...") as step:
                 step.stream(f"`{agent_name}` agent is working on the following task:\n\n{instruction}")
                 self._current_agent.object = f"## **Current Agent**: {agent_name}"
                 custom_messages = messages.copy()
                 if isinstance(subagent, SQLAgent):
-                    custom_agent = next((agent for agent in self.agents if isinstance(agent, AnalysisAgent)), None)
+                    custom_agent = next((a for a in self.agents if isinstance(a, AnalysisAgent)), None)
                     if custom_agent:
                         custom_analysis_doc = custom_agent.__doc__.replace("Available analyses include:\n", "")
                         custom_message = (
@@ -416,10 +448,20 @@ async def _get_agent(self, messages: list | str):
                         custom_messages.append({"role": "user", "content": custom_message})
                 if instruction:
                     custom_messages.append({"role": "user", "content": instruction})
-                await subagent.answer(custom_messages)
+
+                respond_kwargs = {}
+                # attach the new steps to the existing steps--used when there is intermediate Lumen output
+                last_steps_message = self.interface.objects[-2]
+                if last_steps_message.user == "Assistant" and isinstance(last_steps_message.object, Card):
+                    respond_kwargs["steps_layout"] = last_steps_message.object
+
+                await subagent.respond(custom_messages, title=title, render_output=render_output, **respond_kwargs)
                 step.stream(f"`{agent_name}` agent successfully completed the following task:\n\n- {instruction}", replace=True)
                 step.success_title = f"{agent_name} agent successfully responded"
-        return selected
+
+        selected_chain_link = agent_chain[-1]
+        print(f"Assistant decided on \033[95m{selected_chain_link!r}\033[0m")
+        return selected_chain_link
 
     def _serialize(self, obj, exclude_passwords=True):
         if isinstance(obj, Tabs | Column):
@@ -445,15 +487,16 @@ def _serialize(self, obj, exclude_passwords=True):
             obj = obj.value
         return str(obj)
 
-    async def invoke(self, messages: list | str) -> str:
+    async def invoke(self, messages: list[Message]) -> str:
         messages = self.interface.serialize(custom_serializer=self._serialize)[-4:]
         invalidation_assessment = await self._invalidate_memory(messages[-2:])
         context_length = 3
         if invalidation_assessment:
             messages.append({"role": "assistant", "content": invalidation_assessment + " so do not choose that table."})
             context_length += 1
-        agent = await self._get_agent(messages[-context_length:])
-        if agent is None:
+
+        agent_chain_link = await self._get_agent_chain_link(messages[-context_length:])
+        if agent_chain_link is None:
             msg = (
                 "Assistant could not settle on an agent to perform the requested query. "
                 "Please restate your request."
@@ -461,6 +504,9 @@ async def invoke(self, messages: list | str) -> str:
             self.interface.stream(msg, user='Lumen')
             return msg
 
+        agent = agent_chain_link.agent
+        title = agent_chain_link.title.capitalize()
+
         self._current_agent.object = f"## **Current Agent**: {agent.name[:-5]}"
         print("\n\033[95mMESSAGES:\033[0m")
         for message in messages:
@@ -469,10 +515,14 @@ async def invoke(self, messages: list | str) -> str:
 
         print("\n\033[95mAGENT:\033[0m", agent, messages[-context_length:])
 
-        kwargs = {}
+        last_steps_message = self.interface.objects[-2]
+        respond_kwargs = {"title": title, "render_output": True}
+        # attach the new steps to the existing steps--used when there is intermediate Lumen output
+        if last_steps_message.user == "Assistant" and isinstance(last_steps_message.object, Card):
+            respond_kwargs["steps_layout"] = last_steps_message.object
         if isinstance(agent, AnalysisAgent):
-            kwargs["agents"] = self.agents
-        await agent.invoke(messages[-context_length:], **kwargs)
+            respond_kwargs["agents"] = self.agents
+        await agent.respond(messages[-context_length:], **respond_kwargs)
         self._current_agent.object = "## No agent active"
         if "current_pipeline" in agent.provides:
             await self._add_analysis_suggestions()
@@ -518,7 +568,7 @@ async def _lookup_schemas(
 
     async def _make_plan(
         self,
-        messages: list,
+        messages: list[Message],
         agents: dict[str, Agent],
         tables: dict[str, Source],
         unmet_dependencies: set[str],
@@ -547,7 +597,8 @@ async def _make_plan(
                 system=system,
                 response_model=reason_model,
             ):
-                step.stream(reasoning.chain_of_thought, replace=True)
+                if reasoning.chain_of_thought:  # do not replace with empty string
+                    step.stream(reasoning.chain_of_thought, replace=True)
             requested = [
                 t for t in getattr(reasoning, 'tables', [])
                 if t and t not in provided
@@ -557,13 +608,21 @@ async def _make_plan(
         plan = await self._fill_model(messages, system, plan_model)
         return plan
 
-    async def _resolve_plan(self, plan, agents, messages):
+    async def _resolve_plan(self, plan, agents, messages) -> tuple[list[AgentChainLink], set[str]]:
         step = plan.steps[-1]
         subagent = agents[step.expert]
         unmet_dependencies = {
             r for r in await subagent.requirements(messages) if r not in memory
         }
-        agent_chain = [(subagent, unmet_dependencies, step.instruction)]
+        agent_chain = [
+            AgentChainLink(
+                agent=subagent,
+                provides=unmet_dependencies,
+                instruction=step.instruction,
+                title=step.title,
+                render_output=step.render_output
+            )
+        ]
         for step in plan.steps[:-1][::-1]:
             subagent = agents[step.expert]
             requires = set(await subagent.requirements(messages))
@@ -571,10 +630,18 @@ async def _resolve_plan(self, plan, agents, messages):
                 dep for dep in (unmet_dependencies | requires)
                 if dep not in subagent.provides and dep not in memory
             }
-            agent_chain.append((subagent, subagent.provides, step.instruction))
+            agent_chain.append(
+                AgentChainLink(
+                    agent=subagent,
+                    provides=subagent.provides,
+                    instruction=step.instruction,
+                    title=step.title,
+                    render_output=step.render_output
+                )
+            )
         return agent_chain, unmet_dependencies
 
-    async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent]) -> list[tuple(Agent, any)]:
+    async def _resolve_dependencies(self, messages: list[Message], agents: dict[str, Agent]) -> list[AgentChainLink]:
         agent_names = tuple(sagent.name[:-5] for sagent in agents.values())
         tables = {}
         for src in memory['available_sources']:
@@ -598,5 +665,5 @@ async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent])
             istep.stream('\n\nHere are the steps:\n\n')
             for i, step in enumerate(plan.steps):
                 istep.stream(f"{i+1}. {step.expert}: {step.instruction}\n")
-            istep.success_title = "Successfully came up with a plan."
+            istep.success_title = "Successfully came up with a plan"
         return agent_chain[::-1]
diff --git a/lumen/ai/config.py b/lumen/ai/config.py
index f04ddef0..c54c17d1 100644
--- a/lumen/ai/config.py
+++ b/lumen/ai/config.py
@@ -37,4 +37,4 @@ class LlmSetupError(Exception):
     RecursionError,
 )
 
-pn.chat.ChatStep.min_width = 350
+pn.chat.ChatStep.min_width = 450
diff --git a/lumen/ai/llm.py b/lumen/ai/llm.py
index b3b48068..7be1187a 100644
--- a/lumen/ai/llm.py
+++ b/lumen/ai/llm.py
@@ -4,6 +4,7 @@
 
 from functools import partial
 from types import SimpleNamespace
+from typing import Literal, TypedDict
 
 import instructor
 import panel as pn
@@ -16,6 +17,12 @@
 from .interceptor import Interceptor
 
 
+class Message(TypedDict):
+    role: Literal["system", "user", "assistant"]
+    content: str
+    name: str | None
+
+
 class Llm(param.Parameterized):
 
     mode = param.Selector(
@@ -58,7 +65,7 @@ def _add_system_message(self, messages, system, input_kwargs):
 
     async def invoke(
         self,
-        messages: list | str,
+        messages: list[Message],
         system: str = "",
         response_model: BaseModel | None = None,
         allow_partial: bool = False,
@@ -91,7 +98,7 @@ def _get_delta(cls, chunk):
 
     async def stream(
        self,
-        messages: list | str,
+        messages: list[Message],
         system: str = "",
         response_model: BaseModel | None = None,
         field: str | None = None,
@@ -357,7 +364,7 @@ def _get_delta(cls, chunk):
 
     async def invoke(
         self,
-        messages: list | str,
+        messages: list[Message],
         system: str = "",
         response_model: BaseModel | None = None,
         allow_partial: bool = False,
@@ -470,7 +477,7 @@ def _add_system_message(self, messages, system, input_kwargs):
 
     async def invoke(
         self,
-        messages: list | str,
+        messages: list[Message],
         system: str = "",
         response_model: BaseModel | None = None,
         allow_partial: bool = False,
diff --git a/lumen/ai/models.py b/lumen/ai/models.py
index f6072897..195e0b6b 100644
--- a/lumen/ai/models.py
+++ b/lumen/ai/models.py
@@ -102,11 +102,14 @@ class VegaLiteSpec(BaseModel):
 
     json_spec: str = Field(description="A vega-lite JSON specification WITHOUT the data field, which will be added automatically.")
 
+
 def make_plan_models(agent_names: list[str], tables: list[str]):
     step = create_model(
         "Step",
         expert=(Literal[agent_names], FieldInfo(description="The name of the expert to assign a task to.")),
-        instruction=(str, FieldInfo(description="Instructions to the expert to assist in the task."))
+        instruction=(str, FieldInfo(description="Instructions to the expert to assist in the task, and whether rendering is required.")),
+        title=(str, FieldInfo(description="Short title of the task to be performed; up to six words.")),
+        render_output=(bool, FieldInfo(description="Whether the output of the expert should be rendered. If the user wants to see the table, and the expert is SQL, then this should be `True`.")),
     )
     extras = {}
     if tables:
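
Usage sketch (illustrative, not part of the patch): the snippet below shows how the unified respond() entry point introduced above might be called with the new Message TypedDict; the agent construction and interface wiring are assumed for the example rather than taken from the diff.

import panel as pn

from lumen.ai.agents import ChatAgent
from lumen.ai.llm import Message


async def run_agent(agent: ChatAgent, steps_layout: pn.Column | None = None) -> None:
    # Messages follow the Message TypedDict added in lumen/ai/llm.py.
    messages: list[Message] = [
        {"role": "user", "content": "Summarize the current table", "name": None},
    ]
    # title labels the ChatStep, render_output controls whether output is streamed
    # into the chat, and steps_layout lets the caller attach new steps to an
    # existing Card of intermediate output, as the Assistant does in invoke().
    await agent.respond(
        messages,
        title="Summarizing the table",
        render_output=True,
        steps_layout=steps_layout,
    )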