# Page banner captured with this file -- Spaces: Runtime error
| import pandas as pd | |
| import gradio as gr | |
| from pathlib import Path | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import numpy as np | |
| import torch | |
| from chronos import ChronosPipeline | |
| from datetime import datetime | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as ticker | |
def filter_data(start, end, df_state, select_product_column, date_column, target_column):
    """Sum the target column per date inside a time window and plot the result.

    Args:
        start: Window start as a POSIX timestamp in seconds (read as UTC).
        end: Window end as a POSIX timestamp in seconds (read as UTC).
        df_state: DataFrame holding the loaded data.
        select_product_column: Accepted for interface compatibility; the
            filter does not use it.
        date_column: Display name ("Title Case" with spaces) of the date column.
        target_column: Display name of the column to aggregate.

    Returns:
        ``[filtered_df, fig]`` where ``filtered_df`` has one row per date
        (date column and summed target column, sorted by date) and ``fig`` is
        a Plotly line chart of that frame.

    Raises:
        gr.Error: If a required selection is missing, no data is loaded, or a
            selected column cannot be found in the data.
    """
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")
    if df_state is None:
        raise gr.Error("Please upload a CSV file first")
    if start is None or end is None:
        raise gr.Error("Please select a start and end date")

    # Dropdowns show prettified names; map them back to the real column labels.
    column_mapping = {
        " ".join(word.capitalize() for word in str(col).split("_")): col
        for col in df_state.columns
    }
    original_date_column = column_mapping.get(date_column)
    original_target_column = column_mapping.get(target_column)
    if original_date_column is None:
        raise gr.Error(f"Date column '{date_column}' was not found in the data")
    if original_target_column is None:
        raise gr.Error(f"Target column '{target_column}' was not found in the data")

    # unit="s" yields the same naive-UTC instants as the deprecated
    # datetime.utcfromtimestamp().
    start_datetime = pd.to_datetime(start, unit="s")
    end_datetime = pd.to_datetime(end, unit="s")

    # Work on a copy so the caller's DataFrame is not modified in place.
    working = df_state.copy()
    working[original_date_column] = pd.to_datetime(working[original_date_column])

    in_window = (
        (working[original_date_column] >= start_datetime)
        & (working[original_date_column] <= end_datetime)
    )
    filtered_df = (
        working[in_window]
        .groupby(original_date_column)[original_target_column]
        .sum()
        .reset_index()
        .sort_values(by=original_date_column)
    )

    fig = px.line(
        filtered_df,
        x=original_date_column,
        y=original_target_column,
        title="Historical Sales Data",
    )
    return [filtered_df, fig]
def upload_file(filepath):
    """Load an uploaded CSV and build the choices for the column dropdowns.

    Args:
        filepath: The uploaded file. May be a path (``str`` or ``Path``), an
            object exposing the path as ``.name``, or ``None`` when the file
            input has been cleared.

    Returns:
        ``[df, data_columns, date_columns, target_columns]``: the loaded
        DataFrame (``None`` when no file is given) followed by three
        ``gr.Dropdown`` updates listing all columns (with a leading blank
        entry), the datetime columns, and the numeric columns, all using
        prettified display names.
    """

    def display_name(column):
        # "sold_qty" -> "Sold Qty"
        return " ".join(word.capitalize() for word in str(column).split("_"))

    if filepath is None:
        # File input was cleared: reset the state and empty every dropdown.
        return [
            None,
            gr.Dropdown(choices=[], value=None),
            gr.Dropdown(choices=[], value=None),
            gr.Dropdown(choices=[], value=None),
        ]

    # A plain path is used as-is; file-like wrappers carry the path in .name.
    csv_path = filepath if isinstance(filepath, (str, Path)) else filepath.name
    df = pd.read_csv(csv_path)

    # Decide from the parsed dtype rather than probing the first few rows.
    numeric_columns = [
        col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])
    ]

    datetime_columns = []
    for col in df.columns:
        is_text = (
            pd.api.types.is_object_dtype(df[col])
            or pd.api.types.is_string_dtype(df[col])
        )
        if is_text:
            try:
                df[col] = pd.to_datetime(df[col])
            except (ValueError, TypeError, OverflowError):
                # Not a date column; leave the values untouched.
                pass
        if pd.api.types.is_datetime64_dtype(df[col]):
            datetime_columns.append(col)

    all_choices = [""] + [display_name(col) for col in df.columns]
    date_choices = [display_name(col) for col in datetime_columns]
    target_choices = [display_name(col) for col in numeric_columns]

    data_columns = gr.Dropdown(choices=all_choices, value=None)
    date_columns = gr.Dropdown(choices=date_choices, value=None)
    target_columns = gr.Dropdown(choices=target_choices, value=None)
    return [df, data_columns, date_columns, target_columns]
def download_file():
    """Return updates that make the upload button visible and hide the download button."""
    upload_button = gr.UploadButton(visible=True)
    download_button = gr.DownloadButton(visible=False)
    return [upload_button, download_button]
def set_products(selected_column, df_state):
    """Return the distinct non-null values of the selected column.

    Args:
        selected_column: Display name ("Title Case" with spaces) of a column.
        df_state: DataFrame holding the loaded data, or ``None`` if nothing
            has been loaded yet.

    Returns:
        A list of the column's unique non-null values in order of first
        appearance, or ``[]`` when no data is loaded or the name does not
        match any column.
    """
    if df_state is None:
        return []

    # Map prettified display names back to the real column labels.
    column_mapping = {
        " ".join(word.capitalize() for word in str(col).split("_")): col
        for col in df_state.columns
    }
    original_column = column_mapping.get(selected_column)
    if original_column is None:
        return []
    return df_state[original_column].dropna().unique().tolist()
def set_dates(selected_column, df_state):
    """Return the earliest and latest value of the selected column.

    Args:
        selected_column: Display name ("Title Case" with spaces) of a column.
        df_state: DataFrame holding the loaded data, or ``None`` if nothing
            has been loaded yet.

    Returns:
        ``(min_date, max_date)`` for the column, or ``(None, None)`` when no
        data is loaded, the name does not match any column, or the column has
        no non-missing values.
    """
    if df_state is None:
        return None, None

    # Map prettified display names back to the real column labels.
    column_mapping = {
        " ".join(word.capitalize() for word in str(col).split("_")): col
        for col in df_state.columns
    }
    original_column = column_mapping.get(selected_column)
    if original_column is None:
        return None, None

    values = df_state[original_column]
    min_date = values.min()
    max_date = values.max()
    # An empty or all-missing column yields NaT/NaN; report "no range" instead.
    if pd.isna(min_date) or pd.isna(max_date):
        return None, None
    return min_date, max_date
# Loaded pipelines keyed by device, so the model is not reloaded on every click.
_CHRONOS_PIPELINES = {}


def _load_chronos_pipeline(device):
    """Return the Chronos pipeline for *device*, loading it on first use."""
    if device not in _CHRONOS_PIPELINES:
        _CHRONOS_PIPELINES[device] = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-base",
            device_map=device,
            torch_dtype=torch.float32,
        )
    return _CHRONOS_PIPELINES[device]


def forecast_chronos_data(df_state, date_column, target_column, select_period, forecasting_type="monthly"):
    """Forecast monthly totals of the target column and chart them.

    The target column is summed per calendar month, the monthly series is fed
    to the Chronos pipeline, and the 10%/50%/90% quantiles of the sampled
    forecasts (rounded up to integers) are drawn after the history.

    Args:
        df_state: DataFrame holding the data to forecast from.
        date_column: Display name ("Title Case" with spaces) of the date column.
        target_column: Display name of the column to forecast.
        select_period: Number of future months to predict.
        forecasting_type: Accepted for interface compatibility and not used;
            aggregation is always monthly. Defaults to ``"monthly"`` so the
            function can be called with four arguments.

    Returns:
        ``(fig, bar_chart)``: a Plotly line chart with the history, median
        forecast and quantile band, and a Plotly bar chart with the history
        and median forecast.

    Raises:
        gr.Error: If a required selection is missing, no data is loaded, a
            selected column cannot be found, the period is not positive, or
            there is no data to forecast from.
    """
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")
    if df_state is None:
        raise gr.Error("Please upload a CSV file first")

    # Dropdowns show prettified names; map them back to the real column labels.
    column_mapping = {
        " ".join(word.capitalize() for word in str(col).split("_")): col
        for col in df_state.columns
    }
    original_date_column = column_mapping.get(date_column)
    original_target_column = column_mapping.get(target_column)
    if original_date_column is None:
        raise gr.Error(f"Date column '{date_column}' was not found in the data")
    if original_target_column is None:
        raise gr.Error(f"Target column '{target_column}' was not found in the data")

    # Sliders may deliver floats; range() and the model need an int.
    prediction_length = int(select_period)
    if prediction_length < 1:
        raise gr.Error("Please select a forecast period of at least 1")

    # Coerce so the .dt accessor works even if the column is still text.
    dates = pd.to_datetime(df_state[original_date_column])
    df_forecast = pd.DataFrame(
        {
            "year": dates.dt.year,
            "month": dates.dt.month,
            "sold_qty": df_state[original_target_column],
        }
    )
    monthly_sales = (
        df_forecast.groupby(["year", "month"])["sold_qty"]
        .sum()
        .reset_index()
        .rename(columns={"sold_qty": "y"})
    )
    if monthly_sales.empty:
        raise gr.Error("No data available to forecast")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = _load_chronos_pipeline(device)

    context = torch.tensor(monthly_sales["y"].to_numpy(), dtype=torch.float32)
    forecast = pipeline.predict(context, prediction_length)

    # Forecast points continue the positional month index of the history.
    forecast_index = list(
        range(len(monthly_sales), len(monthly_sales) + prediction_length)
    )
    low, median, high = np.quantile(
        forecast[0].cpu().numpy(), [0.1, 0.5, 0.9], axis=0
    )
    low = np.ceil(low).astype(int)
    median = np.ceil(median).astype(int)
    high = np.ceil(high).astype(int)

    # Line chart
    fig = px.line(
        x=monthly_sales.index,
        y=monthly_sales["y"],
        title="Sales Forecasting Visualization",
        labels={"x": "Months", "y": f"{target_column}"},
    )
    # Style the historical line now, while it is the only trace.
    # NOTE(review): the original selected it by name "y", which does not
    # appear to match the trace px.line creates here -- confirm.
    fig.update_traces(line=dict(color="royalblue", width=2))
    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=median,
            name="Median Forecast",
            line=dict(color="tomato", width=2),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=high,
            name="80% Prediction Interval",
            mode="lines",
            line=dict(width=2, color="rgba(50, 205, 50, 1)"),
            showlegend=False,
        )
    )
    # fill="tonexty" shades the area between this trace and the previous one,
    # i.e. between the low and high quantiles.
    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=low,
            name="10% Prediction Interval",
            mode="lines",
            line=dict(width=1, color="rgba(255, 255, 0, 1)"),
            showlegend=False,
            fillcolor="rgba(255, 99, 71, 0.3)",
            fill="tonexty",
        )
    )
    fig.update_layout(
        title_font_size=20,
        xaxis_title_font_size=16,
        yaxis_title_font_size=16,
        legend_font_size=16,
        xaxis_tickfont_size=14,
        yaxis_tickfont_size=14,
        showlegend=True,
        width=1600,
        height=400,
        xaxis=dict(
            title="Months",
            tickfont=dict(size=14),
            gridcolor="rgba(128, 128, 128, 0.7)",
            gridwidth=1.2,
            dtick=3,
            griddash="dash",
            rangeslider=dict(visible=True),
            rangeselector=dict(
                buttons=[
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(count=12, label="1y", step="month", stepmode="backward"),
                    dict(count=24, label="2y", step="month", stepmode="backward"),
                    dict(step="all", label="All"),
                ]
            ),
        ),
        yaxis=dict(
            gridcolor="rgba(128, 128, 128, 0.7)",
            gridwidth=1.2,
            dtick=5,  # Tick interval of 5 units
            griddash="dash",
        ),
        plot_bgcolor="white",
    )

    # Bar chart
    bar_chart = go.Figure()
    bar_chart.add_trace(
        go.Bar(
            x=monthly_sales.index,
            y=monthly_sales["y"],
            name="Historical Sales",
            marker_color="rgba(50, 150, 250, 0.6)",
            opacity=0.8,
        )
    )
    bar_chart.add_trace(
        go.Bar(
            x=forecast_index,
            y=median,
            name="Median Forecast",
            marker_color="rgba(255, 99, 71, 0.9)",
            opacity=0.8,
        )
    )
    bar_chart.update_layout(
        title="Sales Forecasting Visualization (Bar Chart)",
        xaxis_title="Months",
        yaxis_title=f"{target_column}",
        title_font_size=20,
        xaxis_title_font_size=16,
        yaxis_title_font_size=16,
        legend_font_size=16,
        width=1800,
        height=600,
        plot_bgcolor="white",
    )
    return fig, bar_chart
def home_page():
    """Return the introductory Markdown text for the app."""
    content = """
### **Sales Forecasting with Chronos**
Welcome to the future of sales optimization with **Chronos**.
Say goodbye to guesswork and unlock the power of **data-driven insights** with our advanced forecasting platform.
- **Seamless CSV Upload**: Quickly upload your sales data in CSV format—no technical expertise needed.
- **AI-Powered Predictions**: Harness the power of state-of-the-art machine learning models to uncover trends and forecast future sales performance.
- **Interactive Visualizations**: Gain actionable insights with intuitive charts and graphs that make data easy to understand.
Start making smarter, data-backed business decisions today with **Chronos**!
"""
    return content
def about_page():
    """Return the Markdown text with contact details and service offerings."""
    # NOTE(review): the emoji and the apostrophe in this text were restored
    # from a mis-encoded copy. Where the garbled form did not identify the
    # symbol uniquely, the emoji below is a best guess -- confirm against the
    # intended copy.
    content = """
### 📧 **Contact Us:**
- **Email**: contact@topsinfosolutions.com ✉️
- **Website**: [https://www.topsinfosolutions.com/](https://www.topsinfosolutions.com/) 🌐
### 🚀 **What We Offer:**
- **Custom AI Solutions**: Tailored to your business needs 🤖
- **Chatbot Development**: Build intelligent conversational agents 💬
- **Vision Models**: Computer vision solutions for various applications 🖼️
- **AI Agents**: Personalized agents powered by advanced LLMs 🤖
### 🤝 **How We Can Help:**
Reach out to us for bespoke AI services. Whether you need chatbots, vision models, or AI-powered agents, we’re here to build solutions that make a difference! 🌟
### 💬 **Get in Touch:**
If you have any questions or need a custom solution, click the button below to schedule a consultation with us. 📅
"""
    return content
# Gradio UI: a "Home" tab (upload -> visualize -> forecast) and an "About Tops" tab.
# NOTE(review): the nesting and row grouping below were reconstructed from a
# copy without indentation -- confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Home"):
            # Holds the working DataFrame between events.
            df_state = gr.State()
            # gr.Image("/content/chronos-logo.png", interactive=False)
            home_output = gr.Markdown(value=home_page(), label="Playground")

            gr.Markdown("## Step 1: Historical/Training Data (currently supports *.csv only)")
            with gr.Row():
                file_input = gr.File(
                    label="Upload Historical (Training Data) Sales Data",
                    file_types=[".csv"],
                )
            with gr.Row():
                date_column = gr.Dropdown(choices=[], label="Select Date column (*Required)", multiselect=False, value=None)
                target_column = gr.Dropdown(choices=[], label="Select Target column (*Required)", multiselect=False, value=None)
                select_product_column = gr.Dropdown(choices=[], label="Select Product column (Optional)", multiselect=False, value=None)
                select_product = gr.Dropdown(choices=[], label="Select Product (Optional)", multiselect=False, value=None)
            with gr.Row():
                start = gr.DateTime("2021-01-01 00:00:00", label="Training data Start date")
                end = gr.DateTime("2021-01-05 00:00:00", label="Training data End date")
                apply_btn = gr.Button("Visualize Data", scale=0)

            gr.Examples(
                examples=[
                    ["example_files/test_tops_product_id_1.csv"],
                    ["example_files/test_tops_product_id_2.csv"],
                    ["example_files/test_tops_product_id_3.csv"],
                    ["example_files/test_tops_product_id_4.csv"],
                ],
                inputs=file_input,
                outputs=[df_state, select_product_column, date_column, target_column],
                fn=upload_file,
                # Do not pre-run upload_file on every example at startup;
                # selecting an example fills file_input, whose change handler
                # below does the loading.
                cache_examples=False,
            )

            with gr.Row():
                historical_data_plot = gr.Plot()

            # NOTE(review): this writes the filtered, aggregated frame back
            # into df_state, so later steps only see the selected window --
            # confirm this is intended.
            apply_btn.click(
                filter_data,
                inputs=[start, end, df_state, select_product_column, date_column, target_column],
                outputs=[df_state, historical_data_plot],
            )

            gr.Markdown("## Step 2: Forecast")
            with gr.Row():
                forecasting_type = gr.Radio(["day", "monthly", "year"], value="monthly", label="Forecasting Type", interactive=False)
                select_period = gr.Slider(2, 60, value=12, label="Select Period", info="Check Selected Forecast Type", interactive=True, step=1)
                forecast_btn = gr.Button("Forecast")

            with gr.Tabs():
                with gr.TabItem("Line Chart"):
                    with gr.Row():
                        plot_forecast_output = gr.Plot(label="Chronos Forecasting Visualization (Line)")
                with gr.TabItem("Bar Chart"):
                    with gr.Row():
                        bar_plot_forecast_output = gr.Plot(label="Chronos Forecasting Visualization (Bar)")

            # forecast_chronos_data declares five parameters, so all five
            # inputs are passed.
            forecast_btn.click(
                forecast_chronos_data,
                inputs=[df_state, date_column, target_column, select_period, forecasting_type],
                outputs=[plot_forecast_output, bar_plot_forecast_output],
            )

            file_input.change(
                upload_file,
                inputs=[file_input],
                outputs=[df_state, select_product_column, date_column, target_column],
            )

            def _update_product_choices(column, df):
                # Fill the product dropdown with the chosen column's values.
                choices = set_products(column, df) if df is not None else []
                return gr.Dropdown(choices=choices, value=None)

            select_product_column.change(
                _update_product_choices,
                inputs=[select_product_column, df_state],
                outputs=[select_product],
            )

            date_column.change(
                set_dates,
                inputs=[date_column, df_state],
                outputs=[start, end],
            )

        with gr.TabItem("About Tops"):
            # gr.Image("/content/chronos-logo.png", interactive=False)
            about_output = gr.Markdown(value=about_page(), label="About Tops")


if __name__ == "__main__":
    demo.launch()