kimyechan commited on
Commit
7355d65
·
1 Parent(s): 0b16dd0

fix:수정

Browse files
Files changed (1) hide show
  1. app.py +108 -23
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  import torch
4
  import gradio as gr
5
  import yfinance as yf
 
6
 
7
  from chronos import BaseChronosPipeline # pip: chronos-forecasting
8
 
@@ -24,17 +25,98 @@ def get_pipeline(model_id: str, device: str = "cpu"):
24
 
25
 
26
  # =============================
27
- # yfinance 전용 견고 로더
28
  # =============================
29
- def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  """
31
- yfinance만 사용.
32
- 1) Ticker().history(start/end)
33
- 2) download(start/end, repair=True)
34
- 3) period 기반 폴백:
35
- - 1d → ["max", "10y", "5y", "2y", "1y"]
36
- - 1h → ["730d", "365d", "60d"]
37
- - 30m/15m/5m → ["60d", "30d", "14d"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  """
39
  ticker = ticker.strip().upper()
40
 
@@ -45,12 +127,12 @@ def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
45
  sdt = pd.to_datetime(_start)
46
  edt = pd.to_datetime(_end)
47
  if edt < sdt:
48
- sdt, edt = edt, sdt # 뒤바뀐 경우 교환
49
  _start, _end = sdt.date().isoformat(), edt.date().isoformat()
50
  except Exception:
51
- pass # 파싱 실패해도 밑의 period 폴백이 커버
52
 
53
- def _extract_close(df):
54
  if df is None or df.empty:
55
  return None
56
  c = df.get("Close")
@@ -85,12 +167,11 @@ def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
85
  if interval == "1d":
86
  period_candidates = ["max", "10y", "5y", "2y", "1y"]
87
  elif interval == "1h":
88
- period_candidates = ["730d", "365d", "60d"] # 1시간봉은 과거 제한 큼
89
- else: # 30m/15m/5m 등 분봉
90
- period_candidates = ["60d", "30d", "14d"] # 분봉은 보통 60~30일 이내만 가능
91
 
92
  for per in period_candidates:
93
- # Ticker().history(period=…)
94
  try:
95
  df = tk.history(period=per, interval=interval, auto_adjust=True, actions=False)
96
  s = _extract_close(df)
@@ -98,7 +179,6 @@ def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
98
  return s
99
  except Exception:
100
  pass
101
- # download(period=…)
102
  try:
103
  df = yf.download(
104
  ticker, period=per, interval=interval,
@@ -110,9 +190,15 @@ def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
110
  except Exception:
111
  pass
112
 
 
 
 
 
 
 
 
113
  raise ValueError(
114
- "yfinance에서 데이터를 가져오지 못했습니다. "
115
- "간격(interval)이나 기간(start/end 혹은 period)을 조정해 다시 시도해 보세요."
116
  )
117
 
118
 
@@ -123,26 +209,25 @@ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interv
123
  try:
124
  series = load_close_series(ticker, start_date, end_date, interval)
125
  except Exception as e:
126
- # Gradio v4: Plot.update 없음 → None 반환
127
  return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"
128
 
129
  pipe = get_pipeline(model_id, device)
130
  H = int(horizon)
131
 
132
- # Chronos 입력: 1D 텐서 (float)
133
  context = torch.tensor(series.values, dtype=torch.float32)
134
 
135
  # 예측: (num_series=1, num_quantiles=3, H) with q=[0.1, 0.5, 0.9]
136
  preds = pipe.predict(context=context, prediction_length=H)[0]
137
  q10, q50, q90 = preds[0], preds[1], preds[2]
138
 
139
- # 표 데이터
140
  df_fcst = pd.DataFrame(
141
  {"q10": q10.numpy(), "q50": q50.numpy(), "q90": q90.numpy()},
142
  index=pd.RangeIndex(1, H + 1, name="step"),
143
  )
144
 
145
- # 미래 x축: interval→pandas freq 매핑
146
  import matplotlib.pyplot as plt
147
  freq_map = {"1d": "D", "1h": "H", "30m": "30T", "15m": "15T", "5m": "5T"}
148
  freq = freq_map.get(interval, "D")
 
3
  import torch
4
  import gradio as gr
5
  import yfinance as yf
6
+ import requests
7
 
8
  from chronos import BaseChronosPipeline # pip: chronos-forecasting
9
 
 
25
 
26
 
27
  # =============================
28
+ # Binance (무인증) 폴백
29
  # =============================
30
+ _BINANCE_INTERVAL = {"1d": "1d", "1h": "1h", "30m": "30m", "15m": "15m", "5m": "5m"}
31
+
32
+ def _yf_to_binance_symbol(ticker: str) -> str | None:
33
+ """
34
+ BTC-USD -> BTCUSDT, ETH-USD -> ETHUSDT ...
35
+ 주식/원화/기타 심볼은 None (Binance 폴백하지 않음)
36
+ """
37
+ t = ticker.upper()
38
+ if t.endswith("-USD") and len(t) >= 6:
39
+ base = t[:-4] # remove "-USD"
40
+ return f"{base}USDT"
41
+ return None
42
+
43
+ def _fetch_binance_klines(ticker: str, interval: str, start: str | None, end: str | None) -> pd.Series:
44
  """
45
+ Binance Klines (무인증)
46
+ https://api.binance.com/api/v3/klines
47
+ 반환: pandas.Series(index=datetime, values=float close)
48
+ """
49
+ if interval not in _BINANCE_INTERVAL:
50
+ raise ValueError("Binance는 해당 interval을 지원하지 않습니다.")
51
+
52
+ symbol = _yf_to_binance_symbol(ticker)
53
+ if not symbol:
54
+ raise ValueError("이 티커는 Binance 폴백 대상이 아닙니다.")
55
+
56
+ base = "https://api.binance.com/api/v3/klines"
57
+
58
+ def to_ms(s: str) -> int:
59
+ return int(pd.to_datetime(s).timestamp() * 1000)
60
+
61
+ start_ms = to_ms(start) if start else None
62
+ end_ms = to_ms(end) if end else None
63
+
64
+ rows = []
65
+ cur_start = start_ms
66
+ while True:
67
+ params = {"symbol": symbol, "interval": _BINANCE_INTERVAL[interval], "limit": 1000}
68
+ if cur_start is not None:
69
+ params["startTime"] = cur_start
70
+ if end_ms is not None:
71
+ params["endTime"] = end_ms
72
+
73
+ r = requests.get(base, params=params, timeout=30)
74
+ r.raise_for_status()
75
+ data = r.json()
76
+ if not data:
77
+ break
78
+
79
+ rows.extend(data)
80
+
81
+ last_close_time = data[-1][6] # closeTime (ms)
82
+ next_start = last_close_time + 1
83
+ if cur_start is not None and next_start <= cur_start:
84
+ break
85
+ cur_start = next_start
86
+
87
+ if len(data) < 1000:
88
+ break
89
+
90
+ if not rows:
91
+ raise ValueError("Binance에서 데이터가 비어 있습니다.")
92
+
93
+ df = pd.DataFrame(rows, columns=[
94
+ "openTime","open","high","low","close","volume","closeTime",
95
+ "quoteAssetVolume","numTrades","takerBuyBase","takerBuyQuote","ignore"
96
+ ])
97
+ df["ts"] = pd.to_datetime(df["closeTime"], unit="ms")
98
+ s = df.set_index("ts")["close"].astype(float).sort_index()
99
+
100
+ if start:
101
+ s = s[s.index >= pd.to_datetime(start)]
102
+ if end:
103
+ s = s[s.index <= pd.to_datetime(end)]
104
+ if s.empty:
105
+ raise ValueError("Binance 시리즈가 비어 있습니다.")
106
+ return s
107
+
108
+
109
+ # =============================
110
+ # yfinance 전용 견고 로더 (+Binance 폴백)
111
+ # =============================
112
+ def load_close_series(ticker: str, start: str, end: str, interval: str = "1d") -> pd.Series:
113
+ """
114
+ yfinance만으로 먼저 시도:
115
+ 1) Ticker().history(start/end)
116
+ 2) download(start/end, repair=True)
117
+ 3) period 폴백 (interval별 후보 순회)
118
+ 그래도 실패할 경우:
119
+ 4) Binance (무인증) 폴백 — BTC-USD 같은 암호화폐만 대상
120
  """
121
  ticker = ticker.strip().upper()
122
 
 
127
  sdt = pd.to_datetime(_start)
128
  edt = pd.to_datetime(_end)
129
  if edt < sdt:
130
+ sdt, edt = edt, sdt
131
  _start, _end = sdt.date().isoformat(), edt.date().isoformat()
132
  except Exception:
133
+ pass
134
 
135
+ def _extract_close(df: pd.DataFrame | None) -> pd.Series | None:
136
  if df is None or df.empty:
137
  return None
138
  c = df.get("Close")
 
167
  if interval == "1d":
168
  period_candidates = ["max", "10y", "5y", "2y", "1y"]
169
  elif interval == "1h":
170
+ period_candidates = ["730d", "365d", "60d"]
171
+ else: # 30m/15m/5m
172
+ period_candidates = ["60d", "30d", "14d"]
173
 
174
  for per in period_candidates:
 
175
  try:
176
  df = tk.history(period=per, interval=interval, auto_adjust=True, actions=False)
177
  s = _extract_close(df)
 
179
  return s
180
  except Exception:
181
  pass
 
182
  try:
183
  df = yf.download(
184
  ticker, period=per, interval=interval,
 
190
  except Exception:
191
  pass
192
 
193
+ # 4) Binance 폴백 (암호화폐만)
194
+ try:
195
+ s = _fetch_binance_klines(ticker, interval, _start, _end)
196
+ return s
197
+ except Exception:
198
+ pass
199
+
200
  raise ValueError(
201
+ "데이터를 가져오지 못했습니다. 간격(interval)이나 기간(start/end 혹은 period)을 조정해 다시 시도해 보세요."
 
202
  )
203
 
204
 
 
209
  try:
210
  series = load_close_series(ticker, start_date, end_date, interval)
211
  except Exception as e:
 
212
  return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"
213
 
214
  pipe = get_pipeline(model_id, device)
215
  H = int(horizon)
216
 
217
+ # Chronos 입력
218
  context = torch.tensor(series.values, dtype=torch.float32)
219
 
220
  # 예측: (num_series=1, num_quantiles=3, H) with q=[0.1, 0.5, 0.9]
221
  preds = pipe.predict(context=context, prediction_length=H)[0]
222
  q10, q50, q90 = preds[0], preds[1], preds[2]
223
 
224
+ # 표
225
  df_fcst = pd.DataFrame(
226
  {"q10": q10.numpy(), "q50": q50.numpy(), "q90": q90.numpy()},
227
  index=pd.RangeIndex(1, H + 1, name="step"),
228
  )
229
 
230
+ # 미래 x
231
  import matplotlib.pyplot as plt
232
  freq_map = {"1d": "D", "1h": "H", "30m": "30T", "15m": "15T", "5m": "5T"}
233
  freq = freq_map.get(interval, "D")