# Uploaded by syedtamim2020 — commit "Create app.py" (10b50ee, verified), 6.66 kB.
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import os
# --- Constants ---
# Length of the input window: the model consumes sequences of shape
# (TIME_STEP, 1), and forecasting is seeded with the last TIME_STEP prices.
TIME_STEP: int = 100
# Path of the pre-trained Keras model file loaded at import time below.
MODEL_PATH: str = 'stock_prediction_model.h5'
# --- Model Architecture Definition ---
def create_lstm_model():
    """Build and compile the stacked Bidirectional LSTM regressor.

    The layer stack is the architecture used for training:
    - three Bidirectional LSTM layers with 128, 64 and 32 units;
    - the first two return full sequences and are each followed by 30% dropout;
    - a 64-unit ReLU dense layer;
    - a single-unit linear output.

    Inputs are windows of shape ``(TIME_STEP, 1)``.

    Returns:
        A ``tf.keras.models.Sequential`` compiled with MSE loss and the Adam
        optimizer.
    """
    layers = tf.keras.layers
    network = tf.keras.models.Sequential([
        layers.Bidirectional(
            layers.LSTM(128, return_sequences=True),
            input_shape=(TIME_STEP, 1),
        ),
        layers.Dropout(0.3),
        layers.Bidirectional(layers.LSTM(64, return_sequences=True)),
        layers.Dropout(0.3),
        layers.Bidirectional(layers.LSTM(32)),
        layers.Dense(64, activation='relu'),
        layers.Dense(1),
    ])
    network.compile(loss='mse', optimizer='adam')
    return network
# --- Model Loading ---
# The app only runs inference, so the model is loaded with compile=False.
# Restoring the saved training config (loss/optimizer) is unnecessary here.
# It may also fail to deserialize under some Keras versions, which would
# otherwise silently drop us into the random-weight fallback below.
try:
    model = tf.keras.models.load_model(MODEL_PATH, compile=False)
    print(f"Successfully loaded model from {MODEL_PATH}")
except Exception as e:
    # Loading failed (e.g. the file is missing in a local test). Fall back to an
    # untrained model with the same architecture so the UI can still start.
    print(f"Warning: Could not load {MODEL_PATH}. Error: {e}")
    print(f"Initializing a dummy model. Please ensure your '{MODEL_PATH}' is uploaded.")
    # NOTE: this model has random weights; its forecasts are meaningless until
    # the trained H5 file is provided.
    model = create_lstm_model()
# --- Prediction Logic Function ---
def _csv_path(csv_file):
    """Return a filesystem path for a Gradio upload (path string or file wrapper)."""
    # Newer Gradio versions pass uploads as path strings. Older ones pass a
    # tempfile-like object whose ``.name`` holds the path.
    if isinstance(csv_file, (str, os.PathLike)):
        return csv_file
    return csv_file.name


def _recursive_forecast(keras_model, scaled_history, steps):
    """Predict ``steps`` future scaled values, feeding each prediction back in.

    ``scaled_history`` is a 1-D sequence whose last TIME_STEP values seed the
    input window (the caller guarantees at least TIME_STEP values).
    Returns a float64 array of shape ``(steps, 1)``.
    """
    window = np.asarray(scaled_history, dtype="float32")[-TIME_STEP:]
    predictions = np.empty((steps, 1))
    for step in range(steps):
        x = window.reshape(1, TIME_STEP, 1)
        # Calling the model directly avoids predict()'s per-call overhead for a
        # single sample. training=False keeps dropout disabled.
        yhat = float(np.asarray(keras_model(x, training=False))[0, 0])
        predictions[step, 0] = yhat
        # Slide the window: drop the oldest value, append the new prediction.
        window = np.roll(window, -1)
        window[-1] = yhat
    return predictions


def _plot_forecast(actual_prices, predicted_prices):
    """Plot the last TIME_STEP actual prices followed by the forecast.

    Returns the Figure. It is already detached from pyplot, so repeated
    requests do not accumulate open figures.
    """
    n_actual = len(actual_prices)
    n_pred = len(predicted_prices)
    day_actual = np.arange(n_actual - TIME_STEP, n_actual)
    # Forecast days continue immediately after the last actual day.
    day_pred = np.arange(n_actual, n_actual + n_pred)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(day_actual, actual_prices[-TIME_STEP:], label=f'Last {TIME_STEP} Actual Days', color='blue')
    ax.plot(day_pred, predicted_prices, label=f'Forecasted {n_pred} Days', color='red', linestyle='--')
    # Connect the last actual point to the first predicted point.
    ax.plot(
        [day_actual[-1], day_pred[0]],
        [actual_prices[-1, 0], predicted_prices[0, 0]],
        color='red',
        linestyle='--',
    )
    ax.set_title('Stock Price Forecast (LSTM)')
    ax.set_xlabel('Days')
    ax.set_ylabel('Close Price')
    ax.legend()
    ax.grid(True)
    # Release pyplot's reference; the Figure object itself can still be rendered.
    plt.close(fig)
    return fig


def _forecast_dates(df, steps):
    """Return ``steps`` consecutive calendar dates after the last observation."""
    if 'Date' in df.columns:
        last_date = pd.to_datetime(df['Date'].iloc[-1])
    else:
        # No date information: anchor at today. Converting the integer row
        # index instead would yield nonsense dates in January 1970.
        last_date = pd.Timestamp.today().normalize()
    return pd.date_range(start=last_date + pd.Timedelta(days=1), periods=steps)


def forecast_stock(csv_file, days_to_predict=30):
    """Forecast future closing prices from an uploaded stock-history CSV.

    Reads the CSV's 'Close' column and min-max scales it. It then predicts
    ``days_to_predict`` steps recursively with the module-level ``model``,
    feeding each prediction back into the TIME_STEP-long input window.
    Rows are assumed to be in chronological order (oldest first), since the
    last rows seed the forecast.

    Args:
        csv_file: Gradio upload. Either a path string or a tempfile-like object
            with a ``.name`` path, depending on the Gradio version.
        days_to_predict: Number of future steps to forecast (coerced to int).

    Returns:
        On success, ``(figure, "Forecast successful!", forecast_df)``.
        ``forecast_df`` has 'Date' and 'Forecasted Price' columns.
        On any failure, ``(None, error_message, None)``.
    """
    if csv_file is None:
        return None, "Error: Please upload a CSV file.", None
    try:
        df = pd.read_csv(_csv_path(csv_file))
        if 'Close' not in df.columns:
            return None, "Error: CSV must contain a 'Close' price column.", None

        # Slider values may arrive as floats; the loop and date_range need an int.
        days = int(days_to_predict)
        if days < 1:
            return None, "Error: Number of days to forecast must be at least 1.", None

        # Coerce prices to numbers and drop rows without a usable price, so
        # blanks or text values do not crash scaling or leak NaN into the forecast.
        close = pd.to_numeric(df['Close'], errors='coerce')
        has_price = close.notna()
        df = df.loc[has_price]
        ds_close = close[has_price].to_numpy(dtype=float).reshape(-1, 1)
        if len(ds_close) < TIME_STEP:
            return None, f"Error: Dataset must contain at least {TIME_STEP} entries for initial prediction.", None

        # NOTE(review): the scaler is fit on the uploaded series, not on the data
        # the model was trained with. Scaling may therefore differ from
        # training; confirm this is acceptable.
        scaler = MinMaxScaler(feature_range=(0, 1))
        ds_close_scaled = scaler.fit_transform(ds_close)

        scaled_forecast = _recursive_forecast(model, ds_close_scaled[:, 0], days)
        predicted_prices = scaler.inverse_transform(scaled_forecast)

        plot_output = _plot_forecast(ds_close, predicted_prices)

        future_dates = _forecast_dates(df, days)
        df_forecast = pd.DataFrame({
            'Date': future_dates.strftime('%Y-%m-%d'),
            'Forecasted Price': np.round(predicted_prices.flatten(), 2),
        })
        return plot_output, "Forecast successful!", df_forecast
    except Exception as e:
        # Top-level UI boundary: report the failure to the user instead of crashing.
        return None, f"An unexpected error occurred: {e}", None
# --- Gradio Interface ---
# Inputs: the CSV upload (restricted to .csv files) and the forecast horizon.
input_csv = gr.File(
    label="1. Upload Historical Stock CSV (must include 'Date' and 'Close' columns)",
    file_types=[".csv"],
)
input_days = gr.Slider(minimum=10, maximum=180, value=30, step=1, label="2. Number of days to forecast")
# Outputs, in the order forecast_stock returns them: plot, status text, table.
output_plot = gr.Plot(label="Price Forecast Visualization")
output_message = gr.Textbox(label="Status / Notes", value="Waiting for file upload...")
output_df = gr.Dataframe(label="Forecasted Prices Table")
# Create the Gradio Interface. The window size in the description is taken from
# TIME_STEP so the text cannot drift from the model's actual requirement.
iface = gr.Interface(
    fn=forecast_stock,
    inputs=[input_csv, input_days],
    outputs=[output_plot, output_message, output_df],
    title="LSTM Stock Price Prediction",
    description=(
        "Upload a CSV file of historical stock data and use the pre-trained Bidirectional LSTM model "
        "to forecast future closing prices. "
        f"The model requires the latest {TIME_STEP} data points to make the initial forecast."
    ),
    allow_flagging='never'
)
# Launch the app when this file is run as a script.
# NOTE(review): whether Hugging Face Spaces runs this file as __main__ (so that
# this call is what starts the server) or launches `iface` itself is not
# visible here. Confirm against the Spaces documentation.
if __name__ == "__main__":
    iface.launch()