Commit 
							
							·
						
						aba2a7c
	
1
								Parent(s):
							
							e5ba042
								
feat: model filtering and UI upgrade for TTD
Browse files- app.py +537 -155
- climateqa/engine/talk_to_data/config.py +33 -0
- climateqa/engine/talk_to_data/main.py +10 -5
- climateqa/engine/talk_to_data/plot.py +26 -17
- climateqa/engine/talk_to_data/sql_query.py +13 -35
- climateqa/engine/talk_to_data/utils.py +44 -30
- climateqa/engine/talk_to_data/workflow.py +38 -34
- style.css +9 -3
    	
        app.py
    CHANGED
    
    | @@ -9,14 +9,14 @@ from climateqa.engine.embeddings import get_embeddings_function | |
| 9 | 
             
            from climateqa.engine.llm import get_llm
         | 
| 10 | 
             
            from climateqa.engine.vectorstore import get_pinecone_vectorstore
         | 
| 11 | 
             
            from climateqa.engine.reranker import get_reranker
         | 
| 12 | 
            -
            from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
         | 
| 13 | 
             
            from climateqa.engine.chains.retrieve_papers import find_papers
         | 
| 14 | 
             
            from climateqa.chat import start_chat, chat_stream, finish_chat
         | 
| 15 | 
             
            from climateqa.engine.talk_to_data.main import ask_drias, DRIAS_MODELS
         | 
| 16 | 
             
            from climateqa.engine.talk_to_data.myVanna import MyVanna
         | 
| 17 |  | 
| 18 | 
            -
            from front.tabs import  | 
| 19 | 
            -
            from front.tabs import  | 
| 20 | 
             
            from front.utils import process_figures
         | 
| 21 | 
             
            from gradio_modal import Modal
         | 
| 22 |  | 
| @@ -25,14 +25,14 @@ from utils import create_user_id | |
| 25 | 
             
            import logging
         | 
| 26 |  | 
| 27 | 
             
            logging.basicConfig(level=logging.WARNING)
         | 
| 28 | 
            -
            os.environ[ | 
| 29 | 
             
            logging.getLogger().setLevel(logging.WARNING)
         | 
| 30 |  | 
| 31 |  | 
| 32 | 
            -
             | 
| 33 | 
             
            # Load environment variables in local mode
         | 
| 34 | 
             
            try:
         | 
| 35 | 
             
                from dotenv import load_dotenv
         | 
|  | |
| 36 | 
             
                load_dotenv()
         | 
| 37 | 
             
            except Exception as e:
         | 
| 38 | 
             
                pass
         | 
| @@ -63,42 +63,105 @@ share_client = service.get_share_client(file_share_name) | |
| 63 | 
             
            user_id = create_user_id()
         | 
| 64 |  | 
| 65 |  | 
| 66 | 
            -
             | 
| 67 | 
             
            # Create vectorstore and retriever
         | 
| 68 | 
             
            embeddings_function = get_embeddings_function()
         | 
| 69 | 
            -
            vectorstore = get_pinecone_vectorstore( | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 72 |  | 
| 73 | 
            -
            llm = get_llm(provider="openai",max_tokens | 
| 74 | 
             
            if os.environ["GRADIO_ENV"] == "local":
         | 
| 75 | 
             
                reranker = get_reranker("nano")
         | 
| 76 | 
            -
            else | 
| 77 | 
             
                reranker = get_reranker("large")
         | 
| 78 |  | 
| 79 | 
            -
            agent = make_graph_agent( | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
             | 
| 84 | 
            -
             | 
| 85 | 
            -
             | 
| 86 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 87 |  | 
| 88 | 
             
            # def ask_vanna_query(query):
         | 
| 89 | 
             
            #     return ask_vanna(vn, db_vanna_path, query)
         | 
| 90 |  | 
| 91 | 
            -
            def ask_drias_query(query: str, index_state: int, drias_model: str):
         | 
| 92 | 
            -
                return ask_drias(db_vanna_path, query, index_state, drias_model)
         | 
| 93 |  | 
| 94 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 95 | 
             
                print("chat cqa - message received")
         | 
| 96 | 
            -
                async for event in chat_stream( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 97 | 
             
                    yield event
         | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 100 | 
             
                print("chat poc - message received")
         | 
| 101 | 
            -
                async for event in chat_stream( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 102 | 
             
                    yield event
         | 
| 103 |  | 
| 104 |  | 
| @@ -106,14 +169,17 @@ async def chat_poc(query, history, audience, sources, reports, relevant_content_ | |
| 106 | 
             
            # Gradio
         | 
| 107 | 
             
            # --------------------------------------------------------------------
         | 
| 108 |  | 
|  | |
| 109 | 
             
            # Function to update modal visibility
         | 
| 110 | 
             
            def update_config_modal_visibility(config_open):
         | 
| 111 | 
             
                print(config_open)
         | 
| 112 | 
             
                new_config_visibility_status = not config_open
         | 
| 113 | 
             
                return Modal(visible=new_config_visibility_status), new_config_visibility_status
         | 
| 114 | 
            -
                
         | 
| 115 |  | 
| 116 | 
            -
             | 
|  | |
|  | |
|  | |
| 117 | 
             
                sources_number = sources_textbox.count("<h2>")
         | 
| 118 | 
             
                figures_number = figures_cards.count("<h2>")
         | 
| 119 | 
             
                graphs_number = current_graphs.count("<iframe")
         | 
| @@ -122,9 +188,18 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs | |
| 122 | 
             
                figures_notif_label = f"Figures ({figures_number})"
         | 
| 123 | 
             
                graphs_notif_label = f"Graphs ({graphs_number})"
         | 
| 124 | 
             
                papers_notif_label = f"Papers ({papers_number})"
         | 
| 125 | 
            -
                recommended_content_notif_label =  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 126 |  | 
| 127 | 
            -
                return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
         | 
| 128 |  | 
| 129 | 
             
            # def create_drias_tab():
         | 
| 130 | 
             
            #     with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
         | 
| @@ -141,24 +216,112 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs | |
| 141 | 
             
            #         vanna_display = gr.Plot()
         | 
| 142 | 
             
            #         vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
         | 
| 143 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 144 | 
             
            def create_drias_tab():
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 145 | 
             
                with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
         | 
|  | |
|  | |
|  | |
|  | |
| 146 | 
             
                    with gr.Row():
         | 
| 147 | 
            -
                        drias_direct_question = gr.Textbox( | 
| 148 | 
            -
             | 
| 149 | 
            -
             | 
| 150 | 
            -
             | 
| 151 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 152 |  | 
| 153 | 
            -
                    with gr.Accordion(label= | 
| 154 | 
            -
                         | 
|  | |
|  | |
| 155 |  | 
| 156 | 
            -
                    with gr.Accordion(label="Chart"):
         | 
|  | |
|  | |
|  | |
| 157 | 
             
                        drias_display = gr.Plot(elem_id="vanna-plot")
         | 
| 158 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 159 | 
             
                    with gr.Row():
         | 
| 160 | 
            -
                        prev_button = gr.Button("Previous")
         | 
| 161 | 
            -
                        next_button = gr.Button("Next")
         | 
| 162 |  | 
| 163 | 
             
                    sql_queries_state = gr.State([])
         | 
| 164 | 
             
                    dataframes_state = gr.State([])
         | 
| @@ -166,96 +329,104 @@ def create_drias_tab(): | |
| 166 | 
             
                    index_state = gr.State(0)
         | 
| 167 |  | 
| 168 | 
             
                    drias_direct_question.submit(
         | 
| 169 | 
            -
                        ask_drias_query, | 
| 170 | 
            -
                        inputs=[drias_direct_question, index_state | 
| 171 | 
            -
                        outputs=[ | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 172 | 
             
                    )
         | 
| 173 |  | 
| 174 | 
             
                    model_selection.change(
         | 
| 175 | 
            -
                         | 
| 176 | 
            -
                        inputs=[ | 
| 177 | 
            -
                        outputs=[ | 
| 178 | 
             
                    )
         | 
| 179 |  | 
| 180 | 
             
                    def show_previous(index, sql_queries, dataframes, plots):
         | 
| 181 | 
             
                        if index > 0:
         | 
| 182 | 
             
                            index -= 1
         | 
| 183 | 
            -
                        return  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 184 |  | 
| 185 | 
             
                    def show_next(index, sql_queries, dataframes, plots):
         | 
| 186 | 
             
                        if index < len(sql_queries) - 1:
         | 
| 187 | 
             
                            index += 1
         | 
| 188 | 
            -
                        return  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 189 |  | 
| 190 | 
             
                    prev_button.click(
         | 
| 191 | 
            -
                        show_previous, | 
| 192 | 
             
                        inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
         | 
| 193 | 
            -
                        outputs=[drias_sql_query, drias_table, drias_display, index_state]
         | 
|  | |
|  | |
|  | |
|  | |
| 194 | 
             
                    )
         | 
| 195 |  | 
| 196 | 
             
                    next_button.click(
         | 
| 197 | 
            -
                        show_next, | 
| 198 | 
             
                        inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
         | 
| 199 | 
            -
                        outputs=[drias_sql_query, drias_table, drias_display, index_state]
         | 
|  | |
|  | |
|  | |
|  | |
| 200 | 
             
                    )
         | 
| 201 | 
            -
             | 
| 202 | 
            -
             | 
| 203 | 
            -
            def  | 
| 204 | 
            -
                 | 
| 205 | 
            -
             | 
| 206 | 
            -
                with gr.Tab(tab_name):
         | 
| 207 | 
            -
                    with gr.Row(elem_id="chatbot-row"):
         | 
| 208 | 
            -
                        # Left column - Chat interface
         | 
| 209 | 
            -
                        with gr.Column(scale=2):
         | 
| 210 | 
            -
                            chatbot, textbox, config_button = create_chat_interface(tab_name)
         | 
| 211 | 
            -
             | 
| 212 | 
            -
                        # Right column - Content panels
         | 
| 213 | 
            -
                        with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
         | 
| 214 | 
            -
                            with gr.Tabs(elem_id="right_panel_tab") as tabs:
         | 
| 215 | 
            -
                                # Examples tab
         | 
| 216 | 
            -
                                with gr.TabItem("Examples", elem_id="tab-examples", id=0):
         | 
| 217 | 
            -
                                    examples_hidden = create_examples_tab(tab_name)
         | 
| 218 | 
            -
             | 
| 219 | 
            -
                                # Sources tab
         | 
| 220 | 
            -
                                with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
         | 
| 221 | 
            -
                                    sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
         | 
| 222 | 
            -
             | 
| 223 | 
            -
             | 
| 224 | 
            -
                                # Recommended content tab
         | 
| 225 | 
            -
                                with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
         | 
| 226 | 
            -
                                    with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
         | 
| 227 | 
            -
                                        # Figures subtab
         | 
| 228 | 
            -
                                        with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
         | 
| 229 | 
            -
                                            sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal = create_figures_tab()
         | 
| 230 | 
            -
             | 
| 231 | 
            -
                                        # Papers subtab
         | 
| 232 | 
            -
                                        with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
         | 
| 233 | 
            -
                                            papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
         | 
| 234 | 
            -
             | 
| 235 | 
            -
                                        # Graphs subtab
         | 
| 236 | 
            -
                                        with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
         | 
| 237 | 
            -
                                            graphs_container = gr.HTML(
         | 
| 238 | 
            -
                                                "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
         | 
| 239 | 
            -
                                                elem_id="graphs-container"
         | 
| 240 | 
            -
                                            )
         | 
| 241 | 
            -
             | 
| 242 | 
            -
                                            
         | 
| 243 | 
            -
            def config_event_handling(main_tabs_components : list[MainTabPanel], config_componenets : ConfigPanel):
         | 
| 244 | 
             
                config_open = config_componenets.config_open
         | 
| 245 | 
             
                config_modal = config_componenets.config_modal
         | 
| 246 | 
             
                close_config_modal = config_componenets.close_config_modal_button
         | 
| 247 | 
            -
             | 
| 248 | 
            -
                for button in [close_config_modal] + [ | 
|  | |
|  | |
| 249 | 
             
                    button.click(
         | 
| 250 | 
             
                        fn=update_config_modal_visibility,
         | 
| 251 | 
             
                        inputs=[config_open],
         | 
| 252 | 
            -
                        outputs=[config_modal, config_open]
         | 
| 253 | 
            -
                    ) | 
| 254 | 
            -
             | 
|  | |
| 255 | 
             
            def event_handling(
         | 
| 256 | 
            -
                main_tab_components | 
| 257 | 
            -
                config_components | 
| 258 | 
            -
                tab_name="ClimateQ&A"
         | 
| 259 | 
             
            ):
         | 
| 260 | 
             
                chatbot = main_tab_components.chatbot
         | 
| 261 | 
             
                textbox = main_tab_components.textbox
         | 
| @@ -279,7 +450,7 @@ def event_handling( | |
| 279 | 
             
                graphs_container = main_tab_components.graph_container
         | 
| 280 | 
             
                follow_up_examples = main_tab_components.follow_up_examples
         | 
| 281 | 
             
                follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
         | 
| 282 | 
            -
             | 
| 283 | 
             
                dropdown_sources = config_components.dropdown_sources
         | 
| 284 | 
             
                dropdown_reports = config_components.dropdown_reports
         | 
| 285 | 
             
                dropdown_external_sources = config_components.dropdown_external_sources
         | 
| @@ -288,91 +459,302 @@ def event_handling( | |
| 288 | 
             
                after = config_components.after
         | 
| 289 | 
             
                output_query = config_components.output_query
         | 
| 290 | 
             
                output_language = config_components.output_language
         | 
| 291 | 
            -
             | 
| 292 | 
             
                new_sources_hmtl = gr.State([])
         | 
| 293 | 
             
                ttd_data = gr.State([])
         | 
| 294 |  | 
| 295 | 
            -
                
         | 
| 296 | 
             
                if tab_name == "ClimateQ&A":
         | 
| 297 | 
             
                    print("chat cqa - message sent")
         | 
| 298 |  | 
| 299 | 
             
                    # Event for textbox
         | 
| 300 | 
            -
                    ( | 
| 301 | 
            -
                        .submit( | 
| 302 | 
            -
             | 
| 303 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 304 | 
             
                    )
         | 
| 305 | 
             
                    # Event for examples_hidden
         | 
| 306 | 
            -
                    ( | 
| 307 | 
            -
                        .change( | 
| 308 | 
            -
             | 
| 309 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 310 | 
             
                    )
         | 
| 311 | 
            -
                    ( | 
| 312 | 
            -
                        .change( | 
| 313 | 
            -
             | 
| 314 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 315 | 
             
                    )
         | 
| 316 | 
            -
             | 
| 317 | 
             
                elif tab_name == "Beta - POC Adapt'Action":
         | 
| 318 | 
             
                    print("chat poc - message sent")
         | 
| 319 | 
             
                    # Event for textbox
         | 
| 320 | 
            -
                    ( | 
| 321 | 
            -
                        .submit( | 
| 322 | 
            -
             | 
| 323 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 324 | 
             
                    )
         | 
| 325 | 
             
                    # Event for examples_hidden
         | 
| 326 | 
            -
                    ( | 
| 327 | 
            -
                        .change( | 
| 328 | 
            -
             | 
| 329 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 330 | 
             
                    )
         | 
| 331 | 
            -
                    ( | 
| 332 | 
            -
                        .change( | 
| 333 | 
            -
             | 
| 334 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 335 | 
             
                    )
         | 
| 336 | 
            -
             | 
| 337 | 
            -
                
         | 
| 338 | 
            -
             | 
| 339 | 
            -
                 | 
| 340 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 341 |  | 
| 342 | 
             
                # Update sources numbers
         | 
| 343 | 
             
                for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
         | 
| 344 | 
            -
                    component.change( | 
| 345 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 346 | 
             
                # Search for papers
         | 
| 347 | 
             
                for component in [textbox, examples_hidden, papers_direct_search]:
         | 
| 348 | 
            -
                    component.submit( | 
| 349 | 
            -
             | 
|  | |
|  | |
|  | |
| 350 |  | 
| 351 | 
             
                # if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
         | 
| 352 | 
             
                #     # Drias search
         | 
| 353 | 
             
                #     textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
         | 
| 354 |  | 
|  | |
| 355 | 
             
            def main_ui():
         | 
| 356 | 
             
                # config_open = gr.State(True)
         | 
| 357 | 
            -
                with gr.Blocks( | 
| 358 | 
            -
                     | 
| 359 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 360 | 
             
                    with gr.Tabs():
         | 
| 361 | 
            -
                        cqa_components = cqa_tab(tab_name | 
| 362 | 
            -
                        local_cqa_components = cqa_tab(tab_name | 
| 363 | 
             
                        create_drias_tab()
         | 
| 364 | 
            -
             | 
| 365 | 
             
                        create_about_tab()
         | 
| 366 | 
            -
             | 
| 367 | 
            -
                    event_handling(cqa_components, config_components, tab_name | 
| 368 | 
            -
                    event_handling( | 
| 369 | 
            -
             | 
| 370 | 
            -
                     | 
| 371 | 
            -
             | 
|  | |
|  | |
| 372 | 
             
                    demo.queue()
         | 
| 373 | 
            -
             | 
| 374 | 
             
                return demo
         | 
| 375 |  | 
| 376 | 
            -
             | 
| 377 | 
             
            demo = main_ui()
         | 
| 378 | 
             
            demo.launch(ssr_mode=False)
         | 
|  | |
| 9 | 
             
            from climateqa.engine.llm import get_llm
         | 
| 10 | 
             
            from climateqa.engine.vectorstore import get_pinecone_vectorstore
         | 
| 11 | 
             
            from climateqa.engine.reranker import get_reranker
         | 
| 12 | 
            +
            from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
         | 
| 13 | 
             
            from climateqa.engine.chains.retrieve_papers import find_papers
         | 
| 14 | 
             
            from climateqa.chat import start_chat, chat_stream, finish_chat
         | 
| 15 | 
             
            from climateqa.engine.talk_to_data.main import ask_drias, DRIAS_MODELS
         | 
| 16 | 
             
            from climateqa.engine.talk_to_data.myVanna import MyVanna
         | 
| 17 |  | 
| 18 | 
            +
            from front.tabs import create_config_modal, cqa_tab, create_about_tab
         | 
| 19 | 
            +
            from front.tabs import MainTabPanel, ConfigPanel
         | 
| 20 | 
             
            from front.utils import process_figures
         | 
| 21 | 
             
            from gradio_modal import Modal
         | 
| 22 |  | 
|  | |
| 25 | 
             
            import logging
         | 
| 26 |  | 
| 27 | 
             
            logging.basicConfig(level=logging.WARNING)
         | 
| 28 | 
            +
            os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppresses INFO and WARNING logs
         | 
| 29 | 
             
            logging.getLogger().setLevel(logging.WARNING)
         | 
| 30 |  | 
| 31 |  | 
|  | |
| 32 | 
             
            # Load environment variables in local mode
         | 
| 33 | 
             
            try:
         | 
| 34 | 
             
                from dotenv import load_dotenv
         | 
| 35 | 
            +
             | 
| 36 | 
             
                load_dotenv()
         | 
| 37 | 
             
            except Exception as e:
         | 
| 38 | 
             
                pass
         | 
|  | |
| 63 | 
             
            user_id = create_user_id()
         | 
| 64 |  | 
| 65 |  | 
|  | |
| 66 | 
             
            # Create vectorstore and retriever
         | 
| 67 | 
             
            embeddings_function = get_embeddings_function()
         | 
| 68 | 
            +
            vectorstore = get_pinecone_vectorstore(
         | 
| 69 | 
            +
                embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
         | 
| 70 | 
            +
            )
         | 
| 71 | 
            +
            vectorstore_graphs = get_pinecone_vectorstore(
         | 
| 72 | 
            +
                embeddings_function,
         | 
| 73 | 
            +
                index_name=os.getenv("PINECONE_API_INDEX_OWID"),
         | 
| 74 | 
            +
                text_key="description",
         | 
| 75 | 
            +
            )
         | 
| 76 | 
            +
            vectorstore_region = get_pinecone_vectorstore(
         | 
| 77 | 
            +
                embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
         | 
| 78 | 
            +
            )
         | 
| 79 |  | 
| 80 | 
            +
            llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
         | 
| 81 | 
             
            if os.environ["GRADIO_ENV"] == "local":
         | 
| 82 | 
             
                reranker = get_reranker("nano")
         | 
| 83 | 
            +
            else:
         | 
| 84 | 
             
                reranker = get_reranker("large")
         | 
| 85 |  | 
| 86 | 
            +
            agent = make_graph_agent(
         | 
| 87 | 
            +
                llm=llm,
         | 
| 88 | 
            +
                vectorstore_ipcc=vectorstore,
         | 
| 89 | 
            +
                vectorstore_graphs=vectorstore_graphs,
         | 
| 90 | 
            +
                vectorstore_region=vectorstore_region,
         | 
| 91 | 
            +
                reranker=reranker,
         | 
| 92 | 
            +
                threshold_docs=0.2,
         | 
| 93 | 
            +
            )
         | 
| 94 | 
            +
            agent_poc = make_graph_agent_poc(
         | 
| 95 | 
            +
                llm=llm,
         | 
| 96 | 
            +
                vectorstore_ipcc=vectorstore,
         | 
| 97 | 
            +
                vectorstore_graphs=vectorstore_graphs,
         | 
| 98 | 
            +
                vectorstore_region=vectorstore_region,
         | 
| 99 | 
            +
                reranker=reranker,
         | 
| 100 | 
            +
                threshold_docs=0,
         | 
| 101 | 
            +
                version="v4",
         | 
| 102 | 
            +
            )  # TODO put back default 0.2
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            # Vanna object
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            # vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
         | 
| 107 | 
            +
            # db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
         | 
| 108 | 
            +
            # vn.connect_to_sqlite(db_vanna_path)
         | 
| 109 |  | 
| 110 | 
             
            # def ask_vanna_query(query):
         | 
| 111 | 
             
            #     return ask_vanna(vn, db_vanna_path, query)
         | 
| 112 |  | 
|  | |
|  | |
| 113 |  | 
| 114 | 
            +
            def ask_drias_query(query: str, index_state: int):
         | 
| 115 | 
            +
                return ask_drias(query, index_state)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            async def chat(
         | 
| 119 | 
            +
                query,
         | 
| 120 | 
            +
                history,
         | 
| 121 | 
            +
                audience,
         | 
| 122 | 
            +
                sources,
         | 
| 123 | 
            +
                reports,
         | 
| 124 | 
            +
                relevant_content_sources_selection,
         | 
| 125 | 
            +
                search_only,
         | 
| 126 | 
            +
            ):
         | 
| 127 | 
             
                print("chat cqa - message received")
         | 
| 128 | 
            +
                async for event in chat_stream(
         | 
| 129 | 
            +
                    agent,
         | 
| 130 | 
            +
                    query,
         | 
| 131 | 
            +
                    history,
         | 
| 132 | 
            +
                    audience,
         | 
| 133 | 
            +
                    sources,
         | 
| 134 | 
            +
                    reports,
         | 
| 135 | 
            +
                    relevant_content_sources_selection,
         | 
| 136 | 
            +
                    search_only,
         | 
| 137 | 
            +
                    share_client,
         | 
| 138 | 
            +
                    user_id,
         | 
| 139 | 
            +
                ):
         | 
| 140 | 
             
                    yield event
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            async def chat_poc(
         | 
| 144 | 
            +
                query,
         | 
| 145 | 
            +
                history,
         | 
| 146 | 
            +
                audience,
         | 
| 147 | 
            +
                sources,
         | 
| 148 | 
            +
                reports,
         | 
| 149 | 
            +
                relevant_content_sources_selection,
         | 
| 150 | 
            +
                search_only,
         | 
| 151 | 
            +
            ):
         | 
| 152 | 
             
                print("chat poc - message received")
         | 
| 153 | 
            +
                async for event in chat_stream(
         | 
| 154 | 
            +
                    agent_poc,
         | 
| 155 | 
            +
                    query,
         | 
| 156 | 
            +
                    history,
         | 
| 157 | 
            +
                    audience,
         | 
| 158 | 
            +
                    sources,
         | 
| 159 | 
            +
                    reports,
         | 
| 160 | 
            +
                    relevant_content_sources_selection,
         | 
| 161 | 
            +
                    search_only,
         | 
| 162 | 
            +
                    share_client,
         | 
| 163 | 
            +
                    user_id,
         | 
| 164 | 
            +
                ):
         | 
| 165 | 
             
                    yield event
         | 
| 166 |  | 
| 167 |  | 
|  | |
| 169 | 
             
            # Gradio
         | 
| 170 | 
             
            # --------------------------------------------------------------------
         | 
| 171 |  | 
| 172 | 
            +
             | 
| 173 | 
             
            # Function to update modal visibility
         | 
| 174 | 
             
            def update_config_modal_visibility(config_open):
         | 
| 175 | 
             
                print(config_open)
         | 
| 176 | 
             
                new_config_visibility_status = not config_open
         | 
| 177 | 
             
                return Modal(visible=new_config_visibility_status), new_config_visibility_status
         | 
|  | |
| 178 |  | 
| 179 | 
            +
             | 
| 180 | 
            +
            def update_sources_number_display(
         | 
| 181 | 
            +
                sources_textbox, figures_cards, current_graphs, papers_html
         | 
| 182 | 
            +
            ):
         | 
| 183 | 
             
                sources_number = sources_textbox.count("<h2>")
         | 
| 184 | 
             
                figures_number = figures_cards.count("<h2>")
         | 
| 185 | 
             
                graphs_number = current_graphs.count("<iframe")
         | 
|  | |
| 188 | 
             
                figures_notif_label = f"Figures ({figures_number})"
         | 
| 189 | 
             
                graphs_notif_label = f"Graphs ({graphs_number})"
         | 
| 190 | 
             
                papers_notif_label = f"Papers ({papers_number})"
         | 
| 191 | 
            +
                recommended_content_notif_label = (
         | 
| 192 | 
            +
                    f"Recommended content ({figures_number + graphs_number + papers_number})"
         | 
| 193 | 
            +
                )
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                return (
         | 
| 196 | 
            +
                    gr.update(label=recommended_content_notif_label),
         | 
| 197 | 
            +
                    gr.update(label=sources_notif_label),
         | 
| 198 | 
            +
                    gr.update(label=figures_notif_label),
         | 
| 199 | 
            +
                    gr.update(label=graphs_notif_label),
         | 
| 200 | 
            +
                    gr.update(label=papers_notif_label),
         | 
| 201 | 
            +
                )
         | 
| 202 |  | 
|  | |
| 203 |  | 
| 204 | 
             
            # def create_drias_tab():
         | 
| 205 | 
             
            #     with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
         | 
|  | |
| 216 | 
             
            #         vanna_display = gr.Plot()
         | 
| 217 | 
             
            #         vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
         | 
| 218 |  | 
| 219 | 
            +
             | 
| 220 | 
            +
            def show_results(sql_queries_state, dataframes_state, plots_state):
         | 
| 221 | 
            +
                if not sql_queries_state or not dataframes_state or not plots_state:
         | 
| 222 | 
            +
                    # If all results are empty, show "No result"
         | 
| 223 | 
            +
                    return (
         | 
| 224 | 
            +
                        gr.update(visible=True),
         | 
| 225 | 
            +
                        gr.update(visible=False),
         | 
| 226 | 
            +
                        gr.update(visible=False),
         | 
| 227 | 
            +
                        gr.update(visible=False),
         | 
| 228 | 
            +
                        gr.update(visible=False),
         | 
| 229 | 
            +
                        gr.update(visible=False),
         | 
| 230 | 
            +
                        gr.update(visible=False),
         | 
| 231 | 
            +
                    )
         | 
| 232 | 
            +
                else:
         | 
| 233 | 
            +
                    # Show the appropriate components with their data
         | 
| 234 | 
            +
                    return (
         | 
| 235 | 
            +
                        gr.update(visible=False),
         | 
| 236 | 
            +
                        gr.update(visible=True),
         | 
| 237 | 
            +
                        gr.update(visible=True),
         | 
| 238 | 
            +
                        gr.update(visible=True),
         | 
| 239 | 
            +
                        gr.update(visible=True),
         | 
| 240 | 
            +
                        gr.update(visible=True),
         | 
| 241 | 
            +
                        gr.update(visible=True),
         | 
| 242 | 
            +
                    )
         | 
| 243 | 
            +
             | 
| 244 | 
            +
             | 
| 245 | 
            +
            def filter_by_model(dataframes, figures, index_state, model_selection):
         | 
| 246 | 
            +
                df = dataframes[index_state]
         | 
| 247 | 
            +
                if model_selection != "ALL":
         | 
| 248 | 
            +
                    df = df[df["model"] == model_selection]
         | 
| 249 | 
            +
                figure = figures[index_state](df)
         | 
| 250 | 
            +
                return df, figure
         | 
| 251 | 
            +
             | 
| 252 | 
            +
             | 
| 253 | 
            +
            def update_pagination(index, sql_queries):
         | 
| 254 | 
            +
                pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
         | 
| 255 | 
            +
                return pagination
         | 
| 256 | 
            +
             | 
| 257 | 
            +
             | 
| 258 | 
             
            def create_drias_tab():
         | 
| 259 | 
            +
                details_text = """
         | 
| 260 | 
            +
            Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.  
         | 
| 261 | 
            +
            I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
         | 
| 262 | 
            +
             | 
| 263 | 
            +
            ❓ **How to use?**  
         | 
| 264 | 
            +
            You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.  
         | 
| 265 | 
            +
            You can specify **location** and/or **year**.  
         | 
| 266 | 
            +
            You can choose from a list of climate models. By default, we take the **average of each model**.
         | 
| 267 | 
            +
             | 
| 268 | 
            +
            For example, you can ask:  
         | 
| 269 | 
            +
            - What will the temperature be like in Paris?  
         | 
| 270 | 
            +
            - What will be the total rainfall in France in 2030?  
         | 
| 271 | 
            +
            - How frequent will extreme events be in Lyon?  
         | 
| 272 | 
            +
             | 
| 273 | 
            +
            **Example of indicators in the data**:  
         | 
| 274 | 
            +
            - Mean temperature (annual, winter, summer)  
         | 
| 275 | 
            +
            - Total precipitation (annual, winter, summer)  
         | 
| 276 | 
            +
            - Number of days with remarkable precipitations, with dry ground, with temperature above 30°C  
         | 
| 277 | 
            +
             | 
| 278 | 
            +
            ⚠️ **Limitations**:  
         | 
| 279 | 
            +
            - You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.  
         | 
| 280 | 
            +
            - You can only ask about **locations in France**.  
         | 
| 281 | 
            +
            - If you specify a year, there may be **no data for that year for some models**.  
         | 
| 282 | 
            +
            - You **cannot compare two models**.  
         | 
| 283 | 
            +
             | 
| 284 | 
            +
            🛈 **Information**  
         | 
| 285 | 
            +
            Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
         | 
| 286 | 
            +
            """
         | 
| 287 | 
             
                with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                    with gr.Accordion(label="Details"):
         | 
| 290 | 
            +
                        gr.Markdown(details_text)
         | 
| 291 | 
            +
             | 
| 292 | 
             
                    with gr.Row():
         | 
| 293 | 
            +
                        drias_direct_question = gr.Textbox(
         | 
| 294 | 
            +
                            label="Direct Question",
         | 
| 295 | 
            +
                            placeholder="You can write direct question here",
         | 
| 296 | 
            +
                            elem_id="direct-question",
         | 
| 297 | 
            +
                            interactive=True,
         | 
| 298 | 
            +
                        )
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    result_text = gr.Textbox(
         | 
| 301 | 
            +
                        label="", elem_id="no-result-label", interactive=False, visible=True
         | 
| 302 | 
            +
                    )
         | 
| 303 |  | 
| 304 | 
            +
                    with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
         | 
| 305 | 
            +
                        drias_sql_query = gr.Textbox(
         | 
| 306 | 
            +
                            label="", elem_id="sql-query", interactive=False
         | 
| 307 | 
            +
                        )
         | 
| 308 |  | 
| 309 | 
            +
                    with gr.Accordion(label="Chart", visible=False) as chart_accordion:
         | 
| 310 | 
            +
                        model_selection = gr.Dropdown(
         | 
| 311 | 
            +
                            label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
         | 
| 312 | 
            +
                        )
         | 
| 313 | 
             
                        drias_display = gr.Plot(elem_id="vanna-plot")
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    with gr.Accordion(
         | 
| 316 | 
            +
                        label="Data used", open=False, visible=False
         | 
| 317 | 
            +
                    ) as table_accordion:
         | 
| 318 | 
            +
                        drias_table = gr.DataFrame([], elem_id="vanna-table")
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    pagination_display = gr.Markdown(value="", visible=False, elem_id="pagination-display")
         | 
| 321 | 
            +
             | 
| 322 | 
             
                    with gr.Row():
         | 
| 323 | 
            +
                        prev_button = gr.Button("Previous", visible=False)
         | 
| 324 | 
            +
                        next_button = gr.Button("Next", visible=False)
         | 
| 325 |  | 
| 326 | 
             
                    sql_queries_state = gr.State([])
         | 
| 327 | 
             
                    dataframes_state = gr.State([])
         | 
|  | |
| 329 | 
             
                    index_state = gr.State(0)
         | 
| 330 |  | 
| 331 | 
             
                    drias_direct_question.submit(
         | 
| 332 | 
            +
                        ask_drias_query,
         | 
| 333 | 
            +
                        inputs=[drias_direct_question, index_state],
         | 
| 334 | 
            +
                        outputs=[
         | 
| 335 | 
            +
                            drias_sql_query,
         | 
| 336 | 
            +
                            drias_table,
         | 
| 337 | 
            +
                            drias_display,
         | 
| 338 | 
            +
                            sql_queries_state,
         | 
| 339 | 
            +
                            dataframes_state,
         | 
| 340 | 
            +
                            plots_state,
         | 
| 341 | 
            +
                            index_state,
         | 
| 342 | 
            +
                            result_text,
         | 
| 343 | 
            +
                        ],
         | 
| 344 | 
            +
                    ).then(
         | 
| 345 | 
            +
                        show_results,
         | 
| 346 | 
            +
                        inputs=[sql_queries_state, dataframes_state, plots_state],
         | 
| 347 | 
            +
                        outputs=[
         | 
| 348 | 
            +
                            result_text,
         | 
| 349 | 
            +
                            query_accordion,
         | 
| 350 | 
            +
                            table_accordion,
         | 
| 351 | 
            +
                            chart_accordion,
         | 
| 352 | 
            +
                            prev_button,
         | 
| 353 | 
            +
                            next_button,
         | 
| 354 | 
            +
                            pagination_display
         | 
| 355 | 
            +
                        ],
         | 
| 356 | 
            +
                    ).then(
         | 
| 357 | 
            +
                        update_pagination,
         | 
| 358 | 
            +
                        inputs=[index_state, sql_queries_state],
         | 
| 359 | 
            +
                        outputs=[pagination_display],
         | 
| 360 | 
             
                    )
         | 
| 361 |  | 
| 362 | 
             
                    model_selection.change(
         | 
| 363 | 
            +
                        filter_by_model,
         | 
| 364 | 
            +
                        inputs=[dataframes_state, plots_state, index_state, model_selection],
         | 
| 365 | 
            +
                        outputs=[drias_table, drias_display],
         | 
| 366 | 
             
                    )
         | 
| 367 |  | 
| 368 | 
             
                    def show_previous(index, sql_queries, dataframes, plots):
         | 
| 369 | 
             
                        if index > 0:
         | 
| 370 | 
             
                            index -= 1
         | 
| 371 | 
            +
                        return (
         | 
| 372 | 
            +
                            sql_queries[index],
         | 
| 373 | 
            +
                            dataframes[index],
         | 
| 374 | 
            +
                            plots[index](dataframes[index]),
         | 
| 375 | 
            +
                            index,
         | 
| 376 | 
            +
                        )
         | 
| 377 |  | 
| 378 | 
             
                    def show_next(index, sql_queries, dataframes, plots):
         | 
| 379 | 
             
                        if index < len(sql_queries) - 1:
         | 
| 380 | 
             
                            index += 1
         | 
| 381 | 
            +
                        return (
         | 
| 382 | 
            +
                            sql_queries[index],
         | 
| 383 | 
            +
                            dataframes[index],
         | 
| 384 | 
            +
                            plots[index](dataframes[index]),
         | 
| 385 | 
            +
                            index,
         | 
| 386 | 
            +
                        )
         | 
| 387 |  | 
| 388 | 
             
                    prev_button.click(
         | 
| 389 | 
            +
                        show_previous,
         | 
| 390 | 
             
                        inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
         | 
| 391 | 
            +
                        outputs=[drias_sql_query, drias_table, drias_display, index_state],
         | 
| 392 | 
            +
                    ).then(
         | 
| 393 | 
            +
                        update_pagination,
         | 
| 394 | 
            +
                        inputs=[index_state, sql_queries_state],
         | 
| 395 | 
            +
                        outputs=[pagination_display],
         | 
| 396 | 
             
                    )
         | 
| 397 |  | 
| 398 | 
             
                    next_button.click(
         | 
| 399 | 
            +
                        show_next,
         | 
| 400 | 
             
                        inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
         | 
| 401 | 
            +
                        outputs=[drias_sql_query, drias_table, drias_display, index_state],
         | 
| 402 | 
            +
                    ).then(
         | 
| 403 | 
            +
                        update_pagination,
         | 
| 404 | 
            +
                        inputs=[index_state, sql_queries_state],
         | 
| 405 | 
            +
                        outputs=[pagination_display],
         | 
| 406 | 
             
                    )
         | 
| 407 | 
            +
             | 
| 408 | 
            +
             | 
| 409 | 
            +
            def config_event_handling(
         | 
| 410 | 
            +
                main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
         | 
| 411 | 
            +
            ):
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 412 | 
             
                config_open = config_componenets.config_open
         | 
| 413 | 
             
                config_modal = config_componenets.config_modal
         | 
| 414 | 
             
                close_config_modal = config_componenets.close_config_modal_button
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                for button in [close_config_modal] + [
         | 
| 417 | 
            +
                    main_tab_component.config_button for main_tab_component in main_tabs_components
         | 
| 418 | 
            +
                ]:
         | 
| 419 | 
             
                    button.click(
         | 
| 420 | 
             
                        fn=update_config_modal_visibility,
         | 
| 421 | 
             
                        inputs=[config_open],
         | 
| 422 | 
            +
                        outputs=[config_modal, config_open],
         | 
| 423 | 
            +
                    )
         | 
| 424 | 
            +
             | 
| 425 | 
            +
             | 
| 426 | 
             
            def event_handling(
         | 
| 427 | 
            +
                main_tab_components: MainTabPanel,
         | 
| 428 | 
            +
                config_components: ConfigPanel,
         | 
| 429 | 
            +
                tab_name="ClimateQ&A",
         | 
| 430 | 
             
            ):
         | 
| 431 | 
             
                chatbot = main_tab_components.chatbot
         | 
| 432 | 
             
                textbox = main_tab_components.textbox
         | 
|  | |
| 450 | 
             
                graphs_container = main_tab_components.graph_container
         | 
| 451 | 
             
                follow_up_examples = main_tab_components.follow_up_examples
         | 
| 452 | 
             
                follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
         | 
| 453 | 
            +
             | 
| 454 | 
             
                dropdown_sources = config_components.dropdown_sources
         | 
| 455 | 
             
                dropdown_reports = config_components.dropdown_reports
         | 
| 456 | 
             
                dropdown_external_sources = config_components.dropdown_external_sources
         | 
|  | |
| 459 | 
             
                after = config_components.after
         | 
| 460 | 
             
                output_query = config_components.output_query
         | 
| 461 | 
             
                output_language = config_components.output_language
         | 
| 462 | 
            +
             | 
| 463 | 
             
                new_sources_hmtl = gr.State([])
         | 
| 464 | 
             
                ttd_data = gr.State([])
         | 
| 465 |  | 
|  | |
| 466 | 
             
                if tab_name == "ClimateQ&A":
         | 
| 467 | 
             
                    print("chat cqa - message sent")
         | 
| 468 |  | 
| 469 | 
             
                    # Event for textbox
         | 
| 470 | 
            +
                    (
         | 
| 471 | 
            +
                        textbox.submit(
         | 
| 472 | 
            +
                            start_chat,
         | 
| 473 | 
            +
                            [textbox, chatbot, search_only],
         | 
| 474 | 
            +
                            [textbox, tabs, chatbot, sources_raw],
         | 
| 475 | 
            +
                            queue=False,
         | 
| 476 | 
            +
                            api_name=f"start_chat_{textbox.elem_id}",
         | 
| 477 | 
            +
                        )
         | 
| 478 | 
            +
                        .then(
         | 
| 479 | 
            +
                            chat,
         | 
| 480 | 
            +
                            [
         | 
| 481 | 
            +
                                textbox,
         | 
| 482 | 
            +
                                chatbot,
         | 
| 483 | 
            +
                                dropdown_audience,
         | 
| 484 | 
            +
                                dropdown_sources,
         | 
| 485 | 
            +
                                dropdown_reports,
         | 
| 486 | 
            +
                                dropdown_external_sources,
         | 
| 487 | 
            +
                                search_only,
         | 
| 488 | 
            +
                            ],
         | 
| 489 | 
            +
                            [
         | 
| 490 | 
            +
                                chatbot,
         | 
| 491 | 
            +
                                new_sources_hmtl,
         | 
| 492 | 
            +
                                output_query,
         | 
| 493 | 
            +
                                output_language,
         | 
| 494 | 
            +
                                new_figures,
         | 
| 495 | 
            +
                                current_graphs,
         | 
| 496 | 
            +
                                follow_up_examples.dataset,
         | 
| 497 | 
            +
                            ],
         | 
| 498 | 
            +
                            concurrency_limit=8,
         | 
| 499 | 
            +
                            api_name=f"chat_{textbox.elem_id}",
         | 
| 500 | 
            +
                        )
         | 
| 501 | 
            +
                        .then(
         | 
| 502 | 
            +
                            finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
         | 
| 503 | 
            +
                        )
         | 
| 504 | 
             
                    )
         | 
| 505 | 
             
                    # Event for examples_hidden
         | 
| 506 | 
            +
                    (
         | 
| 507 | 
            +
                        examples_hidden.change(
         | 
| 508 | 
            +
                            start_chat,
         | 
| 509 | 
            +
                            [examples_hidden, chatbot, search_only],
         | 
| 510 | 
            +
                            [examples_hidden, tabs, chatbot, sources_raw],
         | 
| 511 | 
            +
                            queue=False,
         | 
| 512 | 
            +
                            api_name=f"start_chat_{examples_hidden.elem_id}",
         | 
| 513 | 
            +
                        )
         | 
| 514 | 
            +
                        .then(
         | 
| 515 | 
            +
                            chat,
         | 
| 516 | 
            +
                            [
         | 
| 517 | 
            +
                                examples_hidden,
         | 
| 518 | 
            +
                                chatbot,
         | 
| 519 | 
            +
                                dropdown_audience,
         | 
| 520 | 
            +
                                dropdown_sources,
         | 
| 521 | 
            +
                                dropdown_reports,
         | 
| 522 | 
            +
                                dropdown_external_sources,
         | 
| 523 | 
            +
                                search_only,
         | 
| 524 | 
            +
                            ],
         | 
| 525 | 
            +
                            [
         | 
| 526 | 
            +
                                chatbot,
         | 
| 527 | 
            +
                                new_sources_hmtl,
         | 
| 528 | 
            +
                                output_query,
         | 
| 529 | 
            +
                                output_language,
         | 
| 530 | 
            +
                                new_figures,
         | 
| 531 | 
            +
                                current_graphs,
         | 
| 532 | 
            +
                                follow_up_examples.dataset,
         | 
| 533 | 
            +
                            ],
         | 
| 534 | 
            +
                            concurrency_limit=8,
         | 
| 535 | 
            +
                            api_name=f"chat_{examples_hidden.elem_id}",
         | 
| 536 | 
            +
                        )
         | 
| 537 | 
            +
                        .then(
         | 
| 538 | 
            +
                            finish_chat,
         | 
| 539 | 
            +
                            None,
         | 
| 540 | 
            +
                            [textbox],
         | 
| 541 | 
            +
                            api_name=f"finish_chat_{examples_hidden.elem_id}",
         | 
| 542 | 
            +
                        )
         | 
| 543 | 
             
                    )
         | 
| 544 | 
            +
                    (
         | 
| 545 | 
            +
                        follow_up_examples_hidden.change(
         | 
| 546 | 
            +
                            start_chat,
         | 
| 547 | 
            +
                            [follow_up_examples_hidden, chatbot, search_only],
         | 
| 548 | 
            +
                            [follow_up_examples_hidden, tabs, chatbot, sources_raw],
         | 
| 549 | 
            +
                            queue=False,
         | 
| 550 | 
            +
                            api_name=f"start_chat_{examples_hidden.elem_id}",
         | 
| 551 | 
            +
                        )
         | 
| 552 | 
            +
                        .then(
         | 
| 553 | 
            +
                            chat,
         | 
| 554 | 
            +
                            [
         | 
| 555 | 
            +
                                follow_up_examples_hidden,
         | 
| 556 | 
            +
                                chatbot,
         | 
| 557 | 
            +
                                dropdown_audience,
         | 
| 558 | 
            +
                                dropdown_sources,
         | 
| 559 | 
            +
                                dropdown_reports,
         | 
| 560 | 
            +
                                dropdown_external_sources,
         | 
| 561 | 
            +
                                search_only,
         | 
| 562 | 
            +
                            ],
         | 
| 563 | 
            +
                            [
         | 
| 564 | 
            +
                                chatbot,
         | 
| 565 | 
            +
                                new_sources_hmtl,
         | 
| 566 | 
            +
                                output_query,
         | 
| 567 | 
            +
                                output_language,
         | 
| 568 | 
            +
                                new_figures,
         | 
| 569 | 
            +
                                current_graphs,
         | 
| 570 | 
            +
                                follow_up_examples.dataset,
         | 
| 571 | 
            +
                            ],
         | 
| 572 | 
            +
                            concurrency_limit=8,
         | 
| 573 | 
            +
                            api_name=f"chat_{examples_hidden.elem_id}",
         | 
| 574 | 
            +
                        )
         | 
| 575 | 
            +
                        .then(
         | 
| 576 | 
            +
                            finish_chat,
         | 
| 577 | 
            +
                            None,
         | 
| 578 | 
            +
                            [textbox],
         | 
| 579 | 
            +
                            api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
         | 
| 580 | 
            +
                        )
         | 
| 581 | 
             
                    )
         | 
| 582 | 
            +
             | 
| 583 | 
             
                elif tab_name == "Beta - POC Adapt'Action":
         | 
| 584 | 
             
                    print("chat poc - message sent")
         | 
| 585 | 
             
                    # Event for textbox
         | 
| 586 | 
            +
                    (
         | 
| 587 | 
            +
                        textbox.submit(
         | 
| 588 | 
            +
                            start_chat,
         | 
| 589 | 
            +
                            [textbox, chatbot, search_only],
         | 
| 590 | 
            +
                            [textbox, tabs, chatbot, sources_raw],
         | 
| 591 | 
            +
                            queue=False,
         | 
| 592 | 
            +
                            api_name=f"start_chat_{textbox.elem_id}",
         | 
| 593 | 
            +
                        )
         | 
| 594 | 
            +
                        .then(
         | 
| 595 | 
            +
                            chat_poc,
         | 
| 596 | 
            +
                            [
         | 
| 597 | 
            +
                                textbox,
         | 
| 598 | 
            +
                                chatbot,
         | 
| 599 | 
            +
                                dropdown_audience,
         | 
| 600 | 
            +
                                dropdown_sources,
         | 
| 601 | 
            +
                                dropdown_reports,
         | 
| 602 | 
            +
                                dropdown_external_sources,
         | 
| 603 | 
            +
                                search_only,
         | 
| 604 | 
            +
                            ],
         | 
| 605 | 
            +
                            [
         | 
| 606 | 
            +
                                chatbot,
         | 
| 607 | 
            +
                                new_sources_hmtl,
         | 
| 608 | 
            +
                                output_query,
         | 
| 609 | 
            +
                                output_language,
         | 
| 610 | 
            +
                                new_figures,
         | 
| 611 | 
            +
                                current_graphs,
         | 
| 612 | 
            +
                            ],
         | 
| 613 | 
            +
                            concurrency_limit=8,
         | 
| 614 | 
            +
                            api_name=f"chat_{textbox.elem_id}",
         | 
| 615 | 
            +
                        )
         | 
| 616 | 
            +
                        .then(
         | 
| 617 | 
            +
                            finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
         | 
| 618 | 
            +
                        )
         | 
| 619 | 
             
                    )
         | 
| 620 | 
             
                    # Event for examples_hidden
         | 
| 621 | 
            +
                    (
         | 
| 622 | 
            +
                        examples_hidden.change(
         | 
| 623 | 
            +
                            start_chat,
         | 
| 624 | 
            +
                            [examples_hidden, chatbot, search_only],
         | 
| 625 | 
            +
                            [examples_hidden, tabs, chatbot, sources_raw],
         | 
| 626 | 
            +
                            queue=False,
         | 
| 627 | 
            +
                            api_name=f"start_chat_{examples_hidden.elem_id}",
         | 
| 628 | 
            +
                        )
         | 
| 629 | 
            +
                        .then(
         | 
| 630 | 
            +
                            chat_poc,
         | 
| 631 | 
            +
                            [
         | 
| 632 | 
            +
                                examples_hidden,
         | 
| 633 | 
            +
                                chatbot,
         | 
| 634 | 
            +
                                dropdown_audience,
         | 
| 635 | 
            +
                                dropdown_sources,
         | 
| 636 | 
            +
                                dropdown_reports,
         | 
| 637 | 
            +
                                dropdown_external_sources,
         | 
| 638 | 
            +
                                search_only,
         | 
| 639 | 
            +
                            ],
         | 
| 640 | 
            +
                            [
         | 
| 641 | 
            +
                                chatbot,
         | 
| 642 | 
            +
                                new_sources_hmtl,
         | 
| 643 | 
            +
                                output_query,
         | 
| 644 | 
            +
                                output_language,
         | 
| 645 | 
            +
                                new_figures,
         | 
| 646 | 
            +
                                current_graphs,
         | 
| 647 | 
            +
                            ],
         | 
| 648 | 
            +
                            concurrency_limit=8,
         | 
| 649 | 
            +
                            api_name=f"chat_{examples_hidden.elem_id}",
         | 
| 650 | 
            +
                        )
         | 
| 651 | 
            +
                        .then(
         | 
| 652 | 
            +
                            finish_chat,
         | 
| 653 | 
            +
                            None,
         | 
| 654 | 
            +
                            [textbox],
         | 
| 655 | 
            +
                            api_name=f"finish_chat_{examples_hidden.elem_id}",
         | 
| 656 | 
            +
                        )
         | 
| 657 | 
             
                    )
         | 
| 658 | 
            +
                    (
         | 
| 659 | 
            +
                        follow_up_examples_hidden.change(
         | 
| 660 | 
            +
                            start_chat,
         | 
| 661 | 
            +
                            [follow_up_examples_hidden, chatbot, search_only],
         | 
| 662 | 
            +
                            [follow_up_examples_hidden, tabs, chatbot, sources_raw],
         | 
| 663 | 
            +
                            queue=False,
         | 
| 664 | 
            +
                            api_name=f"start_chat_{examples_hidden.elem_id}",
         | 
| 665 | 
            +
                        )
         | 
| 666 | 
            +
                        .then(
         | 
| 667 | 
            +
                            chat,
         | 
| 668 | 
            +
                            [
         | 
| 669 | 
            +
                                follow_up_examples_hidden,
         | 
| 670 | 
            +
                                chatbot,
         | 
| 671 | 
            +
                                dropdown_audience,
         | 
| 672 | 
            +
                                dropdown_sources,
         | 
| 673 | 
            +
                                dropdown_reports,
         | 
| 674 | 
            +
                                dropdown_external_sources,
         | 
| 675 | 
            +
                                search_only,
         | 
| 676 | 
            +
                            ],
         | 
| 677 | 
            +
                            [
         | 
| 678 | 
            +
                                chatbot,
         | 
| 679 | 
            +
                                new_sources_hmtl,
         | 
| 680 | 
            +
                                output_query,
         | 
| 681 | 
            +
                                output_language,
         | 
| 682 | 
            +
                                new_figures,
         | 
| 683 | 
            +
                                current_graphs,
         | 
| 684 | 
            +
                                follow_up_examples.dataset,
         | 
| 685 | 
            +
                            ],
         | 
| 686 | 
            +
                            concurrency_limit=8,
         | 
| 687 | 
            +
                            api_name=f"chat_{examples_hidden.elem_id}",
         | 
| 688 | 
            +
                        )
         | 
| 689 | 
            +
                        .then(
         | 
| 690 | 
            +
                            finish_chat,
         | 
| 691 | 
            +
                            None,
         | 
| 692 | 
            +
                            [textbox],
         | 
| 693 | 
            +
                            api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
         | 
| 694 | 
            +
                        )
         | 
| 695 | 
             
                    )
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                new_sources_hmtl.change(
         | 
| 698 | 
            +
                    lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
         | 
| 699 | 
            +
                )
         | 
| 700 | 
            +
                current_graphs.change(
         | 
| 701 | 
            +
                    lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
         | 
| 702 | 
            +
                )
         | 
| 703 | 
            +
                new_figures.change(
         | 
| 704 | 
            +
                    process_figures,
         | 
| 705 | 
            +
                    inputs=[sources_raw, new_figures],
         | 
| 706 | 
            +
                    outputs=[sources_raw, figures_cards, gallery_component],
         | 
| 707 | 
            +
                )
         | 
| 708 |  | 
| 709 | 
             
                # Update sources numbers
         | 
| 710 | 
             
                for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
         | 
| 711 | 
            +
                    component.change(
         | 
| 712 | 
            +
                        update_sources_number_display,
         | 
| 713 | 
            +
                        [sources_textbox, figures_cards, current_graphs, papers_html],
         | 
| 714 | 
            +
                        [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
         | 
| 715 | 
            +
                    )
         | 
| 716 | 
            +
             | 
| 717 | 
             
                # Search for papers
         | 
| 718 | 
             
                for component in [textbox, examples_hidden, papers_direct_search]:
         | 
| 719 | 
            +
                    component.submit(
         | 
| 720 | 
            +
                        find_papers,
         | 
| 721 | 
            +
                        [component, after, dropdown_external_sources],
         | 
| 722 | 
            +
                        [papers_html, citations_network, papers_summary],
         | 
| 723 | 
            +
                    )
         | 
| 724 |  | 
| 725 | 
             
                # if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
         | 
| 726 | 
             
                #     # Drias search
         | 
| 727 | 
             
                #     textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
         | 
| 728 |  | 
| 729 | 
            +
             | 
| 730 | 
             
            def main_ui():
         | 
| 731 | 
             
                # config_open = gr.State(True)
         | 
| 732 | 
            +
                with gr.Blocks(
         | 
| 733 | 
            +
                    title="Climate Q&A",
         | 
| 734 | 
            +
                    css_paths=os.getcwd() + "/style.css",
         | 
| 735 | 
            +
                    theme=theme,
         | 
| 736 | 
            +
                    elem_id="main-component",
         | 
| 737 | 
            +
                ) as demo:
         | 
| 738 | 
            +
                    config_components = create_config_modal()
         | 
| 739 | 
            +
             | 
| 740 | 
             
                    with gr.Tabs():
         | 
| 741 | 
            +
                        cqa_components = cqa_tab(tab_name="ClimateQ&A")
         | 
| 742 | 
            +
                        local_cqa_components = cqa_tab(tab_name="Beta - POC Adapt'Action")
         | 
| 743 | 
             
                        create_drias_tab()
         | 
| 744 | 
            +
             | 
| 745 | 
             
                        create_about_tab()
         | 
| 746 | 
            +
             | 
| 747 | 
            +
                    event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
         | 
| 748 | 
            +
                    event_handling(
         | 
| 749 | 
            +
                        local_cqa_components, config_components, tab_name="Beta - POC Adapt'Action"
         | 
| 750 | 
            +
                    )
         | 
| 751 | 
            +
             | 
| 752 | 
            +
                    config_event_handling([cqa_components, local_cqa_components], config_components)
         | 
| 753 | 
            +
             | 
| 754 | 
             
                    demo.queue()
         | 
| 755 | 
            +
             | 
| 756 | 
             
                return demo
         | 
| 757 |  | 
| 758 | 
            +
             | 
| 759 | 
             
            demo = main_ui()
         | 
| 760 | 
             
            demo.launch(ssr_mode=False)
         | 
    	
        climateqa/engine/talk_to_data/config.py
    ADDED
    
    | @@ -0,0 +1,33 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            DRIAS_TABLES = [
         | 
| 2 | 
            +
                "total_winter_precipitation",
         | 
| 3 | 
            +
                "total_summer_precipiation",
         | 
| 4 | 
            +
                "total_annual_precipitation",
         | 
| 5 | 
            +
                "total_remarkable_daily_precipitation",
         | 
| 6 | 
            +
                "frequency_of_remarkable_daily_precipitation",
         | 
| 7 | 
            +
                "extreme_precipitation_intensity",
         | 
| 8 | 
            +
                "mean_winter_temperature",
         | 
| 9 | 
            +
                "mean_summer_temperature",
         | 
| 10 | 
            +
                "mean_annual_temperature",
         | 
| 11 | 
            +
                "number_of_tropical_nights",
         | 
| 12 | 
            +
                "maximum_summer_temperature",
         | 
| 13 | 
            +
                "number_of_days_with_tx_above_30",
         | 
| 14 | 
            +
                "number_of_days_with_tx_above_35",
         | 
| 15 | 
            +
                "number_of_days_with_a_dry_ground",
         | 
| 16 | 
            +
            ]
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            INDICATOR_COLUMNS_PER_TABLE = {
         | 
| 19 | 
            +
                "total_winter_precipitation": "total_winter_precipitation",
         | 
| 20 | 
            +
                "total_summer_precipiation": "total_summer_precipitation",
         | 
| 21 | 
            +
                "total_annual_precipitation": "total_annual_precipitation",
         | 
| 22 | 
            +
                "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
         | 
| 23 | 
            +
                "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
         | 
| 24 | 
            +
                "extreme_precipitation_intensity": "extreme_precipitation_intensity",
         | 
| 25 | 
            +
                "mean_winter_temperature": "mean_winter_temperature",
         | 
| 26 | 
            +
                "mean_summer_temperature": "mean_summer_temperature",
         | 
| 27 | 
            +
                "mean_annual_temperature": "mean_annual_temperature",
         | 
| 28 | 
            +
                "number_of_tropical_nights": "number_tropical_nights",
         | 
| 29 | 
            +
                "maximum_summer_temperature": "maximum_summer_temperature",
         | 
| 30 | 
            +
                "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
         | 
| 31 | 
            +
                "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
         | 
| 32 | 
            +
                "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
         | 
| 33 | 
            +
            }
         | 
    	
        climateqa/engine/talk_to_data/main.py
    CHANGED
    
    | @@ -13,8 +13,8 @@ def ask_llm_column_names(sql_query, llm): | |
| 13 | 
             
                columns_list = ast.literal_eval(columns.strip("```python\n").strip())
         | 
| 14 | 
             
                return columns_list
         | 
| 15 |  | 
| 16 | 
            -
            def ask_drias( | 
| 17 | 
            -
                final_state = drias_workflow( | 
| 18 | 
             
                sql_queries = []
         | 
| 19 | 
             
                result_dataframes = []
         | 
| 20 | 
             
                figures = []
         | 
| @@ -28,10 +28,15 @@ def ask_drias(db_drias_path:str, query:str, index_state: int = 0, drias_model: s | |
| 28 | 
             
                            if 'dataframe' in table_state and table_state['dataframe'] is not None:
         | 
| 29 | 
             
                                result_dataframes.append(table_state['dataframe'])
         | 
| 30 | 
             
                                if 'figure' in table_state and table_state['figure'] is not None:
         | 
| 31 | 
            -
                                    figures.append(table_state['figure'] | 
| 32 |  | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 35 |  | 
| 36 | 
             
            DRIAS_MODELS = [
         | 
| 37 | 
             
                'ALL',
         | 
|  | |
| 13 | 
             
                columns_list = ast.literal_eval(columns.strip("```python\n").strip())
         | 
| 14 | 
             
                return columns_list
         | 
| 15 |  | 
| 16 | 
            +
            def ask_drias(query:str, index_state: int = 0):
         | 
| 17 | 
            +
                final_state = drias_workflow(query)
         | 
| 18 | 
             
                sql_queries = []
         | 
| 19 | 
             
                result_dataframes = []
         | 
| 20 | 
             
                figures = []
         | 
|  | |
| 28 | 
             
                            if 'dataframe' in table_state and table_state['dataframe'] is not None:
         | 
| 29 | 
             
                                result_dataframes.append(table_state['dataframe'])
         | 
| 30 | 
             
                                if 'figure' in table_state and table_state['figure'] is not None:
         | 
| 31 | 
            +
                                    figures.append(table_state['figure'])
         | 
| 32 |  | 
| 33 | 
            +
                if "error" in final_state and final_state["error"] != "":
         | 
| 34 | 
            +
                    return None, None, None, [], [], [], 0, final_state["error"]
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                sql_query = sql_queries[index_state]
         | 
| 37 | 
            +
                dataframe = result_dataframes[index_state]
         | 
| 38 | 
            +
                figure = figures[index_state](dataframe)       
         | 
| 39 | 
            +
                return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, ""
         | 
| 40 |  | 
| 41 | 
             
            DRIAS_MODELS = [
         | 
| 42 | 
             
                'ALL',
         | 
    	
        climateqa/engine/talk_to_data/plot.py
    CHANGED
    
    | @@ -29,7 +29,6 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]: | |
| 29 | 
             
                    Callable[..., Figure]: Function which can be call to create the figure with the associated dataframe
         | 
| 30 | 
             
                """
         | 
| 31 | 
             
                indicator = params["indicator_column"]
         | 
| 32 | 
            -
                model = params["model"]
         | 
| 33 | 
             
                location = params["location"]
         | 
| 34 | 
             
                indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
         | 
| 35 |  | 
| @@ -43,7 +42,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]: | |
| 43 | 
             
                        Figure: Plotly figure
         | 
| 44 | 
             
                    """
         | 
| 45 | 
             
                    fig = go.Figure()
         | 
| 46 | 
            -
                    if model  | 
| 47 | 
             
                        df_avg = df.groupby("year", as_index=False)[indicator].mean()
         | 
| 48 |  | 
| 49 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| @@ -58,8 +57,10 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]: | |
| 58 | 
             
                            .astype(float)
         | 
| 59 | 
             
                            .tolist()
         | 
| 60 | 
             
                        )
         | 
|  | |
|  | |
| 61 | 
             
                    else:
         | 
| 62 | 
            -
                        df_model = df | 
| 63 |  | 
| 64 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 65 | 
             
                        indicators = df_model[indicator].astype(float).tolist()
         | 
| @@ -73,6 +74,8 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]: | |
| 73 | 
             
                            .astype(float)
         | 
| 74 | 
             
                            .tolist()
         | 
| 75 | 
             
                        )
         | 
|  | |
|  | |
| 76 |  | 
| 77 | 
             
                    # Indicator per year plot
         | 
| 78 | 
             
                    fig.add_scatter(
         | 
| @@ -93,7 +96,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]: | |
| 93 | 
             
                        marker=dict(color="#d62728"),
         | 
| 94 | 
             
                    )
         | 
| 95 | 
             
                    fig.update_layout(
         | 
| 96 | 
            -
                        title=f"Plot of {indicator_label} in {location} { | 
| 97 | 
             
                        xaxis_title="Year",
         | 
| 98 | 
             
                        yaxis_title=indicator_label,
         | 
| 99 | 
             
                        template="plotly_white",
         | 
| @@ -125,7 +128,6 @@ def plot_indicator_number_of_days_per_year_at_location( | |
| 125 | 
             
                """
         | 
| 126 |  | 
| 127 | 
             
                indicator = params["indicator_column"]
         | 
| 128 | 
            -
                model = params["model"]
         | 
| 129 | 
             
                location = params["location"]
         | 
| 130 |  | 
| 131 | 
             
                def plot_data(df: pd.DataFrame) -> Figure:
         | 
| @@ -138,19 +140,21 @@ def plot_indicator_number_of_days_per_year_at_location( | |
| 138 | 
             
                        Figure: Plotly figure
         | 
| 139 | 
             
                    """
         | 
| 140 | 
             
                    fig = go.Figure()
         | 
| 141 | 
            -
                    if model  | 
| 142 | 
             
                        df_avg = df.groupby("year", as_index=False)[indicator].mean()
         | 
| 143 |  | 
| 144 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 145 | 
             
                        indicators = df_avg[indicator].astype(float).tolist()
         | 
| 146 | 
             
                        years = df_avg["year"].astype(int).tolist()
         | 
|  | |
| 147 |  | 
| 148 | 
             
                    else:
         | 
| 149 | 
            -
                        df_model = df | 
| 150 | 
            -
             | 
| 151 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 152 | 
             
                        indicators = df_model[indicator].astype(float).tolist()
         | 
| 153 | 
             
                        years = df_model["year"].astype(int).tolist()
         | 
|  | |
|  | |
| 154 |  | 
| 155 | 
             
                    # Bar plot
         | 
| 156 | 
             
                    fig.add_trace(
         | 
| @@ -165,7 +169,7 @@ def plot_indicator_number_of_days_per_year_at_location( | |
| 165 | 
             
                    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
         | 
| 166 |  | 
| 167 | 
             
                    fig.update_layout(
         | 
| 168 | 
            -
                        title=f"{indicator_label} in {location} { | 
| 169 | 
             
                        xaxis_title="Year",
         | 
| 170 | 
             
                        yaxis_title=indicator,
         | 
| 171 | 
             
                        yaxis=dict(range=[0, max(indicators)]),
         | 
| @@ -199,7 +203,6 @@ def plot_distribution_of_indicator_for_given_year( | |
| 199 | 
             
                    Callable[..., Figure]: Function which can be call to create the figure with the associated dataframe
         | 
| 200 | 
             
                """
         | 
| 201 | 
             
                indicator = params["indicator_column"]
         | 
| 202 | 
            -
                model = params["model"]
         | 
| 203 | 
             
                year = params["year"]
         | 
| 204 | 
             
                indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
         | 
| 205 |  | 
| @@ -213,18 +216,22 @@ def plot_distribution_of_indicator_for_given_year( | |
| 213 | 
             
                        Figure: Plotly figure
         | 
| 214 | 
             
                    """
         | 
| 215 | 
             
                    fig = go.Figure()
         | 
| 216 | 
            -
                    if  | 
| 217 | 
             
                        df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
         | 
| 218 | 
             
                            indicator
         | 
| 219 | 
             
                        ].mean()
         | 
| 220 |  | 
| 221 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 222 | 
             
                        indicators = df_avg[indicator].astype(float).tolist()
         | 
|  | |
|  | |
| 223 | 
             
                    else:
         | 
| 224 | 
            -
                        df_model = df | 
| 225 |  | 
| 226 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 227 | 
             
                        indicators = df_model[indicator].astype(float).tolist()
         | 
|  | |
|  | |
| 228 |  | 
| 229 | 
             
                    fig.add_trace(
         | 
| 230 | 
             
                        go.Histogram(
         | 
| @@ -236,7 +243,7 @@ def plot_distribution_of_indicator_for_given_year( | |
| 236 | 
             
                    )
         | 
| 237 |  | 
| 238 | 
             
                    fig.update_layout(
         | 
| 239 | 
            -
                        title=f"Distribution of {indicator_label} in {year} { | 
| 240 | 
             
                        xaxis_title=indicator_label,
         | 
| 241 | 
             
                        yaxis_title="Frequency",
         | 
| 242 | 
             
                        plot_bgcolor="rgba(0, 0, 0, 0)",
         | 
| @@ -270,13 +277,12 @@ def plot_map_of_france_of_indicator_for_given_year( | |
| 270 | 
             
                """
         | 
| 271 |  | 
| 272 | 
             
                indicator = params["indicator_column"]
         | 
| 273 | 
            -
                model = params["model"]
         | 
| 274 | 
             
                year = params["year"]
         | 
| 275 | 
             
                indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
         | 
| 276 |  | 
| 277 | 
             
                def plot_data(df: pd.DataFrame) -> Figure:
         | 
| 278 | 
             
                    fig = go.Figure()
         | 
| 279 | 
            -
                    if model  | 
| 280 | 
             
                        df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
         | 
| 281 | 
             
                            indicator
         | 
| 282 | 
             
                        ].mean()
         | 
| @@ -284,14 +290,17 @@ def plot_map_of_france_of_indicator_for_given_year( | |
| 284 | 
             
                        indicators = df_avg[indicator].astype(float).tolist()
         | 
| 285 | 
             
                        latitudes = df_avg["latitude"].astype(float).tolist()
         | 
| 286 | 
             
                        longitudes = df_avg["longitude"].astype(float).tolist()
         | 
|  | |
| 287 |  | 
| 288 | 
             
                    else:
         | 
| 289 | 
            -
                        df_model = df | 
| 290 |  | 
| 291 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 292 | 
             
                        indicators = df_model[indicator].astype(float).tolist()
         | 
| 293 | 
             
                        latitudes = df_model["latitude"].astype(float).tolist()
         | 
| 294 | 
             
                        longitudes = df_model["longitude"].astype(float).tolist()
         | 
|  | |
|  | |
| 295 |  | 
| 296 | 
             
                    fig.add_trace(
         | 
| 297 | 
             
                        go.Scattermapbox(
         | 
| @@ -314,7 +323,7 @@ def plot_map_of_france_of_indicator_for_given_year( | |
| 314 | 
             
                        mapbox_zoom=3,
         | 
| 315 | 
             
                        mapbox_center={"lat": 46.6, "lon": 2.0},
         | 
| 316 | 
             
                        coloraxis_colorbar=dict(title=f"{indicator_label}"),  # Add legend
         | 
| 317 | 
            -
                        title=f"{indicator_label} in {year} in France { | 
| 318 | 
             
                    )
         | 
| 319 | 
             
                    return fig
         | 
| 320 |  | 
|  | |
| 29 | 
             
                    Callable[..., Figure]: Function which can be call to create the figure with the associated dataframe
         | 
| 30 | 
             
                """
         | 
| 31 | 
             
                indicator = params["indicator_column"]
         | 
|  | |
| 32 | 
             
                location = params["location"]
         | 
| 33 | 
             
                indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
         | 
| 34 |  | 
|  | |
| 42 | 
             
                        Figure: Plotly figure
         | 
| 43 | 
             
                    """
         | 
| 44 | 
             
                    fig = go.Figure()
         | 
| 45 | 
            +
                    if df['model'].nunique() != 1:
         | 
| 46 | 
             
                        df_avg = df.groupby("year", as_index=False)[indicator].mean()
         | 
| 47 |  | 
| 48 | 
             
                        # Transform to list to avoid pandas encoding
         | 
|  | |
| 57 | 
             
                            .astype(float)
         | 
| 58 | 
             
                            .tolist()
         | 
| 59 | 
             
                        )
         | 
| 60 | 
            +
                        model_label = "Model Average"
         | 
| 61 | 
            +
             | 
| 62 | 
             
                    else:
         | 
| 63 | 
            +
                        df_model = df
         | 
| 64 |  | 
| 65 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 66 | 
             
                        indicators = df_model[indicator].astype(float).tolist()
         | 
|  | |
| 74 | 
             
                            .astype(float)
         | 
| 75 | 
             
                            .tolist()
         | 
| 76 | 
             
                        )
         | 
| 77 | 
            +
                        model_label = f"Model : {df['model'].unique()[0]}"
         | 
| 78 | 
            +
             | 
| 79 |  | 
| 80 | 
             
                    # Indicator per year plot
         | 
| 81 | 
             
                    fig.add_scatter(
         | 
|  | |
| 96 | 
             
                        marker=dict(color="#d62728"),
         | 
| 97 | 
             
                    )
         | 
| 98 | 
             
                    fig.update_layout(
         | 
| 99 | 
            +
                        title=f"Plot of {indicator_label} in {location} ({model_label})",
         | 
| 100 | 
             
                        xaxis_title="Year",
         | 
| 101 | 
             
                        yaxis_title=indicator_label,
         | 
| 102 | 
             
                        template="plotly_white",
         | 
|  | |
| 128 | 
             
                """
         | 
| 129 |  | 
| 130 | 
             
                indicator = params["indicator_column"]
         | 
|  | |
| 131 | 
             
                location = params["location"]
         | 
| 132 |  | 
| 133 | 
             
                def plot_data(df: pd.DataFrame) -> Figure:
         | 
|  | |
| 140 | 
             
                        Figure: Plotly figure
         | 
| 141 | 
             
                    """
         | 
| 142 | 
             
                    fig = go.Figure()
         | 
| 143 | 
            +
                    if df['model'].nunique() != 1:
         | 
| 144 | 
             
                        df_avg = df.groupby("year", as_index=False)[indicator].mean()
         | 
| 145 |  | 
| 146 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 147 | 
             
                        indicators = df_avg[indicator].astype(float).tolist()
         | 
| 148 | 
             
                        years = df_avg["year"].astype(int).tolist()
         | 
| 149 | 
            +
                        model_label = "Model Average"
         | 
| 150 |  | 
| 151 | 
             
                    else:
         | 
| 152 | 
            +
                        df_model = df
         | 
|  | |
| 153 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 154 | 
             
                        indicators = df_model[indicator].astype(float).tolist()
         | 
| 155 | 
             
                        years = df_model["year"].astype(int).tolist()
         | 
| 156 | 
            +
                        model_label = f"Model : {df['model'].unique()[0]}"
         | 
| 157 | 
            +
             | 
| 158 |  | 
| 159 | 
             
                    # Bar plot
         | 
| 160 | 
             
                    fig.add_trace(
         | 
|  | |
| 169 | 
             
                    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
         | 
| 170 |  | 
| 171 | 
             
                    fig.update_layout(
         | 
| 172 | 
            +
                        title=f"{indicator_label} in {location} ({model_label})",
         | 
| 173 | 
             
                        xaxis_title="Year",
         | 
| 174 | 
             
                        yaxis_title=indicator,
         | 
| 175 | 
             
                        yaxis=dict(range=[0, max(indicators)]),
         | 
|  | |
| 203 | 
             
                    Callable[..., Figure]: Function which can be call to create the figure with the associated dataframe
         | 
| 204 | 
             
                """
         | 
| 205 | 
             
                indicator = params["indicator_column"]
         | 
|  | |
| 206 | 
             
                year = params["year"]
         | 
| 207 | 
             
                indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
         | 
| 208 |  | 
|  | |
| 216 | 
             
                        Figure: Plotly figure
         | 
| 217 | 
             
                    """
         | 
| 218 | 
             
                    fig = go.Figure()
         | 
| 219 | 
            +
                    if df['model'].nunique() != 1:
         | 
| 220 | 
             
                        df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
         | 
| 221 | 
             
                            indicator
         | 
| 222 | 
             
                        ].mean()
         | 
| 223 |  | 
| 224 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 225 | 
             
                        indicators = df_avg[indicator].astype(float).tolist()
         | 
| 226 | 
            +
                        model_label = "Model Average"
         | 
| 227 | 
            +
             | 
| 228 | 
             
                    else:
         | 
| 229 | 
            +
                        df_model = df
         | 
| 230 |  | 
| 231 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 232 | 
             
                        indicators = df_model[indicator].astype(float).tolist()
         | 
| 233 | 
            +
                        model_label = f"Model : {df['model'].unique()[0]}"
         | 
| 234 | 
            +
             | 
| 235 |  | 
| 236 | 
             
                    fig.add_trace(
         | 
| 237 | 
             
                        go.Histogram(
         | 
|  | |
| 243 | 
             
                    )
         | 
| 244 |  | 
| 245 | 
             
                    fig.update_layout(
         | 
| 246 | 
            +
                        title=f"Distribution of {indicator_label} in {year} ({model_label})",
         | 
| 247 | 
             
                        xaxis_title=indicator_label,
         | 
| 248 | 
             
                        yaxis_title="Frequency",
         | 
| 249 | 
             
                        plot_bgcolor="rgba(0, 0, 0, 0)",
         | 
|  | |
| 277 | 
             
                """
         | 
| 278 |  | 
| 279 | 
             
                indicator = params["indicator_column"]
         | 
|  | |
| 280 | 
             
                year = params["year"]
         | 
| 281 | 
             
                indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
         | 
| 282 |  | 
| 283 | 
             
                def plot_data(df: pd.DataFrame) -> Figure:
         | 
| 284 | 
             
                    fig = go.Figure()
         | 
| 285 | 
            +
                    if df['model'].nunique() != 1:
         | 
| 286 | 
             
                        df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
         | 
| 287 | 
             
                            indicator
         | 
| 288 | 
             
                        ].mean()
         | 
|  | |
| 290 | 
             
                        indicators = df_avg[indicator].astype(float).tolist()
         | 
| 291 | 
             
                        latitudes = df_avg["latitude"].astype(float).tolist()
         | 
| 292 | 
             
                        longitudes = df_avg["longitude"].astype(float).tolist()
         | 
| 293 | 
            +
                        model_label = "Model Average"
         | 
| 294 |  | 
| 295 | 
             
                    else:
         | 
| 296 | 
            +
                        df_model = df
         | 
| 297 |  | 
| 298 | 
             
                        # Transform to list to avoid pandas encoding
         | 
| 299 | 
             
                        indicators = df_model[indicator].astype(float).tolist()
         | 
| 300 | 
             
                        latitudes = df_model["latitude"].astype(float).tolist()
         | 
| 301 | 
             
                        longitudes = df_model["longitude"].astype(float).tolist()
         | 
| 302 | 
            +
                        model_label = f"Model : {df['model'].unique()[0]}"
         | 
| 303 | 
            +
             | 
| 304 |  | 
| 305 | 
             
                    fig.add_trace(
         | 
| 306 | 
             
                        go.Scattermapbox(
         | 
|  | |
| 323 | 
             
                        mapbox_zoom=3,
         | 
| 324 | 
             
                        mapbox_center={"lat": 46.6, "lon": 2.0},
         | 
| 325 | 
             
                        coloraxis_colorbar=dict(title=f"{indicator_label}"),  # Add legend
         | 
| 326 | 
            +
                        title=f"{indicator_label} in {year} in France ({model_label}) " # Title
         | 
| 327 | 
             
                    )
         | 
| 328 | 
             
                    return fig
         | 
| 329 |  | 
    	
        climateqa/engine/talk_to_data/sql_query.py
    CHANGED
    
    | @@ -1,41 +1,23 @@ | |
| 1 | 
            -
            import  | 
| 2 | 
            -
             | 
|  | |
| 3 |  | 
| 4 | 
            -
             | 
| 5 | 
            -
            class SqlQueryOutput(TypedDict):
         | 
| 6 | 
            -
                labels: list[str]
         | 
| 7 | 
            -
                data: list[list[Any]]
         | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
            def execute_sql_query(db_path: str, sql_query: str) -> SqlQueryOutput:
         | 
| 11 | 
             
                """Execute the SQL Query on the sqlite database
         | 
| 12 |  | 
| 13 | 
             
                Args:
         | 
| 14 | 
            -
                    db_ (str): path to the sqlite database
         | 
| 15 | 
             
                    sql_query (str): sql query to execute
         | 
| 16 |  | 
| 17 | 
             
                Returns:
         | 
| 18 | 
             
                    SqlQueryOutput: labels of the selected column and fetched data
         | 
| 19 | 
             
                """
         | 
| 20 |  | 
| 21 | 
            -
                # Connect to sqlite3 database
         | 
| 22 | 
            -
                conn = sqlite3.connect(db_path)
         | 
| 23 | 
            -
                cursor = conn.cursor()
         | 
| 24 |  | 
| 25 | 
             
                # Execute the query
         | 
| 26 | 
            -
                 | 
| 27 | 
            -
             | 
| 28 | 
            -
                # Fetch labels of selected columns
         | 
| 29 | 
            -
                labels = [desc[0] for desc in cursor.description]
         | 
| 30 |  | 
| 31 | 
            -
                #  | 
| 32 | 
            -
                 | 
| 33 | 
            -
                conn.close()
         | 
| 34 | 
            -
             | 
| 35 | 
            -
                return {
         | 
| 36 | 
            -
                    "labels": labels,
         | 
| 37 | 
            -
                    "data": data,
         | 
| 38 | 
            -
                }
         | 
| 39 |  | 
| 40 |  | 
| 41 | 
             
            class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
         | 
| @@ -60,15 +42,13 @@ def indicator_per_year_at_location_query( | |
| 60 | 
             
                indicator_column = params.get("indicator_column")
         | 
| 61 | 
             
                latitude = params.get("latitude")
         | 
| 62 | 
             
                longitude = params.get("longitude")
         | 
| 63 | 
            -
                model = params.get('model')
         | 
| 64 |  | 
| 65 | 
             
                if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
         | 
| 66 | 
             
                    return ""
         | 
| 67 |  | 
| 68 | 
            -
                 | 
| 69 | 
            -
             | 
| 70 | 
            -
                 | 
| 71 | 
            -
                    sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nAnd model = '{model}' \nOrder by Year"
         | 
| 72 |  | 
| 73 | 
             
                return sql_query
         | 
| 74 |  | 
| @@ -91,12 +71,10 @@ def indicator_for_given_year_query( | |
| 91 | 
             
                """
         | 
| 92 | 
             
                indicator_column = params.get("indicator_column")
         | 
| 93 | 
             
                year = params.get('year')
         | 
| 94 | 
            -
                model = params.get('model')
         | 
| 95 | 
             
                if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
         | 
| 96 | 
             
                    return ""
         | 
| 97 |  | 
| 98 | 
            -
                 | 
| 99 | 
            -
             | 
| 100 | 
            -
                 | 
| 101 | 
            -
                    sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}\nAnd model = '{model}'"
         | 
| 102 | 
             
                return sql_query    
         | 
|  | |
| 1 | 
            +
            from typing import TypedDict
         | 
| 2 | 
            +
            import duckdb
         | 
| 3 | 
            +
            import pandas as pd
         | 
| 4 |  | 
| 5 | 
            +
            def execute_sql_query(sql_query: str) -> pd.DataFrame:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 6 | 
             
                """Execute the SQL Query on the sqlite database
         | 
| 7 |  | 
| 8 | 
             
                Args:
         | 
|  | |
| 9 | 
             
                    sql_query (str): sql query to execute
         | 
| 10 |  | 
| 11 | 
             
                Returns:
         | 
| 12 | 
             
                    SqlQueryOutput: labels of the selected column and fetched data
         | 
| 13 | 
             
                """
         | 
| 14 |  | 
|  | |
|  | |
|  | |
| 15 |  | 
| 16 | 
             
                # Execute the query
         | 
| 17 | 
            +
                results = duckdb.sql(sql_query)
         | 
|  | |
|  | |
|  | |
| 18 |  | 
| 19 | 
            +
                # return fetched data
         | 
| 20 | 
            +
                return results.fetchdf()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 21 |  | 
| 22 |  | 
| 23 | 
             
            class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
         | 
|  | |
| 42 | 
             
                indicator_column = params.get("indicator_column")
         | 
| 43 | 
             
                latitude = params.get("latitude")
         | 
| 44 | 
             
                longitude = params.get("longitude")
         | 
|  | |
| 45 |  | 
| 46 | 
             
                if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
         | 
| 47 | 
             
                    return ""
         | 
| 48 |  | 
| 49 | 
            +
                table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
         | 
|  | |
| 52 |  | 
| 53 | 
             
                return sql_query
         | 
| 54 |  | 
|  | |
| 71 | 
             
                """
         | 
| 72 | 
             
                indicator_column = params.get("indicator_column")
         | 
| 73 | 
             
                year = params.get('year')
         | 
|  | |
| 74 | 
             
                if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
         | 
| 75 | 
             
                    return ""
         | 
| 76 |  | 
| 77 | 
            +
                table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
         | 
|  | |
| 80 | 
             
                return sql_query    
         | 
    	
        climateqa/engine/talk_to_data/utils.py
    CHANGED
    
    | @@ -1,11 +1,10 @@ | |
| 1 | 
             
            import re
         | 
| 2 | 
             
            from typing import Annotated, TypedDict
         | 
| 3 | 
            -
             | 
| 4 | 
            -
            from sympy import use
         | 
| 5 | 
             
            from geopy.geocoders import Nominatim
         | 
| 6 | 
            -
            import sqlite3
         | 
| 7 | 
             
            import ast
         | 
| 8 | 
             
            from climateqa.engine.llm import get_llm
         | 
|  | |
| 9 | 
             
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 10 | 
             
            from langchain_core.prompts import ChatPromptTemplate
         | 
| 11 |  | 
| @@ -35,7 +34,7 @@ class ArrayOutput(TypedDict): | |
| 35 |  | 
| 36 | 
             
                array: Annotated[str, ..., "Syntactically valid python array."]
         | 
| 37 |  | 
| 38 | 
            -
            def detect_year_with_openai(sentence: str):
         | 
| 39 | 
             
                """
         | 
| 40 | 
             
                Detects years in a sentence using OpenAI's API via LangChain.
         | 
| 41 | 
             
                """
         | 
| @@ -56,7 +55,7 @@ def detect_year_with_openai(sentence: str): | |
| 56 | 
             
                if len(years_list) > 0:
         | 
| 57 | 
             
                    return years_list[0]
         | 
| 58 | 
             
                else:
         | 
| 59 | 
            -
                    return  | 
| 60 |  | 
| 61 |  | 
| 62 | 
             
            def detectTable(sql_query):
         | 
| @@ -81,24 +80,26 @@ def coords2loc(coords: tuple): | |
| 81 | 
             
                    return "Unknown Location"
         | 
| 82 |  | 
| 83 |  | 
| 84 | 
            -
            def nearestNeighbourSQL( | 
| 85 | 
            -
                conn = sqlite3.connect(db)
         | 
| 86 | 
             
                long = round(location[1], 3)
         | 
| 87 | 
             
                lat = round(location[0], 3)
         | 
| 88 | 
            -
             | 
| 89 | 
            -
                 | 
|  | |
|  | |
| 90 | 
             
                    f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
         | 
| 91 | 
            -
                )
         | 
|  | |
|  | |
|  | |
| 92 | 
             
                # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
         | 
| 93 | 
            -
                results  | 
| 94 | 
            -
                return results[0]
         | 
| 95 |  | 
| 96 |  | 
| 97 | 
            -
            def detect_relevant_tables( | 
| 98 | 
             
                """Detect relevant tables regarding the plot and the user input
         | 
| 99 |  | 
| 100 | 
             
                Args:
         | 
| 101 | 
            -
                    db (str): database path
         | 
| 102 | 
             
                    user_question (str): initial user input
         | 
| 103 | 
             
                    plot (Plot): plot object for which we wanna plot
         | 
| 104 | 
             
                    llm (_type_): LLM
         | 
| @@ -106,19 +107,21 @@ def detect_relevant_tables(db: str, user_question: str, plot: Plot, llm) -> list | |
| 106 | 
             
                Returns:
         | 
| 107 | 
             
                    list[str]: list of table names
         | 
| 108 | 
             
                """
         | 
| 109 | 
            -
                conn = sqlite3.connect(db)
         | 
| 110 | 
            -
                cursor = conn.cursor()
         | 
| 111 |  | 
| 112 | 
             
                # Get all table names
         | 
| 113 | 
            -
                 | 
| 114 | 
            -
                table_names_list = cursor.fetchall()
         | 
| 115 |  | 
| 116 | 
             
                prompt = (
         | 
| 117 | 
             
                    f"You are helping to build a plot following this description : {plot['description']}."
         | 
|  | |
| 118 | 
             
                    f"Based on the description of the plot, which table are appropriate for that kind of plot."
         | 
| 119 | 
            -
                    f" | 
| 120 | 
            -
                    f" | 
|  | |
|  | |
| 121 | 
             
                )
         | 
|  | |
|  | |
| 122 | 
             
                table_names = ast.literal_eval(
         | 
| 123 | 
             
                    llm.invoke(prompt).content.strip("```python\n").strip()
         | 
| 124 | 
             
                )
         | 
| @@ -141,17 +144,28 @@ def detect_relevant_plots(user_question: str, llm): | |
| 141 | 
             
                    plots_description += " - Description: " + plot["description"] + "\n"
         | 
| 142 |  | 
| 143 | 
             
                prompt = (
         | 
| 144 | 
            -
                    f"You are helping to answer a  | 
| 145 | 
            -
                    f" | 
| 146 | 
            -
                    f" | 
| 147 | 
            -
                    f" | 
| 148 | 
            -
                    f" | 
| 149 | 
            -
                    f" | 
| 150 | 
            -
                    f" | 
| 151 | 
             
                )
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                 | 
| 154 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 155 |  | 
| 156 |  | 
| 157 | 
             
            # Next Version
         | 
|  | |
| 1 | 
             
            import re
         | 
| 2 | 
             
            from typing import Annotated, TypedDict
         | 
| 3 | 
            +
            import duckdb
         | 
|  | |
| 4 | 
             
            from geopy.geocoders import Nominatim
         | 
|  | |
| 5 | 
             
            import ast
         | 
| 6 | 
             
            from climateqa.engine.llm import get_llm
         | 
| 7 | 
            +
            from climateqa.engine.talk_to_data.config import DRIAS_TABLES
         | 
| 8 | 
             
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 9 | 
             
            from langchain_core.prompts import ChatPromptTemplate
         | 
| 10 |  | 
|  | |
| 34 |  | 
| 35 | 
             
                array: Annotated[str, ..., "Syntactically valid python array."]
         | 
| 36 |  | 
| 37 | 
            +
            def detect_year_with_openai(sentence: str) -> str:
         | 
| 38 | 
             
                """
         | 
| 39 | 
             
                Detects years in a sentence using OpenAI's API via LangChain.
         | 
| 40 | 
             
                """
         | 
|  | |
| 55 | 
             
                if len(years_list) > 0:
         | 
| 56 | 
             
                    return years_list[0]
         | 
| 57 | 
             
                else:
         | 
| 58 | 
            +
                    return ""
         | 
| 59 |  | 
| 60 |  | 
| 61 | 
             
            def detectTable(sql_query):
         | 
|  | |
| 80 | 
             
                    return "Unknown Location"
         | 
| 81 |  | 
| 82 |  | 
| 83 | 
            +
            def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
         | 
|  | |
| 84 | 
             
                long = round(location[1], 3)
         | 
| 85 | 
             
                lat = round(location[0], 3)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                results = duckdb.sql(
         | 
| 90 | 
             
                    f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
         | 
| 91 | 
            +
                ).fetchdf()
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                if len(results) == 0:
         | 
| 94 | 
            +
                    return "", ""
         | 
| 95 | 
             
                # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
         | 
| 96 | 
            +
                return results['latitude'].iloc[0], results['longitude'].iloc[0]
         | 
|  | |
| 97 |  | 
| 98 |  | 
| 99 | 
            +
            def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
         | 
| 100 | 
             
                """Detect relevant tables regarding the plot and the user input
         | 
| 101 |  | 
| 102 | 
             
                Args:
         | 
|  | |
| 103 | 
             
                    user_question (str): initial user input
         | 
| 104 | 
             
                    plot (Plot): plot object for which we wanna plot
         | 
| 105 | 
             
                    llm (_type_): LLM
         | 
|  | |
| 107 | 
             
                Returns:
         | 
| 108 | 
             
                    list[str]: list of table names
         | 
| 109 | 
             
                """
         | 
|  | |
|  | |
| 110 |  | 
| 111 | 
             
                # Get all table names
         | 
| 112 | 
            +
                table_names_list = DRIAS_TABLES
         | 
|  | |
| 113 |  | 
| 114 | 
             
                prompt = (
         | 
| 115 | 
             
                    f"You are helping to build a plot following this description : {plot['description']}."
         | 
| 116 | 
            +
                    f"You are given a list of tables and a user question."
         | 
| 117 | 
             
                    f"Based on the description of the plot, which table are appropriate for that kind of plot."
         | 
| 118 | 
            +
                    f"Write the 3 most relevant tables to use. Answer only a python list of table name."
         | 
| 119 | 
            +
                    f"### List of tables : {table_names_list}"
         | 
| 120 | 
            +
                    f"### User question : {user_question}"
         | 
| 121 | 
            +
                    f"### List of table name : "
         | 
| 122 | 
             
                )
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
             
                table_names = ast.literal_eval(
         | 
| 126 | 
             
                    llm.invoke(prompt).content.strip("```python\n").strip()
         | 
| 127 | 
             
                )
         | 
|  | |
| 144 | 
             
                    plots_description += " - Description: " + plot["description"] + "\n"
         | 
| 145 |  | 
| 146 | 
             
                prompt = (
         | 
| 147 | 
            +
                    f"You are helping to answer a quesiton with insightful visualizations."
         | 
| 148 | 
            +
                    f"You are given an user question and a list of plots with their name and description."
         | 
| 149 | 
            +
                    f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
         | 
| 150 | 
            +
                    f"Write the most relevant tables to use. Answer only a python list of plot name."
         | 
| 151 | 
            +
                    f"### Descriptions of the plots : {plots_description}"
         | 
| 152 | 
            +
                    f"### User question : {user_question}"
         | 
| 153 | 
            +
                    f"### Name of the plot : "
         | 
| 154 | 
             
                )
         | 
| 155 | 
            +
                # prompt = (
         | 
| 156 | 
            +
                #     f"You are helping to answer a question with insightful visualizations. "
         | 
| 157 | 
            +
                #     f"Given a list of plots with their name and description: "
         | 
| 158 | 
            +
                #     f"{plots_description} "
         | 
| 159 | 
            +
                #     f"The user question is: {user_question}. "
         | 
| 160 | 
            +
                #     f"Choose the most relevant plots to answer the question. "
         | 
| 161 | 
            +
                #     f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
         | 
| 162 | 
            +
                #     f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
         | 
| 163 | 
            +
                # )
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                plot_names = ast.literal_eval(
         | 
| 166 | 
            +
                    llm.invoke(prompt).content.strip("```python\n").strip()
         | 
| 167 | 
            +
                )
         | 
| 168 | 
            +
                return plot_names
         | 
| 169 |  | 
| 170 |  | 
| 171 | 
             
            # Next Version
         | 
    	
        climateqa/engine/talk_to_data/workflow.py
    CHANGED
    
    | @@ -5,6 +5,7 @@ import pandas as pd | |
| 5 |  | 
| 6 | 
             
            from plotly.graph_objects import Figure
         | 
| 7 | 
             
            from climateqa.engine.llm import get_llm
         | 
|  | |
| 8 | 
             
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 9 | 
             
            from climateqa.engine.talk_to_data.sql_query import execute_sql_query
         | 
| 10 | 
             
            from climateqa.engine.talk_to_data.utils import (
         | 
| @@ -37,12 +38,12 @@ class State(TypedDict): | |
| 37 | 
             
                user_input: str
         | 
| 38 | 
             
                plots: list[str]
         | 
| 39 | 
             
                plot_states: dict[str, PlotState]
         | 
|  | |
| 40 |  | 
| 41 | 
            -
            def drias_workflow( | 
| 42 | 
             
                """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
         | 
| 43 |  | 
| 44 | 
             
                Args:
         | 
| 45 | 
            -
                    db_drias_path (str): path to the drias database
         | 
| 46 | 
             
                    user_input (str): initial user input
         | 
| 47 |  | 
| 48 | 
             
                Returns:
         | 
| @@ -60,8 +61,12 @@ def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State: | |
| 60 | 
             
                state['plots'] = plots
         | 
| 61 |  | 
| 62 | 
             
                if not state['plots']:
         | 
|  | |
| 63 | 
             
                    return state
         | 
| 64 |  | 
|  | |
|  | |
|  | |
| 65 | 
             
                for plot_name in state['plots']:
         | 
| 66 |  | 
| 67 | 
             
                    plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
         | 
| @@ -76,21 +81,23 @@ def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State: | |
| 76 |  | 
| 77 | 
             
                    plot_state['plot_name'] = plot_name
         | 
| 78 |  | 
| 79 | 
            -
                    relevant_tables = find_relevant_tables_per_plot(state, plot,  | 
|  | |
|  | |
| 80 |  | 
| 81 | 
             
                    plot_state['tables'] = relevant_tables
         | 
| 82 |  | 
| 83 | 
            -
                    for table in plot_state['tables']:
         | 
|  | |
|  | |
|  | |
| 84 | 
             
                        table_state: TableState = {
         | 
| 85 | 
             
                            'table_name': table,
         | 
| 86 | 
             
                            'params': {},
         | 
| 87 | 
             
                            'status': 'OK'
         | 
| 88 | 
             
                        } 
         | 
| 89 | 
            -
                        table_state['params'] = {
         | 
| 90 | 
            -
                            'model': model
         | 
| 91 | 
            -
                        }
         | 
| 92 | 
             
                        for param_name in plot['params']:
         | 
| 93 | 
            -
                            param = find_param(state, param_name, table | 
| 94 | 
             
                            if param:
         | 
| 95 | 
             
                                table_state['params'].update(param)
         | 
| 96 |  | 
| @@ -99,17 +106,30 @@ def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State: | |
| 99 | 
             
                        if sql_query == "":
         | 
| 100 | 
             
                            table_state['status'] = 'ERROR'
         | 
| 101 | 
             
                            continue
         | 
|  | |
|  | |
| 102 |  | 
| 103 | 
             
                        table_state['sql_query'] = sql_query
         | 
| 104 | 
            -
                         | 
|  | |
|  | |
|  | |
| 105 |  | 
| 106 | 
            -
                        df = pd.DataFrame(results['data'], columns=results['labels'])
         | 
| 107 | 
             
                        figure = plot['plot_function'](table_state['params'])
         | 
| 108 | 
             
                        table_state['dataframe'] = df
         | 
| 109 | 
             
                        table_state['figure'] = figure
         | 
| 110 | 
             
                        plot_state['table_states'][table] = table_state
         | 
| 111 |  | 
| 112 | 
             
                    state['plot_states'][plot_name] = plot_state
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 113 | 
             
                return state
         | 
| 114 |  | 
| 115 |  | 
| @@ -118,26 +138,25 @@ def find_relevant_plots(state: State, llm) -> list[str]: | |
| 118 | 
             
                relevant_plots = detect_relevant_plots(state['user_input'], llm)
         | 
| 119 | 
             
                return relevant_plots
         | 
| 120 |  | 
| 121 | 
            -
            def find_relevant_tables_per_plot(state: State, plot: Plot,  | 
| 122 | 
             
                print(f"---- Find relevant tables for {plot['name']} ----")
         | 
| 123 | 
            -
                relevant_tables = detect_relevant_tables( | 
| 124 | 
             
                return relevant_tables
         | 
| 125 |  | 
| 126 |  | 
| 127 | 
            -
            def find_param(state: State, param_name:str, table: str | 
| 128 | 
             
                """Perform the good method to retrieve the desired parameter
         | 
| 129 |  | 
| 130 | 
             
                Args:
         | 
| 131 | 
             
                    state (State): state of the workflow
         | 
| 132 | 
             
                    param_name (str): name of the desired parameter
         | 
| 133 | 
             
                    table (str): name of the table
         | 
| 134 | 
            -
                    db_path (str): path to the databse
         | 
| 135 |  | 
| 136 | 
             
                Returns:
         | 
| 137 | 
             
                    dict[str, Any] | None: 
         | 
| 138 | 
             
                """
         | 
| 139 | 
             
                if param_name == 'location':
         | 
| 140 | 
            -
                    location = find_location(state['user_input'], table | 
| 141 | 
             
                    return location
         | 
| 142 | 
             
                if param_name == 'indicator_column':
         | 
| 143 | 
             
                    indicator_column = find_indicator_column(table)
         | 
| @@ -153,13 +172,13 @@ class Location(TypedDict): | |
| 153 | 
             
                latitude: NotRequired[str]
         | 
| 154 | 
             
                longitude: NotRequired[str]
         | 
| 155 |  | 
| 156 | 
            -
            def find_location(user_input: str, table: str | 
| 157 | 
             
                print(f"---- Find location in table {table} ----")
         | 
| 158 | 
             
                location = detect_location_with_openai(user_input)
         | 
| 159 | 
             
                output: Location = {'location' : location}
         | 
| 160 | 
             
                if location:
         | 
| 161 | 
             
                    coords = loc2coords(location)
         | 
| 162 | 
            -
                    neighbour = nearestNeighbourSQL( | 
| 163 | 
             
                    output.update({
         | 
| 164 | 
             
                        "latitude": neighbour[0],
         | 
| 165 | 
             
                        "longitude": neighbour[1],
         | 
| @@ -182,23 +201,8 @@ def find_indicator_column(table: str) -> str: | |
| 182 | 
             
                """
         | 
| 183 |  | 
| 184 | 
             
                print(f"---- Find indicator column in table {table} ----")
         | 
| 185 | 
            -
             | 
| 186 | 
            -
             | 
| 187 | 
            -
                    "total_summer_precipiation": "total_summer_precipitation",
         | 
| 188 | 
            -
                    "total_annual_precipitation": "total_annual_precipitation",
         | 
| 189 | 
            -
                    "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
         | 
| 190 | 
            -
                    "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
         | 
| 191 | 
            -
                    "extreme_precipitation_intensity": "extreme_precipitation_intensity",
         | 
| 192 | 
            -
                    "mean_winter_temperature": "mean_winter_temperature",
         | 
| 193 | 
            -
                    "mean_summer_temperature": "mean_summer_temperature",
         | 
| 194 | 
            -
                    "mean_annual_temperature": "mean_annual_temperature",
         | 
| 195 | 
            -
                    "number_of_tropical_nights": "number_tropical_nights",
         | 
| 196 | 
            -
                    "maximum_summer_temperature": "maximum_summer_temperature",
         | 
| 197 | 
            -
                    "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
         | 
| 198 | 
            -
                    "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
         | 
| 199 | 
            -
                    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
         | 
| 200 | 
            -
                }
         | 
| 201 | 
            -
                return indicator_columns_per_table[table]
         | 
| 202 |  | 
| 203 |  | 
| 204 | 
             
            # def make_write_query_node():
         | 
|  | |
| 5 |  | 
| 6 | 
             
            from plotly.graph_objects import Figure
         | 
| 7 | 
             
            from climateqa.engine.llm import get_llm
         | 
| 8 | 
            +
            from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
         | 
| 9 | 
             
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 10 | 
             
            from climateqa.engine.talk_to_data.sql_query import execute_sql_query
         | 
| 11 | 
             
            from climateqa.engine.talk_to_data.utils import (
         | 
|  | |
| 38 | 
             
                user_input: str
         | 
| 39 | 
             
                plots: list[str]
         | 
| 40 | 
             
                plot_states: dict[str, PlotState]
         | 
| 41 | 
            +
                error: NotRequired[str]
         | 
| 42 |  | 
| 43 | 
            +
            def drias_workflow(user_input: str) -> State:
         | 
| 44 | 
             
                """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
         | 
| 45 |  | 
| 46 | 
             
                Args:
         | 
|  | |
| 47 | 
             
                    user_input (str): initial user input
         | 
| 48 |  | 
| 49 | 
             
                Returns:
         | 
|  | |
| 61 | 
             
                state['plots'] = plots
         | 
| 62 |  | 
| 63 | 
             
                if not state['plots']:
         | 
| 64 | 
            +
                    state['error'] = 'There is no plot to answer to the question'
         | 
| 65 | 
             
                    return state
         | 
| 66 |  | 
| 67 | 
            +
                have_relevant_table = False
         | 
| 68 | 
            +
                have_sql_query = False
         | 
| 69 | 
            +
                have_dataframe = False
         | 
| 70 | 
             
                for plot_name in state['plots']:
         | 
| 71 |  | 
| 72 | 
             
                    plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
         | 
|  | |
| 81 |  | 
| 82 | 
             
                    plot_state['plot_name'] = plot_name
         | 
| 83 |  | 
| 84 | 
            +
                    relevant_tables = find_relevant_tables_per_plot(state, plot, llm)
         | 
| 85 | 
            +
                    if len(relevant_tables) > 0 :
         | 
| 86 | 
            +
                        have_relevant_table = True
         | 
| 87 |  | 
| 88 | 
             
                    plot_state['tables'] = relevant_tables
         | 
| 89 |  | 
| 90 | 
            +
                    for n, table in enumerate(plot_state['tables']):
         | 
| 91 | 
            +
                        if n > 2:
         | 
| 92 | 
            +
                            break
         | 
| 93 | 
            +
             | 
| 94 | 
             
                        table_state: TableState = {
         | 
| 95 | 
             
                            'table_name': table,
         | 
| 96 | 
             
                            'params': {},
         | 
| 97 | 
             
                            'status': 'OK'
         | 
| 98 | 
             
                        } 
         | 
|  | |
|  | |
|  | |
| 99 | 
             
                        for param_name in plot['params']:
         | 
| 100 | 
            +
                            param = find_param(state, param_name, table)
         | 
| 101 | 
             
                            if param:
         | 
| 102 | 
             
                                table_state['params'].update(param)
         | 
| 103 |  | 
|  | |
| 106 | 
             
                        if sql_query == "":
         | 
| 107 | 
             
                            table_state['status'] = 'ERROR'
         | 
| 108 | 
             
                            continue
         | 
| 109 | 
            +
                        else : 
         | 
| 110 | 
            +
                            have_sql_query = True
         | 
| 111 |  | 
| 112 | 
             
                        table_state['sql_query'] = sql_query
         | 
| 113 | 
            +
                        df = execute_sql_query(sql_query)
         | 
| 114 | 
            +
                        
         | 
| 115 | 
            +
                        if len(df) > 0:
         | 
| 116 | 
            +
                            have_dataframe = True
         | 
| 117 |  | 
|  | |
| 118 | 
             
                        figure = plot['plot_function'](table_state['params'])
         | 
| 119 | 
             
                        table_state['dataframe'] = df
         | 
| 120 | 
             
                        table_state['figure'] = figure
         | 
| 121 | 
             
                        plot_state['table_states'][table] = table_state
         | 
| 122 |  | 
| 123 | 
             
                    state['plot_states'][plot_name] = plot_state
         | 
| 124 | 
            +
                
         | 
| 125 | 
            +
                if not have_relevant_table:
         | 
| 126 | 
            +
                    state['error'] = "There is no relevant table in the our database to answer your question"
         | 
| 127 | 
            +
                elif not have_sql_query:
         | 
| 128 | 
            +
                    state['error'] = "There is no relevant sql query on our database that can help to answer your question"
         | 
| 129 | 
            +
                elif not have_dataframe:
         | 
| 130 | 
            +
                    state['error'] = "There is no data in our table that can answer to your question"
         | 
| 131 | 
            +
                
         | 
| 132 | 
            +
             | 
| 133 | 
             
                return state
         | 
| 134 |  | 
| 135 |  | 
|  | |
| 138 | 
             
                relevant_plots = detect_relevant_plots(state['user_input'], llm)
         | 
| 139 | 
             
                return relevant_plots
         | 
| 140 |  | 
| 141 | 
            +
            def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
         | 
| 142 | 
             
                print(f"---- Find relevant tables for {plot['name']} ----")
         | 
| 143 | 
            +
                relevant_tables = detect_relevant_tables(state['user_input'], plot, llm)
         | 
| 144 | 
             
                return relevant_tables
         | 
| 145 |  | 
| 146 |  | 
| 147 | 
            +
            def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
         | 
| 148 | 
             
                """Perform the good method to retrieve the desired parameter
         | 
| 149 |  | 
| 150 | 
             
                Args:
         | 
| 151 | 
             
                    state (State): state of the workflow
         | 
| 152 | 
             
                    param_name (str): name of the desired parameter
         | 
| 153 | 
             
                    table (str): name of the table
         | 
|  | |
| 154 |  | 
| 155 | 
             
                Returns:
         | 
| 156 | 
             
                    dict[str, Any] | None: 
         | 
| 157 | 
             
                """
         | 
| 158 | 
             
                if param_name == 'location':
         | 
| 159 | 
            +
                    location = find_location(state['user_input'], table)
         | 
| 160 | 
             
                    return location
         | 
| 161 | 
             
                if param_name == 'indicator_column':
         | 
| 162 | 
             
                    indicator_column = find_indicator_column(table)
         | 
|  | |
| 172 | 
             
                latitude: NotRequired[str]
         | 
| 173 | 
             
                longitude: NotRequired[str]
         | 
| 174 |  | 
| 175 | 
            +
            def find_location(user_input: str, table: str) -> Location:
         | 
| 176 | 
             
                print(f"---- Find location in table {table} ----")
         | 
| 177 | 
             
                location = detect_location_with_openai(user_input)
         | 
| 178 | 
             
                output: Location = {'location' : location}
         | 
| 179 | 
             
                if location:
         | 
| 180 | 
             
                    coords = loc2coords(location)
         | 
| 181 | 
            +
                    neighbour = nearestNeighbourSQL(coords, table)
         | 
| 182 | 
             
                    output.update({
         | 
| 183 | 
             
                        "latitude": neighbour[0],
         | 
| 184 | 
             
                        "longitude": neighbour[1],
         | 
|  | |
| 201 | 
             
                """
         | 
| 202 |  | 
| 203 | 
             
                print(f"---- Find indicator column in table {table} ----")
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                return INDICATOR_COLUMNS_PER_TABLE[table]
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 206 |  | 
| 207 |  | 
| 208 | 
             
            # def make_write_query_node():
         | 
    	
        style.css
    CHANGED
    
    | @@ -644,17 +644,23 @@ a { | |
| 644 | 
             
                overflow-y:scroll;
         | 
| 645 | 
             
            }
         | 
| 646 |  | 
|  | |
|  | |
|  | |
|  | |
| 647 | 
             
            #sql-query span{
         | 
| 648 | 
             
                display: none;
         | 
| 649 | 
             
            }
         | 
| 650 | 
             
            div#tab-vanna{
         | 
| 651 | 
             
                max-height: 100¨vh;
         | 
| 652 | 
            -
                overflow-y: | 
| 653 | 
             
            } 
         | 
| 654 | 
             
            #vanna-plot{
         | 
| 655 | 
             
                max-height:500px
         | 
| 656 | 
             
            }
         | 
| 657 |  | 
| 658 | 
            -
            # | 
| 659 | 
            -
                 | 
|  | |
|  | |
| 660 | 
             
            }
         | 
|  | |
| 644 | 
             
                overflow-y:scroll;
         | 
| 645 | 
             
            }
         | 
| 646 |  | 
| 647 | 
            +
            #sql-query textarea{
         | 
| 648 | 
            +
                min-height: 100px !important;
         | 
| 649 | 
            +
            }
         | 
| 650 | 
            +
             | 
| 651 | 
             
            #sql-query span{
         | 
| 652 | 
             
                display: none;
         | 
| 653 | 
             
            }
         | 
| 654 | 
             
            div#tab-vanna{
         | 
| 655 | 
             
                max-height: 100¨vh;
         | 
| 656 | 
            +
                overflow-y: hidden;
         | 
| 657 | 
             
            } 
         | 
| 658 | 
             
            #vanna-plot{
         | 
| 659 | 
             
                max-height:500px
         | 
| 660 | 
             
            }
         | 
| 661 |  | 
| 662 | 
            +
            #pagination-display{
         | 
| 663 | 
            +
                text-align: center;
         | 
| 664 | 
            +
                font-weight: bold;
         | 
| 665 | 
            +
                font-size: 16px;
         | 
| 666 | 
             
            }
         | 
