import datetime as dt

import gradio as gr
import pandas as pd
import torch
import yfinance as yf
from chronos import BaseChronosPipeline  # pip install chronos-forecasting
from matplotlib.figure import Figure

# =============================
# Chronos model cache / loader
# =============================
_PIPELINE_CACHE = {}


def get_pipeline(model_id: str, device: str = "cpu"):
    """Load a Chronos pipeline once per (model_id, device) and reuse it."""
    key = (model_id, device)
    if key not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
            model_id,
            device_map=device,  # "cpu" / "cuda"
            torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
        )
    return _PIPELINE_CACHE[key]


# =============================
# Robust yfinance-only loader
# =============================
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
    """
    Load a Close-price series using yfinance only, trying in order:
      1) Ticker().history(start/end)
      2) download(start/end, repair=True)
      3) period-based fallbacks:
         - 1d          -> ["max", "10y", "5y", "2y", "1y"]
         - 1h          -> ["730d", "365d", "60d"]
         - 30m/15m/5m  -> ["60d", "30d", "14d"]
    Raises ValueError if every attempt returns no data.
    """
    ticker = ticker.strip().upper()
    tk = yf.Ticker(ticker)  # created up front so the period fallback can always use it

    # Normalize dates
    _start = start or "2014-09-17"  # around the start of BTC-USD history
    _end = end or dt.date.today().isoformat()  # yfinance treats `end` as exclusive
    try:
        sdt = pd.to_datetime(_start)
        edt = pd.to_datetime(_end)
        if edt < sdt:
            sdt, edt = edt, sdt  # swap if the dates are reversed
        _start, _end = sdt.date().isoformat(), edt.date().isoformat()
    except Exception:
        pass  # even if parsing fails, the period fallback below still applies

    def _extract_close(df):
        """Return the Close column as a float Series, or None if unusable."""
        if df is None or df.empty:
            return None
        c = df.get("Close")
        if c is None:
            return None
        # yf.download can return MultiIndex columns (Price, Ticker); then
        # df["Close"] is a one-column DataFrame, so squeeze it to a Series.
        # Otherwise series.values would be 2-D and Chronos would see the
        # wrong shape.
        if isinstance(c, pd.DataFrame):
            c = c.iloc[:, 0]
        c = c.dropna().astype(float)
        return c if not c.empty else None

    # 1) history(start/end)
    try:
        df = tk.history(start=_start, end=_end, interval=interval,
                        auto_adjust=True, actions=False)
        s = _extract_close(df)
        if s is not None:
            return s
    except Exception:
        pass

    # 2) download(start/end) + repair=True
    try:
        df = yf.download(
            ticker, start=_start, end=_end, interval=interval,
            progress=False, threads=False, repair=True,
        )
        s = _extract_close(df)
        if s is not None:
            return s
    except Exception:
        pass

    # 3) period fallback
    if interval == "1d":
        period_candidates = ["max", "10y", "5y", "2y", "1y"]
    elif interval == "1h":
        period_candidates = ["730d", "365d", "60d"]  # hourly bars: only about the last 730 days
    else:  # intraday minute bars such as 30m/15m/5m
        period_candidates = ["60d", "30d", "14d"]  # typically only the last ~60 days

    for per in period_candidates:
        # Ticker().history(period=...)
        try:
            df = tk.history(period=per, interval=interval,
                            auto_adjust=True, actions=False)
            s = _extract_close(df)
            if s is not None:
                return s
        except Exception:
            pass
        # download(period=...)
        try:
            df = yf.download(
                ticker, period=per, interval=interval,
                progress=False, threads=False, repair=True,
            )
            s = _extract_close(df)
            if s is not None:
                return s
        except Exception:
            pass

    raise ValueError(
        "Could not fetch data from yfinance. "
        "Try adjusting the interval or the date range (start/end or period)."
    )


# =============================
# Forecast function (Gradio handler)
# =============================
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
    try:
        series = load_close_series(ticker, start_date, end_date, interval)
    except Exception as e:
        # Gradio 4+ has no Plot.update(); returning None clears the plot.
        return None, pd.DataFrame(), f"Data loading error: {e}"

    pipe = get_pipeline(model_id, device)
    H = int(horizon)

    # Chronos input: a 1-D float tensor holding one series. The pipeline keeps
    # only the most recent points that fit the model's context window.
    context = torch.tensor(series.values, dtype=torch.float32)

    # pipe.predict() on a Chronos-Bolt model returns all of its built-in
    # quantiles (0.1, 0.2, ..., 0.9), so indexing [0], [1], [2] would give
    # q0.1/q0.2/q0.3, not q0.1/q0.5/q0.9. Request the levels explicitly.
    # `quantiles` has shape (num_series=1, H, num_levels=3).
    # Chronos-Bolt natively forecasts 64 steps; a longer H is rolled out
    # autoregressively (with a warning), so expect accuracy to degrade.
    quantiles, _mean = pipe.predict_quantiles(
        context, prediction_length=H, quantile_levels=[0.1, 0.5, 0.9]
    )
    q = quantiles[0].float().cpu().numpy()  # shape (H, 3)
    q10, q50, q90 = q[:, 0], q[:, 1], q[:, 2]

    # Table data
    df_fcst = pd.DataFrame(
        {"q10": q10, "q50": q50, "q90": q90},
        index=pd.RangeIndex(1, H + 1, name="step"),
    )

    # Future x-axis: map interval to a pandas frequency ("min" aliases avoid
    # the deprecated "H"/"T"). This is calendar spacing, so for stocks it
    # ignores weekends, holidays, and off-market hours.
    freq_map = {"1d": "D", "1h": "60min", "30m": "30min", "15m": "15min", "5m": "5min"}
    freq = freq_map.get(interval, "D")
    future_index = pd.date_range(series.index[-1], periods=H + 1, freq=freq)[1:]

    # Plot (a standalone Figure avoids accumulating pyplot figures across requests)
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(series.index, series.values, label="history")
    ax.plot(future_index, q50, label="forecast (q50)")
    ax.fill_between(future_index, q10, q90, alpha=0.2, label="q10–q90")
    ax.set_title(f"{ticker} forecast by Chronos-Bolt ({interval}, H={H})")
    ax.legend()
    fig.tight_layout()

    note = "※ For demonstration purposes only. You are solely responsible for any investment decisions."
    return fig, df_fcst, note


# =============================
# Gradio UI
# =============================
with gr.Blocks(title="Chronos Stock/Crypto Forecast") as demo:
    gr.Markdown("# Chronos Stock & Crypto Forecast Demo")
    with gr.Row():
        ticker = gr.Textbox(value="BTC-USD", label="Ticker (e.g., AAPL, MSFT, 005930.KS, BTC-USD)")
        horizon = gr.Slider(5, 365, value=90, step=1, label="Forecast steps H (in units of the interval)")
    with gr.Row():
        start = gr.Textbox(value="2014-09-17", label="Start date (YYYY-MM-DD)")
        end = gr.Textbox(value=dt.date.today().isoformat(), label="End date (YYYY-MM-DD)")
    with gr.Row():
        model_id = gr.Dropdown(
            choices=[
                "amazon/chronos-bolt-tiny",
                "amazon/chronos-bolt-mini",
                "amazon/chronos-bolt-small",
                "amazon/chronos-bolt-base",
            ],
            value="amazon/chronos-bolt-small",
            label="Model",
        )
        device = gr.Dropdown(choices=["cpu"], value="cpu", label="Device")
        interval = gr.Dropdown(
            choices=["1d", "1h", "30m", "15m", "5m"], value="1d", label="Interval"
        )
    btn = gr.Button("Run forecast")
    plot = gr.Plot(label="History + Forecast")
    table = gr.Dataframe(label="Forecast (quantiles)")
    note = gr.Markdown()

    btn.click(
        fn=run_forecast,
        inputs=[ticker, start, end, horizon, model_id, device, interval],
        outputs=[plot, table, note],
    )

if __name__ == "__main__":
    demo.launch()