kimyechan commited on
Commit
0b16dd0
·
1 Parent(s): 98b6850

fix:수정

Browse files
Files changed (2) hide show
  1. app.py +100 -88
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,14 +1,15 @@
1
- import os
2
  import datetime as dt
3
  import pandas as pd
4
  import torch
5
  import gradio as gr
6
  import yfinance as yf
7
- import requests # ← 추가
8
 
9
- from chronos import BaseChronosPipeline # from 'chronos-forecasting'
10
 
11
- # ---- 전역 캐시: 모델을 한 번만 로드해 재사용 ----
 
 
 
12
  _PIPELINE_CACHE = {}
13
 
14
  def get_pipeline(model_id: str, device: str = "cpu"):
@@ -20,125 +21,134 @@ def get_pipeline(model_id: str, device: str = "cpu"):
20
  torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
21
  )
22
  return _PIPELINE_CACHE[key]
23
- # ---- 심볼 매핑: 'BTC-USD' → 'bitcoin' (Coingecko id)
24
- _CG_MAP = {
25
- "BTC-USD": "bitcoin",
26
- "ETH-USD": "ethereum",
27
- "SOL-USD": "solana",
28
- "XRP-USD": "ripple",
29
- "ADA-USD": "cardano",
30
- }
31
-
32
- def _fetch_coingecko_daily(ticker: str, start: str, end: str):
33
- """
34
- Coingecko: /coins/{id}/market_chart?vs_currency=usd&days=max
35
- 반환: (date, price) 일별 데이터프레임
36
- """
37
- coin_id = _CG_MAP.get(ticker.upper())
38
- if not coin_id:
39
- raise ValueError("해당 티커는 Coingecko 매핑이 없습니다. (예: BTC-USD, ETH-USD)")
40
-
41
- url = f"https://api.coingecko.com/api/v3/coins/{coin_id}/market_chart"
42
- # days=max 로 전체 일봉 받아온 뒤, 날짜 필터링
43
- resp = requests.get(url, params={"vs_currency": "usd", "days": "max"}, timeout=30)
44
- resp.raise_for_status()
45
- data = resp.json()
46
- prices = data.get("prices", [])
47
- if not prices:
48
- raise ValueError("Coingecko 응답에 prices가 없습니다.")
49
-
50
- # prices: [[timestamp_ms, price], ...]
51
- df = pd.DataFrame(prices, columns=["ts", "close"])
52
- df["ts"] = pd.to_datetime(df["ts"], unit="ms", utc=True).dt.tz_convert(None)
53
- df = df.set_index("ts").sort_index()
54
- # 날짜 범위 적용
55
- s = df["close"].astype(float)
56
- if start:
57
- s = s[s.index >= pd.to_datetime(start)]
58
- if end:
59
- s = s[s.index <= pd.to_datetime(end)]
60
- return s
61
 
 
 
 
 
62
  def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
63
  """
64
- 1) yfinance(history → download)로 시도
65
- 2) 실패 시 Coingecko 일봉으로 대체 (BTC-USD/ETH-USD 등)
 
 
 
 
 
66
  """
67
  ticker = ticker.strip().upper()
68
- _start = start or "2014-09-17"
69
- _end = end or dt.date.today().isoformat()
70
 
71
- # ---- 1차: yfinance 시도
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  try:
73
  tk = yf.Ticker(ticker)
74
  df = tk.history(start=_start, end=_end, interval=interval, auto_adjust=True, actions=False)
75
- if df.empty or "Close" not in df:
76
- raise ValueError("empty history")
77
-
78
- s = df["Close"].dropna().astype(float)
79
- if s.empty:
80
- raise ValueError("empty close after dropna")
81
- return s
82
  except Exception:
83
  pass
84
 
 
85
  try:
86
- df = yf.download(ticker, start=_start, end=_end, interval=interval, progress=False, threads=False)
87
- if not df.empty and "Close" in df and not df["Close"].dropna().empty:
88
- return df["Close"].dropna().astype(float)
 
 
 
 
89
  except Exception:
90
  pass
91
 
92
- # ---- 2차: Coingecko fallback (일봉만)
93
- try:
94
- s = _fetch_coingecko_daily(ticker, _start, _end)
95
- if s.empty:
96
- raise ValueError("Coingecko 데이터가 비어 있습니다.")
97
- # interval이 일봉이 아니면 일봉으로 강제 전환 안내 (호출 측에서 메시지로 보여줌)
98
- if interval != "1d":
99
- raise RuntimeError("FALLBACK_DAILY_ONLY")
100
- return s
101
- except RuntimeError as r:
102
- if str(r) == "FALLBACK_DAILY_ONLY":
103
- # 호출부에서 메시지 처리할 있게 예외를 다시 던짐
104
- raise RuntimeError("FALLBACK_DAILY_ONLY")
105
- raise
106
- except Exception as e:
107
- raise ValueError(f"데이터를 가져오지 못했습니다 (yfinance/Coingecko 실패): {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
 
 
 
109
  def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
110
  try:
111
  series = load_close_series(ticker, start_date, end_date, interval)
112
- fallback_note = ""
113
- except RuntimeError as r:
114
- if str(r) == "FALLBACK_DAILY_ONLY":
115
- # 일봉으로 재시도
116
- series = load_close_series(ticker, start_date, end_date, "1d")
117
- interval = "1d"
118
- fallback_note = "※ Coingecko 대체 소스 사용으로 간격을 '1d(일봉)'로 자동 전환했습니다."
119
- else:
120
- return None, pd.DataFrame(), f"데이터 로딩 오류: {r}"
121
  except Exception as e:
 
122
  return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"
123
 
124
  pipe = get_pipeline(model_id, device)
125
  H = int(horizon)
126
 
127
- import numpy as np
128
  context = torch.tensor(series.values, dtype=torch.float32)
 
 
129
  preds = pipe.predict(context=context, prediction_length=H)[0]
130
  q10, q50, q90 = preds[0], preds[1], preds[2]
131
 
 
132
  df_fcst = pd.DataFrame(
133
  {"q10": q10.numpy(), "q50": q50.numpy(), "q90": q90.numpy()},
134
  index=pd.RangeIndex(1, H + 1, name="step"),
135
  )
136
 
 
137
  import matplotlib.pyplot as plt
138
  freq_map = {"1d": "D", "1h": "H", "30m": "30T", "15m": "15T", "5m": "5T"}
139
  freq = freq_map.get(interval, "D")
140
  future_index = pd.date_range(series.index[-1], periods=H + 1, freq=freq)[1:]
141
 
 
142
  fig = plt.figure(figsize=(10, 4))
143
  plt.plot(series.index, series.values, label="history")
144
  plt.plot(future_index, q50.numpy(), label="forecast(q50)")
@@ -147,19 +157,21 @@ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interv
147
  plt.legend()
148
  plt.tight_layout()
149
 
150
- base_note = "※ 데모 목적입니다. 투자 판단의 책임은 본인에게 있습니다."
151
- note = (fallback_note + " " + base_note).strip()
152
  return fig, df_fcst, note
153
 
154
- # ---- Gradio UI ----
 
 
 
155
  with gr.Blocks(title="Chronos Stock/Crypto Forecast") as demo:
156
  gr.Markdown("# Chronos 주가·크립토 예측 데모")
157
  with gr.Row():
158
  ticker = gr.Textbox(value="BTC-USD", label="티커 (예: AAPL, MSFT, 005930.KS, BTC-USD)")
159
  horizon = gr.Slider(5, 365, value=90, step=1, label="예측 스텝 H (간격 단위와 동일)")
160
  with gr.Row():
161
- start = gr.Textbox(value="2014-09-17", label="시작일 (YYYY-MM-DD, 예: 2014-09-17)")
162
- end = gr.Textbox(value=dt.date.today().isoformat(), label="종료일 (YYYY-MM-DD, 비워두면 오늘)")
163
  with gr.Row():
164
  model_id = gr.Dropdown(
165
  choices=[
 
 
1
  import datetime as dt
2
  import pandas as pd
3
  import torch
4
  import gradio as gr
5
  import yfinance as yf
 
6
 
7
+ from chronos import BaseChronosPipeline # pip: chronos-forecasting
8
 
9
+
10
+ # =============================
11
+ # Chronos 모델 캐시/로더
12
+ # =============================
13
  _PIPELINE_CACHE = {}
14
 
15
  def get_pipeline(model_id: str, device: str = "cpu"):
 
21
  torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
22
  )
23
  return _PIPELINE_CACHE[key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+
26
+ # =============================
27
+ # yfinance 전용 견고 로더
28
+ # =============================
29
  def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
30
  """
31
+ yfinance 사용.
32
+ 1) Ticker().history(start/end)
33
+ 2) download(start/end, repair=True)
34
+ 3) period 기반 폴백:
35
+ - 1d → ["max", "10y", "5y", "2y", "1y"]
36
+ - 1h → ["730d", "365d", "60d"]
37
+ - 30m/15m/5m → ["60d", "30d", "14d"]
38
  """
39
  ticker = ticker.strip().upper()
 
 
40
 
41
+ # 날짜 보정
42
+ _start = start or "2014-09-17" # BTC-USD 히스토리 시작 근처
43
+ _end = end or dt.date.today().isoformat()
44
+ try:
45
+ sdt = pd.to_datetime(_start)
46
+ edt = pd.to_datetime(_end)
47
+ if edt < sdt:
48
+ sdt, edt = edt, sdt # 뒤바뀐 경우 교환
49
+ _start, _end = sdt.date().isoformat(), edt.date().isoformat()
50
+ except Exception:
51
+ pass # 파싱 실패해도 밑의 period 폴백이 커버
52
+
53
+ def _extract_close(df):
54
+ if df is None or df.empty:
55
+ return None
56
+ c = df.get("Close")
57
+ if c is None:
58
+ return None
59
+ c = c.dropna().astype(float)
60
+ return c if not c.empty else None
61
+
62
+ # 1) history(start/end)
63
  try:
64
  tk = yf.Ticker(ticker)
65
  df = tk.history(start=_start, end=_end, interval=interval, auto_adjust=True, actions=False)
66
+ s = _extract_close(df)
67
+ if s is not None:
68
+ return s
 
 
 
 
69
  except Exception:
70
  pass
71
 
72
+ # 2) download(start/end) + repair=True
73
  try:
74
+ df = yf.download(
75
+ ticker, start=_start, end=_end, interval=interval,
76
+ progress=False, threads=False, repair=True
77
+ )
78
+ s = _extract_close(df)
79
+ if s is not None:
80
+ return s
81
  except Exception:
82
  pass
83
 
84
+ # 3) period 폴백
85
+ if interval == "1d":
86
+ period_candidates = ["max", "10y", "5y", "2y", "1y"]
87
+ elif interval == "1h":
88
+ period_candidates = ["730d", "365d", "60d"] # 1시간봉은 과거 제한 큼
89
+ else: # 30m/15m/5m 분봉
90
+ period_candidates = ["60d", "30d", "14d"] # 분봉은 보통 60~30일 이내만 가능
91
+
92
+ for per in period_candidates:
93
+ # Ticker().history(period=…)
94
+ try:
95
+ df = tk.history(period=per, interval=interval, auto_adjust=True, actions=False)
96
+ s = _extract_close(df)
97
+ if s is not None:
98
+ return s
99
+ except Exception:
100
+ pass
101
+ # download(period=…)
102
+ try:
103
+ df = yf.download(
104
+ ticker, period=per, interval=interval,
105
+ progress=False, threads=False, repair=True
106
+ )
107
+ s = _extract_close(df)
108
+ if s is not None:
109
+ return s
110
+ except Exception:
111
+ pass
112
+
113
+ raise ValueError(
114
+ "yfinance에서 데이터를 가져오지 못했습니다. "
115
+ "간격(interval)이나 기간(start/end 혹은 period)을 조정해 다시 시도해 보세요."
116
+ )
117
+
118
 
119
+ # =============================
120
+ # 예측 함수 (Gradio 핸들러)
121
+ # =============================
122
  def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
123
  try:
124
  series = load_close_series(ticker, start_date, end_date, interval)
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
+ # Gradio v4: Plot.update 없음 → None 반환
127
  return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"
128
 
129
  pipe = get_pipeline(model_id, device)
130
  H = int(horizon)
131
 
132
+ # Chronos 입력: 1D 텐서 (float)
133
  context = torch.tensor(series.values, dtype=torch.float32)
134
+
135
+ # 예측: (num_series=1, num_quantiles=3, H) with q=[0.1, 0.5, 0.9]
136
  preds = pipe.predict(context=context, prediction_length=H)[0]
137
  q10, q50, q90 = preds[0], preds[1], preds[2]
138
 
139
+ # 표 데이터
140
  df_fcst = pd.DataFrame(
141
  {"q10": q10.numpy(), "q50": q50.numpy(), "q90": q90.numpy()},
142
  index=pd.RangeIndex(1, H + 1, name="step"),
143
  )
144
 
145
+ # 미래 x축: interval→pandas freq 매핑
146
  import matplotlib.pyplot as plt
147
  freq_map = {"1d": "D", "1h": "H", "30m": "30T", "15m": "15T", "5m": "5T"}
148
  freq = freq_map.get(interval, "D")
149
  future_index = pd.date_range(series.index[-1], periods=H + 1, freq=freq)[1:]
150
 
151
+ # 그래프
152
  fig = plt.figure(figsize=(10, 4))
153
  plt.plot(series.index, series.values, label="history")
154
  plt.plot(future_index, q50.numpy(), label="forecast(q50)")
 
157
  plt.legend()
158
  plt.tight_layout()
159
 
160
+ note = "※ 데모 목적입니다. 투자 판단의 책임은 본인에게 있습니다."
 
161
  return fig, df_fcst, note
162
 
163
+
164
+ # =============================
165
+ # Gradio UI
166
+ # =============================
167
  with gr.Blocks(title="Chronos Stock/Crypto Forecast") as demo:
168
  gr.Markdown("# Chronos 주가·크립토 예측 데모")
169
  with gr.Row():
170
  ticker = gr.Textbox(value="BTC-USD", label="티커 (예: AAPL, MSFT, 005930.KS, BTC-USD)")
171
  horizon = gr.Slider(5, 365, value=90, step=1, label="예측 스텝 H (간격 단위와 동일)")
172
  with gr.Row():
173
+ start = gr.Textbox(value="2014-09-17", label="시작일 (YYYY-MM-DD)")
174
+ end = gr.Textbox(value=dt.date.today().isoformat(), label="종료일 (YYYY-MM-DD)")
175
  with gr.Row():
176
  model_id = gr.Dropdown(
177
  choices=[
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  gradio>=4.44
2
  pandas>=2.2
3
- yfinance==0.2.40 # (우린 실시간 안 써서 이대로 안전)
4
- requests>=2.31
5
  matplotlib>=3.8
6
  torch>=2.2 ; platform_system != "Darwin"
7
  chronos-forecasting>=1.0
 
1
  gradio>=4.44
2
  pandas>=2.2
3
+ yfinance==0.2.40
 
4
  matplotlib>=3.8
5
  torch>=2.2 ; platform_system != "Darwin"
6
  chronos-forecasting>=1.0