Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from chronos import ChronosPipeline | |
| import yfinance as yf | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.dates as mdates | |
| from sklearn.metrics import mean_absolute_error, mean_squared_error | |
| import tempfile | |
def get_popular_tickers():
    """Return a new list with the ticker symbols offered by the app."""
    symbols = (
        "AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM",
        "JNJ", "V", "PG", "WMT", "BAC", "DIS", "NFLX", "INTC",
    )
    # Hand out a fresh list so callers can mutate their copy freely.
    return list(symbols)
# Loaded Chronos pipelines keyed by model name, so the weights are read once
# per process instead of on every button click.
_PIPELINE_CACHE = {}


def _get_pipeline(model_name="amazon/chronos-t5-mini"):
    """Return the CPU ``ChronosPipeline`` for *model_name*, loading it on first use."""
    if model_name not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[model_name] = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cpu",
            torch_dtype=torch.float32,
        )
    return _PIPELINE_CACHE[model_name]


def _forecast_dates(dates, train_points, horizon):
    """Return *horizon* timestamps for the steps that follow ``dates[:train_points]``.

    Steps that fall inside the downloaded history reuse the real trading
    dates, so step ``i`` of the forecast lines up with row
    ``train_points + i`` of the history. Any remaining steps continue on
    business days after the last known date (market holidays are not known
    here, so those future dates are approximate).
    """
    known = pd.DatetimeIndex(dates.iloc[train_points:train_points + horizon])
    missing = horizon - len(known)
    if missing == 0:
        return known
    last_known = dates.iloc[train_points + len(known) - 1]
    extension = pd.date_range(
        start=last_known + pd.Timedelta(days=1), periods=missing, freq='B'
    )
    return known.append(extension)


def predict_stock(ticker, train_data_points, prediction_days):
    """Forecast closing prices for *ticker* and build the plot and CSV outputs.

    The first ``train_data_points`` closes of the full price history are used
    as model context and the following ``prediction_days`` steps are
    forecast. When the history already contains all forecast steps
    (a back-test), the real prices are overlaid and MAE / RMSE / MAPE are
    shown in the title and a ``Real_Price`` column is added to the CSV.

    Args:
        ticker: Symbol passed to ``yfinance.Ticker``.
        train_data_points: Number of leading history rows used as context;
            coerced with ``int()`` and capped at the available history.
        prediction_days: Number of steps to forecast; coerced with ``int()``.

    Returns:
        A ``(figure, csv_path)`` tuple: a standalone matplotlib ``Figure``
        and the path of a temporary CSV file holding the forecast.

    Raises:
        gr.Error: For any failure (no data, too little data, invalid
            arguments, download or model errors), chained to the cause.
    """
    try:
        # Slider values may arrive as floats; the slicing below needs ints.
        train_data_points = int(train_data_points)
        prediction_days = int(prediction_days)
        if train_data_points < 1 or prediction_days < 1:
            raise ValueError(
                "El número de datos de entrenamiento y de días a predecir debe ser al menos 1"
            )

        hist = yf.Ticker(ticker).history(period="max")
        if hist.empty:
            raise ValueError(f"No hay datos disponibles para {ticker}")

        # Rows without a close would put NaNs into the context and the metrics.
        df = hist[['Close']].dropna().reset_index()
        dates = df['Date']
        closes = df['Close'].to_numpy()
        total_points = len(df)
        if total_points < 50:
            raise ValueError(f"Datos insuficientes para {ticker}")

        # Never ask for more context than the history provides.
        train_data_points = min(train_data_points, total_points)
        context = torch.tensor(closes[:train_data_points], dtype=torch.float32)

        forecast = _get_pipeline().predict(
            context, prediction_days, limit_prediction_length=False
        )
        # forecast[0] is presumably (num_samples, prediction_days); reducing
        # over axis 0 yields one value per forecast step.
        low, median, high = np.quantile(
            forecast[0].numpy(), [0.01, 0.5, 0.99], axis=0
        )

        prediction_dates = _forecast_dates(dates, train_data_points, prediction_days)
        # Number of forecast steps for which a real close is already known.
        n_overlap = min(prediction_days, total_points - train_data_points)
        has_validation = n_overlap == prediction_days
        actual = closes[train_data_points:train_data_points + n_overlap]

        # A standalone Figure keeps pyplot's global figure registry out of the
        # request path: nothing to close, nothing shared between requests.
        fig = plt.Figure(figsize=(20, 10))
        ax = fig.add_subplot(1, 1, 1)

        # Show up to 10 context points plus whatever real data overlaps the forecast.
        context_days = min(10, train_data_points)
        start_index = train_data_points - context_days
        end_index = train_data_points + n_overlap
        ax.plot(dates.iloc[start_index:end_index],
                closes[start_index:end_index],
                color='blue',
                linewidth=2,
                label='Datos Reales')

        ax.plot(prediction_dates,
                median,
                color='black',
                linewidth=2,
                linestyle='-',
                label='Predicción')

        # Band between the 1% and 99% sample quantiles.
        ax.fill_between(prediction_dates, low, high,
                        color='gray', alpha=0.2,
                        label='Intervalo de Confianza')

        if has_validation:
            # Forecast step i and actual[i] refer to the same trading day.
            ax.plot(prediction_dates,
                    actual,
                    color='red',
                    linewidth=2,
                    linestyle='--',
                    label='Datos Reales de Validación')
            mae = mean_absolute_error(actual, median)
            rmse = np.sqrt(mean_squared_error(actual, median))
            mape = np.mean(np.abs((actual - median) / actual)) * 100
            title = (f"Predicción del Precio de {ticker}\n"
                     f"MAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.2f}%")
        else:
            title = f"Predicción Futura del Precio de {ticker}"
        ax.set_title(title, fontsize=14, pad=20)

        ax.legend(loc="upper left", fontsize=12)
        ax.set_xlabel("Fecha", fontsize=12)
        ax.set_ylabel("Precio", fontsize=12)
        ax.grid(True, which='both', axis='x', linestyle='--', linewidth=0.5)
        ax.xaxis.set_major_locator(mdates.DayLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()

        prediction_df = pd.DataFrame({
            'Date': prediction_dates,
            'Predicted_Price': median,
            'Lower_Bound': low,
            'Upper_Bound': high,
        })
        if has_validation:
            prediction_df['Real_Price'] = actual

        # delete=False: the file must outlive this function so it can be
        # downloaded. Writing through the open handle avoids reopening the
        # file by name while it is still held open.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.csv', delete=False, newline='', encoding='utf-8'
        ) as csv_file:
            prediction_df.to_csv(csv_file, index=False)

        return fig, csv_file.name
    except Exception as e:
        # UI boundary: log, then surface every failure as a Gradio error.
        print(f"Error: {str(e)}")
        raise gr.Error(f"Error al procesar {ticker}: {str(e)}") from e
def update_train_data_points(ticker):
    """Resize the training-points slider to the history available for *ticker*.

    Args:
        ticker: Selected symbol; a falsy value means nothing is selected.

    Returns:
        A ``gr.update(...)`` payload for the slider. On success the maximum
        is the number of downloaded history rows and the value is
        ``min(1000, rows)``. With no ticker, or on any failure, the slider
        is reset to its defaults (value 1000, range 50-5000, step 1).
    """
    # gr.update(...) is used instead of gr.Slider.update(...), which newer
    # Gradio releases no longer provide.
    defaults = gr.update(value=1000, maximum=5000, minimum=50, step=1)
    if not ticker:
        return defaults
    try:
        hist = yf.Ticker(ticker).history(period="max")
        if hist.empty:
            raise ValueError(f"No hay datos disponibles para {ticker}")
        total_points = len(hist)
        if total_points < 50:
            raise ValueError(f"Datos insuficientes para {ticker}")
    except Exception as e:
        # Best effort: a failed lookup must not break the UI, so log it and
        # fall back to the default slider range.
        print(f"Error al actualizar datos para {ticker}: {str(e)}")
        return defaults
    return gr.update(
        maximum=total_points,
        value=min(1000, total_points),
        minimum=50,
        step=1,
        interactive=True,
    )
# Gradio interface: ticker selector and forecast controls, with the plot and
# the downloadable CSV as outputs.
# NOTE(review): the nesting of the Row/Column containers is reconstructed;
# confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Aplicación de Predicción de Precios de Acciones")
    with gr.Row():
        with gr.Column(scale=1):
            ticker = gr.Dropdown(
                choices=get_popular_tickers(),
                value="AAPL",
                label="Selecciona el Símbolo de la Acción",
                interactive=True,
            )
        with gr.Column():
            train_data_points = gr.Slider(
                minimum=50,
                maximum=5000,
                value=1000,
                step=1,
                label="Número de Datos para Entrenamiento",
                interactive=True,
            )
            prediction_days = gr.Slider(
                minimum=1,
                maximum=60,
                value=5,
                step=1,
                label="Número de Días a Predecir",
                interactive=True,
            )
            predict_btn = gr.Button("Predecir", interactive=True)
    with gr.Column():
        # Hidden status box; no event below writes to it.
        error_output = gr.Textbox(label="Estado", visible=False)
        plot_output = gr.Plot(label="Gráfico de Predicción")
        download_btn = gr.File(label="Descargar Predicciones")

    # Events
    # Changing the ticker rescales the training slider to the available history.
    ticker.change(
        fn=update_train_data_points,
        inputs=[ticker],
        outputs=[train_data_points],
        api_name="update_data",
    )
    predict_btn.click(
        fn=predict_stock,
        inputs=[ticker, train_data_points, prediction_days],
        outputs=[plot_output, download_btn],
    )

# Start the server only when run as a script, so importing this module
# (e.g. to reuse `demo`) does not launch it.
if __name__ == "__main__":
    demo.launch()