# Chronos / app.py
# kimyechan
# clean commit: only app.py, README.md, requirements.txt
# d24798b
# raw
# history blame
# 4.13 kB
import os
import datetime as dt
import pandas as pd
import torch
import gradio as gr
import yfinance as yf
from chronos import BaseChronosPipeline # from 'chronos-forecasting'
# ---- Global cache: each (model_id, device) pipeline is loaded once and reused ----
_PIPELINE_CACHE = {}


def get_pipeline(model_id: str, device: str = "cpu"):
    """Return the Chronos pipeline for ``model_id`` on ``device``, loading it on first use.

    Loaded pipelines are memoised in ``_PIPELINE_CACHE`` under the key
    ``(model_id, device)``, so later calls with the same pair reuse the same object
    instead of calling ``BaseChronosPipeline.from_pretrained`` again.

    Args:
        model_id: Model identifier forwarded to ``from_pretrained``.
        device: Forwarded as ``device_map`` ("cpu" or "cuda"; the Space defaults to "cpu").

    Returns:
        The cached (or freshly loaded) pipeline object.
    """
    cache_key = (model_id, device)
    try:
        return _PIPELINE_CACHE[cache_key]
    except KeyError:
        pass

    # float32 on CPU; bfloat16 for any other device.
    weights_dtype = torch.bfloat16 if device != "cpu" else torch.float32
    pipeline = BaseChronosPipeline.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=weights_dtype,
    )
    _PIPELINE_CACHE[cache_key] = pipeline
    return pipeline
# ---- Stock price loading (yfinance) ----
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
    """Download closing prices for ``ticker`` and return them as a 1-D float Series.

    Korean listings use an exchange suffix, e.g. ``005930.KS`` (Samsung Electronics).

    Args:
        ticker: Yahoo Finance symbol; surrounding whitespace is stripped.
        start: Start date string (``YYYY-MM-DD``) forwarded to ``yf.download``.
        end: End date string (``YYYY-MM-DD``) forwarded to ``yf.download``.
        interval: Bar interval forwarded to ``yf.download`` (default ``"1d"``).

    Returns:
        ``pd.Series`` of closing prices (float), indexed like the downloaded frame,
        with NaN rows dropped.

    Raises:
        ValueError: if the ticker is blank, nothing was downloaded, there is no
            ``Close`` column, or every close is NaN.
    """
    no_data_msg = "데이터가 없거나 'Close' 열을 찾을 수 없습니다. 티커/날짜를 확인하세요."

    symbol = (ticker or "").strip()
    if not symbol:
        raise ValueError(no_data_msg)

    df = yf.download(symbol, start=start, end=end, interval=interval, progress=False)
    if df is None or df.empty or "Close" not in df:
        raise ValueError(no_data_msg)

    close = df["Close"]
    # Newer yfinance releases return (field, ticker) MultiIndex columns even for a
    # single symbol, so df["Close"] is a one-column DataFrame. Collapse it to a
    # Series: callers need 1-D values, and a 2-D (n, 1) array would be read by
    # Chronos as n separate series of length 1. (If several symbols were given,
    # only the first column is kept.)
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]

    close = close.dropna().astype(float)
    if close.empty:
        # Every close was NaN: there is nothing to forecast from.
        raise ValueError(no_data_msg)
    return close
# ---- Forecast callback (invoked by Gradio) ----
def _predict_quantiles(pipe, series, horizon):
    """Return the 10th/50th/90th-percentile forecasts as 1-D float32 numpy arrays.

    ``predict_quantiles`` is used rather than ``predict``. Chronos-Bolt's ``predict``
    returns every quantile the model was trained on (0.1, 0.2, ..., 0.9), so rows
    0/1/2 of its output are q0.1/q0.2/q0.3, not q10/q50/q90.
    """
    # Chronos context: a 1-D float tensor for a single series.
    context = torch.tensor(series.to_numpy(), dtype=torch.float32)
    # Per the chronos-forecasting docs, `quantiles` has shape
    # (num_series, prediction_length, len(quantile_levels)).
    quantiles, _mean = pipe.predict_quantiles(
        context,
        prediction_length=horizon,
        quantile_levels=[0.1, 0.5, 0.9],
    )
    levels = quantiles[0].to(dtype=torch.float32, device="cpu").numpy()  # (horizon, 3)
    return levels[:, 0], levels[:, 1], levels[:, 2]


def _future_index(history_index, horizon):
    """Return ``horizon`` timestamps that follow the last entry of ``history_index``.

    Business days ("B") are used when the history contains no weekend dates, so
    each forecast step lines up with a trading session just like the history's
    steps. Otherwise calendar days ("D") are used, e.g. for markets that also
    trade on weekends.
    """
    freq = "D"
    if isinstance(history_index, pd.DatetimeIndex) and not (history_index.dayofweek >= 5).any():
        freq = "B"
    # periods=horizon + 1 and [1:] so that the last observed date itself is excluded.
    return pd.date_range(history_index[-1], periods=horizon + 1, freq=freq)[1:]


def _make_figure(ticker, series, future_index, q10, q50, q90):
    """Plot the price history, the median forecast and the q10–q90 band."""
    # Object-oriented Figure instead of pyplot: pyplot keeps every figure it creates
    # registered until plt.close(), which leaks memory in a long-running server.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(series.index, series.to_numpy(), label="history")
    ax.plot(future_index, q50, label="forecast(q50)")
    ax.fill_between(future_index, q10, q90, alpha=0.2, label="q10–q90")
    ax.set_title(f"{ticker} forecast by Chronos-Bolt")
    ax.legend()
    fig.tight_layout()
    return fig


def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
    """Gradio callback: load closing prices for ``ticker`` and forecast ``horizon`` steps.

    Args:
        ticker, start_date, end_date, interval: Forwarded to ``load_close_series``.
        horizon: Number of future steps to forecast (coerced with ``int``).
        model_id, device: Forwarded to ``get_pipeline``.

    Returns:
        ``(figure, table, note)`` for the ``[plot, table, note]`` outputs. The table
        has columns ``step``, ``date``, ``q10``, ``q50`` and ``q90``. On failure the
        figure is ``None`` (clears the plot), the table is empty, and the note holds
        the error message.
    """
    try:
        series = load_close_series(ticker, start_date, end_date, interval)
    except Exception as e:
        # Return plain None to clear the plot: gr.Plot.update() no longer exists in Gradio 4.
        return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"

    # Defensive: collapse a one-column frame (yfinance MultiIndex columns) to 1-D.
    if isinstance(series, pd.DataFrame):
        series = series.iloc[:, 0]
    if series.empty:
        return None, pd.DataFrame(), "데이터 로딩 오류: 유효한 종가 데이터가 없습니다."

    H = int(horizon)
    try:
        pipe = get_pipeline(model_id, device)
        q10, q50, q90 = _predict_quantiles(pipe, series, H)
    except Exception as e:
        # Report model download / inference failures in the note, like data errors.
        return None, pd.DataFrame(), f"예측 오류: {e}"

    future_index = _future_index(series.index, H)

    df_fcst = pd.DataFrame(
        {
            "date": future_index.strftime("%Y-%m-%d"),
            "q10": q10,
            "q50": q50,
            "q90": q90,
        },
        index=pd.RangeIndex(1, H + 1, name="step"),
    ).reset_index()  # gr.Dataframe renders columns only, so expose "step" as a column

    fig = _make_figure(ticker, series, future_index, q10, q50, q90)
    note = "※ 데모 목적입니다. 투자 판단의 책임은 본인에게 있습니다."
    return fig, df_fcst, note
# ---- Gradio UI ----
def _default_start_date() -> str:
    """Return the ISO date 365 days before today (default start of the history window)."""
    return (dt.date.today() - dt.timedelta(days=365)).isoformat()


def _default_end_date() -> str:
    """Return today's ISO date (default end of the history window)."""
    return dt.date.today().isoformat()


with gr.Blocks(title="Chronos Stock Forecast") as demo:
    gr.Markdown("# Chronos 주가 예측 데모")
    with gr.Row():
        ticker = gr.Textbox(value="AAPL", label="티커 (예: AAPL, MSFT, 005930.KS)")
        horizon = gr.Slider(5, 60, value=20, step=1, label="예측 길이 H (일)")
    with gr.Row():
        # Pass callables rather than strings. Gradio evaluates a callable value on
        # every page load; a plain string would freeze the dates at process start-up,
        # so a long-running Space would keep showing stale defaults.
        start = gr.Textbox(value=_default_start_date, label="시작일 (YYYY-MM-DD)")
        end = gr.Textbox(value=_default_end_date, label="종료일 (YYYY-MM-DD)")
    with gr.Row():
        model_id = gr.Dropdown(
            choices=[
                "amazon/chronos-bolt-tiny",
                "amazon/chronos-bolt-mini",
                "amazon/chronos-bolt-small",
                "amazon/chronos-bolt-base",
            ],
            value="amazon/chronos-bolt-small",
            label="모델",
        )
        device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
        interval = gr.Dropdown(choices=["1d"], value="1d", label="간격")
    btn = gr.Button("예측 실행")
    plot = gr.Plot(label="History + Forecast")
    table = gr.Dataframe(label="예측 결과 (분위수)")
    note = gr.Markdown()
    # Input order must match run_forecast's parameter order.
    btn.click(
        fn=run_forecast,
        inputs=[ticker, start, end, horizon, model_id, device, interval],
        outputs=[plot, table, note],
    )
def _main() -> None:
    """Start the Gradio server for ``demo``."""
    demo.launch()


# Launch only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    _main()