Chronos / app.py
kimyechan
fix: ๋น„ํŠธ์ฝ”์ธ ๋กœ์ง ์ˆ˜์ •
0b929da
raw
history blame
5.24 kB
import os
import datetime as dt
import pandas as pd
import torch
import gradio as gr
import yfinance as yf
from chronos import BaseChronosPipeline # from 'chronos-forecasting'
# ---- ์ „์—ญ ์บ์‹œ: ๋ชจ๋ธ์„ ํ•œ ๋ฒˆ๋งŒ ๋กœ๋“œํ•ด ์žฌ์‚ฌ์šฉ ----
_PIPELINE_CACHE = {}
def get_pipeline(model_id: str, device: str = "cpu"):
key = (model_id, device)
if key not in _PIPELINE_CACHE:
_PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
model_id,
device_map=device, # "cpu" / "cuda"
torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
)
return _PIPELINE_CACHE[key]
# ---- ์ฃผ๊ฐ€/ํฌ๋ฆฝํ†  ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ (yfinance, ๊ฒฌ๊ณ ํ™”) ----
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
"""
BTC-USD ๋“ฑ ํฌ๋ฆฝํ†  ์‹ฌ๋ณผ์—์„œ ๊ฐ„ํ—์ ์œผ๋กœ timezone/ํŒŒ์‹ฑ ์˜ค๋ฅ˜๊ฐ€ ๋‚˜๋ฏ€๋กœ
history() ๊ฒฝ๋กœ๋ฅผ ์šฐ์„  ์‚ฌ์šฉํ•˜๊ณ , ์‹คํŒจ ์‹œ ํ•œ ๋ฒˆ ์žฌ์‹œ๋„.
"""
# ๊ธฐ๋ณธ๊ฐ’ ๋ณด์ •: ๋„ˆ๋ฌด ์ตœ๊ทผ๋งŒ ๊ณ ๋ฅด๋ฉด ๋นˆ ๋ฐ์ดํ„ฐ๊ฐ€ ๋‚˜์˜ฌ ์ˆ˜ ์žˆ์–ด ์ผ๋ด‰์€ ๊ณผ๊ฑฐ๋ถ€ํ„ฐ ๊ถŒ์žฅ
_start = start or "2014-09-17" # BTC-USD ์ตœ์ดˆ ์ƒ์žฅ์ผ ๊ทผ์ฒ˜
_end = end or dt.date.today().isoformat()
tk = yf.Ticker(ticker)
try:
df = tk.history(start=_start, end=_end, interval=interval, auto_adjust=True, actions=False)
if df.empty or "Close" not in df:
raise ValueError("empty")
except Exception:
# fallback: download() ๊ฒฝ๋กœ ์‹œ๋„
df = yf.download(ticker, start=_start, end=_end, interval=interval, progress=False, threads=False)
if df.empty or "Close" not in df:
raise ValueError("๋ฐ์ดํ„ฐ๊ฐ€ ์—†๊ฑฐ๋‚˜ 'Close' ์—ด์ด ์—†์Šต๋‹ˆ๋‹ค. ํ‹ฐ์ปค/๋‚ ์งœ/๊ฐ„๊ฒฉ์„ ํ™•์ธํ•˜์„ธ์š”.")
s = df["Close"].dropna().astype(float)
if s.empty:
raise ValueError("๋‹ค์šด๋กœ๋“œ ๊ฒฐ๊ณผ๊ฐ€ ๋น„์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ๊ธฐ๊ฐ„/๊ฐ„๊ฒฉ์„ ์ค„์ด๊ฑฐ๋‚˜ ๋‹ค์‹œ ์‹œ๋„ํ•˜์„ธ์š”.")
return s
# ---- ์˜ˆ์ธก ํ•จ์ˆ˜ (Gradio๊ฐ€ ํ˜ธ์ถœ) ----
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
try:
series = load_close_series(ticker.strip(), start_date, end_date, interval)
except Exception as e:
# Gradio v4์—์„œ๋Š” Plot.update๊ฐ€ ์—†์Œ โ†’ None ๋ฐ˜ํ™˜์œผ๋กœ ์ •๋ฆฌ
return None, pd.DataFrame(), f"๋ฐ์ดํ„ฐ ๋กœ๋”ฉ ์˜ค๋ฅ˜: {e}"
pipe = get_pipeline(model_id, device)
H = int(horizon)
# Chronos ์ž…๋ ฅ: 1D ํ…์„œ (float)
context = torch.tensor(series.values, dtype=torch.float32)
# ์˜ˆ์ธก: (num_series=1, num_quantiles=3, H) with q=[0.1, 0.5, 0.9]
preds = pipe.predict(context=context, prediction_length=H)[0]
q10, q50, q90 = preds[0], preds[1], preds[2]
# ํ‘œ ๋ฐ์ดํ„ฐ
df_fcst = pd.DataFrame(
{"q10": q10.numpy(), "q50": q50.numpy(), "q90": q90.numpy()},
index=pd.RangeIndex(1, H + 1, name="step"),
)
# ๋ฏธ๋ž˜ x์ถ•: interval์— ๋งž๋Š” pandas ์ฃผ๊ธฐ
import matplotlib.pyplot as plt
freq_map = {"1d": "D", "1h": "H", "30m": "30T", "15m": "15T", "5m": "5T"}
freq = freq_map.get(interval, "D")
future_index = pd.date_range(series.index[-1], periods=H + 1, freq=freq)[1:]
# ๊ทธ๋ž˜ํ”„
fig = plt.figure(figsize=(10, 4))
plt.plot(series.index, series.values, label="history")
plt.plot(future_index, q50.numpy(), label="forecast(q50)")
plt.fill_between(future_index, q10.numpy(), q90.numpy(), alpha=0.2, label="q10โ€“q90")
plt.title(f"{ticker} forecast by Chronos-Bolt ({interval}, H={H})")
plt.legend()
plt.tight_layout()
note = "โ€ป ๋ฐ๋ชจ ๋ชฉ์ ์ž…๋‹ˆ๋‹ค. ํˆฌ์ž ํŒ๋‹จ์˜ ์ฑ…์ž„์€ ๋ณธ์ธ์—๊ฒŒ ์žˆ์Šต๋‹ˆ๋‹ค."
return fig, df_fcst, note
# ---- Gradio UI ----
with gr.Blocks(title="Chronos Stock/Crypto Forecast") as demo:
gr.Markdown("# Chronos ์ฃผ๊ฐ€ยทํฌ๋ฆฝํ†  ์˜ˆ์ธก ๋ฐ๋ชจ")
with gr.Row():
ticker = gr.Textbox(value="BTC-USD", label="ํ‹ฐ์ปค (์˜ˆ: AAPL, MSFT, 005930.KS, BTC-USD)")
horizon = gr.Slider(5, 365, value=90, step=1, label="์˜ˆ์ธก ์Šคํ… H (๊ฐ„๊ฒฉ ๋‹จ์œ„์™€ ๋™์ผ)")
with gr.Row():
start = gr.Textbox(value="2014-09-17", label="์‹œ์ž‘์ผ (YYYY-MM-DD, ์˜ˆ: 2014-09-17)")
end = gr.Textbox(value=dt.date.today().isoformat(), label="์ข…๋ฃŒ์ผ (YYYY-MM-DD, ๋น„์›Œ๋‘๋ฉด ์˜ค๋Š˜)")
with gr.Row():
model_id = gr.Dropdown(
choices=[
"amazon/chronos-bolt-tiny",
"amazon/chronos-bolt-mini",
"amazon/chronos-bolt-small",
"amazon/chronos-bolt-base",
],
value="amazon/chronos-bolt-small",
label="๋ชจ๋ธ"
)
device = gr.Dropdown(choices=["cpu"], value="cpu", label="๋””๋ฐ”์ด์Šค")
interval = gr.Dropdown(
choices=["1d", "1h", "30m", "15m", "5m"],
value="1d",
label="๊ฐ„๊ฒฉ"
)
btn = gr.Button("์˜ˆ์ธก ์‹คํ–‰")
plot = gr.Plot(label="History + Forecast")
table = gr.Dataframe(label="์˜ˆ์ธก ๊ฒฐ๊ณผ (๋ถ„์œ„์ˆ˜)")
note = gr.Markdown()
btn.click(
fn=run_forecast,
inputs=[ticker, start, end, horizon, model_id, device, interval],
outputs=[plot, table, note]
)
if __name__ == "__main__":
demo.launch()