Skip to content

Commit

Permalink
Improved
Browse files Browse the repository at this point in the history
  • Loading branch information
joelsiby02 committed Apr 11, 2024
1 parent 90700f9 commit b09595e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
8 changes: 8 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import streamlit as st

def main():
st.title("A streait app for demo")
st

if __name__ == "__main__":
main()
10 changes: 7 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
import streamlit as st
from langchain_google_genai import ChatGoogleGenerativeAI

# Function to initialize the database connection
def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
Expand All @@ -28,16 +29,17 @@ def get_sql_chain(db):
SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
Question: Name 10 artists
SQL Query: SELECT Name FROM Artist LIMIT 10;
Your turn:
make sure to return proper syntax with no error
Question: {question}
SQL Query:
"""

prompt = ChatPromptTemplate.from_template(template)

llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0.3)
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
# llm = ChatGoogleGenerativeAI(model="gemini-pro")

def get_schema(_):
return db.get_table_info()
Expand Down Expand Up @@ -72,7 +74,9 @@ def get_response(user_query: str, db: SQLDatabase, chat_history: list):

prompt = ChatPromptTemplate.from_template(template)

# llm = ChatGoogleGenerativeAI(model="gemini-pro")
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)


chain = (
RunnablePassthrough.assign(query=sql_chain).assign(
Expand Down Expand Up @@ -149,4 +153,4 @@ def get_response(user_query: str, db: SQLDatabase, chat_history: list):
response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
st.markdown(response)

st.session_state.chat_history.append(AIMessage(content=response))
st.session_state.chat_history.append(AIMessage(content=response))

0 comments on commit b09595e

Please sign in to comment.