You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
async def find_join_tables(self, messages: list | str, join_tables: list[str] | None = None) -> dict[str, tuple]:
    """Ask the LLM which tables a join needs and resolve each one to its source.

    Parameters
    ----------
    messages : list | str
        Conversation passed straight through to ``self.llm.invoke``.
    join_tables : list[str] | None
        Optional table names. They are only used to decide whether table
        names are source-qualified (any entry containing ``//`` means
        multi-source mode); the LLM is still asked for the final table list,
        whose answer replaces this value.

    Returns
    -------
    dict[str, tuple]
        Maps a key to ``(source, table)``. The key is the source name. If
        a further table from an already-seen source is requested, it is
        stored under ``"<source name>//<table>"`` instead, so no requested
        table is silently dropped.

    Raises
    ------
    ValueError
        In multi-source mode, if a table returned by the LLM is not of the
        form ``//source//table`` or ``source//table``.
    """
    if join_tables:
        multi_source = any('//' in jt for jt in join_tables)
    else:
        multi_source = len(memory['available_sources']) > 1

    if multi_source:
        # Qualify each table with the source *name*: that same name is matched
        # against ``source.name`` when the LLM's answer is parsed below.
        available_tables = ", ".join(
            f"//{a_source.name}//{a_table}"
            for a_source in memory["available_sources"]
            for a_table in a_source.get_tables()
        )
    else:
        # Same comma-separated format as the multi-source branch (instead of a
        # Python list repr).
        available_tables = ", ".join(map(str, memory['current_source'].get_tables()))

    # The `//` instruction only makes sense when table names are qualified;
    # in single-source mode it would invite the LLM to invent prefixes.
    system = "List the tables that need to be joined"
    if multi_source:
        system += "; be sure to include both `//`"
    system += f": {available_tables}"

    # NOTE(review): the LLM is consulted even when the caller supplied
    # `join_tables`, and its answer replaces them — confirm this is intended.
    with self.interface.add_step(title="Determining tables required for join") as step:
        output = await self.llm.invoke(
            messages,
            system=system,
            response_model=TableJoins,
        )
        join_tables = output.tables
        step.stream(f'\nJoin requires following tables: {join_tables}', replace=True)
        step.success_title = 'Found tables required for join'

    sources = {}
    for source_table in join_tables:
        if multi_source:
            # Accept both "//source//table" and "source//table"; anything after
            # the first separator (including further "//") is the table name.
            a_source_name, sep, a_table = source_table.removeprefix("//").partition("//")
            if not sep:
                raise ValueError(
                    f"Expected a table of the form '//source//table', got {source_table!r}"
                )
            # NOTE(review): yields None when the LLM names an unknown source;
            # downstream code receives (None, table) in that case.
            a_source = next(
                (source for source in memory["available_sources"] if source.name == a_source_name),
                None,
            )
        else:
            a_source = memory['current_source']
            # Assumes current_source exposes `.name` like the entries of
            # available_sources do.
            a_source_name = a_source.name
            a_table = source_table
        # Keying by source name alone would let a second table from the same
        # source overwrite the first, losing a table required for the join.
        key = a_source_name
        if key in sources:
            key = f"{a_source_name}//{a_table}"
        sources[key] = (a_source, a_table)
    return sources
The text was updated successfully, but these errors were encountered: