Skip to content

Commit

Permalink
Refactoring implementation for ODA RAG POC for NL2SQL use case
Browse files Browse the repository at this point in the history
  • Loading branch information
shekharchavan1990 committed Sep 4, 2024
1 parent 78f359d commit 8210040
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ The key prerequisites that you would need to set up before you can proceed to ru

- Notebook session used to initiate the distributed training and to access the fine-tuned model. For more information, see [Notebook session](https://docs.oracle.com/en-us/iaas/data-science/using/manage-notebook-sessions.htm).

- Install the latest version of [Oracle Accelerated Data Science (ADS)](https://accelerated-data-science.readthedocs.io/en/latest/index.html)
- Install the "PyTorch 2.0 for GPU on Python 3.9" conda environment, then install the required dependencies listed in requirements.txt

```
pip install oracle-ads[opctl] -U
pip install -r requirements.txt
```

## Task 1: Deploy Required Models
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@
}
],
"source": [
"#Copy required python scripts to artifact dir \n",
"#Copy required python scripts to artifact dir\n",
"#mkdir langchain_nl2sql_model\n",
"#cp config.py config_private.py oci_utils.py oracle_vector_db.py langchain_nl2sql_model/\n",
"langchain_model.prepare(\n",
" inference_conda_env=CONDA_PACK_PATH,\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ The key prerequisites that you would need to set up before you can proceed to ru

- Notebook session used to initiate the distributed training and to access the fine-tuned model. For more information, see [Notebook session](https://docs.oracle.com/en-us/iaas/data-science/using/manage-notebook-sessions.htm).

- Install the latest version of [Oracle Accelerated Data Science (ADS)](https://accelerated-data-science.readthedocs.io/en/latest/index.html)
- Install the "PyTorch 2.0 for GPU on Python 3.9" conda environment, then install the required dependencies listed in requirements.txt

```
pip install oracle-ads[opctl] -U
pip install -r requirements.txt
```

## Task 1: Deploy Required Models
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

from langchain_community.llms import OCIModelDeploymentVLLM
import ads
from oracle_vector_db import oracle_query, test_oracle_query
from oracle_vector_db import oracle_query, oracle_query
from ads.model.generic_model import GenericModel
import ads
from oracle_vector_db import oracle_query, test_oracle_query
from oracle_vector_db import oracle_query, oracle_query
from langchain.chains import RetrievalQA
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from oracle_vector_db import oracle_query, test_oracle_query
from oracle_vector_db import oracle_query, oracle_query
from config import COMMAND_MD_ENDPOINT, EMBEDDING_MD_ENDPOINT, RERANKER_MD_ENDPOINT, TOP_K
from pprint import pprint

Expand Down Expand Up @@ -57,50 +57,48 @@
- You should NEVER generate SQL queries with JOIN, since the Schema only includes 1 table.
- Use Context section to get additional details while bulding the query
- Do not return multiple queries in response. Just respond with single SQL query and nothing else
- Return only single query and append response between ###
- Return only single query and append response between backtick character
- For date column use TO_DATE function in where clause. e.g. Nomination_date > TO_DATE('2001-01-01', 'YYYY-MM-DD');
Question: {question}
Oracle SQL: """

relational_data_summary_prompt_template = """We have implmented the natural language question to SQL model which retrieves records from relational database.
The data retrieved is provided below in the Context section and the corresponding question is provided in the Question section.
By considering the question and context data can you please summarize the answer. The first row in the context data represents the column headers, and the subsequent rows represent the data.
Just convert context data to user understandable answer
relational_data_summary_prompt_template = """In the context section, we provide a question and its corresponding answer. The answer contains a list of rows fetched from a database, where the first row is the column headers, and the subsequent rows are the data.
Your task is to summarize the answer in a single, human-understandable sentence. If there is data in the context section, assume it is correct and summarize it directly without evaluating its content.
Do not assume there is only one result unless specified.
### Context:
Question: {question}
Ans:
{context_data}
### In-Context Examples:
Question: list or view partners who renewal last year
Answer: Partner2 has renewal last year
Can you please summarize the answer based on the data provided in the context? Your summary should be a single, concise sentence.
Can you please summarize the ans mentioned in context section according to question in context? Just return the summarized text. Dont mention 'Based on the context data provided'
Ans:"""
Summary:
"""

import json
import oci
import requests

from langchain_community.llms import OCIModelDeploymentVLLM
import ads
from oracle_vector_db import oracle_query, test_oracle_query
from oracle_vector_db import oracle_query, oracle_query
from ads.model.generic_model import GenericModel
import ads
from oracle_vector_db import oracle_query, test_oracle_query
from oracle_vector_db import oracle_query, oracle_query
from langchain.chains import RetrievalQA
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from oracle_vector_db import oracle_query, test_oracle_query
from oracle_vector_db import oracle_query, oracle_query
from config import COMMAND_MD_ENDPOINT, EMBEDDING_MD_ENDPOINT, RERANKER_MD_ENDPOINT, TOP_K
from pprint import pprint
from langchain import PromptTemplate
from config_private import DB_USER, DB_PWD, DB_HOST_IP, DB_SERVICE
import oracledb

import re

ads.set_auth("resource_principal")
command_md = OCIModelDeploymentVLLM(
Expand All @@ -122,7 +120,7 @@ def _get_relevant_documents(
prediction = requests.post(EMBEDDING_MD_ENDPOINT, data=f'["{query}"]', auth=rps)

#Search in DB
q_result = test_oracle_query(prediction.json()['embeddings'][0], TOP_K, True, False)
q_result = oracle_query(prediction.json()['embeddings'][0], TOP_K, True, False)
text_list = []
for n, id, sim in zip(q_result.nodes, q_result.ids, q_result.similarities):
text_list.append(n.text)
Expand Down Expand Up @@ -156,7 +154,7 @@ def load_model(model_file_name=model_name):

def get_data_from_DB(query):
DSN = f"{DB_HOST_IP}/{DB_SERVICE}"
connection = oracledb.connect(user=DB_USER, password=DB_PWD, dsn=DSN)
connection = oracledb.connect(user=DB_USER, password=DB_PWD, dsn=DSN, config_dir="/opt/oracle/config")
# Creating a cursor object
cursor = connection.cursor()

Expand All @@ -170,7 +168,7 @@ def get_data_from_DB(query):
for row in cursor:
row_values = [str(value).replace(',', ' ') for value in row]
result += ','.join(row_values) + '\n'
#print(f"Query Result: {result}")
print(f"Query Result: {result}")
# Closing the cursor and connection
cursor.close()
connection.close()
Expand All @@ -197,6 +195,13 @@ def predict(data, model=load_model()):
query_str = op.split("Question:")[0]
else:
query_str = op

match = re.search(r'`([^`]*)`', op)
if match:
query_str = match.group(1)
else:
query_str = match

query = query_str.replace(';', '')
result = get_data_from_DB(query)

Expand Down

0 comments on commit 8210040

Please sign in to comment.