|
|
import datetime as dt |
|
|
import pandas as pd |
|
|
import torch |
|
|
import gradio as gr |
|
|
import yfinance as yf |
|
|
|
|
|
from chronos import BaseChronosPipeline |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Loaded Chronos pipelines keyed by (model_id, device); repeat requests for the
# same pair reuse the already-loaded pipeline instead of loading it again.
_PIPELINE_CACHE = {}


def get_pipeline(model_id: str, device: str = "cpu"):
    """Return the Chronos pipeline for ``model_id`` on ``device``, loading it on first use.

    A cache miss loads the model with ``BaseChronosPipeline.from_pretrained``
    (float32 weights on CPU, bfloat16 on any other device) and stores it in
    ``_PIPELINE_CACHE``; a hit returns the stored object unchanged.
    """
    cache_key = (model_id, device)
    if cache_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[cache_key]

    weight_dtype = torch.bfloat16 if device != "cpu" else torch.float32
    pipeline = BaseChronosPipeline.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=weight_dtype,
    )
    _PIPELINE_CACHE[cache_key] = pipeline
    return pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"): |
|
|
""" |
|
|
yfinance๋ง ์ฌ์ฉ. |
|
|
1) Ticker().history(start/end) |
|
|
2) download(start/end, repair=True) |
|
|
3) period ๊ธฐ๋ฐ ํด๋ฐฑ: |
|
|
- 1d โ ["max", "10y", "5y", "2y", "1y"] |
|
|
- 1h โ ["730d", "365d", "60d"] |
|
|
- 30m/15m/5m โ ["60d", "30d", "14d"] |
|
|
""" |
|
|
ticker = ticker.strip().upper() |
|
|
|
|
|
|
|
|
_start = start or "2014-09-17" |
|
|
_end = end or dt.date.today().isoformat() |
|
|
try: |
|
|
sdt = pd.to_datetime(_start) |
|
|
edt = pd.to_datetime(_end) |
|
|
if edt < sdt: |
|
|
sdt, edt = edt, sdt |
|
|
_start, _end = sdt.date().isoformat(), edt.date().isoformat() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
def _extract_close(df): |
|
|
if df is None or df.empty: |
|
|
return None |
|
|
c = df.get("Close") |
|
|
if c is None: |
|
|
return None |
|
|
c = c.dropna().astype(float) |
|
|
return c if not c.empty else None |
|
|
|
|
|
|
|
|
try: |
|
|
tk = yf.Ticker(ticker) |
|
|
df = tk.history(start=_start, end=_end, interval=interval, auto_adjust=True, actions=False) |
|
|
s = _extract_close(df) |
|
|
if s is not None: |
|
|
return s |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
try: |
|
|
df = yf.download( |
|
|
ticker, start=_start, end=_end, interval=interval, |
|
|
progress=False, threads=False, repair=True |
|
|
) |
|
|
s = _extract_close(df) |
|
|
if s is not None: |
|
|
return s |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
if interval == "1d": |
|
|
period_candidates = ["max", "10y", "5y", "2y", "1y"] |
|
|
elif interval == "1h": |
|
|
period_candidates = ["730d", "365d", "60d"] |
|
|
else: |
|
|
period_candidates = ["60d", "30d", "14d"] |
|
|
|
|
|
for per in period_candidates: |
|
|
|
|
|
try: |
|
|
df = tk.history(period=per, interval=interval, auto_adjust=True, actions=False) |
|
|
s = _extract_close(df) |
|
|
if s is not None: |
|
|
return s |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
try: |
|
|
df = yf.download( |
|
|
ticker, period=per, interval=interval, |
|
|
progress=False, threads=False, repair=True |
|
|
) |
|
|
s = _extract_close(df) |
|
|
if s is not None: |
|
|
return s |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
raise ValueError( |
|
|
"yfinance์์ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ ธ์ค์ง ๋ชปํ์ต๋๋ค. " |
|
|
"๊ฐ๊ฒฉ(interval)์ด๋ ๊ธฐ๊ฐ(start/end ํน์ period)์ ์กฐ์ ํด ๋ค์ ์๋ํด ๋ณด์ธ์." |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Quantile levels requested from Chronos; order matches the q10/q50/q90 columns.
_QUANTILE_LEVELS = (0.1, 0.5, 0.9)

# yfinance interval -> pandas offset alias used to lay out the forecast axis.
# Lower-case "h" / "min" aliases avoid the "H" / "T" deprecation in pandas >= 2.2.
_FREQ_BY_INTERVAL = {"1d": "D", "1h": "h", "30m": "30min", "15m": "15min", "5m": "5min"}


def _forecast_quantiles(pipe, series, horizon):
    """Forecast ``horizon`` steps after *series* and return q10/q50/q90 as a DataFrame.

    Uses ``predict_quantiles`` with explicit levels instead of indexing the raw
    ``predict`` output: for Chronos-Bolt that output holds the model's default
    quantile grid (0.1 ... 0.9), so its first three rows are q0.1/q0.2/q0.3,
    not q10/q50/q90.
    """
    context = torch.tensor(series.values, dtype=torch.float32)
    quantiles, _mean = pipe.predict_quantiles(
        context,
        prediction_length=horizon,
        quantile_levels=list(_QUANTILE_LEVELS),
    )
    # quantiles: (batch, horizon, n_levels); the single series is batch row 0.
    q = quantiles[0].detach().to(device="cpu", dtype=torch.float32).numpy()
    return pd.DataFrame(
        {"q10": q[:, 0], "q50": q[:, 1], "q90": q[:, 2]},
        index=pd.RangeIndex(1, horizon + 1, name="step"),
    )


def _future_index(last_timestamp, horizon, interval):
    """Return ``horizon`` timestamps following *last_timestamp* at *interval* spacing."""
    freq = _FREQ_BY_INTERVAL.get(interval, "D")
    return pd.date_range(last_timestamp, periods=horizon + 1, freq=freq)[1:]


def _forecast_figure(title, series, future_index, df_fcst):
    """Draw history, the q50 forecast and the q10-q90 band on a standalone Figure.

    ``matplotlib.figure.Figure`` is used instead of ``pyplot`` because pyplot
    keeps a reference to every figure until it is closed explicitly, which
    leaks memory when a server renders one chart per request.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 4))
    ax = fig.add_subplot(111)
    ax.plot(series.index, series.values, label="history")
    ax.plot(future_index, df_fcst["q50"].to_numpy(), label="forecast(q50)")
    ax.fill_between(
        future_index,
        df_fcst["q10"].to_numpy(),
        df_fcst["q90"].to_numpy(),
        alpha=0.2,
        label="q10–q90",
    )
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    return fig


def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
    """Load prices, forecast them with Chronos, and build the three UI outputs.

    Returns:
        ``(figure, forecast_table, note)`` where ``forecast_table`` has columns
        q10/q50/q90 indexed by step 1..horizon. If loading data or running the
        model fails, returns ``(None, empty DataFrame, error message)`` instead
        of raising so the message is shown in the UI.
    """
    try:
        series = load_close_series(ticker, start_date, end_date, interval)
    except Exception as e:
        return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"

    # A multi-column frame would become a 2-D context (a batch of length-1 series).
    if isinstance(series, pd.DataFrame):
        series = series.iloc[:, 0]

    H = int(horizon)

    try:
        pipe = get_pipeline(model_id, device)
        df_fcst = _forecast_quantiles(pipe, series, H)
    except Exception as e:
        return None, pd.DataFrame(), f"예측 오류: {e}"

    future_index = _future_index(series.index[-1], H, interval)
    fig = _forecast_figure(
        f"{ticker} forecast by Chronos-Bolt ({interval}, H={H})",
        series,
        future_index,
        df_fcst,
    )

    note = "※ 데모 목적입니다. 투자 판단의 책임은 본인에게 있습니다."
    return fig, df_fcst, note
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: inputs (ticker, horizon, date range, model, device, interval) feed
# run_forecast, whose (figure, quantile table, note) fill the three outputs.
with gr.Blocks(title="Chronos Stock/Crypto Forecast") as demo:
    gr.Markdown("# Chronos 주가·크립토 예측 데모")

    with gr.Row():
        ticker = gr.Textbox(value="BTC-USD", label="티커 (예: AAPL, MSFT, 005930.KS, BTC-USD)")
        horizon = gr.Slider(5, 365, value=90, step=1, label="예측 스텝 H (간격 단위와 동일)")

    with gr.Row():
        start = gr.Textbox(value="2014-09-17", label="시작일 (YYYY-MM-DD)")
        # A callable default is re-evaluated on every page load, so the end date
        # follows "today" instead of freezing at the server's start-up date.
        end = gr.Textbox(value=lambda: dt.date.today().isoformat(), label="종료일 (YYYY-MM-DD)")

    with gr.Row():
        model_id = gr.Dropdown(
            choices=[
                "amazon/chronos-bolt-tiny",
                "amazon/chronos-bolt-mini",
                "amazon/chronos-bolt-small",
                "amazon/chronos-bolt-base",
            ],
            value="amazon/chronos-bolt-small",
            label="모델",
        )
        device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
        interval = gr.Dropdown(
            choices=["1d", "1h", "30m", "15m", "5m"],
            value="1d",
            label="간격",
        )

    btn = gr.Button("예측 실행")

    plot = gr.Plot(label="History + Forecast")
    table = gr.Dataframe(label="예측 결과 (분위수)")
    note = gr.Markdown()

    btn.click(
        fn=run_forecast,
        inputs=[ticker, start, end, horizon, model_id, device, interval],
        outputs=[plot, table, note],
    )


if __name__ == "__main__":
    demo.launch()
|
|
|