update OpenAI usage from Vanna
Browse files
- app.py +10 -1
- climateqa/engine/talk_to_data/main.py +5 -18
- climateqa/engine/talk_to_data/utils.py +10 -16
app.py
CHANGED
|
@@ -13,6 +13,7 @@ from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
|
|
| 13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
| 14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
| 15 |
from climateqa.engine.talk_to_data.main import ask_vanna
|
|
|
|
| 16 |
|
| 17 |
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
|
| 18 |
from front.utils import process_figures
|
|
@@ -77,6 +78,14 @@ else :
|
|
| 77 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
| 78 |
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
| 82 |
print("chat cqa - message received")
|
|
@@ -126,7 +135,7 @@ def create_drias_tab():
|
|
| 126 |
show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
|
| 127 |
|
| 128 |
vanna_display = gr.Plot()
|
| 129 |
-
vanna_direct_question.submit(ask_vanna, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
|
| 130 |
|
| 131 |
# # UI Layout Components
|
| 132 |
def cqa_tab(tab_name):
|
|
|
|
| 13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
| 14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
| 15 |
from climateqa.engine.talk_to_data.main import ask_vanna
|
| 16 |
+
from climateqa.engine.talk_to_data.myVanna import MyVanna
|
| 17 |
|
| 18 |
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
|
| 19 |
from front.utils import process_figures
|
|
|
|
| 78 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
| 79 |
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
|
| 80 |
|
| 81 |
+
#Vanna object
|
| 82 |
+
|
| 83 |
+
vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
|
| 84 |
+
db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
|
| 85 |
+
vn.connect_to_sqlite(db_vanna_path)
|
| 86 |
+
|
| 87 |
+
def ask_vanna_query(query):
|
| 88 |
+
return ask_vanna(vn, db_vanna_path, query)
|
| 89 |
|
| 90 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
| 91 |
print("chat cqa - message received")
|
|
|
|
| 135 |
show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
|
| 136 |
|
| 137 |
vanna_display = gr.Plot()
|
| 138 |
+
vanna_direct_question.submit(ask_vanna_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
|
| 139 |
|
| 140 |
# # UI Layout Components
|
| 141 |
def cqa_tab(tab_name):
|
climateqa/engine/talk_to_data/main.py
CHANGED
|
@@ -4,24 +4,10 @@ import sqlite3
|
|
| 4 |
import os
|
| 5 |
import pandas as pd
|
| 6 |
from climateqa.engine.llm import get_llm
|
| 7 |
-
|
| 8 |
-
from dotenv import load_dotenv
|
| 9 |
import ast
|
| 10 |
|
| 11 |
-
load_dotenv()
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
OPENAI_API_KEY = os.getenv('THEO_API_KEY')
|
| 15 |
-
PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')
|
| 16 |
-
INDEX_NAME = os.getenv('VANNA_INDEX_NAME')
|
| 17 |
-
VANNA_MODEL = os.getenv('VANNA_MODEL')
|
| 18 |
|
| 19 |
|
| 20 |
-
#Vanna object
|
| 21 |
-
vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
|
| 22 |
-
db_vanna_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data/drias/drias.db")
|
| 23 |
-
vn.connect_to_sqlite(db_vanna_path)
|
| 24 |
-
|
| 25 |
llm = get_llm(provider="openai")
|
| 26 |
|
| 27 |
def ask_llm_to_add_table_names(sql_query, llm):
|
|
@@ -33,9 +19,10 @@ def ask_llm_column_names(sql_query, llm):
|
|
| 33 |
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
| 34 |
return columns_list
|
| 35 |
|
| 36 |
-
def ask_vanna(query):
|
|
|
|
| 37 |
try :
|
| 38 |
-
location = detect_location_with_openai(OPENAI_API_KEY, query)
|
| 39 |
if location:
|
| 40 |
|
| 41 |
coords = loc2coords(location)
|
|
@@ -51,10 +38,10 @@ def ask_vanna(query):
|
|
| 51 |
|
| 52 |
else :
|
| 53 |
empty_df = pd.DataFrame()
|
| 54 |
-
empty_fig =
|
| 55 |
return "", empty_df, empty_fig
|
| 56 |
except Exception as e:
|
| 57 |
print(f"Error: {e}")
|
| 58 |
empty_df = pd.DataFrame()
|
| 59 |
-
empty_fig =
|
| 60 |
return "", empty_df, empty_fig
|
|
|
|
| 4 |
import os
|
| 5 |
import pandas as pd
|
| 6 |
from climateqa.engine.llm import get_llm
|
|
|
|
|
|
|
| 7 |
import ast
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
llm = get_llm(provider="openai")
|
| 12 |
|
| 13 |
def ask_llm_to_add_table_names(sql_query, llm):
|
|
|
|
| 19 |
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
| 20 |
return columns_list
|
| 21 |
|
| 22 |
+
def ask_vanna(vn,db_vanna_path, query):
|
| 23 |
+
|
| 24 |
try :
|
| 25 |
+
location = detect_location_with_openai(query)
|
| 26 |
if location:
|
| 27 |
|
| 28 |
coords = loc2coords(location)
|
|
|
|
| 38 |
|
| 39 |
else :
|
| 40 |
empty_df = pd.DataFrame()
|
| 41 |
+
empty_fig = None
|
| 42 |
return "", empty_df, empty_fig
|
| 43 |
except Exception as e:
|
| 44 |
print(f"Error: {e}")
|
| 45 |
empty_df = pd.DataFrame()
|
| 46 |
+
empty_fig = None
|
| 47 |
return "", empty_df, empty_fig
|
climateqa/engine/talk_to_data/utils.py
CHANGED
|
@@ -4,13 +4,13 @@ import pandas as pd
|
|
| 4 |
from geopy.geocoders import Nominatim
|
| 5 |
import sqlite3
|
| 6 |
import ast
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
def detect_location_with_openai(api_key, sentence):
|
| 10 |
"""
|
| 11 |
-
Detects locations in a sentence using OpenAI's API.
|
| 12 |
"""
|
| 13 |
-
|
| 14 |
|
| 15 |
prompt = f"""
|
| 16 |
Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
|
|
@@ -19,18 +19,12 @@ def detect_location_with_openai(api_key, sentence):
|
|
| 19 |
Sentence: "{sentence}"
|
| 20 |
"""
|
| 21 |
|
| 22 |
-
response =
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
max_tokens=100,
|
| 29 |
-
temperature=0
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
return response.choices[0].message.content.split("\n")[1][2:-2]
|
| 33 |
-
|
| 34 |
|
| 35 |
def detectTable(sql_query):
|
| 36 |
pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
|
|
|
|
| 4 |
from geopy.geocoders import Nominatim
|
| 5 |
import sqlite3
|
| 6 |
import ast
|
| 7 |
+
from climateqa.engine.llm import get_llm
|
| 8 |
|
| 9 |
+
def detect_location_with_openai(sentence):
|
|
|
|
| 10 |
"""
|
| 11 |
+
Detects locations in a sentence using OpenAI's API via LangChain.
|
| 12 |
"""
|
| 13 |
+
llm = get_llm()
|
| 14 |
|
| 15 |
prompt = f"""
|
| 16 |
Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
|
|
|
|
| 19 |
Sentence: "{sentence}"
|
| 20 |
"""
|
| 21 |
|
| 22 |
+
response = llm.invoke(prompt)
|
| 23 |
+
location_list = ast.literal_eval(response.content.strip("```python\n").strip())
|
| 24 |
+
if location_list:
|
| 25 |
+
return location_list[0]
|
| 26 |
+
else:
|
| 27 |
+
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def detectTable(sql_query):
|
| 30 |
pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
|