kimyechan commited on
Commit
d24798b
ยท
0 Parent(s):

clean commit: only app.py, README.md, requirements.txt

Browse files
Files changed (3) hide show
  1. README.md +9 -0
  2. app.py +106 -0
  3. requirements.txt +6 -0
README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Chronos Stock Forecast (Gradio on Hugging Face Spaces)
2
+
3
+ - Zero-shot time series forecasting with amazon/chronos-bolt-*
4
+ - UI: Gradio
5
+ - Data: yfinance Close prices
6
+
7
+ ## Local run
8
+ pip install -r requirements.txt
9
+ python app.py
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime as dt
3
+ import pandas as pd
4
+ import torch
5
+ import gradio as gr
6
+ import yfinance as yf
7
+
8
+ from chronos import BaseChronosPipeline # from 'chronos-forecasting'
9
+
10
+ # ---- ์ „์—ญ ์บ์‹œ: ๋ชจ๋ธ์„ ํ•œ ๋ฒˆ๋งŒ ๋กœ๋“œํ•ด ์žฌ์‚ฌ์šฉ ----
11
+ _PIPELINE_CACHE = {}
12
+
13
+ def get_pipeline(model_id: str, device: str = "cpu"):
14
+ key = (model_id, device)
15
+ if key not in _PIPELINE_CACHE:
16
+ _PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
17
+ model_id,
18
+ device_map=device, # "cpu" / "cuda" (Spaces ๊ธฐ๋ณธ์€ cpu)
19
+ torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
20
+ )
21
+ return _PIPELINE_CACHE[key]
22
+
23
+ # ---- ์ฃผ๊ฐ€ ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ (yfinance) ----
24
+ def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
25
+ # ํ•œ๊ตญ ์ฃผ์‹์€ ์˜ˆ: 005930.KS (์‚ผ์„ฑ์ „์ž)
26
+ df = yf.download(ticker, start=start, end=end, interval=interval, progress=False)
27
+ if df.empty or "Close" not in df:
28
+ raise ValueError("๋ฐ์ดํ„ฐ๊ฐ€ ์—†๊ฑฐ๋‚˜ 'Close' ์—ด์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ํ‹ฐ์ปค/๋‚ ์งœ๋ฅผ ํ™•์ธํ•˜์„ธ์š”.")
29
+ s = df["Close"].dropna().astype(float)
30
+ return s
31
+
32
+ # ---- ์˜ˆ์ธก ํ•จ์ˆ˜ (Gradio๊ฐ€ ํ˜ธ์ถœ) ----
33
+ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
34
+ try:
35
+ series = load_close_series(ticker, start_date, end_date, interval)
36
+ except Exception as e:
37
+ return gr.Plot.update(None), pd.DataFrame(), f"๋ฐ์ดํ„ฐ ๋กœ๋”ฉ ์˜ค๋ฅ˜: {e}"
38
+
39
+ pipe = get_pipeline(model_id, device)
40
+ H = int(horizon)
41
+
42
+ # Chronos ์ž…๋ ฅ: 1D ํ…์„œ (float)
43
+ context = torch.tensor(series.values, dtype=torch.float32)
44
+
45
+ # ์ถœ๋ ฅ: (num_series=1, num_quantiles=3, H)
46
+ # ๋ณดํ†ต q=[0.1, 0.5, 0.9]
47
+ preds = pipe.predict(context=context, prediction_length=H)[0]
48
+ q10, q50, q90 = preds[0], preds[1], preds[2]
49
+
50
+ # ํ‘œ ๋ฐ์ดํ„ฐ
51
+ df_fcst = pd.DataFrame(
52
+ {"q10": q10.numpy(), "q50": q50.numpy(), "q90": q90.numpy()},
53
+ index=pd.RangeIndex(1, H + 1, name="step"),
54
+ )
55
+
56
+ # ๊ทธ๋ž˜ํ”„
57
+ import matplotlib.pyplot as plt
58
+ fig = plt.figure(figsize=(10, 4))
59
+ plt.plot(series.index, series.values, label="history")
60
+ # ๋ฏธ๋ž˜ ๊ตฌ๊ฐ„ x์ถ• ๋งŒ๋“ค๊ธฐ: ์ข…๊ฐ€๊ฐ€ ์ผ ๋‹จ์œ„๋ผ 'D' ์ฃผ๊ธฐ ์‚ฌ์šฉ
61
+ future_index = pd.date_range(series.index[-1], periods=H + 1, freq="D")[1:]
62
+ plt.plot(future_index, q50.numpy(), label="forecast(q50)")
63
+ plt.fill_between(future_index, q10.numpy(), q90.numpy(), alpha=0.2, label="q10โ€“q90")
64
+ plt.title(f"{ticker} forecast by Chronos-Bolt")
65
+ plt.legend()
66
+ plt.tight_layout()
67
+
68
+ note = "โ€ป ๋ฐ๋ชจ ๋ชฉ์ ์ž…๋‹ˆ๋‹ค. ํˆฌ์ž ํŒ๋‹จ์˜ ์ฑ…์ž„์€ ๋ณธ์ธ์—๊ฒŒ ์žˆ์Šต๋‹ˆ๋‹ค."
69
+ return fig, df_fcst, note
70
+
71
+ # ---- Gradio UI ----
72
+ with gr.Blocks(title="Chronos Stock Forecast") as demo:
73
+ gr.Markdown("# Chronos ์ฃผ๊ฐ€ ์˜ˆ์ธก ๋ฐ๋ชจ")
74
+ with gr.Row():
75
+ ticker = gr.Textbox(value="AAPL", label="ํ‹ฐ์ปค (์˜ˆ: AAPL, MSFT, 005930.KS)")
76
+ horizon = gr.Slider(5, 60, value=20, step=1, label="์˜ˆ์ธก ๊ธธ์ด H (์ผ)")
77
+ with gr.Row():
78
+ start = gr.Textbox(value=(dt.date.today()-dt.timedelta(days=365)).isoformat(), label="์‹œ์ž‘์ผ (YYYY-MM-DD)")
79
+ end = gr.Textbox(value=dt.date.today().isoformat(), label="์ข…๋ฃŒ์ผ (YYYY-MM-DD)")
80
+ with gr.Row():
81
+ model_id = gr.Dropdown(
82
+ choices=[
83
+ "amazon/chronos-bolt-tiny",
84
+ "amazon/chronos-bolt-mini",
85
+ "amazon/chronos-bolt-small",
86
+ "amazon/chronos-bolt-base",
87
+ ],
88
+ value="amazon/chronos-bolt-small",
89
+ label="๋ชจ๋ธ"
90
+ )
91
+ device = gr.Dropdown(choices=["cpu"], value="cpu", label="๋””๋ฐ”์ด์Šค")
92
+ interval = gr.Dropdown(choices=["1d"], value="1d", label="๊ฐ„๊ฒฉ")
93
+ btn = gr.Button("์˜ˆ์ธก ์‹คํ–‰")
94
+
95
+ plot = gr.Plot(label="History + Forecast")
96
+ table = gr.Dataframe(label="์˜ˆ์ธก ๊ฒฐ๊ณผ (๋ถ„์œ„์ˆ˜)")
97
+ note = gr.Markdown()
98
+
99
+ btn.click(
100
+ fn=run_forecast,
101
+ inputs=[ticker, start, end, horizon, model_id, device, interval],
102
+ outputs=[plot, table, note]
103
+ )
104
+
105
+ if __name__ == "__main__":
106
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.44
2
+ pandas>=2.2
3
+ yfinance>=0.2
4
+ torch>=2.2 ; platform_system != "Darwin"
5
+ chronos-forecasting>=1.0
6
+ matplotlib>=3.8