feature/drias_parallelization (#25)
Browse files
- log to huggingface (f9c4c84a71d320c6db05ee099b9e35492ba7b184)
- Merged in feat/logs_on_huggingface (pull request #5) (261632833e7d939c321f208edfd6576e68947b4b)
- feat: added multithreading to run sql queries in talk to drias (705ccece7775c65a9c7b73b091cdddbc4246f2e7)
- chore: remove prints in talk to drias workflow (a967134f90c70d87bb3d786f2feb81f2e56fdb9f)
- Merged in feat/improve_drias_exeuction_time (pull request #6) (7c38528636ac19c1efa240a122e523ed0c34706a)
- fix import (05b8df9c9b74926459da70797b6852ff07a4d838)
- Merge branch 'main' into dev (8fb231c8beabf8a6406f05cf4cac564c5d81c7ce)
- Merged in dev (pull request #7) (6b9f71b1cf216eef0fd7973412f675d8633a5f4a)
- fix import (b35df2a8160723e43f74a040aa94983069066213)
- Merge branch 'main' of https://bitbucket.org/ekimetrics/climate_qa (f96cfd0715ec2b1ed7a78775ea7f8722f5793d8f)
Co-authored-by: Armand Demasson <armanddemasson@users.noreply.huggingface.co>
- climateqa/chat.py +5 -54
- climateqa/engine/talk_to_data/main.py +5 -2
- climateqa/engine/talk_to_data/sql_query.py +3 -2
- climateqa/engine/talk_to_data/{workflow.py → talk_to_drias.py} +126 -96
- climateqa/handle_stream_events.py +1 -1
- climateqa/logging.py +194 -0
- data/drias/drias.db +0 -3
- front/tabs/chat_interface.py +1 -1
- front/tabs/tab_drias.py +6 -31
- front/utils.py +0 -11
- requirements.txt +1 -0
| @@ -12,15 +12,11 @@ from .handle_stream_events import ( | |
| 12 | 
             
                convert_to_docs_to_html,
         | 
| 13 | 
             
                stream_answer,
         | 
| 14 | 
             
                handle_retrieved_owid_graphs,
         | 
| 15 | 
            -
                serialize_docs,
         | 
| 16 | 
             
            )
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
                file_client = share_client.get_file_client(file)
         | 
| 22 | 
            -
                file_client.upload_file(logs)
         | 
| 23 | 
            -
             | 
| 24 | 
             
            # Chat functions
         | 
| 25 | 
             
            def start_chat(query, history, search_only):
         | 
| 26 | 
             
                history = history + [ChatMessage(role="user", content=query)]
         | 
| @@ -32,28 +28,6 @@ def start_chat(query, history, search_only): | |
| 32 | 
             
            def finish_chat():
         | 
| 33 | 
             
                return gr.update(interactive=True, value="")
         | 
| 34 |  | 
| 35 | 
            -
            def log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id):
         | 
| 36 | 
            -
                try:
         | 
| 37 | 
            -
                    # Log interaction to Azure if not in local environment
         | 
| 38 | 
            -
                    if os.getenv("GRADIO_ENV") != "local":
         | 
| 39 | 
            -
                        timestamp = str(datetime.now().timestamp())
         | 
| 40 | 
            -
                        prompt = history[1]["content"]
         | 
| 41 | 
            -
                        logs = {
         | 
| 42 | 
            -
                            "user_id": str(user_id),
         | 
| 43 | 
            -
                            "prompt": prompt,
         | 
| 44 | 
            -
                            "query": prompt,
         | 
| 45 | 
            -
                            "question": output_query,
         | 
| 46 | 
            -
                            "sources": sources,
         | 
| 47 | 
            -
                            "docs": serialize_docs(docs),
         | 
| 48 | 
            -
                            "answer": history[-1].content,
         | 
| 49 | 
            -
                            "time": timestamp,
         | 
| 50 | 
            -
                        }
         | 
| 51 | 
            -
                        log_on_azure(f"{timestamp}.json", logs, share_client)
         | 
| 52 | 
            -
                except Exception as e:
         | 
| 53 | 
            -
                    print(f"Error logging on Azure Blob Storage: {e}")
         | 
| 54 | 
            -
                    error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
         | 
| 55 | 
            -
                    raise gr.Error(error_msg)
         | 
| 56 | 
            -
             | 
| 57 | 
             
            def handle_numerical_data(event):
         | 
| 58 | 
             
                if event["name"] == "retrieve_drias_data" and event["event"] == "on_chain_end":
         | 
| 59 | 
             
                    numerical_data = event["data"]["output"]["drias_data"]
         | 
| @@ -61,27 +35,6 @@ def handle_numerical_data(event): | |
| 61 | 
             
                    return numerical_data, sql_query
         | 
| 62 | 
             
                return None, None
         | 
| 63 |  | 
| 64 | 
            -
            def log_drias_interaction_to_azure(query, sql_query, data, share_client, user_id):
         | 
| 65 | 
            -
                try:
         | 
| 66 | 
            -
                    # Log interaction to Azure if not in local environment
         | 
| 67 | 
            -
                    if os.getenv("GRADIO_ENV") != "local":
         | 
| 68 | 
            -
                        timestamp = str(datetime.now().timestamp())
         | 
| 69 | 
            -
                        logs = {
         | 
| 70 | 
            -
                            "user_id": str(user_id),
         | 
| 71 | 
            -
                            "query": query,
         | 
| 72 | 
            -
                            "sql_query": sql_query,
         | 
| 73 | 
            -
                            # "data": data.to_dict() if data is not None else None,
         | 
| 74 | 
            -
                            "time": timestamp,
         | 
| 75 | 
            -
                        }
         | 
| 76 | 
            -
                        log_on_azure(f"drias_{timestamp}.json", logs, share_client)
         | 
| 77 | 
            -
                        print(f"Logged Drias interaction to Azure Blob Storage: {logs}")
         | 
| 78 | 
            -
                    else:
         | 
| 79 | 
            -
                        print("share_client or user_id is None, or GRADIO_ENV is local")
         | 
| 80 | 
            -
                except Exception as e:
         | 
| 81 | 
            -
                    print(f"Error logging Drias interaction on Azure Blob Storage: {e}")
         | 
| 82 | 
            -
                    error_msg = f"Drias Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
         | 
| 83 | 
            -
                    raise gr.Error(error_msg)
         | 
| 84 | 
            -
             | 
| 85 | 
             
            # Main chat function
         | 
| 86 | 
             
            async def chat_stream(
         | 
| 87 | 
             
                agent : CompiledStateGraph,
         | 
| @@ -235,9 +188,7 @@ async def chat_stream( | |
| 235 | 
             
                    print(f"Event {event} has failed")
         | 
| 236 | 
             
                    raise gr.Error(str(e))
         | 
| 237 |  | 
| 238 | 
            -
             
         | 
| 239 | 
            -
             | 
| 240 | 
             
                # Call the function to log interaction
         | 
| 241 | 
            -
                 | 
| 242 |  | 
| 243 | 
             
                yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
         | 
|  | |
| 12 | 
             
                convert_to_docs_to_html,
         | 
| 13 | 
             
                stream_answer,
         | 
| 14 | 
             
                handle_retrieved_owid_graphs,
         | 
|  | |
| 15 | 
             
            )
         | 
| 16 | 
            +
            from .logging import (
         | 
| 17 | 
            +
                log_interaction_to_huggingface
         | 
| 18 | 
            +
            )
         | 
| 19 | 
            +
                    
         | 
|  | |
|  | |
|  | |
| 20 | 
             
            # Chat functions
         | 
| 21 | 
             
            def start_chat(query, history, search_only):
         | 
| 22 | 
             
                history = history + [ChatMessage(role="user", content=query)]
         | 
|  | |
| 28 | 
             
            def finish_chat():
         | 
| 29 | 
             
                return gr.update(interactive=True, value="")
         | 
| 30 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 31 | 
             
            def handle_numerical_data(event):
         | 
| 32 | 
             
                if event["name"] == "retrieve_drias_data" and event["event"] == "on_chain_end":
         | 
| 33 | 
             
                    numerical_data = event["data"]["output"]["drias_data"]
         | 
|  | |
| 35 | 
             
                    return numerical_data, sql_query
         | 
| 36 | 
             
                return None, None
         | 
| 37 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 38 | 
             
            # Main chat function
         | 
| 39 | 
             
            async def chat_stream(
         | 
| 40 | 
             
                agent : CompiledStateGraph,
         | 
|  | |
| 188 | 
             
                    print(f"Event {event} has failed")
         | 
| 189 | 
             
                    raise gr.Error(str(e))
         | 
| 190 |  | 
|  | |
|  | |
| 191 | 
             
                # Call the function to log interaction
         | 
| 192 | 
            +
                log_interaction_to_huggingface(history, output_query, sources, docs, share_client, user_id)
         | 
| 193 |  | 
| 194 | 
             
                yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
         | 
| @@ -1,5 +1,6 @@ | |
| 1 | 
            -
            from climateqa.engine.talk_to_data. | 
| 2 | 
             
            from climateqa.engine.llm import get_llm
         | 
|  | |
| 3 | 
             
            import ast
         | 
| 4 |  | 
| 5 | 
             
            llm = get_llm(provider="openai")
         | 
| @@ -37,7 +38,7 @@ def ask_llm_column_names(sql_query: str, llm) -> list[str]: | |
| 37 | 
             
                columns_list = ast.literal_eval(columns.strip("```python\n").strip())
         | 
| 38 | 
             
                return columns_list
         | 
| 39 |  | 
| 40 | 
            -
            async def ask_drias(query: str, index_state: int = 0) -> tuple:
         | 
| 41 | 
             
                """Main function to process a DRIAS query and return results.
         | 
| 42 |  | 
| 43 | 
             
                This function orchestrates the DRIAS workflow, processing a user query to generate
         | 
| @@ -85,6 +86,8 @@ async def ask_drias(query: str, index_state: int = 0) -> tuple: | |
| 85 | 
             
                sql_query = sql_queries[index_state]
         | 
| 86 | 
             
                dataframe = result_dataframes[index_state]
         | 
| 87 | 
             
                figure = figures[index_state](dataframe)
         | 
|  | |
|  | |
| 88 |  | 
| 89 | 
             
                return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
         | 
| 90 |  | 
|  | |
| 1 | 
            +
            from climateqa.engine.talk_to_data.talk_to_drias import drias_workflow
         | 
| 2 | 
             
            from climateqa.engine.llm import get_llm
         | 
| 3 | 
            +
            from climateqa.logging import log_drias_interaction_to_huggingface
         | 
| 4 | 
             
            import ast
         | 
| 5 |  | 
| 6 | 
             
            llm = get_llm(provider="openai")
         | 
|  | |
| 38 | 
             
                columns_list = ast.literal_eval(columns.strip("```python\n").strip())
         | 
| 39 | 
             
                return columns_list
         | 
| 40 |  | 
| 41 | 
            +
            async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tuple:
         | 
| 42 | 
             
                """Main function to process a DRIAS query and return results.
         | 
| 43 |  | 
| 44 | 
             
                This function orchestrates the DRIAS workflow, processing a user query to generate
         | 
|  | |
| 86 | 
             
                sql_query = sql_queries[index_state]
         | 
| 87 | 
             
                dataframe = result_dataframes[index_state]
         | 
| 88 | 
             
                figure = figures[index_state](dataframe)
         | 
| 89 | 
            +
                
         | 
| 90 | 
            +
                log_drias_interaction_to_huggingface(query, sql_query, user_id)
         | 
| 91 |  | 
| 92 | 
             
                return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
         | 
| 93 |  | 
| @@ -22,9 +22,10 @@ async def execute_sql_query(sql_query: str) -> pd.DataFrame: | |
| 22 | 
             
                """
         | 
| 23 | 
             
                def _execute_query():
         | 
| 24 | 
             
                    # Execute the query
         | 
| 25 | 
            -
                     | 
|  | |
| 26 | 
             
                    # return fetched data
         | 
| 27 | 
            -
                    return results | 
| 28 |  | 
| 29 | 
             
                # Run the query in a thread pool to avoid blocking
         | 
| 30 | 
             
                loop = asyncio.get_event_loop()
         | 
|  | |
| 22 | 
             
                """
         | 
| 23 | 
             
                def _execute_query():
         | 
| 24 | 
             
                    # Execute the query
         | 
| 25 | 
            +
                    con = duckdb.connect()
         | 
| 26 | 
            +
                    results = con.sql(sql_query).fetchdf()
         | 
| 27 | 
             
                    # return fetched data
         | 
| 28 | 
            +
                    return results
         | 
| 29 |  | 
| 30 | 
             
                # Run the query in a thread pool to avoid blocking
         | 
| 31 | 
             
                loop = asyncio.get_event_loop()
         | 
| @@ -1,10 +1,12 @@ | |
| 1 | 
             
            import os
         | 
| 2 |  | 
| 3 | 
             
            from typing import Any, Callable, TypedDict, Optional
         | 
|  | |
| 4 | 
             
            import pandas as pd
         | 
| 5 | 
            -
             | 
| 6 | 
             
            from plotly.graph_objects import Figure
         | 
| 7 | 
             
            from climateqa.engine.llm import get_llm
         | 
|  | |
| 8 | 
             
            from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
         | 
| 9 | 
             
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 10 | 
             
            from climateqa.engine.talk_to_data.sql_query import execute_sql_query
         | 
| @@ -17,6 +19,7 @@ from climateqa.engine.talk_to_data.utils import ( | |
| 17 | 
             
                detect_relevant_tables,
         | 
| 18 | 
             
            )
         | 
| 19 |  | 
|  | |
| 20 | 
             
            ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
         | 
| 21 |  | 
| 22 | 
             
            class TableState(TypedDict):
         | 
| @@ -61,101 +64,6 @@ class State(TypedDict): | |
| 61 | 
             
                plot_states: dict[str, PlotState]
         | 
| 62 | 
             
                error: Optional[str]
         | 
| 63 |  | 
| 64 | 
            -
            async def drias_workflow(user_input: str) -> State:
         | 
| 65 | 
            -
                """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
         | 
| 66 | 
            -
             | 
| 67 | 
            -
                Args:
         | 
| 68 | 
            -
                    user_input (str): initial user input
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                Returns:
         | 
| 71 | 
            -
                    State: Final state with all the results
         | 
| 72 | 
            -
                """
         | 
| 73 | 
            -
                state: State = {
         | 
| 74 | 
            -
                    'user_input': user_input,
         | 
| 75 | 
            -
                    'plots': [],
         | 
| 76 | 
            -
                    'plot_states': {}
         | 
| 77 | 
            -
                }
         | 
| 78 | 
            -
             | 
| 79 | 
            -
                llm = get_llm(provider="openai")
         | 
| 80 | 
            -
             | 
| 81 | 
            -
                plots = await find_relevant_plots(state, llm)
         | 
| 82 | 
            -
                state['plots'] = plots
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                if not state['plots']:
         | 
| 85 | 
            -
                    state['error'] = 'There is no plot to answer to the question'
         | 
| 86 | 
            -
                    return state
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                have_relevant_table = False
         | 
| 89 | 
            -
                have_sql_query = False
         | 
| 90 | 
            -
                have_dataframe = False
         | 
| 91 | 
            -
                for plot_name in state['plots']:
         | 
| 92 | 
            -
             | 
| 93 | 
            -
                    plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
         | 
| 94 | 
            -
                    if plot is None:
         | 
| 95 | 
            -
                        continue
         | 
| 96 | 
            -
                    
         | 
| 97 | 
            -
                    plot_state: PlotState = {
         | 
| 98 | 
            -
                        'plot_name': plot_name,
         | 
| 99 | 
            -
                        'tables': [],
         | 
| 100 | 
            -
                        'table_states': {}
         | 
| 101 | 
            -
                    }
         | 
| 102 | 
            -
             | 
| 103 | 
            -
                    plot_state['plot_name'] = plot_name
         | 
| 104 | 
            -
             | 
| 105 | 
            -
                    relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
         | 
| 106 | 
            -
                    if len(relevant_tables) > 0 :
         | 
| 107 | 
            -
                        have_relevant_table = True
         | 
| 108 | 
            -
             | 
| 109 | 
            -
                    plot_state['tables'] = relevant_tables
         | 
| 110 | 
            -
             | 
| 111 | 
            -
                    params = {}
         | 
| 112 | 
            -
                    for param_name in plot['params']:
         | 
| 113 | 
            -
                        param = await find_param(state, param_name, relevant_tables[0])
         | 
| 114 | 
            -
                        if param:
         | 
| 115 | 
            -
                            params.update(param)
         | 
| 116 | 
            -
             | 
| 117 | 
            -
                    for n, table in enumerate(plot_state['tables']):
         | 
| 118 | 
            -
                        if n > 2:
         | 
| 119 | 
            -
                            break
         | 
| 120 | 
            -
             | 
| 121 | 
            -
                        table_state: TableState = {
         | 
| 122 | 
            -
                            'table_name': table,
         | 
| 123 | 
            -
                            'params': params,
         | 
| 124 | 
            -
                            'status': 'OK'
         | 
| 125 | 
            -
                        } 
         | 
| 126 | 
            -
             | 
| 127 | 
            -
                        table_state["params"]['indicator_column'] = find_indicator_column(table)
         | 
| 128 | 
            -
                        
         | 
| 129 | 
            -
                        sql_query = plot['sql_query'](table, table_state['params'])
         | 
| 130 | 
            -
             | 
| 131 | 
            -
                        if sql_query == "":
         | 
| 132 | 
            -
                            table_state['status'] = 'ERROR'
         | 
| 133 | 
            -
                            continue
         | 
| 134 | 
            -
                        else : 
         | 
| 135 | 
            -
                            have_sql_query = True
         | 
| 136 | 
            -
             | 
| 137 | 
            -
                        table_state['sql_query'] = sql_query
         | 
| 138 | 
            -
                        df = await execute_sql_query(sql_query)
         | 
| 139 | 
            -
             | 
| 140 | 
            -
                        if len(df) > 0:
         | 
| 141 | 
            -
                            have_dataframe = True
         | 
| 142 | 
            -
             | 
| 143 | 
            -
                        figure = plot['plot_function'](table_state['params'])
         | 
| 144 | 
            -
                        table_state['dataframe'] = df
         | 
| 145 | 
            -
                        table_state['figure'] = figure
         | 
| 146 | 
            -
                        plot_state['table_states'][table] = table_state
         | 
| 147 | 
            -
             | 
| 148 | 
            -
                    state['plot_states'][plot_name] = plot_state
         | 
| 149 | 
            -
                
         | 
| 150 | 
            -
                if not have_relevant_table:
         | 
| 151 | 
            -
                    state['error'] = "There is no relevant table in the our database to answer your question"
         | 
| 152 | 
            -
                elif not have_sql_query:
         | 
| 153 | 
            -
                    state['error'] = "There is no relevant sql query on our database that can help to answer your question"
         | 
| 154 | 
            -
                elif not have_dataframe:
         | 
| 155 | 
            -
                    state['error'] = "There is no data in our table that can answer to your question"
         | 
| 156 | 
            -
                
         | 
| 157 | 
            -
                return state
         | 
| 158 | 
            -
             | 
| 159 | 
             
            async def find_relevant_plots(state: State, llm) -> list[str]:
         | 
| 160 | 
             
                print("---- Find relevant plots ----")
         | 
| 161 | 
             
                relevant_plots = await detect_relevant_plots(state['user_input'], llm)
         | 
| @@ -238,6 +146,128 @@ def find_indicator_column(table: str) -> str: | |
| 238 | 
             
                return INDICATOR_COLUMNS_PER_TABLE[table]
         | 
| 239 |  | 
| 240 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 241 | 
             
            # def make_write_query_node():
         | 
| 242 |  | 
| 243 | 
             
            #     def write_query(state):
         | 
|  | |
| 1 | 
             
            import os
         | 
| 2 |  | 
| 3 | 
             
            from typing import Any, Callable, TypedDict, Optional
         | 
| 4 | 
            +
            from numpy import sort
         | 
| 5 | 
             
            import pandas as pd
         | 
| 6 | 
            +
            import asyncio
         | 
| 7 | 
             
            from plotly.graph_objects import Figure
         | 
| 8 | 
             
            from climateqa.engine.llm import get_llm
         | 
| 9 | 
            +
            from climateqa.engine.talk_to_data import sql_query
         | 
| 10 | 
             
            from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
         | 
| 11 | 
             
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 12 | 
             
            from climateqa.engine.talk_to_data.sql_query import execute_sql_query
         | 
|  | |
| 19 | 
             
                detect_relevant_tables,
         | 
| 20 | 
             
            )
         | 
| 21 |  | 
| 22 | 
            +
             | 
| 23 | 
             
            ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
         | 
| 24 |  | 
| 25 | 
             
            class TableState(TypedDict):
         | 
|  | |
| 64 | 
             
                plot_states: dict[str, PlotState]
         | 
| 65 | 
             
                error: Optional[str]
         | 
| 66 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 67 | 
             
            async def find_relevant_plots(state: State, llm) -> list[str]:
         | 
| 68 | 
             
                print("---- Find relevant plots ----")
         | 
| 69 | 
             
                relevant_plots = await detect_relevant_plots(state['user_input'], llm)
         | 
|  | |
| 146 | 
             
                return INDICATOR_COLUMNS_PER_TABLE[table]
         | 
| 147 |  | 
| 148 |  | 
| 149 | 
            +
            async def process_table(
         | 
| 150 | 
            +
                table: str,
         | 
| 151 | 
            +
                params: dict[str, Any],
         | 
| 152 | 
            +
                plot: Plot,
         | 
| 153 | 
            +
            ) -> TableState:
         | 
| 154 | 
            +
                """Processes a table to extract relevant data and generate visualizations.
         | 
| 155 | 
            +
                
         | 
| 156 | 
            +
                This function retrieves the SQL query for the specified table, executes it,
         | 
| 157 | 
            +
                and generates a visualization based on the results.
         | 
| 158 | 
            +
                
         | 
| 159 | 
            +
                Args:
         | 
| 160 | 
            +
                    table (str): The name of the table to process
         | 
| 161 | 
            +
                    params (dict[str, Any]): Parameters used for querying the table
         | 
| 162 | 
            +
                    plot (Plot): The plot object containing SQL query and visualization function
         | 
| 163 | 
            +
                    
         | 
| 164 | 
            +
                Returns:
         | 
| 165 | 
            +
                    TableState: The state of the processed table
         | 
| 166 | 
            +
                """
         | 
| 167 | 
            +
                table_state: TableState = {
         | 
| 168 | 
            +
                    'table_name': table,
         | 
| 169 | 
            +
                    'params': params.copy(),
         | 
| 170 | 
            +
                    'status': 'OK',
         | 
| 171 | 
            +
                    'dataframe': None,
         | 
| 172 | 
            +
                    'sql_query': None,
         | 
| 173 | 
            +
                    'figure': None
         | 
| 174 | 
            +
                }
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                table_state['params']['indicator_column'] = find_indicator_column(table)
         | 
| 177 | 
            +
                sql_query = plot['sql_query'](table, table_state['params'])
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                if sql_query == "":
         | 
| 180 | 
            +
                    table_state['status'] = 'ERROR'
         | 
| 181 | 
            +
                    return table_state
         | 
| 182 | 
            +
                table_state['sql_query'] = sql_query
         | 
| 183 | 
            +
                df = await execute_sql_query(sql_query)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                table_state['dataframe'] = df
         | 
| 186 | 
            +
                table_state['figure'] = plot['plot_function'](table_state['params'])
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                return table_state
         | 
| 189 | 
            +
             | 
| 190 | 
            +
            async def drias_workflow(user_input: str) -> State:
                """Performs the complete workflow of Talk To Drias: from user input to sql queries, dataframes and figures generated

                For every plot selected for the question, the relevant tables are
                looked up, the plot parameters are extracted from the question and
                at most three tables are processed together through ``asyncio.gather``.

                Args:
                    user_input (str): initial user input

                Returns:
                    State: Final state with all the results. ``state['error']`` is set
                        to a user-facing message when no plot, no relevant table, no sql
                        query or no data could be found; it stays empty otherwise.
                """
                state: State = {
                    'user_input': user_input,
                    'plots': [],
                    'plot_states': {},
                    'error': ''
                }

                llm = get_llm(provider="openai")

                state['plots'] = await find_relevant_plots(state, llm)

                if len(state['plots']) < 1:
                    state['error'] = 'There is no plot to answer to the question'
                    return state

                # These flags accumulate over ALL plots: an error is reported only when
                # no plot at all produced a table / a query / some data. They must not
                # be reset per plot, otherwise only the last plot would be considered.
                have_relevant_table = False
                have_sql_query = False
                have_dataframe = False

                for plot_name in state['plots']:
                    # Find the associated plot object
                    plot = next((p for p in PLOTS if p['name'] == plot_name), None)
                    if plot is None:
                        continue

                    relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

                    plot_state: PlotState = {
                        'plot_name': plot_name,
                        'tables': relevant_tables,
                        'table_states': {}
                    }
                    state['plot_states'][plot_name] = plot_state

                    # Without any table there is nothing to query; skipping here also
                    # avoids indexing relevant_tables[0] on an empty list below.
                    if len(relevant_tables) == 0:
                        continue
                    have_relevant_table = True

                    # Parameters are extracted once, against the first relevant table,
                    # and shared by every table processed for this plot.
                    params = {}
                    for param_name in plot['params']:
                        param = await find_param(state, param_name, relevant_tables[0])
                        if param:
                            params.update(param)

                    # Only the first three tables are processed for a plot.
                    tasks = [process_table(table, params, plot) for table in relevant_tables[:3]]
                    results = await asyncio.gather(*tasks)

                    # Store results back in plot_state
                    for table_state in results:
                        if table_state['sql_query']:
                            have_sql_query = True
                        if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                            have_dataframe = True
                        plot_state['table_states'][table_state['table_name']] = table_state

                if not have_relevant_table:
                    state['error'] = "There is no relevant table in our database to answer your question"
                elif not have_sql_query:
                    state['error'] = "There is no relevant sql query on our database that can help to answer your question"
                elif not have_dataframe:
                    state['error'] = "There is no data in our table that can answer to your question"

                return state
         | 
| 270 | 
            +
             | 
| 271 | 
             
            # def make_write_query_node():
         | 
| 272 |  | 
| 273 | 
             
            #     def write_query(state):
         | 
| @@ -1,7 +1,7 @@ | |
| 1 | 
             
            from langchain_core.runnables.schema import StreamEvent
         | 
| 2 | 
             
            from gradio import ChatMessage
         | 
| 3 | 
             
            from climateqa.engine.chains.prompts import audience_prompts
         | 
| 4 | 
            -
            from front.utils import make_html_source,parse_output_llm_with_sources | 
| 5 | 
             
            import numpy as np
         | 
| 6 |  | 
| 7 | 
             
            def init_audience(audience :str) -> str:
         | 
|  | |
| 1 | 
             
            from langchain_core.runnables.schema import StreamEvent
         | 
| 2 | 
             
            from gradio import ChatMessage
         | 
| 3 | 
             
            from climateqa.engine.chains.prompts import audience_prompts
         | 
| 4 | 
            +
            from front.utils import make_html_source,parse_output_llm_with_sources
         | 
| 5 | 
             
            import numpy as np
         | 
| 6 |  | 
| 7 | 
             
            def init_audience(audience :str) -> str:
         | 
| @@ -0,0 +1,194 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from datetime import datetime
         | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            from huggingface_hub import HfApi
         | 
| 5 | 
            +
            import gradio as gr
         | 
| 6 | 
            +
            import csv
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            def serialize_docs(docs:list)->list:
                """Convert document objects into plain dictionaries.

                Args:
                    docs (list): objects exposing ``page_content`` and ``metadata`` attributes

                Returns:
                    list: one ``{"page_content": ..., "metadata": ...}`` dict per document, in input order
                """
                return [
                    {"page_content": document.page_content, "metadata": document.metadata}
                    for document in docs
                ]
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            ## AZURE LOGGING - DEPRECATED
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # def log_on_azure(file, logs, share_client):
         | 
| 20 | 
            +
            #     """Log data to Azure Blob Storage.
         | 
| 21 | 
            +
                
         | 
| 22 | 
            +
            #     Args:
         | 
| 23 | 
            +
            #         file (str): Name of the file to store logs
         | 
| 24 | 
            +
            #         logs (dict): Log data to store
         | 
| 25 | 
            +
            #         share_client: Azure share client instance
         | 
| 26 | 
            +
            #     """
         | 
| 27 | 
            +
            #     logs = json.dumps(logs)
         | 
| 28 | 
            +
            #     file_client = share_client.get_file_client(file)
         | 
| 29 | 
            +
            #     file_client.upload_file(logs)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            # def log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id):
         | 
| 33 | 
            +
            #     """Log chat interaction to Azure and Hugging Face.
         | 
| 34 | 
            +
                
         | 
| 35 | 
            +
            #     Args:
         | 
| 36 | 
            +
            #         history (list): Chat message history
         | 
| 37 | 
            +
            #         output_query (str): Processed query
         | 
| 38 | 
            +
            #         sources (list): Knowledge base sources used
         | 
| 39 | 
            +
            #         docs (list): Retrieved documents
         | 
| 40 | 
            +
            #         share_client: Azure share client instance
         | 
| 41 | 
            +
            #         user_id (str): User identifier
         | 
| 42 | 
            +
            #     """
         | 
| 43 | 
            +
            #     try:
         | 
| 44 | 
            +
            #         # Log interaction to Azure if not in local environment
         | 
| 45 | 
            +
            #         if os.getenv("GRADIO_ENV") != "local":
         | 
| 46 | 
            +
            #             timestamp = str(datetime.now().timestamp())
         | 
| 47 | 
            +
            #             prompt = history[1]["content"]
         | 
| 48 | 
            +
            #             logs = {
         | 
| 49 | 
            +
            #                 "user_id": str(user_id),
         | 
| 50 | 
            +
            #                 "prompt": prompt,
         | 
| 51 | 
            +
            #                 "query": prompt,
         | 
| 52 | 
            +
            #                 "question": output_query,
         | 
| 53 | 
            +
            #                 "sources": sources,
         | 
| 54 | 
            +
            #                 "docs": serialize_docs(docs),
         | 
| 55 | 
            +
            #                 "answer": history[-1].content,
         | 
| 56 | 
            +
            #                 "time": timestamp,
         | 
| 57 | 
            +
            #             }
         | 
| 58 | 
            +
            #             # Log to Azure
         | 
| 59 | 
            +
            #             log_on_azure(f"{timestamp}.json", logs, share_client)
         | 
| 60 | 
            +
            #     except Exception as e:
         | 
| 61 | 
            +
            #         print(f"Error logging on Azure Blob Storage: {e}")
         | 
| 62 | 
            +
            #         error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
         | 
| 63 | 
            +
            #         raise gr.Error(error_msg)
         | 
| 64 | 
            +
                
         | 
| 65 | 
            +
            # def log_drias_interaction_to_azure(query, sql_query, data, share_client, user_id):
         | 
| 66 | 
            +
            #     """Log Drias data interaction to Azure and Hugging Face.
         | 
| 67 | 
            +
                
         | 
| 68 | 
            +
            #     Args:
         | 
| 69 | 
            +
            #         query (str): User query
         | 
| 70 | 
            +
            #         sql_query (str): SQL query used
         | 
| 71 | 
            +
            #         data: Retrieved data
         | 
| 72 | 
            +
            #         share_client: Azure share client instance
         | 
| 73 | 
            +
            #         user_id (str): User identifier
         | 
| 74 | 
            +
            #     """
         | 
| 75 | 
            +
            #     try:
         | 
| 76 | 
            +
            #         # Log interaction to Azure if not in local environment
         | 
| 77 | 
            +
            #         if os.getenv("GRADIO_ENV") != "local":
         | 
| 78 | 
            +
            #             timestamp = str(datetime.now().timestamp())
         | 
| 79 | 
            +
            #             logs = {
         | 
| 80 | 
            +
            #                 "user_id": str(user_id),
         | 
| 81 | 
            +
            #                 "query": query,
         | 
| 82 | 
            +
            #                 "sql_query": sql_query,
         | 
| 83 | 
            +
            #                 "time": timestamp,
         | 
| 84 | 
            +
            #             }
         | 
| 85 | 
            +
            #             log_on_azure(f"drias_{timestamp}.json", logs, share_client)
         | 
| 86 | 
            +
            #             print(f"Logged Drias interaction to Azure Blob Storage: {logs}")
         | 
| 87 | 
            +
            #         else:
         | 
| 88 | 
            +
            #             print("share_client or user_id is None, or GRADIO_ENV is local")
         | 
| 89 | 
            +
            #     except Exception as e:
         | 
| 90 | 
            +
            #         print(f"Error logging Drias interaction on Azure Blob Storage: {e}")
         | 
| 91 | 
            +
            #         error_msg = f"Drias Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
         | 
| 92 | 
            +
            #         raise gr.Error(error_msg)    
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
            ## HUGGING FACE LOGGING
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            def log_on_huggingface(log_filename, logs):
                """Upload a log record as a JSON file to a Hugging Face dataset repository.

                The access token is read from the ``HF_LOGS_TOKEN`` environment variable
                and the target repository from ``HF_DATASET_REPO`` (default
                ``timeki/climateqa_logs``). Failures are printed and never raised.

                Args:
                    log_filename (str): Path of the file inside the repository
                    logs (dict): Log data to store; a "timestamp" key is added to it in place
                """
                try:
                    token = os.environ.get("HF_LOGS_TOKEN")
                    if token:
                        # Target repository, overridable through the environment
                        repository = os.environ.get("HF_DATASET_REPO", "timeki/climateqa_logs")
                        client = HfApi(token=token)

                        # Stamp the record, then serialize it fully in memory
                        logs["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                        payload = json.dumps(logs).encode('utf-8')

                        client.upload_file(
                            path_or_fileobj=payload,
                            path_in_repo=log_filename,
                            repo_id=repository,
                            repo_type="dataset",
                        )
                    else:
                        print("HF_LOGS_TOKEN not found in environment variables")
                except Exception as e:
                    print(f"Error logging to Hugging Face: {e}")
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                
         | 
| 134 | 
            +
            def log_interaction_to_huggingface(history, output_query, sources, docs, share_client, user_id):
                """Log chat interaction to Hugging Face.

                Nothing is logged when the ``GRADIO_ENV`` environment variable equals
                ``"local"``.

                Args:
                    history (list): Chat message history; the prompt is read from the
                        second message and the answer from the last one. Each message
                        may be a mapping with a "content" key or an object exposing a
                        ``content`` attribute.
                    output_query (str): Processed query
                    sources (list): Knowledge base sources used
                    docs (list): Retrieved documents
                    share_client: Azure share client instance (unused in this function)
                    user_id (str): User identifier

                Raises:
                    gr.Error: if the log record cannot be built or sent.
                """
                def _content_of(message):
                    # Accept both message shapes so that the prompt and the answer are
                    # read the same way whatever the history holds.
                    if isinstance(message, dict):
                        return message["content"]
                    return message.content

                try:
                    # Log interaction if not in local environment
                    if os.getenv("GRADIO_ENV") != "local":
                        timestamp = str(datetime.now().timestamp())
                        prompt = _content_of(history[1])
                        logs = {
                            "user_id": str(user_id),
                            "prompt": prompt,
                            "query": prompt,
                            "question": output_query,
                            "sources": sources,
                            "docs": serialize_docs(docs),
                            "answer": _content_of(history[-1]),
                            "time": timestamp,
                        }
                        # Log to Hugging Face
                        log_on_huggingface(f"chat/{timestamp}.json", logs)
                except Exception as e:
                    print(f"Error logging to Hugging Face: {e}")
                    error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
                    # Chain the cause so the original traceback is kept
                    raise gr.Error(error_msg) from e
         | 
| 166 | 
            +
             | 
| 167 | 
            +
            def log_drias_interaction_to_huggingface(query, sql_query, user_id):
                """Log Drias data interaction to Hugging Face.

                Nothing is logged when the ``GRADIO_ENV`` environment variable equals
                ``"local"``; a message is printed instead.

                Args:
                    query (str): User query
                    sql_query (str): SQL query used
                    user_id (str): User identifier

                Raises:
                    gr.Error: if the log record cannot be built or sent.
                """
                try:
                    if os.getenv("GRADIO_ENV") != "local":
                        timestamp = str(datetime.now().timestamp())
                        logs = {
                            "user_id": str(user_id),
                            "query": query,
                            "sql_query": sql_query,
                            "time": timestamp,
                        }
                        log_on_huggingface(f"drias/drias_{timestamp}.json", logs)
                        print(f"Logged Drias interaction to Hugging Face: {logs}")
                    else:
                        # The only condition checked here is the environment, so the
                        # message must not mention share_client or user_id.
                        print("GRADIO_ENV is local, Drias interaction not logged")
                except Exception as e:
                    print(f"Error logging Drias interaction to Hugging Face: {e}")
                    error_msg = f"Drias Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
                    # Chain the cause so the original traceback is kept
                    raise gr.Error(error_msg) from e
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| @@ -1,3 +0,0 @@ | |
| 1 | 
            -
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            -
            oid sha256:1e29ba55d0122dc034b76113941769b44214355d4528bcc5b3d8f71f3c50bf59
         | 
| 3 | 
            -
            size 280621056
         | 
|  | |
|  | |
|  | |
|  | 
| @@ -39,7 +39,7 @@ What do you want to learn ? | |
| 39 | 
             
            # """
         | 
| 40 |  | 
| 41 | 
             
            init_prompt_poc = """
         | 
| 42 | 
            -
            Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports, the Paris Climate Action Plan (PCAET), the Biodiversity Plan 2018-2024, and the Acclimaterra reports from the Nouvelle-Aquitaine Region**.
         | 
| 43 |  | 
| 44 | 
             
            ❓ How to use
         | 
| 45 | 
             
            - **Language**: You can ask me your questions in any language. 
         | 
|  | |
| 39 | 
             
            # """
         | 
| 40 |  | 
| 41 | 
             
            init_prompt_poc = """
         | 
| 42 | 
            +
            Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports, the Paris Climate Action Plan (PCAET), the Paris Biodiversity Plan 2018-2024, and the Acclimaterra reports from the Nouvelle-Aquitaine Region**.
         | 
| 43 |  | 
| 44 | 
             
            ❓ How to use
         | 
| 45 | 
             
            - **Language**: You can ask me your questions in any language. 
         | 
| @@ -5,8 +5,6 @@ import pandas as pd | |
| 5 |  | 
| 6 | 
             
            from climateqa.engine.talk_to_data.main import ask_drias
         | 
| 7 | 
             
            from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
         | 
| 8 | 
            -
            from climateqa.chat import log_drias_interaction_to_azure
         | 
| 9 | 
            -
             | 
| 10 |  | 
| 11 | 
             
            class DriasUIElements(TypedDict):
         | 
| 12 | 
             
                tab: gr.Tab
         | 
| @@ -28,8 +26,8 @@ class DriasUIElements(TypedDict): | |
| 28 | 
             
                next_button: gr.Button
         | 
| 29 |  | 
| 30 |  | 
| 31 | 
            -
            async def ask_drias_query(query: str, index_state: int):
         | 
| 32 | 
            -
                result = await ask_drias(query, index_state)
         | 
| 33 | 
             
                return result
         | 
| 34 |  | 
| 35 |  | 
| @@ -196,19 +194,7 @@ def create_drias_ui() -> DriasUIElements: | |
| 196 | 
             
                        next_button=next_button
         | 
| 197 | 
             
                    )
         | 
| 198 |  | 
| 199 | 
            -
             | 
| 200 | 
            -
                """Log Drias interaction to Azure storage."""
         | 
| 201 | 
            -
                print("log_drias_to_azure")
         | 
| 202 | 
            -
                if share_client is not None and user_id is not None:
         | 
| 203 | 
            -
                    log_drias_interaction_to_azure(
         | 
| 204 | 
            -
                        query=query,
         | 
| 205 | 
            -
                        sql_query=sql_query,
         | 
| 206 | 
            -
                        data=data,
         | 
| 207 | 
            -
                        share_client=share_client,
         | 
| 208 | 
            -
                        user_id=user_id
         | 
| 209 | 
            -
                    )
         | 
| 210 | 
            -
                else:
         | 
| 211 | 
            -
                    print("share_client or user_id is None")
         | 
| 212 |  | 
| 213 | 
             
            def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
         | 
| 214 | 
             
                """Set up all event handlers for the DRIAS tab."""
         | 
| @@ -218,10 +204,7 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id= | |
| 218 | 
             
                plots_state = gr.State([])
         | 
| 219 | 
             
                index_state = gr.State(0)
         | 
| 220 | 
             
                table_names_list = gr.State([])
         | 
| 221 | 
            -
                
         | 
| 222 | 
            -
                def log_drias_interaction(query: str, sql_query: str, data: pd.DataFrame):
         | 
| 223 | 
            -
                    log_drias_to_azure(query, sql_query, data, share_client, user_id)
         | 
| 224 | 
            -
                
         | 
| 225 |  | 
| 226 | 
             
                # Handle example selection
         | 
| 227 | 
             
                ui_elements["examples_hidden"].change(
         | 
| @@ -230,7 +213,7 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id= | |
| 230 | 
             
                    outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
         | 
| 231 | 
             
                ).then(
         | 
| 232 | 
             
                    ask_drias_query,
         | 
| 233 | 
            -
                    inputs=[ui_elements["examples_hidden"], index_state],
         | 
| 234 | 
             
                    outputs=[
         | 
| 235 | 
             
                        ui_elements["drias_sql_query"],
         | 
| 236 | 
             
                        ui_elements["drias_table"],
         | 
| @@ -242,10 +225,6 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id= | |
| 242 | 
             
                        table_names_list,
         | 
| 243 | 
             
                        ui_elements["result_text"],
         | 
| 244 | 
             
                    ],
         | 
| 245 | 
            -
                ).then(
         | 
| 246 | 
            -
                    log_drias_interaction,
         | 
| 247 | 
            -
                    inputs=[ui_elements["examples_hidden"], ui_elements["drias_sql_query"], ui_elements["drias_table"]],
         | 
| 248 | 
            -
                    outputs=[],
         | 
| 249 | 
             
                ).then(
         | 
| 250 | 
             
                    show_results,
         | 
| 251 | 
             
                    inputs=[sql_queries_state, dataframes_state, plots_state],
         | 
| @@ -276,7 +255,7 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id= | |
| 276 | 
             
                    outputs=[ui_elements["details_accordion"]]
         | 
| 277 | 
             
                ).then(
         | 
| 278 | 
             
                    ask_drias_query,
         | 
| 279 | 
            -
                    inputs=[ui_elements["drias_direct_question"], index_state],
         | 
| 280 | 
             
                    outputs=[
         | 
| 281 | 
             
                        ui_elements["drias_sql_query"],
         | 
| 282 | 
             
                        ui_elements["drias_table"],
         | 
| @@ -288,10 +267,6 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id= | |
| 288 | 
             
                        table_names_list,
         | 
| 289 | 
             
                        ui_elements["result_text"],
         | 
| 290 | 
             
                    ],
         | 
| 291 | 
            -
                ).then(
         | 
| 292 | 
            -
                    log_drias_interaction,
         | 
| 293 | 
            -
                    inputs=[ui_elements["drias_direct_question"], ui_elements["drias_sql_query"], ui_elements["drias_table"]],
         | 
| 294 | 
            -
                    outputs=[],
         | 
| 295 | 
             
                ).then(
         | 
| 296 | 
             
                    show_results,
         | 
| 297 | 
             
                    inputs=[sql_queries_state, dataframes_state, plots_state],
         | 
|  | |
| 5 |  | 
| 6 | 
             
            from climateqa.engine.talk_to_data.main import ask_drias
         | 
| 7 | 
             
            from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
         | 
|  | |
|  | |
| 8 |  | 
| 9 | 
             
            class DriasUIElements(TypedDict):
         | 
| 10 | 
             
                tab: gr.Tab
         | 
|  | |
| 26 | 
             
                next_button: gr.Button
         | 
| 27 |  | 
| 28 |  | 
| 29 | 
            +
            async def ask_drias_query(query: str, index_state: int, user_id: str):
         | 
| 30 | 
            +
                result = await ask_drias(query, index_state, user_id)
         | 
| 31 | 
             
                return result
         | 
| 32 |  | 
| 33 |  | 
|  | |
| 194 | 
             
                        next_button=next_button
         | 
| 195 | 
             
                    )
         | 
| 196 |  | 
| 197 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 198 |  | 
| 199 | 
             
            def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
         | 
| 200 | 
             
                """Set up all event handlers for the DRIAS tab."""
         | 
|  | |
| 204 | 
             
                plots_state = gr.State([])
         | 
| 205 | 
             
                index_state = gr.State(0)
         | 
| 206 | 
             
                table_names_list = gr.State([])
         | 
| 207 | 
            +
                user_id = gr.State(user_id)
         | 
|  | |
|  | |
|  | |
| 208 |  | 
| 209 | 
             
                # Handle example selection
         | 
| 210 | 
             
                ui_elements["examples_hidden"].change(
         | 
|  | |
| 213 | 
             
                    outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
         | 
| 214 | 
             
                ).then(
         | 
| 215 | 
             
                    ask_drias_query,
         | 
| 216 | 
            +
                    inputs=[ui_elements["examples_hidden"], index_state, user_id],
         | 
| 217 | 
             
                    outputs=[
         | 
| 218 | 
             
                        ui_elements["drias_sql_query"],
         | 
| 219 | 
             
                        ui_elements["drias_table"],
         | 
|  | |
| 225 | 
             
                        table_names_list,
         | 
| 226 | 
             
                        ui_elements["result_text"],
         | 
| 227 | 
             
                    ],
         | 
|  | |
|  | |
|  | |
|  | |
| 228 | 
             
                ).then(
         | 
| 229 | 
             
                    show_results,
         | 
| 230 | 
             
                    inputs=[sql_queries_state, dataframes_state, plots_state],
         | 
|  | |
| 255 | 
             
                    outputs=[ui_elements["details_accordion"]]
         | 
| 256 | 
             
                ).then(
         | 
| 257 | 
             
                    ask_drias_query,
         | 
| 258 | 
            +
                    inputs=[ui_elements["drias_direct_question"], index_state, user_id],
         | 
| 259 | 
             
                    outputs=[
         | 
| 260 | 
             
                        ui_elements["drias_sql_query"],
         | 
| 261 | 
             
                        ui_elements["drias_table"],
         | 
|  | |
| 267 | 
             
                        table_names_list,
         | 
| 268 | 
             
                        ui_elements["result_text"],
         | 
| 269 | 
             
                    ],
         | 
|  | |
|  | |
|  | |
|  | |
| 270 | 
             
                ).then(
         | 
| 271 | 
             
                    show_results,
         | 
| 272 | 
             
                    inputs=[sql_queries_state, dataframes_state, plots_state],
         | 
| @@ -13,17 +13,6 @@ def make_pairs(lst:list)->list: | |
| 13 | 
             
                return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
         | 
| 14 |  | 
| 15 |  | 
| 16 | 
            -
            def serialize_docs(docs:list)->list:
         | 
| 17 | 
            -
                new_docs = []
         | 
| 18 | 
            -
                for doc in docs:
         | 
| 19 | 
            -
                    new_doc = {}
         | 
| 20 | 
            -
                    new_doc["page_content"] = doc.page_content
         | 
| 21 | 
            -
                    new_doc["metadata"] = doc.metadata
         | 
| 22 | 
            -
                    new_docs.append(new_doc)
         | 
| 23 | 
            -
                return new_docs
         | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
             
            def parse_output_llm_with_sources(output:str)->str:
         | 
| 28 | 
             
                # Split the content into a list of text and "[Doc X]" references
         | 
| 29 | 
             
                content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
         | 
|  | |
| 13 | 
             
                return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
         | 
| 14 |  | 
| 15 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 16 | 
             
            def parse_output_llm_with_sources(output:str)->str:
         | 
| 17 | 
             
                # Split the content into a list of text and "[Doc X]" references
         | 
| 18 | 
             
                content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
         | 
| @@ -8,6 +8,7 @@ langgraph==0.2.70 | |
| 8 | 
             
            pinecone-client==4.1.0
         | 
| 9 | 
             
            sentence-transformers==2.6.0
         | 
| 10 | 
             
            huggingface-hub==0.25.2
         | 
|  | |
| 11 | 
             
            pyalex==0.13
         | 
| 12 | 
             
            networkx==3.2.1
         | 
| 13 | 
             
            pyvis==0.3.2
         | 
|  | |
| 8 | 
             
            pinecone-client==4.1.0
         | 
| 9 | 
             
            sentence-transformers==2.6.0
         | 
| 10 | 
             
            huggingface-hub==0.25.2
         | 
| 11 | 
            +
            datasets==3.5.0
         | 
| 12 | 
             
            pyalex==0.13
         | 
| 13 | 
             
            networkx==3.2.1
         | 
| 14 | 
             
            pyvis==0.3.2
         | 

