syedtamim2020 commited on
Commit
10b50ee
·
verified ·
1 Parent(s): 30be9a3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import tensorflow as tf
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ import os
8
+
9
+ # --- Constants ---
10
+ TIME_STEP = 100
11
+ MODEL_PATH = 'stock_prediction_model.h5'
12
+
13
+ # --- Model Architecture Definition ---
14
+ def create_lstm_model():
15
+ """Defines the Bidirectional LSTM model architecture used for training."""
16
+ model = tf.keras.models.Sequential()
17
+ model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True), input_shape=(TIME_STEP, 1)))
18
+ model.add(tf.keras.layers.Dropout(0.3))
19
+ model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)))
20
+ model.add(tf.keras.layers.Dropout(0.3))
21
+ model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)))
22
+ model.add(tf.keras.layers.Dense(64, activation='relu'))
23
+ model.add(tf.keras.layers.Dense(1))
24
+ model.compile(loss='mse', optimizer='adam')
25
+ return model
26
+
27
+ # --- Model Loading ---
28
+ try:
29
+ # Attempt to load the pre-trained model
30
+ model = tf.keras.models.load_model(MODEL_PATH)
31
+ print(f"Successfully loaded model from {MODEL_PATH}")
32
+ except Exception as e:
33
+ # If loading fails (e.g., in a local test where the file is missing),
34
+ # create a dummy model for structure checking, but warn the user.
35
+ print(f"Warning: Could not load {MODEL_PATH}. Error: {e}")
36
+ print("Initializing a dummy model. Please ensure your 'stock_prediction_model.h5' is uploaded.")
37
+ model = create_lstm_model() # Create architecture
38
+ # NOTE: The dummy model has random weights and will give poor predictions until the H5 file is provided.
39
+
40
+
41
+ # --- Prediction Logic Function ---
42
+ def forecast_stock(csv_file, days_to_predict=30):
43
+ """
44
+ Takes an uploaded CSV file containing stock data, extracts 'Close' prices,
45
+ and forecasts the next 'days_to_predict' using the loaded LSTM model.
46
+ """
47
+ if csv_file is None:
48
+ return None, "Error: Please upload a CSV file.", None
49
+
50
+ try:
51
+ # Load the uploaded data
52
+ df = pd.read_csv(csv_file.name)
53
+
54
+ # Ensure 'Close' column exists
55
+ if 'Close' not in df.columns:
56
+ return None, "Error: CSV must contain a 'Close' price column.", None
57
+
58
+ # Extract and scale the 'Close' prices (fit the scaler on the entire provided dataset)
59
+ ds_close = df.reset_index()['Close'].values.reshape(-1, 1)
60
+ scaler = MinMaxScaler(feature_range=(0, 1))
61
+ ds_close_scaled = scaler.fit_transform(ds_close)
62
+
63
+ # Get the last TIME_STEP (100) values to use as the initial input for forecasting
64
+ if len(ds_close_scaled) < TIME_STEP:
65
+ return None, f"Error: Dataset must contain at least {TIME_STEP} entries for initial prediction.", None
66
+
67
+ # The input data is the last 100 scaled data points
68
+ x_input = ds_close_scaled[-TIME_STEP:].reshape(1, -1)
69
+ temp_input = list(x_input[0])
70
+
71
+ lst_output = []
72
+ i = 0
73
+
74
+ # Iterative prediction loop (predict 1 day, append, use result for next prediction)
75
+ while i < days_to_predict:
76
+ if len(temp_input) > TIME_STEP:
77
+ # Get the last 100 steps
78
+ x_input = np.array(temp_input[1:])
79
+ x_input = x_input.reshape(1, TIME_STEP, 1)
80
+ temp_input = temp_input[1:]
81
+ else:
82
+ x_input = np.array(temp_input).reshape(1, TIME_STEP, 1)
83
+
84
+ # Predict the next step
85
+ yhat = model.predict(x_input, verbose=0)
86
+
87
+ # Append prediction to the input sequence and to the output list
88
+ temp_input.extend(yhat[0].tolist())
89
+ lst_output.extend(yhat.tolist())
90
+ i = i + 1
91
+
92
+ # Inverse transform the forecasted data to get actual prices
93
+ predicted_prices = scaler.inverse_transform(lst_output)
94
+
95
+ # Create a plot for visualization
96
+ plt.figure(figsize=(10, 6))
97
+
98
+ # Actual Data Plot
99
+ actual_prices = scaler.inverse_transform(ds_close_scaled)
100
+ day_actual = np.arange(len(actual_prices) - TIME_STEP, len(actual_prices))
101
+
102
+ plt.plot(day_actual, actual_prices[-TIME_STEP:], label='Last 100 Actual Days', color='blue')
103
+
104
+ # Predicted Data Plot
105
+ day_pred = np.arange(TIME_STEP, TIME_STEP + days_to_predict)
106
+ plt.plot(day_pred, predicted_prices, label=f'Forecasted {days_to_predict} Days', color='red', linestyle='--')
107
+
108
+ # Connect the last actual point to the first predicted point
109
+ plt.plot([day_actual[-1], day_pred[0]], [actual_prices[-1], predicted_prices[0]], color='red', linestyle='--')
110
+
111
+ plt.title('Stock Price Forecast (LSTM)')
112
+ plt.xlabel('Days')
113
+ plt.ylabel('Close Price')
114
+ plt.legend()
115
+ plt.grid(True)
116
+ plot_output = plt
117
+
118
+ # Create a DataFrame for output
119
+ # Get the date of the last actual entry
120
+ last_date = pd.to_datetime(df.iloc[-1]['Date']) if 'Date' in df.columns else pd.to_datetime(df.index[-1])
121
+
122
+ # Generate future dates
123
+ future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=days_to_predict)
124
+
125
+ df_forecast = pd.DataFrame({
126
+ 'Date': future_dates.strftime('%Y-%m-%d'),
127
+ 'Forecasted Price': np.round(predicted_prices.flatten(), 2)
128
+ })
129
+
130
+ return plot_output, "Forecast successful!", df_forecast
131
+
132
+ except Exception as e:
133
+ return None, f"An unexpected error occurred: {e}", None
134
+
135
+
136
+ # --- Gradio Interface ---
137
+
138
+ # Define the inputs
139
+ input_csv = gr.File(label="1. Upload Historical Stock CSV (must include 'Date' and 'Close' columns)")
140
+ input_days = gr.Slider(minimum=10, maximum=180, value=30, step=1, label="2. Number of days to forecast")
141
+
142
+ # Define the outputs
143
+ output_plot = gr.Plot(label="Price Forecast Visualization")
144
+ output_message = gr.Textbox(label="Status / Notes", value="Waiting for file upload...")
145
+ output_df = gr.Dataframe(label="Forecasted Prices Table")
146
+
147
+ # Create the Gradio Interface
148
+ iface = gr.Interface(
149
+ fn=forecast_stock,
150
+ inputs=[input_csv, input_days],
151
+ outputs=[output_plot, output_message, output_df],
152
+ title="LSTM Stock Price Prediction",
153
+ description="Upload a CSV file of historical stock data and use the pre-trained Bidirectional LSTM model to forecast future closing prices. The model requires the latest 100 data points to make the initial forecast.",
154
+ allow_flagging='never'
155
+ )
156
+
157
+ # Launch the app for local testing (Hugging Face Spaces will ignore this and use 'iface.launch()' internally)
158
+ if __name__ == "__main__":
159
+ iface.launch()