Merge hf-origin/main into pr/30
Browse files
Resolved conflicts by keeping the pr/30 changes, which include:
- Azure AI Search implementation (replacing Pinecone)
- Updated talk_to_data functionality
- New dependencies and vectorstore wrapper
- Enhanced IPCC and DRIAS workflows
New files from main:
- climateqa/engine/talk_to_data/myVanna.py
- climateqa/engine/talk_to_data/plot.py
- climateqa/engine/talk_to_data/sql_query.py
- climateqa/engine/talk_to_data/talk_to_drias.py
- climateqa/engine/talk_to_data/utils.py
- climateqa/engine/talk_to_data/vanna_class.py
    	
        climateqa/engine/talk_to_data/myVanna.py
    ADDED
    
    | @@ -0,0 +1,13 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from dotenv import load_dotenv
         | 
| 2 | 
            +
            from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
         | 
| 3 | 
            +
            from vanna.openai import OpenAI_Chat
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Load variables from a local .env file (if any) into the process
            # environment before the API key is read below.
            load_dotenv()
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # OpenAI key, read from the THEO_API_KEY environment variable; this is
            # None when the variable is unset.
            # NOTE(review): this constant is not referenced anywhere else in this
            # module as shown - presumably importers pass it via `config`; confirm.
            OPENAI_API_KEY = os.getenv('THEO_API_KEY')
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            class MyVanna(MyCustomVectorDB, OpenAI_Chat):
                """Combine ``MyCustomVectorDB`` with vanna's ``OpenAI_Chat``.

                Args:
                    config: Optional configuration object forwarded unchanged to
                        both base-class initializers.
                """

                def __init__(self, config=None):
                    # Each base is initialised explicitly (not via super()), in the
                    # same order as the original: vector store first, then chat.
                    for base in (MyCustomVectorDB, OpenAI_Chat):
                        base.__init__(self, config=config)
         | 
    	
        climateqa/engine/talk_to_data/plot.py
    ADDED
    
    | @@ -0,0 +1,418 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Callable, TypedDict
         | 
| 2 | 
            +
            from matplotlib.figure import figaspect
         | 
| 3 | 
            +
            import pandas as pd
         | 
| 4 | 
            +
            from plotly.graph_objects import Figure
         | 
| 5 | 
            +
            import plotly.graph_objects as go
         | 
| 6 | 
            +
            import plotly.express as px
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from climateqa.engine.talk_to_data.sql_query import (
         | 
| 9 | 
            +
                indicator_for_given_year_query,
         | 
| 10 | 
            +
                indicator_per_year_at_location_query,
         | 
| 11 | 
            +
            )
         | 
| 12 | 
            +
            from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            class Plot(TypedDict):
                """Typed description of one plot kind offered by the DRIAS tooling.

                A ``Plot`` bundles the human-readable metadata of a chart with the
                two callables needed to produce it: one that builds the SQL query
                and one that builds the plotting function.
                """

                # Display name of the plot type.
                name: str
                # Short explanation of what the plot shows.
                description: str
                # Names of the parameters the plot needs.
                params: list[str]
                # Factory: returns the function that turns data into a plotly Figure.
                plot_function: Callable[..., Callable[..., Figure]]
                # Builder returning the SQL query string for the plot.
                sql_query: Callable[..., str]
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
                """Build a function that plots an indicator's yearly evolution at a location.

                The returned function draws a line chart of the indicator per year,
                plus a dashed 10-year rolling average when at least one full window
                is available.

                Args:
                    params (dict): Dictionary containing:
                        - indicator_column (str): The column name for the indicator
                        - location (str): The location, used in the plot title
                        - model (str): Listed among the plot parameters; it is not
                          read here (the model is taken from the DataFrame itself)

                Returns:
                    Callable[..., Figure]: A function that takes a DataFrame and
                    returns a plotly Figure.

                Example:
                    >>> plot_func = plot_indicator_evolution_at_location({
                    ...     'indicator_column': 'mean_temperature',
                    ...     'location': 'Paris',
                    ...     'model': 'ALL'
                    ... })
                    >>> fig = plot_func(df)
                """
                indicator = params["indicator_column"]
                location = params["location"]
                indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
                unit = INDICATOR_TO_UNIT.get(indicator, "")
                rolling_window = 10

                def plot_data(df: pd.DataFrame) -> Figure:
                    """Generate the actual plot from the data.

                    Args:
                        df (pd.DataFrame): Data with ``model``, ``year`` and the
                            indicator column. When it does not hold exactly one
                            distinct model, values are averaged per year.
                            NOTE(review): in the single-model case rows are used in
                            their given order - presumably the query already sorts
                            by year; confirm.

                    Returns:
                        Figure: A plotly Figure object showing the indicator evolution
                    """
                    fig = go.Figure()

                    # Pick the frame to plot once; everything below is shared by
                    # both cases instead of being duplicated per branch.
                    if df["model"].nunique() != 1:
                        yearly = df.groupby("year", as_index=False)[indicator].mean()
                        model_label = "Model Average"
                    else:
                        yearly = df
                        model_label = f"Model : {df['model'].unique()[0]}"

                    # Transform to list to avoid pandas encoding
                    indicators = yearly[indicator].astype(float).tolist()
                    years = yearly["year"].astype(int).tolist()

                    # Rolling mean; the first (rolling_window - 1) entries are NaN
                    # because min_periods equals the window size.
                    sliding_averages = (
                        yearly[indicator]
                        .rolling(window=rolling_window, min_periods=rolling_window)
                        .mean()
                        .astype(float)
                        .tolist()
                    )

                    # Only add the rolling average if at least one window is complete
                    if any(pd.notna(value) for value in sliding_averages):
                        fig.add_scatter(
                            x=years,
                            y=sliding_averages,
                            mode="lines",
                            name="10 years rolling average",
                            line=dict(dash="dash"),
                            marker=dict(color="#d62728"),
                            hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
                        )

                    # Indicator per year plot
                    fig.add_scatter(
                        x=years,
                        y=indicators,
                        name=f"Yearly {indicator_label}",
                        mode="lines",
                        marker=dict(color="#1f77b4"),
                        hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
                    )
                    fig.update_layout(
                        title=f"Plot of {indicator_label} in {location} ({model_label})",
                        xaxis_title="Year",
                        yaxis_title=f"{indicator_label} ({unit})",
                        template="plotly_white",
                    )
                    return fig

                return plot_data
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            # Plot entry: line chart of an indicator over the years at one location.
            indicator_evolution_at_location: Plot = Plot(
                name="Indicator evolution at location",
                description="Plot an evolution of the indicator at a certain location",
                params=["indicator_column", "location", "model"],
                plot_function=plot_indicator_evolution_at_location,
                sql_query=indicator_per_year_at_location_query,
            )
         | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            def plot_indicator_number_of_days_per_year_at_location(
                params: dict,
            ) -> Callable[..., Figure]:
                """Build a function that plots an indicator per year as a bar chart.

                Intended for frequency-style indicators (for example a number of days
                per year) at a specific location.

                Args:
                    params (dict): Dictionary containing:
                        - indicator_column (str): The column name for the indicator
                        - location (str): The location, used in the plot title
                        - model (str): Listed among the plot parameters; it is not
                          read here (the model is taken from the DataFrame itself)

                Returns:
                    Callable[..., Figure]: A function that takes a DataFrame and
                    returns a plotly Figure.
                """
                indicator = params["indicator_column"]
                location = params["location"]
                indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
                unit = INDICATOR_TO_UNIT.get(indicator, "")

                def plot_data(df: pd.DataFrame) -> Figure:
                    """Generate the bar chart from the dataframe.

                    Args:
                        df (pd.DataFrame): Data with ``model``, ``year`` and the
                            indicator column. When it does not hold exactly one
                            distinct model, values are averaged per year.

                    Returns:
                        Figure: Plotly figure
                    """
                    fig = go.Figure()

                    if df["model"].nunique() != 1:
                        yearly = df.groupby("year", as_index=False)[indicator].mean()
                        model_label = "Model Average"
                    else:
                        yearly = df
                        model_label = f"Model : {df['model'].unique()[0]}"

                    # Transform to list to avoid pandas encoding
                    values = yearly[indicator].astype(float)
                    indicators = values.tolist()
                    years = yearly["year"].astype(int).tolist()

                    # Bar plot
                    fig.add_trace(
                        go.Bar(
                            x=years,
                            y=indicators,
                            width=0.5,
                            marker=dict(color="#1f77b4"),
                            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
                        )
                    )

                    layout = dict(
                        title=f"{indicator_label} in {location} ({model_label})",
                        xaxis_title="Year",
                        yaxis_title=f"{indicator_label} ({unit})",
                        bargap=0.5,
                        template="plotly_white",
                    )

                    # Series.max() skips NaN and returns NaN for empty input, unlike
                    # the builtin max(), which raises ValueError on an empty list.
                    # Pin the axis to [0, max] only when that maximum is defined;
                    # otherwise leave plotly's automatic range in place.
                    upper_bound = values.max()
                    if pd.notna(upper_bound):
                        layout["yaxis"] = dict(range=[0, float(upper_bound)])

                    fig.update_layout(**layout)

                    return fig

                return plot_data
         | 
| 238 | 
            +
             | 
| 239 | 
            +
             | 
| 240 | 
            +
            # Plot entry: yearly bar chart of a frequency indicator at one location.
            indicator_number_of_days_per_year_at_location: Plot = Plot(
                name="Indicator number of days per year at location",
                description="Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
                params=["indicator_column", "location", "model"],
                plot_function=plot_indicator_number_of_days_per_year_at_location,
                sql_query=indicator_per_year_at_location_query,
            )
         | 
| 247 | 
            +
             | 
| 248 | 
            +
             | 
| 249 | 
            +
            def plot_distribution_of_indicator_for_given_year(
                params: dict,
            ) -> Callable[..., Figure]:
                """Build a function that plots an indicator's distribution for one year.

                The returned function draws a percent-normalised histogram of the
                indicator values found in the DataFrame it receives.

                Args:
                    params (dict): Dictionary containing:
                        - indicator_column (str): The column name for the indicator
                        - year (str): The year, used in the plot title
                        - model (str): Listed among the plot parameters; it is not
                          read here (the model is taken from the DataFrame itself)

                Returns:
                    Callable[..., Figure]: A function that takes a DataFrame and
                    returns a plotly Figure.
                """
                indicator = params["indicator_column"]
                year = params["year"]
                indicator_label = " ".join(map(str.capitalize, indicator.split("_")))
                unit = INDICATOR_TO_UNIT.get(indicator, "")

                def plot_data(df: pd.DataFrame) -> Figure:
                    """Render the histogram for the given data.

                    Args:
                        df (pd.DataFrame): Data with ``model`` and the indicator
                            column; ``latitude``/``longitude`` are also needed when
                            it does not hold exactly one distinct model.

                    Returns:
                        Figure: Plotly figure
                    """
                    single_model = df["model"].nunique() == 1
                    if single_model:
                        source = df
                        model_label = f"Model : {df['model'].unique()[0]}"
                    else:
                        # Average across models for each grid point.
                        grouped = df.groupby(["latitude", "longitude"], as_index=False)
                        source = grouped[indicator].mean()
                        model_label = "Model Average"

                    # Plain Python floats, to avoid pandas encoding
                    indicators = source[indicator].astype(float).tolist()

                    histogram = go.Histogram(
                        x=indicators,
                        opacity=0.8,
                        histnorm="percent",
                        marker=dict(color="#1f77b4"),
                        hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>",
                    )

                    fig = go.Figure()
                    fig.add_trace(histogram)
                    fig.update_layout(
                        title=f"Distribution of {indicator_label} in {year} ({model_label})",
                        xaxis_title=f"{indicator_label} ({unit})",
                        yaxis_title="Frequency (%)",
                        plot_bgcolor="rgba(0, 0, 0, 0)",
                        showlegend=False,
                    )
                    return fig

                return plot_data
         | 
| 319 | 
            +
             | 
| 320 | 
            +
             | 
| 321 | 
            +
            # Plot entry: histogram of an indicator's values for a given year.
            distribution_of_indicator_for_given_year: Plot = Plot(
                name="Distribution of an indicator for a given year",
                description="Plot an histogram of the distribution for a given year of the values of an indicator",
                params=["indicator_column", "model", "year"],
                plot_function=plot_distribution_of_indicator_for_given_year,
                sql_query=indicator_for_given_year_query,
            )
         | 
| 328 | 
            +
             | 
| 329 | 
            +
             | 
| 330 | 
            +
            def plot_map_of_france_of_indicator_for_given_year(
                params: dict,
            ) -> Callable[..., Figure]:
                """Build a function that plots an indicator on a map of France.

                The returned function draws one coloured marker per
                (latitude, longitude) point of the DataFrame it receives, on an
                OpenStreetMap background centred on France. When the DataFrame holds
                more than one model, the indicator is averaged over the models at each
                point; otherwise the rows are plotted as they are.

                Args:
                    params (dict): Dictionary containing:
                        - indicator_column (str): The column name for the indicator
                        - year (str): The year to plot (only used in the title)
                      Other keys (e.g. ``model``) are not read by this function.

                Returns:
                    Callable[..., Figure]: A function that takes a DataFrame with
                    ``latitude``, ``longitude``, ``model`` and indicator columns and
                    returns a plotly Figure.

                Raises:
                    KeyError: If ``indicator_column`` or ``year`` is missing from params.
                """
                indicator = params["indicator_column"]
                year = params["year"]
                indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
                unit = INDICATOR_TO_UNIT.get(indicator, "")

                def plot_data(df: pd.DataFrame) -> Figure:
                    """Render the map for the rows of ``df``."""
                    if df["model"].nunique() != 1:
                        # Several (or zero) models: show the per-point model average.
                        points = df.groupby(["latitude", "longitude"], as_index=False)[
                            indicator
                        ].mean()
                        model_label = "Model Average"
                    else:
                        points = df
                        model_label = f"Model : {df['model'].unique()[0]}"

                    # Transform to plain lists to avoid pandas encoding in the figure.
                    indicators = points[indicator].astype(float).tolist()
                    latitudes = points["latitude"].astype(float).tolist()
                    longitudes = points["longitude"].astype(float).tolist()

                    # min()/max() raise on an empty sequence; leaving the bounds unset
                    # lets plotly pick them and yields an empty map instead of a crash.
                    color_min = min(indicators) if indicators else None
                    color_max = max(indicators) if indicators else None

                    fig = go.Figure()
                    fig.add_trace(
                        go.Scattermapbox(
                            lat=latitudes,
                            lon=longitudes,
                            mode="markers",
                            marker=dict(
                                size=10,
                                color=indicators,  # Color mapped to values
                                colorscale="Turbo",
                                cmin=color_min,
                                cmax=color_max,
                                showscale=True,  # Show colorbar
                                # The trace owns its colorbar (it is not attached to
                                # layout.coloraxis), so the title has to be set here
                                # to actually be displayed.
                                colorbar=dict(
                                    title=dict(text=f"{indicator_label} ({unit})")
                                ),
                            ),
                            # Hover text showing the indicator value
                            text=[
                                f"{indicator_label}: {value:.2f} {unit}"
                                for value in indicators
                            ],
                            hoverinfo="text",  # Only show the custom text on hover
                        )
                    )

                    fig.update_layout(
                        mapbox_style="open-street-map",  # Use OpenStreetMap
                        mapbox_zoom=3,
                        mapbox_center={"lat": 46.6, "lon": 2.0},
                        title=f"{indicator_label} in {year} in France ({model_label}) ",
                    )
                    return fig

                return plot_data
         | 
| 402 | 
            +
             | 
| 403 | 
            +
             | 
| 404 | 
            +
            # Plot descriptor for the map of France coloured by an indicator's values
            # in one year. It reuses the same SQL query builder as the distribution
            # plot and only differs by its figure-factory function.
            map_of_france_of_indicator_for_given_year: Plot = {
                "name": "Map of France of an indicator for a given year",
                "description": "Heatmap on the map of France of the values of an in indicator for a given year",
                "params": ["indicator_column", "year", "model"],
                "plot_function": plot_map_of_france_of_indicator_for_given_year,
                "sql_query": indicator_for_given_year_query,
            }
         | 
| 411 | 
            +
             | 
| 412 | 
            +
             | 
| 413 | 
            +
            # Registry of every plot descriptor defined in this module, in the order
            # in which they are declared.
            PLOTS = [
                indicator_evolution_at_location,
                indicator_number_of_days_per_year_at_location,
                distribution_of_indicator_for_given_year,
                map_of_france_of_indicator_for_given_year,
            ]
         | 
    	
        climateqa/engine/talk_to_data/sql_query.py
    ADDED
    
    | @@ -0,0 +1,114 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import asyncio
         | 
| 2 | 
            +
            from concurrent.futures import ThreadPoolExecutor
         | 
| 3 | 
            +
            from typing import TypedDict
         | 
| 4 | 
            +
            import duckdb
         | 
| 5 | 
            +
            import pandas as pd
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            async def execute_sql_query(sql_query: str) -> pd.DataFrame:
                """Execute a SQL query with DuckDB and return the result as a DataFrame.

                A fresh DuckDB connection is opened for the call (no database file is
                passed to ``duckdb.connect``), the query is run on it, and the
                connection is closed again whether or not the query succeeds. The
                blocking work is done in a worker thread so the event loop stays free.

                Args:
                    sql_query (str): The SQL query to execute

                Returns:
                    pd.DataFrame: A DataFrame containing the query results

                Raises:
                    duckdb.Error: If there is an error executing the SQL query
                """

                def _execute_query() -> pd.DataFrame:
                    con = duckdb.connect()
                    try:
                        return con.sql(sql_query).fetchdf()
                    finally:
                        # Always release the connection, even when the query fails.
                        con.close()

                # We are inside a coroutine, so a loop is guaranteed to be running;
                # get_running_loop() is the non-deprecated way to obtain it.
                loop = asyncio.get_running_loop()
                # A single task is submitted, so one worker thread is enough.
                with ThreadPoolExecutor(max_workers=1) as executor:
                    return await loop.run_in_executor(executor, _execute_query)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
                """Inputs of the "indicator per year at a location" SQL query.

                The dictionary is declared with ``total=False``, so every key may be
                absent; the query builder decides which ones it requires.

                Attributes:
                    indicator_column (str): Name of the column holding the indicator
                    latitude (str): Latitude of the location
                    longitude (str): Longitude of the location
                    model (str): Climate model to use (optional)
                """

                indicator_column: str
                latitude: str
                longitude: str
                model: str
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def indicator_per_year_at_location_query(
                table: str, params: IndicatorPerYearAtLocationQueryParams
            ) -> str:
                """Build the SQL query giving an indicator's yearly values at a location.

                Args:
                    table (str): name of the indicator table; it is lower-cased and used
                        as the name of the parquet file to read
                    params (IndicatorPerYearAtLocationQueryParams): dictionary with the
                        ``indicator_column``, ``latitude`` and ``longitude`` to query

                Returns:
                    str: the SQL query, or an empty string when one of the three
                    required parameters is missing
                """
                required = [
                    params.get(key) for key in ("indicator_column", "latitude", "longitude")
                ]
                # A missing parameter is signalled to the caller by an empty query.
                if any(value is None for value in required):
                    return ""
                indicator_column, latitude, longitude = required

                parquet_source = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
                return f"SELECT year, {indicator_column}, model\nFROM {parquet_source}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            class IndicatorForGivenYearQueryParams(TypedDict, total=False):
                """Inputs of the "indicator for a given year" SQL query.

                The dictionary is declared with ``total=False``, so every key may be
                absent; the query builder decides which ones it requires.

                Attributes:
                    indicator_column (str): Name of the column holding the indicator
                    year (str): Year to query
                    model (str): Climate model to use (optional)
                """

                indicator_column: str
                year: str
                model: str
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            def indicator_for_given_year_query(
                    table:str, params: IndicatorForGivenYearQueryParams
            ) -> str:
                """SQL Query to get the values of an indicator with their latitudes, longitudes and models for a given year

                Args:
                    table (str): sql table of the indicator; it is lower-cased and used
                        as the name of the parquet file to read
                    params (IndicatorForGivenYearQueryParams): dictionary with the required params for the query

                Returns:
                    str: the sql query, or an empty string when ``indicator_column`` or
                    ``year`` is missing (``None`` or blank)
                """
                indicator_column = params.get("indicator_column")
                year = params.get("year")

                # A year that was not detected may arrive as an empty string rather
                # than None; interpolating it would produce the invalid SQL
                # "Where year = ", so it is treated as a missing parameter too.
                year_missing = year is None or str(year).strip() == ""
                if year_missing or indicator_column is None:
                    return ""

                table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"

                sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
                return sql_query
         | 
    	
        climateqa/engine/talk_to_data/talk_to_drias.py
    ADDED
    
    | @@ -0,0 +1,317 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from typing import Any, Callable, TypedDict, Optional
         | 
| 4 | 
            +
            from numpy import sort
         | 
| 5 | 
            +
            import pandas as pd
         | 
| 6 | 
            +
            import asyncio
         | 
| 7 | 
            +
            from plotly.graph_objects import Figure
         | 
| 8 | 
            +
            from climateqa.engine.llm import get_llm
         | 
| 9 | 
            +
            from climateqa.engine.talk_to_data import sql_query
         | 
| 10 | 
            +
            from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
         | 
| 11 | 
            +
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 12 | 
            +
            from climateqa.engine.talk_to_data.sql_query import execute_sql_query
         | 
| 13 | 
            +
            from climateqa.engine.talk_to_data.utils import (
         | 
| 14 | 
            +
                detect_relevant_plots,
         | 
| 15 | 
            +
                detect_year_with_openai,
         | 
| 16 | 
            +
                loc2coords,
         | 
| 17 | 
            +
                detect_location_with_openai,
         | 
| 18 | 
            +
                nearestNeighbourSQL,
         | 
| 19 | 
            +
                detect_relevant_tables,
         | 
| 20 | 
            +
            )
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            # Directory two levels above the current working directory of the process.
            ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            class TableState(TypedDict):
                """Per-table record built while processing one table for a plot.

                Attributes:
                    table_name (str): Name of the table being processed
                    params (dict[str, Any]): Parameters handed to the SQL query builder
                    sql_query (str, optional): Generated SQL query, ``None`` until built
                    dataframe (pd.DataFrame | None, optional): Query result, ``None``
                        until the query has been executed
                    figure (Callable[..., Figure], optional): Function producing the
                        figure for this table
                    status (str): Processing status, ``'OK'`` or ``'ERROR'``
                """

                table_name: str
                params: dict[str, Any]
                sql_query: Optional[str]
                dataframe: Optional[pd.DataFrame | None]
                figure: Optional[Callable[..., Figure]]
                status: str
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            class PlotState(TypedDict):
                """Per-plot record grouping the tables used by one plot.

                Attributes:
                    plot_name (str): Name of the plot
                    tables (list[str]): Names of the tables used by the plot
                    table_states (dict[str, TableState]): State of each of those
                        tables (presumably keyed by table name)
                """

                plot_name: str
                tables: list[str]
                table_states: dict[str, TableState]
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            class State(TypedDict):
                """Top-level state carried through the DRIAS workflow.

                Attributes:
                    user_input (str): The user's query text
                    plots (list[str]): Plots attached to the query (presumably the
                        names of the relevant plots)
                    plot_states (dict[str, PlotState]): State of each plot
                    error (str, optional): Error message, if any
                """

                user_input: str
                plots: list[str]
                plot_states: dict[str, PlotState]
                error: Optional[str]
         | 
| 66 | 
            +
                
         | 
| 67 | 
            +
            async def find_relevant_plots(state: State, llm) -> list[str]:
                """Ask the LLM which plots are relevant to the user's query.

                Args:
                    state (State): state of the workflow; only ``user_input`` is read
                    llm: language model handed to ``detect_relevant_plots``

                Returns:
                    list[str]: the value returned by ``detect_relevant_plots``
                """
                print("---- Find relevant plots ----")
                return await detect_relevant_plots(state['user_input'], llm)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
                """Ask the LLM which tables are relevant to the user's query for a plot.

                Args:
                    state (State): state of the workflow; only ``user_input`` is read
                    plot (Plot): the plot for which tables are searched
                    llm: language model handed to ``detect_relevant_tables``

                Returns:
                    list[str]: the value returned by ``detect_relevant_tables``
                """
                print(f"---- Find relevant tables for {plot['name']} ----")
                return await detect_relevant_tables(state['user_input'], plot, llm)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
                """Resolve one query parameter from the user's input.

                Args:
                    state (State): state of the workflow
                    param_name (str): name of the desired parameter; ``'location'`` and
                        ``'year'`` are supported
                    table (str): name of the table (used for the location lookup)

                Returns:
                    dict[str, Any] | None: the location dictionary for ``'location'``,
                    ``{'year': ...}`` for ``'year'``, and ``None`` for any other name
                """
                if param_name == 'year':
                    return {'year': await find_year(state['user_input'])}
                if param_name == 'location':
                    return await find_location(state['user_input'], table)
                # Any other parameter name is not handled here.
                return None
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            class Location(TypedDict):
                """Location detected in the user's query and its coordinates.

                Attributes:
                    location (str): The detected location
                    latitude (str, optional): Latitude associated with the location
                    longitude (str, optional): Longitude associated with the location
                """

                location: str
                latitude: Optional[str]
                longitude: Optional[str]
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            async def find_location(user_input: str, table: str) -> Location:
                """Detect the location mentioned in the query and its nearest table point.

                Args:
                    user_input (str): The user's query text
                    table (str): Name of the table in which the nearest point is searched

                Returns:
                    Location: always holds ``location``; ``latitude`` and ``longitude``
                    are only added when a location was detected
                """
                print(f"---- Find location in table {table} ----")
                detected = await detect_location_with_openai(user_input)
                result: Location = {'location' : detected}
                if not detected:
                    # Nothing detected: return without coordinates.
                    return result

                nearest = nearestNeighbourSQL(loc2coords(detected), table)
                result["latitude"] = nearest[0]
                result["longitude"] = nearest[1]
                return result
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            async def find_year(user_input: str) -> str:
                """Detect the year mentioned in the user's query.

                The detection is delegated to ``detect_year_with_openai``.

                Args:
                    user_input (str): The user's query text

                Returns:
                    str: The extracted year, or empty string if no year found
                """
                print("---- Find year ---")
                return await detect_year_with_openai(user_input)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            def find_indicator_column(table: str) -> str:
                """Look up the indicator column of a table.

                The lookup is a plain key access on ``INDICATOR_COLUMNS_PER_TABLE``.

                Args:
                    table (str): Name of the table in the database

                Returns:
                    str: Name of the indicator column for the specified table

                Raises:
                    KeyError: If the table name is not found in the mapping
                """
                print(f"---- Find indicator column in table {table} ----")
                indicator_column = INDICATOR_COLUMNS_PER_TABLE[table]
                return indicator_column
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            async def process_table(
                table: str,
                params: dict[str, Any],
                plot: Plot,
            ) -> TableState:
                """Build the SQL query for one table, run it and prepare its figure.

                The caller's ``params`` are never mutated: a shallow copy is enriched
                with the table's indicator column and stored in the returned state.

                Args:
                    table (str): Name of the table to process.
                    params (dict[str, Any]): Query parameters shared by every table of the plot.
                    plot (Plot): Plot definition; its ``'sql_query'`` entry is called with
                        ``(table, params)`` and its ``'plot_function'`` entry with ``(params)``.

                Returns:
                    TableState: State holding the table name, the enriched params, the
                        status, the SQL query, the dataframe and the figure. When the plot
                        produces an empty SQL query the status is ``'ERROR'`` and the query,
                        dataframe and figure stay ``None``.
                """
                query_params = params.copy()
                outcome: TableState = {
                    'table_name': table,
                    'params': query_params,
                    'status': 'OK',
                    'dataframe': None,
                    'sql_query': None,
                    'figure': None,
                }
                query_params['indicator_column'] = find_indicator_column(table)

                generated_query = plot['sql_query'](table, query_params)
                if generated_query == "":
                    # Nothing to execute for this table: report the failure to the caller.
                    outcome['status'] = 'ERROR'
                    return outcome

                outcome['sql_query'] = generated_query
                outcome['dataframe'] = await execute_sql_query(generated_query)
                outcome['figure'] = plot['plot_function'](query_params)
                return outcome
         | 
| 189 | 
            +
             | 
| 190 | 
            +
            async def drias_workflow(user_input: str) -> State:
                """Run the complete Talk To Drias workflow for one user question.

                The workflow selects the relevant plots, then for each plot selects the
                relevant tables, resolves the plot parameters and processes at most the
                first three tables concurrently (SQL query, dataframe and figure).

                Args:
                    user_input (str): Initial user input.

                Returns:
                    State: Final state. ``'plots'`` holds the selected plot names,
                        ``'plot_states'`` maps each known plot name to its ``PlotState``
                        and ``'error'`` is a non-empty message when no plot, no table,
                        no SQL query or no data could be produced for ANY plot.
                """
                state: State = {
                    'user_input': user_input,
                    'plots': [],
                    'plot_states': {},
                    'error': ''
                }

                llm = get_llm(provider="openai")

                state['plots'] = await find_relevant_plots(state, llm)

                if not state['plots']:
                    state['error'] = 'There is no plot to answer to the question'
                    return state

                # These flags describe the whole run: they become True as soon as one
                # plot succeeds and are never reset, so a failing last plot cannot hide
                # the results obtained for an earlier one.
                have_relevant_table = False
                have_sql_query = False
                have_dataframe = False

                for plot_name in state['plots']:
                    # Find the associated plot object; unknown names are ignored.
                    plot = next((p for p in PLOTS if p['name'] == plot_name), None)
                    if plot is None:
                        continue

                    relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

                    plot_state: PlotState = {
                        'plot_name': plot_name,
                        'tables': relevant_tables,
                        'table_states': {}
                    }
                    state['plot_states'][plot_name] = plot_state

                    if not relevant_tables:
                        # Without a table there is nothing to query, and the parameter
                        # lookup below needs relevant_tables[0].
                        continue
                    have_relevant_table = True

                    params = {}
                    for param_name in plot['params']:
                        param = await find_param(state, param_name, relevant_tables[0])
                        if param:
                            params.update(param)

                    tasks = [process_table(table, params, plot) for table in relevant_tables[:3]]
                    results = await asyncio.gather(*tasks)

                    # Store results back in plot_state
                    for table_state in results:
                        if table_state['sql_query']:
                            have_sql_query = True
                        if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                            have_dataframe = True
                        plot_state['table_states'][table_state['table_name']] = table_state

                if not have_relevant_table:
                    state['error'] = "There is no relevant table in our database to answer your question"
                elif not have_sql_query:
                    state['error'] = "There is no relevant sql query on our database that can help to answer your question"
                elif not have_dataframe:
                    state['error'] = "There is no data in our table that can answer to your question"

                return state
         | 
| 270 | 
            +
             | 
| 271 | 
            +
            # def make_write_query_node():
         | 
| 272 | 
            +
             | 
| 273 | 
            +
            #     def write_query(state):
         | 
| 274 | 
            +
            #         print("---- Write query ----")
         | 
| 275 | 
            +
            #         for table in state["tables"]:
         | 
| 276 | 
            +
            #             sql_query = QUERIES[state[table]['query_type']](
         | 
| 277 | 
            +
            #                 table=table,
         | 
| 278 | 
            +
            #                 indicator_column=state[table]["columns"],
         | 
| 279 | 
            +
            #                 longitude=state[table]["longitude"],
         | 
| 280 | 
            +
            #                 latitude=state[table]["latitude"],
         | 
| 281 | 
            +
            #             )
         | 
| 282 | 
            +
            #             state[table].update({"sql_query": sql_query})
         | 
| 283 | 
            +
             | 
| 284 | 
            +
            #         return state
         | 
| 285 | 
            +
             | 
| 286 | 
            +
            #     return write_query
         | 
| 287 | 
            +
             | 
| 288 | 
            +
            # def make_fetch_data_node(db_path):
         | 
| 289 | 
            +
             | 
| 290 | 
            +
            #     def fetch_data(state):
         | 
| 291 | 
            +
            #         print("---- Fetch data ----")
         | 
| 292 | 
            +
            #         for table in state["tables"]:
         | 
| 293 | 
            +
            #             results = execute_sql_query(db_path, state[table]['sql_query'])
         | 
| 294 | 
            +
            #             state[table].update(results)
         | 
| 295 | 
            +
             | 
| 296 | 
            +
            #         return state
         | 
| 297 | 
            +
                    
         | 
| 298 | 
            +
            #     return fetch_data
         | 
| 299 | 
            +
             | 
| 300 | 
            +
             | 
| 301 | 
            +
             | 
| 302 | 
            +
            ## V2
         | 
| 303 | 
            +
             | 
| 304 | 
            +
             | 
| 305 | 
            +
            # def make_fetch_data_node(db_path: str, llm):
         | 
| 306 | 
            +
            #     def fetch_data(state):
         | 
| 307 | 
            +
            #         print("---- Fetch data ----")
         | 
| 308 | 
            +
            #         db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
         | 
| 309 | 
            +
            #         output = {}
         | 
| 310 | 
            +
            #         sql_query = write_sql_query(state["query"], db, state["tables"], llm)
         | 
| 311 | 
            +
            #         # TO DO : Add query checker
         | 
| 312 | 
            +
            #         print(f"SQL query  : {sql_query}")
         | 
| 313 | 
            +
            #         output["sql_query"] = sql_query
         | 
| 314 | 
            +
            #         output.update(fetch_data_from_sql_query(db_path, sql_query))
         | 
| 315 | 
            +
            #         return output
         | 
| 316 | 
            +
             | 
| 317 | 
            +
            #     return fetch_data
         | 
    	
        climateqa/engine/talk_to_data/utils.py
    ADDED
    
    | @@ -0,0 +1,281 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import re
         | 
| 2 | 
            +
            from typing import Annotated, TypedDict
         | 
| 3 | 
            +
            import duckdb
         | 
| 4 | 
            +
            from geopy.geocoders import Nominatim
         | 
| 5 | 
            +
            import ast
         | 
| 6 | 
            +
            from climateqa.engine.llm import get_llm
         | 
| 7 | 
            +
            from climateqa.engine.talk_to_data.config import DRIAS_TABLES
         | 
| 8 | 
            +
            from climateqa.engine.talk_to_data.plot import PLOTS, Plot
         | 
| 9 | 
            +
            from langchain_core.prompts import ChatPromptTemplate
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            async def detect_location_with_openai(sentence):
                """Detect the first location mentioned in a sentence with the LLM.

                The model is asked to answer with a Python list of locations. The list
                literal is extracted from the reply (which may be wrapped in a Markdown
                code fence or surrounded by prose) and parsed with ``ast.literal_eval``.

                Args:
                    sentence: The sentence to analyse.

                Returns:
                    The first element of the list returned by the model, or ``""`` when
                    the list is empty or the reply does not contain a parsable list.
                """
                llm = get_llm()

                prompt = f"""
                Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
                Return the result as a Python list. If no locations are mentioned, return an empty list.
                
                Sentence: "{sentence}"
                """

                response = await llm.ainvoke(prompt)

                # str.strip("```python\n") removes a *set of characters*, not a prefix,
                # so the list literal is located explicitly instead.
                list_literal = re.search(r"\[.*\]", response.content, re.DOTALL)
                if list_literal is None:
                    return ""
                try:
                    location_list = ast.literal_eval(list_literal.group(0))
                except (ValueError, SyntaxError, TypeError):
                    return ""

                if isinstance(location_list, (list, tuple)) and location_list:
                    return location_list[0]
                return ""
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            # Schema passed to `llm.with_structured_output(...)` in detect_year_with_openai;
            # the `array` string it carries is then evaluated there to obtain a Python list.
            # NOTE(review): the docstring and the Annotated description below are presumably
            # forwarded to the model as part of the structured-output schema - confirm
            # before rewording either of them.
            class ArrayOutput(TypedDict):
                """Represents the output of a function that returns an array.
                
                This class is used to type-hint functions that return arrays,
                ensuring consistent return types across the codebase.
                
                Attributes:
                    array (str): A syntactically valid Python array string
                """
                # Textual form of the list (e.g. "[2030, 2050]"), not the list itself.
                array: Annotated[str, "Syntactically valid python array."]
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            async def detect_year_with_openai(sentence: str) -> str:
                """Detect the first year mentioned in a sentence with the LLM.

                The model is asked, through structured output (``ArrayOutput``), for a
                string containing a Python list of years. That string comes from the
                model, so it is parsed with ``ast.literal_eval`` (literals only) rather
                than ``eval``, which would execute arbitrary expressions.

                Args:
                    sentence (str): The sentence to analyse.

                Returns:
                    str: The first element of the parsed list, or ``""`` when the list is
                        empty or the model output cannot be parsed. The element is
                        returned as parsed, so it may be an ``int`` if the model wrote
                        the years as bare numbers.
                """
                llm = get_llm()

                prompt = """
                Extract all years mentioned in the following sentence.
                Return the result as a Python list. If no year are mentioned, return an empty list.
                
                Sentence: "{sentence}"
                """

                prompt = ChatPromptTemplate.from_template(prompt)
                structured_llm = llm.with_structured_output(ArrayOutput)
                chain = prompt | structured_llm
                response: ArrayOutput = await chain.ainvoke({"sentence": sentence})

                try:
                    years_list = ast.literal_eval(response['array'])
                except (KeyError, TypeError, ValueError, SyntaxError):
                    # Missing field or text that is not a plain Python literal.
                    return ""

                if isinstance(years_list, (list, tuple)) and years_list:
                    return years_list[0]
                return ""
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            def detectTable(sql_query: str) -> list[str]:
                """Return every table reference that follows a ``FROM`` keyword.

                The match is case-insensitive. A reference is a bare word or a
                backtick-, double- or single-quoted name, optionally extended with
                dot-separated parts (``schema.table``). Quotes are kept in the result.

                Args:
                    sql_query (str): The SQL query to analyze

                Returns:
                    list[str]: The references found, in order of appearance

                Example:
                    >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
                    ['temperature_data']
                """
                from_target = re.compile(
                    r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
                )
                return [hit.group(1) for hit in from_target.finditer(sql_query)]
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            def loc2coords(location: str) -> tuple[float, float]:
                """Geocode a location name into ``(latitude, longitude)``.

                The lookup is delegated to the Nominatim geocoding service.

                Args:
                    location (str): The name of the location to geocode (e.g. a city name)

                Returns:
                    tuple[float, float]: A tuple containing (latitude, longitude)

                Raises:
                    AttributeError: If the location cannot be found
                """
                match = Nominatim(user_agent="city_to_latlong").geocode(location)
                return match.latitude, match.longitude
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            def coords2loc(coords: tuple[float, float]) -> str:
                """Reverse-geocode ``(latitude, longitude)`` into an address.

                The lookup is delegated to the Nominatim reverse geocoding service and
                is best-effort: any failure is printed and replaced by a placeholder.

                Args:
                    coords (tuple[float, float]): A tuple containing (latitude, longitude)

                Returns:
                    str: The address of the location, or "Unknown Location" if not found

                Example:
                    >>> coords2loc((48.8566, 2.3522))
                    'Paris, France'
                """
                reverse_geocoder = Nominatim(user_agent="coords_to_city")
                try:
                    return reverse_geocoder.reverse(coords).address
                except Exception as e:
                    print(f"Error: {e}")
                    return "Unknown Location"
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
            def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
                """Find the grid point of a DRIAS table closest to a location.

                Candidates are restricted to a +/- 0.3 degree box around the location
                (rounded to 3 decimals) and ordered by squared distance in degrees, so
                the returned point really is the nearest one instead of whichever row
                happens to come first in the parquet file.

                Args:
                    location (tuple): ``(latitude, longitude)`` of the target location.
                    table (str): Table name; it is lower-cased to build the parquet path
                        ``hf://datasets/timeki/drias_db/<table>.parquet``.

                Returns:
                    tuple[str, str]: ``(latitude, longitude)`` of the nearest grid point
                        as stored in the table, or ``("", "")`` when the box contains no
                        point.
                """
                lat = round(location[0], 3)
                lon = round(location[1], 3)

                table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"

                # Values are parenthesised so a negative coordinate can never produce
                # "--", which SQL would read as the start of a comment.
                results = duckdb.sql(
                    f"SELECT latitude, longitude FROM {table} "
                    f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
                    f"AND longitude BETWEEN {lon - 0.3} AND {lon + 0.3} "
                    f"ORDER BY (latitude - ({lat})) * (latitude - ({lat})) "
                    f"+ (longitude - ({lon})) * (longitude - ({lon})) "
                    f"LIMIT 1"
                ).fetchdf()

                if len(results) == 0:
                    return "", ""
                return results['latitude'].iloc[0], results['longitude'].iloc[0]
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
                """Identifies relevant tables for a plot based on user input.

                This function asks the LLM which tables of ``DRIAS_TABLES`` best fit the
                plot description and the user's question. The model is expected to
                answer with a Python list; the list literal is extracted from the reply
                (which may be wrapped in a Markdown code fence or surrounded by prose)
                and parsed with ``ast.literal_eval``.

                Args:
                    user_question (str): The user's question about climate data
                    plot (Plot): The plot configuration object (its ``'description'`` is used)
                    llm: The language model instance to use for analysis

                Returns:
                    list[str]: The table names proposed by the model, or an empty list
                        when the reply does not contain a parsable list

                Example:
                    >>> detect_relevant_tables(
                    ...     "What will the temperature be like in Paris?",
                    ...     indicator_evolution_at_location,
                    ...     llm
                    ... )
                    ['mean_annual_temperature', 'mean_summer_temperature']
                """
                # Get all table names
                table_names_list = DRIAS_TABLES

                prompt = (
                    f"You are helping to build a plot following this description : {plot['description']}."
                    "You are given a list of tables and a user question."
                    "Based on the description of the plot, which table are appropriate for that kind of plot."
                    "Write the 3 most relevant tables to use. Answer only a python list of table name."
                    f"### List of tables : {table_names_list}"
                    f"### User question : {user_question}"
                    "### List of table name : "
                )

                reply = (await llm.ainvoke(prompt)).content

                # str.strip("```python\n") removes a *set of characters*, not a prefix,
                # so the list literal is located explicitly instead.
                list_literal = re.search(r"\[.*\]", reply, re.DOTALL)
                if list_literal is None:
                    return []
                try:
                    table_names = ast.literal_eval(list_literal.group(0))
                except (ValueError, SyntaxError, TypeError):
                    return []

                if not isinstance(table_names, (list, tuple)):
                    return []
                return [name for name in table_names if isinstance(name, str)]
         | 
| 189 | 
            +
             | 
| 190 | 
            +
             | 
| 191 | 
            +
            def replace_coordonates(coords, query, coords_tables):
                """Replace each occurrence of a coordinate pair with per-table coordinates.

                The i-th occurrence of ``str(coords[0])`` in ``query`` becomes
                ``str(coords_tables[i][0])`` and the i-th occurrence of
                ``str(coords[1])`` becomes ``str(coords_tables[i][1])``.

                The query is rewritten in a single left-to-right pass, so text that has
                just been inserted is never scanned again. Repeated
                ``str.replace(..., 1)`` calls restart from the beginning of the string
                and corrupt the query whenever a replacement contains (or equals) the
                value being searched for.

                Args:
                    coords: ``(latitude, longitude)`` currently written in the query.
                    query: The SQL query text.
                    coords_tables: Sequence of ``(latitude, longitude)`` pairs, one per
                        occurrence, in order.

                Returns:
                    The rewritten query. Occurrences beyond ``len(coords_tables)`` are
                    left unchanged.
                """
                lat_token, lon_token = str(coords[0]), str(coords[1])
                # Longest token first so that a token which is a prefix of the other
                # cannot shadow the longer match.
                tokens = sorted({tok for tok in (lat_token, lon_token) if tok}, key=len, reverse=True)
                if not tokens:
                    return query
                pattern = re.compile("|".join(re.escape(tok) for tok in tokens))
                seen = {"lat": 0, "lon": 0}

                def substitute(match):
                    text = match.group(0)
                    # When both tokens are identical, alternate latitude / longitude.
                    if text == lat_token and (text != lon_token or seen["lat"] <= seen["lon"]):
                        axis, column = "lat", 0
                    else:
                        axis, column = "lon", 1
                    index = seen[axis]
                    seen[axis] += 1
                    if index >= len(coords_tables):
                        return text
                    return str(coords_tables[index][column])

                return pattern.sub(substitute, query)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
             | 
| 200 | 
            +
            async def detect_relevant_plots(user_question: str, llm):
                """Ask the LLM which plots of ``PLOTS`` can answer the user's question.

                Every plot is presented to the model with its name and description. The
                model is expected to answer with a Python list of plot names; the list
                literal is extracted from the reply (which may be wrapped in a Markdown
                code fence or surrounded by prose) and parsed with ``ast.literal_eval``.

                Args:
                    user_question (str): The user's question.
                    llm: The language model instance to query (``ainvoke`` is awaited).

                Returns:
                    list[str]: The plot names proposed by the model, or an empty list
                        when the reply does not contain a parsable list.
                """
                plots_description = "".join(
                    "Name: " + plot["name"] + " - Description: " + plot["description"] + "\n"
                    for plot in PLOTS
                )

                prompt = (
                    "You are helping to answer a quesiton with insightful visualizations."
                    "You are given an user question and a list of plots with their name and description."
                    "Based on the descriptions of the plots, which plot is appropriate to answer to this question."
                    "Write the most relevant tables to use. Answer only a python list of plot name."
                    f"### Descriptions of the plots : {plots_description}"
                    f"### User question : {user_question}"
                    "### Name of the plot : "
                )
                # prompt = (
                #     f"You are helping to answer a question with insightful visualizations. "
                #     f"Given a list of plots with their name and description: "
                #     f"{plots_description} "
                #     f"The user question is: {user_question}. "
                #     f"Choose the most relevant plots to answer the question. "
                #     f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
                #     f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
                # )

                reply = (await llm.ainvoke(prompt)).content

                # str.strip("```python\n") removes a *set of characters*, not a prefix,
                # so the list literal is located explicitly instead.
                list_literal = re.search(r"\[.*\]", reply, re.DOTALL)
                if list_literal is None:
                    return []
                try:
                    plot_names = ast.literal_eval(list_literal.group(0))
                except (ValueError, SyntaxError, TypeError):
                    return []

                if not isinstance(plot_names, (list, tuple)):
                    return []
                return [name for name in plot_names if isinstance(name, str)]
         | 
| 229 | 
            +
             | 
| 230 | 
            +
             | 
| 231 | 
            +
            # Next Version
         | 
| 232 | 
            +
            # class QueryOutput(TypedDict):
         | 
| 233 | 
            +
            #     """Generated SQL query."""
         | 
| 234 | 
            +
             | 
| 235 | 
            +
            #     query: Annotated[str, ..., "Syntactically valid SQL query."]
         | 
| 236 | 
            +
             | 
| 237 | 
            +
             | 
| 238 | 
            +
            # class PlotlyCodeOutput(TypedDict):
         | 
| 239 | 
            +
            #     """Generated Plotly code"""
         | 
| 240 | 
            +
             | 
| 241 | 
            +
            #     code: Annotated[str, ..., "Synatically valid Plotly python code."]
         | 
| 242 | 
            +
            # def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
         | 
| 243 | 
            +
            #     """Generate SQL query to fetch information."""
         | 
| 244 | 
            +
            #     prompt_params = {
         | 
| 245 | 
            +
            #         "dialect": db.dialect,
         | 
| 246 | 
            +
            #         "table_info": db.get_table_info(),
         | 
| 247 | 
            +
            #         "input": user_input,
         | 
| 248 | 
            +
            #         "relevant_tables": relevant_tables,
         | 
| 249 | 
            +
            #         "model": "ALADIN63_CNRM-CM5",
         | 
| 250 | 
            +
            #     }
         | 
| 251 | 
            +
             | 
| 252 | 
            +
            #     prompt = ChatPromptTemplate.from_template(query_prompt_template)
         | 
| 253 | 
            +
            #     structured_llm = llm.with_structured_output(QueryOutput)
         | 
| 254 | 
            +
            #     chain = prompt | structured_llm
         | 
| 255 | 
            +
            #     result = chain.invoke(prompt_params)
         | 
| 256 | 
            +
             | 
| 257 | 
            +
            #     return result["query"]
         | 
| 258 | 
            +
             | 
| 259 | 
            +
             | 
| 260 | 
            +
            # def fetch_data_from_sql_query(db: str, sql_query: str):
         | 
| 261 | 
            +
            #     conn = sqlite3.connect(db)
         | 
| 262 | 
            +
            #     cursor = conn.cursor()
         | 
| 263 | 
            +
            #     cursor.execute(sql_query)
         | 
| 264 | 
            +
            #     column_names = [desc[0] for desc in cursor.description]
         | 
| 265 | 
            +
            #     values = cursor.fetchall()
         | 
| 266 | 
            +
            #     return {"column_names": column_names, "data": values}
         | 
| 267 | 
            +
             | 
| 268 | 
            +
             | 
| 269 | 
            +
            # def generate_chart_code(user_input: str, sql_query: list[str], llm):
         | 
| 270 | 
            +
            #     """ "Generate plotly python code for the chart based on the sql query and the user question"""
         | 
| 271 | 
            +
             | 
| 272 | 
            +
            #     class PlotlyCodeOutput(TypedDict):
         | 
| 273 | 
            +
            #         """Generated Plotly code"""
         | 
| 274 | 
            +
             | 
| 275 | 
            +
            #         code: Annotated[str, ..., "Synatically valid Plotly python code."]
         | 
| 276 | 
            +
             | 
| 277 | 
            +
            #     prompt = ChatPromptTemplate.from_template(plot_prompt_template)
         | 
| 278 | 
            +
            #     structured_llm = llm.with_structured_output(PlotlyCodeOutput)
         | 
| 279 | 
            +
            #     chain = prompt | structured_llm
         | 
| 280 | 
            +
            #     result = chain.invoke({"input": user_input, "sql_query": sql_query})
         | 
| 281 | 
            +
            #     return result["code"]
         | 
    	
        climateqa/engine/talk_to_data/vanna_class.py
    ADDED
    
    | @@ -0,0 +1,325 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from vanna.base import VannaBase
         | 
| 2 | 
            +
            from pinecone import Pinecone
         | 
| 3 | 
            +
            from climateqa.engine.embeddings import get_embeddings_function
         | 
| 4 | 
            +
            import pandas as pd
         | 
| 5 | 
            +
            import hashlib
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            class MyCustomVectorDB(VannaBase):
         | 
| 8 | 
            +
                
         | 
| 9 | 
            +
                """
         | 
| 10 | 
            +
                VectorDB class for storing and retrieving vectors from Pinecone.
         | 
| 11 | 
            +
                
         | 
| 12 | 
            +
                args : 
         | 
| 13 | 
            +
                    config (dict) : Configuration dictionary containing the Pinecone API key and the index name :
         | 
| 14 | 
            +
                        - pc_api_key (str) : Pinecone API key
         | 
| 15 | 
            +
                        - index_name (str) : Pinecone index name
         | 
| 16 | 
            +
                        - top_k (int) : Number of top results to return (default = 2)
         | 
| 17 | 
            +
                        
         | 
| 18 | 
            +
                """
         | 
| 19 | 
            +
                
         | 
| 20 | 
            +
                def __init__(self,config):
                    """Connect to the Pinecone index described by ``config``.

                    Args:
                        config (dict): Settings read with ``config.get``:
                            - ``pc_api_key``: forwarded to ``Pinecone(api_key=...)``.
                            - ``index_name``: forwarded to ``self.pc.Index(...)``; required.
                            - ``top_k``: number of matches requested by the
                              ``get_*`` lookups (defaults to 2).

                    Raises:
                        ValueError: If ``config`` is ``None`` or has no ``index_name``.
                    """
                    super().__init__(config = config)

                    # dict.get never raises for a missing key, so the values are
                    # checked explicitly instead of relying on a try/except.
                    if config is None or config.get('index_name') is None:
                        raise ValueError("Please provide the Pinecone API key and the index name")

                    # NOTE(review): pc_api_key is passed through even when None; the
                    # Pinecone client presumably falls back to its environment
                    # variable in that case -- confirm before making it mandatory.
                    self.api_key = config.get('pc_api_key')
                    self.index_name = config.get('index_name')

                    self.pc = Pinecone(api_key = self.api_key)
                    self.index = self.pc.Index(self.index_name)
                    self.top_k = config.get('top_k', 2)
                    self.embeddings = get_embeddings_function()
         | 
| 32 | 
            +
                    
         | 
| 33 | 
            +
                    
         | 
| 34 | 
            +
                def check_embedding(self, id, namespace):
         | 
| 35 | 
            +
                    fetched = self.index.fetch(ids = [id], namespace = namespace)
         | 
| 36 | 
            +
                    if fetched['vectors'] == {}: 
         | 
| 37 | 
            +
                        return False
         | 
| 38 | 
            +
                    return True
         | 
| 39 | 
            +
                
         | 
| 40 | 
            +
                def generate_hash_id(self, data: str) -> str:
                    """Return the SHA-256 hex digest of ``data`` encoded as UTF-8.

                    Args:
                        data (str): Text to hash.

                    Returns:
                        str: The digest as a hexadecimal string.
                    """
                    return hashlib.sha256(data.encode('utf-8')).hexdigest()
         | 
| 56 | 
            +
                
         | 
| 57 | 
            +
                def add_ddl(self, ddl: str, **kwargs) -> str:
         | 
| 58 | 
            +
                    id = self.generate_hash_id(ddl) + '_ddl'
         | 
| 59 | 
            +
                    
         | 
| 60 | 
            +
                    if self.check_embedding(id, 'ddl'):
         | 
| 61 | 
            +
                        print(f"DDL having id {id} already exists")
         | 
| 62 | 
            +
                        return id
         | 
| 63 | 
            +
                    
         | 
| 64 | 
            +
                    self.index.upsert(
         | 
| 65 | 
            +
                        vectors = [(id, self.embeddings.embed_query(ddl), {'ddl': ddl})],
         | 
| 66 | 
            +
                        namespace = 'ddl'
         | 
| 67 | 
            +
                    )
         | 
| 68 | 
            +
                    
         | 
| 69 | 
            +
                    return id
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                def add_documentation(self, doc: str, **kwargs) -> str:
         | 
| 72 | 
            +
                    id = self.generate_hash_id(doc) + '_doc'
         | 
| 73 | 
            +
                    
         | 
| 74 | 
            +
                    if self.check_embedding(id, 'documentation'):
         | 
| 75 | 
            +
                        print(f"Documentation having id {id} already exists")
         | 
| 76 | 
            +
                        return id
         | 
| 77 | 
            +
                    
         | 
| 78 | 
            +
                    self.index.upsert(
         | 
| 79 | 
            +
                        vectors = [(id, self.embeddings.embed_query(doc), {'doc': doc})],
         | 
| 80 | 
            +
                        namespace = 'documentation'
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
                    
         | 
| 83 | 
            +
                    return id
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
         | 
| 86 | 
            +
                    id = self.generate_hash_id(question) + '_sql'
         | 
| 87 | 
            +
                    
         | 
| 88 | 
            +
                    if self.check_embedding(id, 'question_sql'):
         | 
| 89 | 
            +
                        print(f"Question-SQL pair having id {id} already exists")
         | 
| 90 | 
            +
                        return id
         | 
| 91 | 
            +
                    
         | 
| 92 | 
            +
                    self.index.upsert(
         | 
| 93 | 
            +
                        vectors = [(id, self.embeddings.embed_query(question + sql), {'question': question, 'sql': sql})],
         | 
| 94 | 
            +
                        namespace = 'question_sql'
         | 
| 95 | 
            +
                    )
         | 
| 96 | 
            +
                    
         | 
| 97 | 
            +
                    return id
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def get_related_ddl(self, question: str, **kwargs) -> list:
         | 
| 100 | 
            +
                    res = self.index.query(
         | 
| 101 | 
            +
                        vector=self.embeddings.embed_query(question),
         | 
| 102 | 
            +
                        top_k=self.top_k,
         | 
| 103 | 
            +
                        namespace='ddl',
         | 
| 104 | 
            +
                        include_metadata=True
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
                    
         | 
| 107 | 
            +
                    return [match['metadata']['ddl'] for match in res['matches']]
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                def get_related_documentation(self, question: str, **kwargs) -> list:
         | 
| 110 | 
            +
                    res = self.index.query(
         | 
| 111 | 
            +
                        vector=self.embeddings.embed_query(question),
         | 
| 112 | 
            +
                        top_k=self.top_k,
         | 
| 113 | 
            +
                        namespace='documentation',
         | 
| 114 | 
            +
                        include_metadata=True
         | 
| 115 | 
            +
                    )
         | 
| 116 | 
            +
                    
         | 
| 117 | 
            +
                    return [match['metadata']['doc'] for match in res['matches']]
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                def get_similar_question_sql(self, question: str, **kwargs) -> list:
         | 
| 120 | 
            +
                    res = self.index.query(
         | 
| 121 | 
            +
                        vector=self.embeddings.embed_query(question),
         | 
| 122 | 
            +
                        top_k=self.top_k,
         | 
| 123 | 
            +
                        namespace='question_sql',
         | 
| 124 | 
            +
                        include_metadata=True
         | 
| 125 | 
            +
                    )
         | 
| 126 | 
            +
                    
         | 
| 127 | 
            +
                    return [(match['metadata']['question'], match['metadata']['sql']) for match in res['matches']]
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                def get_training_data(self, **kwargs) -> pd.DataFrame:
         | 
| 130 | 
            +
                    
         | 
| 131 | 
            +
                    list_of_data = []
         | 
| 132 | 
            +
                    
         | 
| 133 | 
            +
                    namespaces = ['ddl', 'documentation', 'question_sql']
         | 
| 134 | 
            +
                    
         | 
| 135 | 
            +
                    for namespace in namespaces:
         | 
| 136 | 
            +
                        
         | 
| 137 | 
            +
                        data = self.index.query(
         | 
| 138 | 
            +
                        top_k=10000,
         | 
| 139 | 
            +
                        namespace=namespace,
         | 
| 140 | 
            +
                        include_metadata=True,
         | 
| 141 | 
            +
                        include_values=False
         | 
| 142 | 
            +
                        )
         | 
| 143 | 
            +
                        
         | 
| 144 | 
            +
                        for match in data['matches']:
         | 
| 145 | 
            +
                            list_of_data.append(match['metadata'])
         | 
| 146 | 
            +
                            
         | 
| 147 | 
            +
                    return pd.DataFrame(list_of_data)
         | 
| 148 | 
            +
                        
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
                def remove_training_data(self, id: str, **kwargs) -> bool:
         | 
| 152 | 
            +
                    if id.endswith("_ddl"):
         | 
| 153 | 
            +
                        self.Index.delete(ids=[id], namespace="_ddl")
         | 
| 154 | 
            +
                        return True
         | 
| 155 | 
            +
                    if id.endswith("_sql"):
         | 
| 156 | 
            +
                        self.index.delete(ids=[id], namespace="_sql")
         | 
| 157 | 
            +
                        return True
         | 
| 158 | 
            +
                    
         | 
| 159 | 
            +
                    if id.endswith("_doc"):
         | 
| 160 | 
            +
                        self.Index.delete(ids=[id], namespace="_doc")
         | 
| 161 | 
            +
                        return True
         | 
| 162 | 
            +
                    
         | 
| 163 | 
            +
                    return False
         | 
| 164 | 
            +
                
         | 
| 165 | 
            +
                def generate_embedding(self, text, **kwargs):
         | 
| 166 | 
            +
                    # Implement the method here
         | 
| 167 | 
            +
                    pass
         | 
| 168 | 
            +
             | 
| 169 | 
            +
             | 
| 170 | 
            +
                def get_sql_prompt(
                    self,
                    initial_prompt : str,
                    question: str,
                    question_sql_list: list,
                    ddl_list: list,
                    doc_list: list,
                    **kwargs,
                ):
                    """Build the message list sent to the LLM to generate SQL.

                    The system message is ``initial_prompt`` (or a default naming
                    ``self.dialect``), extended through ``self.add_ddl_to_prompt`` and
                    ``self.add_documentation_to_prompt`` and followed by the response
                    guidelines. Each usable entry of ``question_sql_list`` is then
                    appended as a user/assistant message pair, and ``question`` as the
                    final user message.

                    Args:
                        initial_prompt (str): Opening of the system prompt; when
                            ``None`` a default is used.
                        question (str): The question to generate SQL for.
                        question_sql_list (list): Previous examples; entries are used
                            only when they contain both ``"question"`` and ``"sql"``.
                        ddl_list (list): DDL statements to add to the prompt.
                        doc_list (list): Documentation texts to add to the prompt.
                            The list passed by the caller is not modified.

                    Returns:
                        list: The messages built with ``self.system_message``,
                        ``self.user_message`` and ``self.assistant_message``.
                    """
                    if initial_prompt is None:
                        initial_prompt = f"You are a {self.dialect} expert. " + \
                        "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

                    initial_prompt = self.add_ddl_to_prompt(
                        initial_prompt, ddl_list, max_tokens=self.max_tokens
                    )

                    if self.static_documentation != "":
                        # Work on a copy so repeated calls with the same list do not
                        # keep accumulating the static documentation in it.
                        doc_list = [*doc_list, self.static_documentation]

                    initial_prompt = self.add_documentation_to_prompt(
                        initial_prompt, doc_list, max_tokens=self.max_tokens
                    )

                    # Examples are not inlined in the system prompt; they are replayed
                    # below as user/assistant message pairs.
                    initial_prompt += (
                        "===Response Guidelines \n"
                        "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
                        "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
                        "3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
                        "4. Please use the most relevant table(s). \n"
                        "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
                        f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
                        "7. Add a description of the table in the result of the sql query, if relevant. \n"
                        "8 Make sure to include the relevant KPI in the SQL query. The query should return impactfull data \n"
                    )

                    message_log = [self.system_message(initial_prompt)]

                    for example in question_sql_list:
                        if example is None:
                            print("example is None")
                        elif "question" in example and "sql" in example:
                            message_log.append(self.user_message(example["question"]))
                            message_log.append(self.assistant_message(example["sql"]))

                    message_log.append(self.user_message(question))

                    return message_log
         | 
| 253 | 
            +
                    
         | 
| 254 | 
            +
             | 
| 255 | 
            +
            # def get_sql_prompt(
         | 
| 256 | 
            +
            #         self,
         | 
| 257 | 
            +
            #         initial_prompt : str,
         | 
| 258 | 
            +
            #         question: str,
         | 
| 259 | 
            +
            #         question_sql_list: list,
         | 
| 260 | 
            +
            #         ddl_list: list,
         | 
| 261 | 
            +
            #         doc_list: list,
         | 
| 262 | 
            +
            #         **kwargs,
         | 
| 263 | 
            +
            #     ):
         | 
| 264 | 
            +
            #         """
         | 
| 265 | 
            +
            #         Example:
         | 
| 266 | 
            +
            #         ```python
         | 
| 267 | 
            +
            #         vn.get_sql_prompt(
         | 
| 268 | 
            +
            #             question="What are the top 10 customers by sales?",
         | 
| 269 | 
            +
            #             question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
         | 
| 270 | 
            +
            #             ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
         | 
| 271 | 
            +
            #             doc_list=["The customers table contains information about customers and their sales."],
         | 
| 272 | 
            +
            #         )
         | 
| 273 | 
            +
             | 
| 274 | 
            +
            #         ```
         | 
| 275 | 
            +
             | 
| 276 | 
            +
            #         This method is used to generate a prompt for the LLM to generate SQL.
         | 
| 277 | 
            +
             | 
| 278 | 
            +
            #         Args:
         | 
| 279 | 
            +
            #             question (str): The question to generate SQL for.
         | 
| 280 | 
            +
            #             question_sql_list (list): A list of questions and their corresponding SQL statements.
         | 
| 281 | 
            +
            #             ddl_list (list): A list of DDL statements.
         | 
| 282 | 
            +
            #             doc_list (list): A list of documentation.
         | 
| 283 | 
            +
             | 
| 284 | 
            +
            #         Returns:
         | 
| 285 | 
            +
            #             any: The prompt for the LLM to generate SQL.
         | 
| 286 | 
            +
            #         """
         | 
| 287 | 
            +
             | 
| 288 | 
            +
            #         if initial_prompt is None:
         | 
| 289 | 
            +
            #             initial_prompt = f"You are a {self.dialect} expert. " + \
         | 
| 290 | 
            +
            #             "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
         | 
| 291 | 
            +
             | 
| 292 | 
            +
            #         initial_prompt = self.add_ddl_to_prompt(
         | 
| 293 | 
            +
            #             initial_prompt, ddl_list, max_tokens=self.max_tokens
         | 
| 294 | 
            +
            #         )
         | 
| 295 | 
            +
             | 
| 296 | 
            +
            #         if self.static_documentation != "":
         | 
| 297 | 
            +
            #             doc_list.append(self.static_documentation)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
            #         initial_prompt = self.add_documentation_to_prompt(
         | 
| 300 | 
            +
            #             initial_prompt, doc_list, max_tokens=self.max_tokens
         | 
| 301 | 
            +
            #         )
         | 
| 302 | 
            +
             | 
| 303 | 
            +
            #         initial_prompt += (
         | 
| 304 | 
            +
            #             "===Response Guidelines \n"
         | 
| 305 | 
            +
            #             "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
         | 
| 306 | 
            +
            #             "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
         | 
| 307 | 
            +
            #             "3. If the provided context is insufficient, please explain why it can't be generated. \n"
         | 
| 308 | 
            +
            #             "4. Please use the most relevant table(s). \n"
         | 
| 309 | 
            +
            #             "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
         | 
| 310 | 
            +
            #             f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
         | 
| 311 | 
            +
            #         )
         | 
| 312 | 
            +
             | 
| 313 | 
            +
            #         message_log = [self.system_message(initial_prompt)]
         | 
| 314 | 
            +
             | 
| 315 | 
            +
            #         for example in question_sql_list:
         | 
| 316 | 
            +
            #             if example is None:
         | 
| 317 | 
            +
            #                 print("example is None")
         | 
| 318 | 
            +
            #             else:
         | 
| 319 | 
            +
            #                 if example is not None and "question" in example and "sql" in example:
         | 
| 320 | 
            +
            #                     message_log.append(self.user_message(example["question"]))
         | 
| 321 | 
            +
            #                     message_log.append(self.assistant_message(example["sql"]))
         | 
| 322 | 
            +
             | 
| 323 | 
            +
            #         message_log.append(self.user_message(question))
         | 
| 324 | 
            +
             | 
| 325 | 
            +
            #         return message_log
         | 

