kimyechan
commited on
Commit
ยท
d24798b
0
Parent(s):
clean commit: only app.py, README.md, requirements.txt
Browse files- README.md +9 -0
- app.py +106 -0
- requirements.txt +6 -0
README.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Chronos Stock Forecast (Gradio on Hugging Face Spaces)
|
| 2 |
+
|
| 3 |
+
- Zero-shot time series forecasting with amazon/chronos-bolt-*
|
| 4 |
+
- UI: Gradio
|
| 5 |
+
- Data: yfinance Close prices
|
| 6 |
+
|
| 7 |
+
## Local run
|
| 8 |
+
pip install -r requirements.txt
|
| 9 |
+
python app.py
|
app.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import datetime as dt
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import yfinance as yf
|
| 7 |
+
|
| 8 |
+
from chronos import BaseChronosPipeline # from 'chronos-forecasting'
|
| 9 |
+
|
| 10 |
+
# ---- ์ ์ญ ์บ์: ๋ชจ๋ธ์ ํ ๋ฒ๋ง ๋ก๋ํด ์ฌ์ฌ์ฉ ----
|
| 11 |
+
_PIPELINE_CACHE = {}
|
| 12 |
+
|
| 13 |
+
def get_pipeline(model_id: str, device: str = "cpu"):
|
| 14 |
+
key = (model_id, device)
|
| 15 |
+
if key not in _PIPELINE_CACHE:
|
| 16 |
+
_PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
|
| 17 |
+
model_id,
|
| 18 |
+
device_map=device, # "cpu" / "cuda" (Spaces ๊ธฐ๋ณธ์ cpu)
|
| 19 |
+
torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
|
| 20 |
+
)
|
| 21 |
+
return _PIPELINE_CACHE[key]
|
| 22 |
+
|
| 23 |
+
# ---- ์ฃผ๊ฐ ๋ฐ์ดํฐ ๋ก๋ฉ (yfinance) ----
|
| 24 |
+
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
|
| 25 |
+
# ํ๊ตญ ์ฃผ์์ ์: 005930.KS (์ผ์ฑ์ ์)
|
| 26 |
+
df = yf.download(ticker, start=start, end=end, interval=interval, progress=False)
|
| 27 |
+
if df.empty or "Close" not in df:
|
| 28 |
+
raise ValueError("๋ฐ์ดํฐ๊ฐ ์๊ฑฐ๋ 'Close' ์ด์ ์ฐพ์ ์ ์์ต๋๋ค. ํฐ์ปค/๋ ์ง๋ฅผ ํ์ธํ์ธ์.")
|
| 29 |
+
s = df["Close"].dropna().astype(float)
|
| 30 |
+
return s
|
| 31 |
+
|
| 32 |
+
# ---- ์์ธก ํจ์ (Gradio๊ฐ ํธ์ถ) ----
|
| 33 |
+
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
|
| 34 |
+
try:
|
| 35 |
+
series = load_close_series(ticker, start_date, end_date, interval)
|
| 36 |
+
except Exception as e:
|
| 37 |
+
return gr.Plot.update(None), pd.DataFrame(), f"๋ฐ์ดํฐ ๋ก๋ฉ ์ค๋ฅ: {e}"
|
| 38 |
+
|
| 39 |
+
pipe = get_pipeline(model_id, device)
|
| 40 |
+
H = int(horizon)
|
| 41 |
+
|
| 42 |
+
# Chronos ์
๋ ฅ: 1D ํ
์ (float)
|
| 43 |
+
context = torch.tensor(series.values, dtype=torch.float32)
|
| 44 |
+
|
| 45 |
+
# ์ถ๋ ฅ: (num_series=1, num_quantiles=3, H)
|
| 46 |
+
# ๋ณดํต q=[0.1, 0.5, 0.9]
|
| 47 |
+
preds = pipe.predict(context=context, prediction_length=H)[0]
|
| 48 |
+
q10, q50, q90 = preds[0], preds[1], preds[2]
|
| 49 |
+
|
| 50 |
+
# ํ ๋ฐ์ดํฐ
|
| 51 |
+
df_fcst = pd.DataFrame(
|
| 52 |
+
{"q10": q10.numpy(), "q50": q50.numpy(), "q90": q90.numpy()},
|
| 53 |
+
index=pd.RangeIndex(1, H + 1, name="step"),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# ๊ทธ๋ํ
|
| 57 |
+
import matplotlib.pyplot as plt
|
| 58 |
+
fig = plt.figure(figsize=(10, 4))
|
| 59 |
+
plt.plot(series.index, series.values, label="history")
|
| 60 |
+
# ๋ฏธ๋ ๊ตฌ๊ฐ x์ถ ๋ง๋ค๊ธฐ: ์ข
๊ฐ๊ฐ ์ผ ๋จ์๋ผ 'D' ์ฃผ๊ธฐ ์ฌ์ฉ
|
| 61 |
+
future_index = pd.date_range(series.index[-1], periods=H + 1, freq="D")[1:]
|
| 62 |
+
plt.plot(future_index, q50.numpy(), label="forecast(q50)")
|
| 63 |
+
plt.fill_between(future_index, q10.numpy(), q90.numpy(), alpha=0.2, label="q10โq90")
|
| 64 |
+
plt.title(f"{ticker} forecast by Chronos-Bolt")
|
| 65 |
+
plt.legend()
|
| 66 |
+
plt.tight_layout()
|
| 67 |
+
|
| 68 |
+
note = "โป ๋ฐ๋ชจ ๋ชฉ์ ์
๋๋ค. ํฌ์ ํ๋จ์ ์ฑ
์์ ๋ณธ์ธ์๊ฒ ์์ต๋๋ค."
|
| 69 |
+
return fig, df_fcst, note
|
| 70 |
+
|
| 71 |
+
# ---- Gradio UI ----
|
| 72 |
+
with gr.Blocks(title="Chronos Stock Forecast") as demo:
|
| 73 |
+
gr.Markdown("# Chronos ์ฃผ๊ฐ ์์ธก ๋ฐ๋ชจ")
|
| 74 |
+
with gr.Row():
|
| 75 |
+
ticker = gr.Textbox(value="AAPL", label="ํฐ์ปค (์: AAPL, MSFT, 005930.KS)")
|
| 76 |
+
horizon = gr.Slider(5, 60, value=20, step=1, label="์์ธก ๊ธธ์ด H (์ผ)")
|
| 77 |
+
with gr.Row():
|
| 78 |
+
start = gr.Textbox(value=(dt.date.today()-dt.timedelta(days=365)).isoformat(), label="์์์ผ (YYYY-MM-DD)")
|
| 79 |
+
end = gr.Textbox(value=dt.date.today().isoformat(), label="์ข
๋ฃ์ผ (YYYY-MM-DD)")
|
| 80 |
+
with gr.Row():
|
| 81 |
+
model_id = gr.Dropdown(
|
| 82 |
+
choices=[
|
| 83 |
+
"amazon/chronos-bolt-tiny",
|
| 84 |
+
"amazon/chronos-bolt-mini",
|
| 85 |
+
"amazon/chronos-bolt-small",
|
| 86 |
+
"amazon/chronos-bolt-base",
|
| 87 |
+
],
|
| 88 |
+
value="amazon/chronos-bolt-small",
|
| 89 |
+
label="๋ชจ๋ธ"
|
| 90 |
+
)
|
| 91 |
+
device = gr.Dropdown(choices=["cpu"], value="cpu", label="๋๋ฐ์ด์ค")
|
| 92 |
+
interval = gr.Dropdown(choices=["1d"], value="1d", label="๊ฐ๊ฒฉ")
|
| 93 |
+
btn = gr.Button("์์ธก ์คํ")
|
| 94 |
+
|
| 95 |
+
plot = gr.Plot(label="History + Forecast")
|
| 96 |
+
table = gr.Dataframe(label="์์ธก ๊ฒฐ๊ณผ (๋ถ์์)")
|
| 97 |
+
note = gr.Markdown()
|
| 98 |
+
|
| 99 |
+
btn.click(
|
| 100 |
+
fn=run_forecast,
|
| 101 |
+
inputs=[ticker, start, end, horizon, model_id, device, interval],
|
| 102 |
+
outputs=[plot, table, note]
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
|
| 106 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.44
|
| 2 |
+
pandas>=2.2
|
| 3 |
+
yfinance>=0.2
|
| 4 |
+
torch>=2.2 ; platform_system != "Darwin"
|
| 5 |
+
chronos-forecasting>=1.0
|
| 6 |
+
matplotlib>=3.8
|