Commit 
							
							·
						
						e7bde89
	
1
								Parent(s):
							
							2beb115
								
feat: added multithreading to run sql queries in talk to drias
Browse files
    	
        climateqa/engine/talk_to_data/sql_query.py
    CHANGED
    
    | @@ -22,9 +22,10 @@ async def execute_sql_query(sql_query: str) -> pd.DataFrame: | |
| 22 | 
             
                """
         | 
| 23 | 
             
                def _execute_query():
         | 
| 24 | 
             
                    # Execute the query
         | 
| 25 | 
            -
                     | 
|  | |
| 26 | 
             
                    # return fetched data
         | 
| 27 | 
            -
                    return results | 
| 28 |  | 
| 29 | 
             
                # Run the query in a thread pool to avoid blocking
         | 
| 30 | 
             
                loop = asyncio.get_event_loop()
         | 
|  | |
| 22 | 
             
                """
         | 
| 23 | 
             
                def _execute_query():
         | 
| 24 | 
             
                    # Execute the query
         | 
| 25 | 
            +
                    con = duckdb.connect()
         | 
| 26 | 
            +
                    results = con.sql(sql_query).fetchdf()
         | 
| 27 | 
             
                    # return fetched data
         | 
| 28 | 
            +
                    return results
         | 
| 29 |  | 
| 30 | 
             
                # Run the query in a thread pool to avoid blocking
         | 
| 31 | 
             
                loop = asyncio.get_event_loop()
         | 
    	
        climateqa/engine/talk_to_data/{workflow.py → talk_to_drias.py}
    RENAMED
    
    | @@ -1,10 +1,12 @@ | |
| 1 | 
             
            import os
         | 
| 2 |  | 
| 3 | 
             
            from typing import Any, Callable, TypedDict, Optional
         | 
|  | |
| 4 | 
             
            import pandas as pd
         | 
| 5 | 
            -
             | 
| 6 | 
             
            from plotly.graph_objects import Figure
         | 
| 7 | 
             
            from climateqa.engine.llm import get_llm
         | 
|  | |
| 8 | 
             
            from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
         | 
| 9 | 
             
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 10 | 
             
            from climateqa.engine.talk_to_data.sql_query import execute_sql_query
         | 
| @@ -17,6 +19,7 @@ from climateqa.engine.talk_to_data.utils import ( | |
| 17 | 
             
                detect_relevant_tables,
         | 
| 18 | 
             
            )
         | 
| 19 |  | 
|  | |
| 20 | 
             
            ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
         | 
| 21 |  | 
| 22 | 
             
            class TableState(TypedDict):
         | 
| @@ -61,101 +64,6 @@ class State(TypedDict): | |
| 61 | 
             
                plot_states: dict[str, PlotState]
         | 
| 62 | 
             
                error: Optional[str]
         | 
| 63 |  | 
| 64 | 
            -
            async def drias_workflow(user_input: str) -> State:
         | 
| 65 | 
            -
                """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
         | 
| 66 | 
            -
             | 
| 67 | 
            -
                Args:
         | 
| 68 | 
            -
                    user_input (str): initial user input
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                Returns:
         | 
| 71 | 
            -
                    State: Final state with all the results
         | 
| 72 | 
            -
                """
         | 
| 73 | 
            -
                state: State = {
         | 
| 74 | 
            -
                    'user_input': user_input,
         | 
| 75 | 
            -
                    'plots': [],
         | 
| 76 | 
            -
                    'plot_states': {}
         | 
| 77 | 
            -
                }
         | 
| 78 | 
            -
             | 
| 79 | 
            -
                llm = get_llm(provider="openai")
         | 
| 80 | 
            -
             | 
| 81 | 
            -
                plots = await find_relevant_plots(state, llm)
         | 
| 82 | 
            -
                state['plots'] = plots
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                if not state['plots']:
         | 
| 85 | 
            -
                    state['error'] = 'There is no plot to answer to the question'
         | 
| 86 | 
            -
                    return state
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                have_relevant_table = False
         | 
| 89 | 
            -
                have_sql_query = False
         | 
| 90 | 
            -
                have_dataframe = False
         | 
| 91 | 
            -
                for plot_name in state['plots']:
         | 
| 92 | 
            -
             | 
| 93 | 
            -
                    plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
         | 
| 94 | 
            -
                    if plot is None:
         | 
| 95 | 
            -
                        continue
         | 
| 96 | 
            -
                    
         | 
| 97 | 
            -
                    plot_state: PlotState = {
         | 
| 98 | 
            -
                        'plot_name': plot_name,
         | 
| 99 | 
            -
                        'tables': [],
         | 
| 100 | 
            -
                        'table_states': {}
         | 
| 101 | 
            -
                    }
         | 
| 102 | 
            -
             | 
| 103 | 
            -
                    plot_state['plot_name'] = plot_name
         | 
| 104 | 
            -
             | 
| 105 | 
            -
                    relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
         | 
| 106 | 
            -
                    if len(relevant_tables) > 0 :
         | 
| 107 | 
            -
                        have_relevant_table = True
         | 
| 108 | 
            -
             | 
| 109 | 
            -
                    plot_state['tables'] = relevant_tables
         | 
| 110 | 
            -
             | 
| 111 | 
            -
                    params = {}
         | 
| 112 | 
            -
                    for param_name in plot['params']:
         | 
| 113 | 
            -
                        param = await find_param(state, param_name, relevant_tables[0])
         | 
| 114 | 
            -
                        if param:
         | 
| 115 | 
            -
                            params.update(param)
         | 
| 116 | 
            -
             | 
| 117 | 
            -
                    for n, table in enumerate(plot_state['tables']):
         | 
| 118 | 
            -
                        if n > 2:
         | 
| 119 | 
            -
                            break
         | 
| 120 | 
            -
             | 
| 121 | 
            -
                        table_state: TableState = {
         | 
| 122 | 
            -
                            'table_name': table,
         | 
| 123 | 
            -
                            'params': params,
         | 
| 124 | 
            -
                            'status': 'OK'
         | 
| 125 | 
            -
                        } 
         | 
| 126 | 
            -
             | 
| 127 | 
            -
                        table_state["params"]['indicator_column'] = find_indicator_column(table)
         | 
| 128 | 
            -
                        
         | 
| 129 | 
            -
                        sql_query = plot['sql_query'](table, table_state['params'])
         | 
| 130 | 
            -
             | 
| 131 | 
            -
                        if sql_query == "":
         | 
| 132 | 
            -
                            table_state['status'] = 'ERROR'
         | 
| 133 | 
            -
                            continue
         | 
| 134 | 
            -
                        else : 
         | 
| 135 | 
            -
                            have_sql_query = True
         | 
| 136 | 
            -
             | 
| 137 | 
            -
                        table_state['sql_query'] = sql_query
         | 
| 138 | 
            -
                        df = await execute_sql_query(sql_query)
         | 
| 139 | 
            -
             | 
| 140 | 
            -
                        if len(df) > 0:
         | 
| 141 | 
            -
                            have_dataframe = True
         | 
| 142 | 
            -
             | 
| 143 | 
            -
                        figure = plot['plot_function'](table_state['params'])
         | 
| 144 | 
            -
                        table_state['dataframe'] = df
         | 
| 145 | 
            -
                        table_state['figure'] = figure
         | 
| 146 | 
            -
                        plot_state['table_states'][table] = table_state
         | 
| 147 | 
            -
             | 
| 148 | 
            -
                    state['plot_states'][plot_name] = plot_state
         | 
| 149 | 
            -
                
         | 
| 150 | 
            -
                if not have_relevant_table:
         | 
| 151 | 
            -
                    state['error'] = "There is no relevant table in the our database to answer your question"
         | 
| 152 | 
            -
                elif not have_sql_query:
         | 
| 153 | 
            -
                    state['error'] = "There is no relevant sql query on our database that can help to answer your question"
         | 
| 154 | 
            -
                elif not have_dataframe:
         | 
| 155 | 
            -
                    state['error'] = "There is no data in our table that can answer to your question"
         | 
| 156 | 
            -
                
         | 
| 157 | 
            -
                return state
         | 
| 158 | 
            -
             | 
| 159 | 
             
            async def find_relevant_plots(state: State, llm) -> list[str]:
         | 
| 160 | 
             
                print("---- Find relevant plots ----")
         | 
| 161 | 
             
                relevant_plots = await detect_relevant_plots(state['user_input'], llm)
         | 
| @@ -238,6 +146,130 @@ def find_indicator_column(table: str) -> str: | |
| 238 | 
             
                return INDICATOR_COLUMNS_PER_TABLE[table]
         | 
| 239 |  | 
| 240 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 241 | 
             
            # def make_write_query_node():
         | 
| 242 |  | 
| 243 | 
             
            #     def write_query(state):
         | 
|  | |
| 1 | 
             
            import os
         | 
| 2 |  | 
| 3 | 
             
            from typing import Any, Callable, TypedDict, Optional
         | 
| 4 | 
            +
            from numpy import sort
         | 
| 5 | 
             
            import pandas as pd
         | 
| 6 | 
            +
            import asyncio
         | 
| 7 | 
             
            from plotly.graph_objects import Figure
         | 
| 8 | 
             
            from climateqa.engine.llm import get_llm
         | 
| 9 | 
            +
            from climateqa.engine.talk_to_data import sql_query
         | 
| 10 | 
             
            from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
         | 
| 11 | 
             
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 12 | 
             
            from climateqa.engine.talk_to_data.sql_query import execute_sql_query
         | 
|  | |
| 19 | 
             
                detect_relevant_tables,
         | 
| 20 | 
             
            )
         | 
| 21 |  | 
| 22 | 
            +
             | 
| 23 | 
             
            ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
         | 
| 24 |  | 
| 25 | 
             
            class TableState(TypedDict):
         | 
|  | |
| 64 | 
             
                plot_states: dict[str, PlotState]
         | 
| 65 | 
             
                error: Optional[str]
         | 
| 66 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 67 | 
             
            async def find_relevant_plots(state: State, llm) -> list[str]:
         | 
| 68 | 
             
                print("---- Find relevant plots ----")
         | 
| 69 | 
             
                relevant_plots = await detect_relevant_plots(state['user_input'], llm)
         | 
|  | |
| 146 | 
             
                return INDICATOR_COLUMNS_PER_TABLE[table]
         | 
| 147 |  | 
| 148 |  | 
| 149 | 
            +
            async def process_table(
         | 
| 150 | 
            +
                table: str,
         | 
| 151 | 
            +
                params: dict[str, Any],
         | 
| 152 | 
            +
                plot: Plot,
         | 
| 153 | 
            +
            ) -> TableState:
         | 
| 154 | 
            +
                """Processes a table to extract relevant data and generate visualizations.
         | 
| 155 | 
            +
                
         | 
| 156 | 
            +
                This function retrieves the SQL query for the specified table, executes it,
         | 
| 157 | 
            +
                and generates a visualization based on the results.
         | 
| 158 | 
            +
                
         | 
| 159 | 
            +
                Args:
         | 
| 160 | 
            +
                    table (str): The name of the table to process
         | 
| 161 | 
            +
                    params (dict[str, Any]): Parameters used for querying the table
         | 
| 162 | 
            +
                    plot (Plot): The plot object containing SQL query and visualization function
         | 
| 163 | 
            +
                    
         | 
| 164 | 
            +
                Returns:
         | 
| 165 | 
            +
                    TableState: The state of the processed table
         | 
| 166 | 
            +
                """
         | 
| 167 | 
            +
                table_state: TableState = {
         | 
| 168 | 
            +
                    'table_name': table,
         | 
| 169 | 
            +
                    'params': params.copy(),
         | 
| 170 | 
            +
                    'status': 'OK',
         | 
| 171 | 
            +
                    'dataframe': None,
         | 
| 172 | 
            +
                    'sql_query': None,
         | 
| 173 | 
            +
                    'figure': None
         | 
| 174 | 
            +
                }
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                table_state['params']['indicator_column'] = find_indicator_column(table)
         | 
| 177 | 
            +
                sql_query = plot['sql_query'](table, table_state['params'])
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                if sql_query == "":
         | 
| 180 | 
            +
                    table_state['status'] = 'ERROR'
         | 
| 181 | 
            +
                    return table_state
         | 
| 182 | 
            +
                table_state['sql_query'] = sql_query
         | 
| 183 | 
            +
                print(sql_query)
         | 
| 184 | 
            +
                df = await execute_sql_query(sql_query)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                table_state['dataframe'] = df
         | 
| 187 | 
            +
                table_state['figure'] = plot['plot_function'](table_state['params'])
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                return table_state
         | 
| 190 | 
            +
             | 
| 191 | 
            +
            async def drias_workflow(user_input: str) -> State:
         | 
| 192 | 
            +
                """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                Args:
         | 
| 195 | 
            +
                    user_input (str): initial user input
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                Returns:
         | 
| 198 | 
            +
                    State: Final state with all the results
         | 
| 199 | 
            +
                """
         | 
| 200 | 
            +
                state: State = {
         | 
| 201 | 
            +
                    'user_input': user_input,
         | 
| 202 | 
            +
                    'plots': [],
         | 
| 203 | 
            +
                    'plot_states': {},
         | 
| 204 | 
            +
                    'error': ''
         | 
| 205 | 
            +
                }
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                llm = get_llm(provider="openai")
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                plots = await find_relevant_plots(state, llm)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                state['plots'] = plots
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                if len(state['plots']) < 1:
         | 
| 214 | 
            +
                    state['error'] = 'There is no plot to answer to the question'
         | 
| 215 | 
            +
                    return state
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                have_relevant_table = False
         | 
| 218 | 
            +
                have_sql_query = False
         | 
| 219 | 
            +
                have_dataframe = False
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                for plot_name in state['plots']:
         | 
| 222 | 
            +
                    
         | 
| 223 | 
            +
                    plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
         | 
| 224 | 
            +
                    if plot is None:
         | 
| 225 | 
            +
                        continue
         | 
| 226 | 
            +
                    
         | 
| 227 | 
            +
                    plot_state: PlotState = {
         | 
| 228 | 
            +
                        'plot_name': plot_name,
         | 
| 229 | 
            +
                        'tables': [],
         | 
| 230 | 
            +
                        'table_states': {}
         | 
| 231 | 
            +
                    }
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    plot_state['plot_name'] = plot_name
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                    relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    if len(relevant_tables) > 0 :
         | 
| 238 | 
            +
                        have_relevant_table = True
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    plot_state['tables'] = relevant_tables
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                    params = {}
         | 
| 243 | 
            +
                    for param_name in plot['params']:
         | 
| 244 | 
            +
                        param = await find_param(state, param_name, relevant_tables[0])
         | 
| 245 | 
            +
                        if param:
         | 
| 246 | 
            +
                            params.update(param)
         | 
| 247 | 
            +
                    
         | 
| 248 | 
            +
                    tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
         | 
| 249 | 
            +
                    results = await asyncio.gather(*tasks)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    # Store results back in plot_state
         | 
| 252 | 
            +
                    have_dataframe = False
         | 
| 253 | 
            +
                    have_sql_query = False
         | 
| 254 | 
            +
                    for table_state in results:
         | 
| 255 | 
            +
                        print(table_state)
         | 
| 256 | 
            +
                        if table_state['sql_query']:
         | 
| 257 | 
            +
                            have_sql_query = True
         | 
| 258 | 
            +
                        if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
         | 
| 259 | 
            +
                            have_dataframe = True
         | 
| 260 | 
            +
                        plot_state['table_states'][table_state['table_name']] = table_state
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    state['plot_states'][plot_name] = plot_state
         | 
| 263 | 
            +
                            
         | 
| 264 | 
            +
                if not have_relevant_table:
         | 
| 265 | 
            +
                    state['error'] = "There is no relevant table in our database to answer your question"
         | 
| 266 | 
            +
                elif not have_sql_query:
         | 
| 267 | 
            +
                    state['error'] = "There is no relevant sql query on our database that can help to answer your question"
         | 
| 268 | 
            +
                elif not have_dataframe:
         | 
| 269 | 
            +
                    state['error'] = "There is no data in our table that can answer to your question"
         | 
| 270 | 
            +
                
         | 
| 271 | 
            +
                return state
         | 
| 272 | 
            +
             | 
| 273 | 
             
            # def make_write_query_node():
         | 
| 274 |  | 
| 275 | 
             
            #     def write_query(state):
         | 
