kimyechan commited on
Commit
0b929da
·
1 Parent(s): c43d5d1

fix: 비트코인 로직 수정

Browse files
Files changed (1) hide show
  1. app.py +45 -21
app.py CHANGED
@@ -15,26 +15,44 @@ def get_pipeline(model_id: str, device: str = "cpu"):
15
  if key not in _PIPELINE_CACHE:
16
  _PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
17
  model_id,
18
- device_map=device, # "cpu" / "cuda" (Spaces 기본은 cpu)
19
  torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
20
  )
21
  return _PIPELINE_CACHE[key]
22
 
23
- # ---- 주가 데이터 로딩 (yfinance) ----
24
  def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
25
- # 한국 주식은 예: 005930.KS (삼성전자)
26
- df = yf.download(ticker, start=start, end=end, interval=interval, progress=False)
27
- if df.empty or "Close" not in df:
28
- raise ValueError("데이터가 없거나 'Close' 열을 찾을 수 없습니다. 티커/날짜를 확인하세요.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  s = df["Close"].dropna().astype(float)
 
 
30
  return s
31
 
32
  # ---- 예측 함수 (Gradio가 호출) ----
33
  def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
34
  try:
35
- series = load_close_series(ticker, start_date, end_date, interval)
36
  except Exception as e:
37
- return gr.Plot.update(None), pd.DataFrame(), f"데이터 로딩 오류: {e}"
 
38
 
39
  pipe = get_pipeline(model_id, device)
40
  H = int(horizon)
@@ -42,8 +60,7 @@ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interv
42
  # Chronos 입력: 1D 텐서 (float)
43
  context = torch.tensor(series.values, dtype=torch.float32)
44
 
45
- # 출력: (num_series=1, num_quantiles=3, H)
46
- # 보통 q=[0.1, 0.5, 0.9]
47
  preds = pipe.predict(context=context, prediction_length=H)[0]
48
  q10, q50, q90 = preds[0], preds[1], preds[2]
49
 
@@ -53,15 +70,18 @@ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interv
53
  index=pd.RangeIndex(1, H + 1, name="step"),
54
  )
55
 
56
- # 그래프
57
  import matplotlib.pyplot as plt
 
 
 
 
 
58
  fig = plt.figure(figsize=(10, 4))
59
  plt.plot(series.index, series.values, label="history")
60
- # 미래 구간 x축 만들기: 종가가 일 단위라 'D' 주기 사용
61
- future_index = pd.date_range(series.index[-1], periods=H + 1, freq="D")[1:]
62
  plt.plot(future_index, q50.numpy(), label="forecast(q50)")
63
  plt.fill_between(future_index, q10.numpy(), q90.numpy(), alpha=0.2, label="q10–q90")
64
- plt.title(f"{ticker} forecast by Chronos-Bolt")
65
  plt.legend()
66
  plt.tight_layout()
67
 
@@ -69,14 +89,14 @@ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interv
69
  return fig, df_fcst, note
70
 
71
  # ---- Gradio UI ----
72
- with gr.Blocks(title="Chronos Stock Forecast") as demo:
73
- gr.Markdown("# Chronos 주가 예측 데모")
74
  with gr.Row():
75
- ticker = gr.Textbox(value="AAPL", label="티커 (예: AAPL, MSFT, 005930.KS)")
76
- horizon = gr.Slider(5, 60, value=20, step=1, label="예측 길이 H ()")
77
  with gr.Row():
78
- start = gr.Textbox(value=(dt.date.today()-dt.timedelta(days=365)).isoformat(), label="시작일 (YYYY-MM-DD)")
79
- end = gr.Textbox(value=dt.date.today().isoformat(), label="종료일 (YYYY-MM-DD)")
80
  with gr.Row():
81
  model_id = gr.Dropdown(
82
  choices=[
@@ -89,7 +109,11 @@ with gr.Blocks(title="Chronos Stock Forecast") as demo:
89
  label="모델"
90
  )
91
  device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
92
- interval = gr.Dropdown(choices=["1d"], value="1d", label="간격")
 
 
 
 
93
  btn = gr.Button("예측 실행")
94
 
95
  plot = gr.Plot(label="History + Forecast")
 
15
  if key not in _PIPELINE_CACHE:
16
  _PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
17
  model_id,
18
+ device_map=device, # "cpu" / "cuda"
19
  torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
20
  )
21
  return _PIPELINE_CACHE[key]
22
 
23
+ # ---- 주가/크립토 데이터 로딩 (yfinance, 견고화) ----
24
  def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
25
+ """
26
+ BTC-USD 크립토 심볼에서 간헐적으로 timezone/파싱 오류가 나므로
27
+ history() 경로를 우선 사용하고, 실패 한 번 재시도.
28
+ """
29
+ # 기본값 보정: 너무 최근만 고르면 빈 데이터가 나올 수 있어 일봉은 과거부터 권장
30
+ _start = start or "2014-09-17" # BTC-USD 최초 상장일 근처
31
+ _end = end or dt.date.today().isoformat()
32
+
33
+ tk = yf.Ticker(ticker)
34
+ try:
35
+ df = tk.history(start=_start, end=_end, interval=interval, auto_adjust=True, actions=False)
36
+ if df.empty or "Close" not in df:
37
+ raise ValueError("empty")
38
+ except Exception:
39
+ # fallback: download() 경로 시도
40
+ df = yf.download(ticker, start=_start, end=_end, interval=interval, progress=False, threads=False)
41
+ if df.empty or "Close" not in df:
42
+ raise ValueError("데이터가 없거나 'Close' 열이 없습니다. 티커/날짜/간격을 확인하세요.")
43
+
44
  s = df["Close"].dropna().astype(float)
45
+ if s.empty:
46
+ raise ValueError("다운로드 결과가 비어 있습니다. 기간/간격을 줄이거나 다시 시도하세요.")
47
  return s
48
 
49
  # ---- 예측 함수 (Gradio가 호출) ----
50
  def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
51
  try:
52
+ series = load_close_series(ticker.strip(), start_date, end_date, interval)
53
  except Exception as e:
54
+ # Gradio v4에서는 Plot.update 없음 None 반환으로 정리
55
+ return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"
56
 
57
  pipe = get_pipeline(model_id, device)
58
  H = int(horizon)
 
60
  # Chronos 입력: 1D 텐서 (float)
61
  context = torch.tensor(series.values, dtype=torch.float32)
62
 
63
+ # 예측: (num_series=1, num_quantiles=3, H) with q=[0.1, 0.5, 0.9]
 
64
  preds = pipe.predict(context=context, prediction_length=H)[0]
65
  q10, q50, q90 = preds[0], preds[1], preds[2]
66
 
 
70
  index=pd.RangeIndex(1, H + 1, name="step"),
71
  )
72
 
73
+ # 미래 x축: interval에 맞는 pandas 주기
74
  import matplotlib.pyplot as plt
75
+ freq_map = {"1d": "D", "1h": "H", "30m": "30T", "15m": "15T", "5m": "5T"}
76
+ freq = freq_map.get(interval, "D")
77
+ future_index = pd.date_range(series.index[-1], periods=H + 1, freq=freq)[1:]
78
+
79
+ # 그래프
80
  fig = plt.figure(figsize=(10, 4))
81
  plt.plot(series.index, series.values, label="history")
 
 
82
  plt.plot(future_index, q50.numpy(), label="forecast(q50)")
83
  plt.fill_between(future_index, q10.numpy(), q90.numpy(), alpha=0.2, label="q10–q90")
84
+ plt.title(f"{ticker} forecast by Chronos-Bolt ({interval}, H={H})")
85
  plt.legend()
86
  plt.tight_layout()
87
 
 
89
  return fig, df_fcst, note
90
 
91
  # ---- Gradio UI ----
92
+ with gr.Blocks(title="Chronos Stock/Crypto Forecast") as demo:
93
+ gr.Markdown("# Chronos 주가·크립토 예측 데모")
94
  with gr.Row():
95
+ ticker = gr.Textbox(value="BTC-USD", label="티커 (예: AAPL, MSFT, 005930.KS, BTC-USD)")
96
+ horizon = gr.Slider(5, 365, value=90, step=1, label="예측 스텝 H (간격 단위�� 동일)")
97
  with gr.Row():
98
+ start = gr.Textbox(value="2014-09-17", label="시작일 (YYYY-MM-DD, 예: 2014-09-17)")
99
+ end = gr.Textbox(value=dt.date.today().isoformat(), label="종료일 (YYYY-MM-DD, 비워두면 오늘)")
100
  with gr.Row():
101
  model_id = gr.Dropdown(
102
  choices=[
 
109
  label="모델"
110
  )
111
  device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
112
+ interval = gr.Dropdown(
113
+ choices=["1d", "1h", "30m", "15m", "5m"],
114
+ value="1d",
115
+ label="간격"
116
+ )
117
  btn = gr.Button("예측 실행")
118
 
119
  plot = gr.Plot(label="History + Forecast")