# Spaces: Runtime error
| from __future__ import annotations | |
| from typing import Iterable | |
| import gradio as gr | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| import math | |
| import torch | |
| from chronos import ChronosPipeline | |
| import warnings | |
| from seafoam import Seafoam | |
| warnings.filterwarnings("ignore") | |
| import numpy as np | |
| import matplotlib.ticker as ticker | |
| os.makedirs("example_files", exist_ok=True) | |
def process_csv(file):
    """Read an uploaded CSV and offer its columns in both column dropdowns.

    Each column name is shown with its underscore-separated words capitalized
    and joined by spaces (``sold_qty`` -> ``Sold Qty``); these labels are what
    the dropdowns display and pass on to the Forecast handler.

    Args:
        file: Value of the ``gr.File`` component: a path string or an object
            exposing the path as ``.name``; ``None`` when nothing is uploaded.

    Returns:
        ``(dataframe, date_dropdown, target_dropdown)``, one value per output
        wired to this handler. With no file, the dataframe is ``None`` and
        both dropdowns are emptied.

    Raises:
        gr.Error: If the file name does not end in ``.csv`` (any case).
    """
    if file is None:
        # One value per output: df_state, date_column, target_column.
        return None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
    path = getattr(file, "name", file)
    if not str(path).lower().endswith('.csv'):
        raise gr.Error("Please upload a CSV file only")
    df = pd.read_csv(path)
    labels = [
        ' '.join(word.capitalize() for word in str(column).split('_'))
        for column in df.columns
    ]
    return (
        df,
        gr.Dropdown(choices=labels, value=None),
        gr.Dropdown(choices=labels, value=None),
    )
# Number of future months forecast by process_data.
_PREDICTION_LENGTH = 12
# Calendar month names, January first; used as the pivot-table columns.
_MONTH_ORDER = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December',
]
# Chronos pipeline, loaded on the first forecast and reused afterwards.
_CHRONOS_PIPELINE = None


def _get_pipeline():
    """Return the shared Chronos pipeline, loading it on first use."""
    global _CHRONOS_PIPELINE
    if _CHRONOS_PIPELINE is None:
        # If two first requests race here the model is merely loaded twice.
        _CHRONOS_PIPELINE = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-base",
            device_map="cpu",
            torch_dtype=torch.float32,
        )
    return _CHRONOS_PIPELINE


def _resolve_column(df, label):
    """Map a column dropdown label back to the matching column name in *df*.

    Dropdown labels capitalize each underscore-separated word of the column
    name and join them with spaces. That is not reversible for names that
    contain capitals or spaces, so after trying the simple inverse
    (lower-case, spaces -> underscores) every column's label is compared.

    Raises:
        gr.Error: If no column matches.
    """
    legacy_name = label.lower().replace(" ", "_")
    if legacy_name in df.columns:
        return legacy_name
    for column in df.columns:
        display = ' '.join(word.capitalize() for word in str(column).split('_'))
        if column == label or display == label:
            return column
    raise gr.Error(f"Error: Column '{label}' was not found in the uploaded CSV")


def _monthly_totals(df, date_column, target_column):
    """Sum *target_column* per calendar month as columns year, month, y (chronological).

    NOTE(review): months with no rows are skipped rather than zero-filled, so
    the series passed to the model can have gaps; confirm this is intended.

    Raises:
        gr.Error: If dates are numeric or unparsable, the target is not
            numeric, or no row has a valid date.
    """
    # Ignore missing cells: NaN is a float and would otherwise count as "numeric".
    present_dates = df[date_column].dropna()
    if present_dates.map(lambda value: isinstance(value, (int, float))).any():
        raise gr.Error(
            f"Error: Found numeric values in column '{date_column}'. "
            "Please provide dates in string format like 'YYYY-MM-DD'."
        )
    try:
        dates = pd.to_datetime(df[date_column])
    except (ValueError, TypeError) as err:
        raise gr.Error(f"Error: Could not parse dates in column '{date_column}'") from err
    try:
        quantities = pd.to_numeric(df[target_column])
    except (ValueError, TypeError) as err:
        raise gr.Error(f"Error: Column '{target_column}' must contain numeric values") from err

    frame = pd.DataFrame({'date': dates, 'y': quantities}).dropna(subset=['date'])
    if frame.empty:
        raise gr.Error(f"Error: No valid dates found in column '{date_column}'")
    frame['year'] = frame['date'].dt.year
    frame['month'] = frame['date'].dt.month
    # groupby sorts by (year, month), so the result is in chronological order.
    return frame.groupby(['year', 'month'], as_index=False)['y'].sum()


def _format_cell(value):
    """Render one pivot value: blank if missing, no trailing '.0' on whole numbers."""
    if pd.isna(value):
        return ""
    value = float(value)
    return str(int(value)) if value.is_integer() else f"{value:.2f}"


def _build_pivot_figure(monthly_sales, median):
    """Draw a year x month table of actual totals plus the rounded-up median forecast.

    Forecast values are placed in the months that directly follow the last
    observed month, so a history ending mid-year continues in the same year.
    """
    named = monthly_sales.assign(
        month_name=monthly_sales['month'].map(lambda month: _MONTH_ORDER[month - 1])
    )
    # Reindex so all twelve months are columns even if some were never observed.
    pivot = named.pivot(index='year', columns='month_name', values='y').reindex(
        columns=_MONTH_ORDER
    )

    last_year = int(monthly_sales['year'].iloc[-1])
    last_month = int(monthly_sales['month'].iloc[-1])
    for step, value in enumerate(median, start=1):
        offset = last_month - 1 + step  # 0-based month count from January of last_year
        # .loc enlarges the frame when the target year row does not exist yet.
        pivot.loc[last_year + offset // 12, _MONTH_ORDER[offset % 12]] = math.ceil(value)

    # A Figure created directly is not registered with pyplot, so it is not
    # kept alive by pyplot's global figure list after the request finishes.
    fig = plt.Figure(figsize=(18, 6))
    ax = fig.subplots()
    ax.axis('off')
    table = ax.table(
        cellText=[[_format_cell(value) for value in row] for row in pivot.to_numpy()],
        colLabels=list(pivot.columns),
        rowLabels=[str(year) for year in pivot.index],
        loc='center',
        cellLoc='center',
    )
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.2, 1.2)
    for (row, col), cell in table.get_celld().items():
        # Row 0 holds the month headers; column -1 holds the year row labels.
        if row == 0 or col == -1:
            cell.set_text_props(weight='bold')
            cell.set_facecolor('#f2f2f2')
        else:
            cell.set_facecolor('white')
    return fig


def _build_forecast_figure(history, low, median, high):
    """Plot the monthly history followed by the median forecast and its 10-90% band."""
    fig = plt.Figure(figsize=(30, 10))
    ax = fig.subplots()
    forecast_index = np.arange(len(history), len(history) + len(median))
    ax.plot(np.arange(len(history)), history, color="royalblue", label="Historical Data", linewidth=2)
    ax.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
    ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
    ax.set_title("Sales Forecasting Visualization", fontsize=16)
    ax.set_xlabel("Months", fontsize=20)
    ax.set_ylabel("Sold Qty", fontsize=20)
    ax.tick_params(axis='both', labelsize=18)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
    # A fixed step of 5 yields thousands of ticks for large sales totals;
    # let matplotlib pick a bounded number of "nice" ticks instead.
    ax.yaxis.set_major_locator(ticker.MaxNLocator(nbins=12))
    ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
    ax.legend(fontsize=18)
    fig.tight_layout()
    return fig


def process_data(csv_file, date_column_value, target_column_value):
    """Forecast the next 12 months of the selected target column with Chronos.

    Args:
        csv_file: Value of the ``gr.File`` component: a path string or an
            object exposing the path as ``.name``.
        date_column_value: Dropdown label of the date column.
        target_column_value: Dropdown label of the quantity column to forecast.

    Returns:
        ``(forecast_figure, pivot_figure)``: a line chart of history plus
        forecast, and a year-by-month table with the forecast filled in.

    Raises:
        gr.Error: For missing input, unknown columns, bad data, or any
            failure while forecasting (the message is shown in the UI).
    """
    try:
        if not csv_file:
            raise gr.Error("Error: Upload Csv File")
        if not date_column_value or not target_column_value:
            raise gr.Error("Error: Both date and target columns must be selected")

        df = pd.read_csv(getattr(csv_file, "name", csv_file))
        date_column = _resolve_column(df, date_column_value)
        target_column = _resolve_column(df, target_column_value)
        monthly_sales = _monthly_totals(df, date_column, target_column)

        history = monthly_sales['y'].to_numpy(dtype=np.float32)
        forecast = _get_pipeline().predict(torch.tensor(history), _PREDICTION_LENGTH)
        # Assumes predict returns (batch, samples, horizon); the quantiles are
        # taken across samples for the single series -- TODO confirm.
        low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

        return (
            _build_forecast_figure(history, low, median, high),
            _build_pivot_figure(monthly_sales, median),
        )
    except gr.Error:
        raise
    except Exception as err:
        print(f"Error: {str(err)}")
        raise gr.Error(f"Error: {err}") from err
# Create Gradio interface.

# Sample CSVs offered under the upload box. With cache_examples=True, Gradio
# runs process_csv on each example while building the app, so only files that
# actually exist are listed; the example_files directory may be empty.
EXAMPLE_FILES = [
    "example_files/test_tops_product_id_1.csv",
    "example_files/test_tops_product_id_2.csv",
    "example_files/test_tops_product_id_3.csv",
    "example_files/test_tops_product_id_4.csv",
]
available_examples = [[path] for path in EXAMPLE_FILES if os.path.isfile(path)]

with gr.Blocks(theme=Seafoam()) as demo:
    gr.Markdown("# Chronos Forecasting - Tops infosolutions Pvt Ltd")
    gr.Markdown("Upload a CSV file and click 'Forecast' to generate sales forecast for next 12 months .")
    # Receives the parsed DataFrame from process_csv; the Forecast button
    # re-reads the uploaded file instead of using this state.
    df_state = gr.State()

    with gr.Row():
        file_input = gr.File(label="Upload CSV File", file_types=[".csv"])

    with gr.Row():
        date_column = gr.Dropdown(
            choices=[],
            label="Select Date column",
            multiselect=False,
            value=None,
        )
        target_column = gr.Dropdown(
            choices=[],
            label="Select Target column",
            multiselect=False,
            value=None,
        )

    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=file_input,
            outputs=[df_state, date_column, target_column],
            fn=process_csv,
            cache_examples=True,
        )

    with gr.Row():
        visualize_btn = gr.Button("Forecast", variant="primary")
    with gr.Row():
        plot_output = gr.Plot(label="Chronos Forecasting Visualization")
    with gr.Row():
        pivot_plot_output = gr.Plot(label="Monthly Sales Pivot Table")

    # Populate both column dropdowns from the uploaded file.
    file_input.upload(
        process_csv,
        inputs=[file_input],
        outputs=[df_state, date_column, target_column],
    )

    visualize_btn.click(
        fn=process_data,
        inputs=[file_input, date_column, target_column],
        outputs=[plot_output, pivot_plot_output],
    )

# Launch the app.
if __name__ == "__main__":
    demo.launch()