update OpenAI usage from Vanna
Browse files- app.py +10 -1
 - climateqa/engine/talk_to_data/main.py +5 -18
 - climateqa/engine/talk_to_data/utils.py +10 -16
 
    	
        app.py
    CHANGED
    
    | 
         @@ -13,6 +13,7 @@ from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc 
     | 
|
| 13 | 
         
             
            from climateqa.engine.chains.retrieve_papers import find_papers
         
     | 
| 14 | 
         
             
            from climateqa.chat import start_chat, chat_stream, finish_chat
         
     | 
| 15 | 
         
             
            from climateqa.engine.talk_to_data.main import ask_vanna
         
     | 
| 
         | 
|
| 16 | 
         | 
| 17 | 
         
             
            from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
         
     | 
| 18 | 
         
             
            from front.utils import process_figures
         
     | 
| 
         @@ -77,6 +78,14 @@ else : 
     | 
|
| 77 | 
         
             
            agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
         
     | 
| 78 | 
         
             
            agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
         
     | 
| 79 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 80 | 
         | 
| 81 | 
         
             
            async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
         
     | 
| 82 | 
         
             
                print("chat cqa - message received")
         
     | 
| 
         @@ -126,7 +135,7 @@ def create_drias_tab(): 
     | 
|
| 126 | 
         
             
                        show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
         
     | 
| 127 | 
         | 
| 128 | 
         
             
                    vanna_display = gr.Plot()
         
     | 
| 129 | 
         
            -
                    vanna_direct_question.submit( 
     | 
| 130 | 
         | 
| 131 | 
         
             
            # # UI Layout Components
         
     | 
| 132 | 
         
             
            def cqa_tab(tab_name):
         
     | 
| 
         | 
|
| 13 | 
         
             
            from climateqa.engine.chains.retrieve_papers import find_papers
         
     | 
| 14 | 
         
             
            from climateqa.chat import start_chat, chat_stream, finish_chat
         
     | 
| 15 | 
         
             
            from climateqa.engine.talk_to_data.main import ask_vanna
         
     | 
| 16 | 
         
            +
            from climateqa.engine.talk_to_data.myVanna import MyVanna
         
     | 
| 17 | 
         | 
| 18 | 
         
             
            from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
         
     | 
| 19 | 
         
             
            from front.utils import process_figures
         
     | 
| 
         | 
|
| 78 | 
         
             
            agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
         
     | 
| 79 | 
         
             
            agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
         
     | 
| 80 | 
         | 
| 81 | 
         
            +
            #Vanna object
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
            vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
         
     | 
| 84 | 
         
            +
            db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
         
     | 
| 85 | 
         
            +
            vn.connect_to_sqlite(db_vanna_path)
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
            def ask_vanna_query(query):
         
     | 
| 88 | 
         
            +
                return ask_vanna(vn, db_vanna_path, query)
         
     | 
| 89 | 
         | 
| 90 | 
         
             
            async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
         
     | 
| 91 | 
         
             
                print("chat cqa - message received")
         
     | 
| 
         | 
|
| 135 | 
         
             
                        show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
         
     | 
| 136 | 
         | 
| 137 | 
         
             
                    vanna_display = gr.Plot()
         
     | 
| 138 | 
         
            +
                    vanna_direct_question.submit(ask_vanna_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
         
     | 
| 139 | 
         | 
| 140 | 
         
             
            # # UI Layout Components
         
     | 
| 141 | 
         
             
            def cqa_tab(tab_name):
         
     | 
    	
        climateqa/engine/talk_to_data/main.py
    CHANGED
    
    | 
         @@ -4,24 +4,10 @@ import sqlite3 
     | 
|
| 4 | 
         
             
            import os
         
     | 
| 5 | 
         
             
            import pandas as pd
         
     | 
| 6 | 
         
             
            from climateqa.engine.llm import get_llm
         
     | 
| 7 | 
         
            -
             
     | 
| 8 | 
         
            -
            from dotenv import load_dotenv
         
     | 
| 9 | 
         
             
            import ast
         
     | 
| 10 | 
         | 
| 11 | 
         
            -
            load_dotenv()
         
     | 
| 12 | 
         
            -
             
     | 
| 13 | 
         
            -
             
     | 
| 14 | 
         
            -
            OPENAI_API_KEY = os.getenv('THEO_API_KEY')
         
     | 
| 15 | 
         
            -
            PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')
         
     | 
| 16 | 
         
            -
            INDEX_NAME = os.getenv('VANNA_INDEX_NAME')
         
     | 
| 17 | 
         
            -
            VANNA_MODEL = os.getenv('VANNA_MODEL')
         
     | 
| 18 | 
         | 
| 19 | 
         | 
| 20 | 
         
            -
            #Vanna object
         
     | 
| 21 | 
         
            -
            vn = MyVanna(config = {"temperature": 0, "api_key": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, "top_k" : 4})
         
     | 
| 22 | 
         
            -
            db_vanna_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data/drias/drias.db")
         
     | 
| 23 | 
         
            -
            vn.connect_to_sqlite(db_vanna_path)
         
     | 
| 24 | 
         
            -
             
     | 
| 25 | 
         
             
            llm = get_llm(provider="openai")
         
     | 
| 26 | 
         | 
| 27 | 
         
             
            def ask_llm_to_add_table_names(sql_query, llm):
         
     | 
| 
         @@ -33,9 +19,10 @@ def ask_llm_column_names(sql_query, llm): 
     | 
|
| 33 | 
         
             
                columns_list = ast.literal_eval(columns.strip("```python\n").strip())
         
     | 
| 34 | 
         
             
                return columns_list
         
     | 
| 35 | 
         | 
| 36 | 
         
            -
            def ask_vanna(query):
         
     | 
| 
         | 
|
| 37 | 
         
             
                try :
         
     | 
| 38 | 
         
            -
                    location = detect_location_with_openai( 
     | 
| 39 | 
         
             
                    if location:
         
     | 
| 40 | 
         | 
| 41 | 
         
             
                        coords = loc2coords(location)
         
     | 
| 
         @@ -51,10 +38,10 @@ def ask_vanna(query): 
     | 
|
| 51 | 
         | 
| 52 | 
         
             
                    else : 
         
     | 
| 53 | 
         
             
                        empty_df = pd.DataFrame()
         
     | 
| 54 | 
         
            -
                        empty_fig =  
     | 
| 55 | 
         
             
                        return "", empty_df, empty_fig
         
     | 
| 56 | 
         
             
                except Exception as e:
         
     | 
| 57 | 
         
             
                    print(f"Error: {e}")
         
     | 
| 58 | 
         
             
                    empty_df = pd.DataFrame()
         
     | 
| 59 | 
         
            -
                    empty_fig =  
     | 
| 60 | 
         
             
                    return "", empty_df, empty_fig
         
     | 
| 
         | 
|
| 4 | 
         
             
            import os
         
     | 
| 5 | 
         
             
            import pandas as pd
         
     | 
| 6 | 
         
             
            from climateqa.engine.llm import get_llm
         
     | 
| 
         | 
|
| 
         | 
|
| 7 | 
         
             
            import ast
         
     | 
| 8 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 9 | 
         | 
| 10 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 11 | 
         
             
            llm = get_llm(provider="openai")
         
     | 
| 12 | 
         | 
| 13 | 
         
             
            def ask_llm_to_add_table_names(sql_query, llm):
         
     | 
| 
         | 
|
| 19 | 
         
             
                columns_list = ast.literal_eval(columns.strip("```python\n").strip())
         
     | 
| 20 | 
         
             
                return columns_list
         
     | 
| 21 | 
         | 
| 22 | 
         
            +
            def ask_vanna(vn,db_vanna_path, query):
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
             
                try :
         
     | 
| 25 | 
         
            +
                    location = detect_location_with_openai(query)
         
     | 
| 26 | 
         
             
                    if location:
         
     | 
| 27 | 
         | 
| 28 | 
         
             
                        coords = loc2coords(location)
         
     | 
| 
         | 
|
| 38 | 
         | 
| 39 | 
         
             
                    else : 
         
     | 
| 40 | 
         
             
                        empty_df = pd.DataFrame()
         
     | 
| 41 | 
         
            +
                        empty_fig = None
         
     | 
| 42 | 
         
             
                        return "", empty_df, empty_fig
         
     | 
| 43 | 
         
             
                except Exception as e:
         
     | 
| 44 | 
         
             
                    print(f"Error: {e}")
         
     | 
| 45 | 
         
             
                    empty_df = pd.DataFrame()
         
     | 
| 46 | 
         
            +
                    empty_fig = None
         
     | 
| 47 | 
         
             
                    return "", empty_df, empty_fig
         
     | 
    	
        climateqa/engine/talk_to_data/utils.py
    CHANGED
    
    | 
         @@ -4,13 +4,13 @@ import pandas as pd 
     | 
|
| 4 | 
         
             
            from geopy.geocoders import Nominatim
         
     | 
| 5 | 
         
             
            import sqlite3
         
     | 
| 6 | 
         
             
            import ast
         
     | 
| 
         | 
|
| 7 | 
         | 
| 8 | 
         
            -
             
     | 
| 9 | 
         
            -
            def detect_location_with_openai(api_key, sentence):
         
     | 
| 10 | 
         
             
                """
         
     | 
| 11 | 
         
            -
                Detects locations in a sentence using OpenAI's API.
         
     | 
| 12 | 
         
             
                """
         
     | 
| 13 | 
         
            -
                 
     | 
| 14 | 
         | 
| 15 | 
         
             
                prompt = f"""
         
     | 
| 16 | 
         
             
                Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
         
     | 
| 
         @@ -19,18 +19,12 @@ def detect_location_with_openai(api_key, sentence): 
     | 
|
| 19 | 
         
             
                Sentence: "{sentence}"
         
     | 
| 20 | 
         
             
                """
         
     | 
| 21 | 
         | 
| 22 | 
         
            -
                response =  
     | 
| 23 | 
         
            -
             
     | 
| 24 | 
         
            -
             
     | 
| 25 | 
         
            -
             
     | 
| 26 | 
         
            -
             
     | 
| 27 | 
         
            -
                     
     | 
| 28 | 
         
            -
                    max_tokens=100,
         
     | 
| 29 | 
         
            -
                    temperature=0
         
     | 
| 30 | 
         
            -
                )
         
     | 
| 31 | 
         
            -
             
     | 
| 32 | 
         
            -
                return response.choices[0].message.content.split("\n")[1][2:-2]
         
     | 
| 33 | 
         
            -
             
     | 
| 34 | 
         | 
| 35 | 
         
             
            def detectTable(sql_query):
         
     | 
| 36 | 
         
             
                pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
         
     | 
| 
         | 
|
| 4 | 
         
             
            from geopy.geocoders import Nominatim
         
     | 
| 5 | 
         
             
            import sqlite3
         
     | 
| 6 | 
         
             
            import ast
         
     | 
| 7 | 
         
            +
            from climateqa.engine.llm import get_llm
         
     | 
| 8 | 
         | 
| 9 | 
         
            +
            def detect_location_with_openai(sentence):
         
     | 
| 
         | 
|
| 10 | 
         
             
                """
         
     | 
| 11 | 
         
            +
                Detects locations in a sentence using OpenAI's API via LangChain.
         
     | 
| 12 | 
         
             
                """
         
     | 
| 13 | 
         
            +
                llm = get_llm()
         
     | 
| 14 | 
         | 
| 15 | 
         
             
                prompt = f"""
         
     | 
| 16 | 
         
             
                Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
         
     | 
| 
         | 
|
| 19 | 
         
             
                Sentence: "{sentence}"
         
     | 
| 20 | 
         
             
                """
         
     | 
| 21 | 
         | 
| 22 | 
         
            +
                response = llm.invoke(prompt)
         
     | 
| 23 | 
         
            +
                location_list = ast.literal_eval(response.content.strip("```python\n").strip())
         
     | 
| 24 | 
         
            +
                if location_list:
         
     | 
| 25 | 
         
            +
                    return location_list[0]
         
     | 
| 26 | 
         
            +
                else:
         
     | 
| 27 | 
         
            +
                    return ""
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 28 | 
         | 
| 29 | 
         
             
            def detectTable(sql_query):
         
     | 
| 30 | 
         
             
                pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
         
     |