kimyechan
commited on
Commit
·
0b929da
1
Parent(s):
c43d5d1
fix: 비트코인 로직 수정
Browse files
app.py
CHANGED
|
@@ -15,26 +15,44 @@ def get_pipeline(model_id: str, device: str = "cpu"):
|
|
| 15 |
if key not in _PIPELINE_CACHE:
|
| 16 |
_PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
|
| 17 |
model_id,
|
| 18 |
-
device_map=device, # "cpu" / "cuda"
|
| 19 |
torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
|
| 20 |
)
|
| 21 |
return _PIPELINE_CACHE[key]
|
| 22 |
|
| 23 |
-
# ----
|
| 24 |
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
s = df["Close"].dropna().astype(float)
|
|
|
|
|
|
|
| 30 |
return s
|
| 31 |
|
| 32 |
# ---- 예측 함수 (Gradio가 호출) ----
|
| 33 |
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
|
| 34 |
try:
|
| 35 |
-
series = load_close_series(ticker, start_date, end_date, interval)
|
| 36 |
except Exception as e:
|
| 37 |
-
|
|
|
|
| 38 |
|
| 39 |
pipe = get_pipeline(model_id, device)
|
| 40 |
H = int(horizon)
|
|
@@ -42,8 +60,7 @@ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interv
|
|
| 42 |
# Chronos 입력: 1D 텐서 (float)
|
| 43 |
context = torch.tensor(series.values, dtype=torch.float32)
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
# 보통 q=[0.1, 0.5, 0.9]
|
| 47 |
preds = pipe.predict(context=context, prediction_length=H)[0]
|
| 48 |
q10, q50, q90 = preds[0], preds[1], preds[2]
|
| 49 |
|
|
@@ -53,15 +70,18 @@ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interv
|
|
| 53 |
index=pd.RangeIndex(1, H + 1, name="step"),
|
| 54 |
)
|
| 55 |
|
| 56 |
-
#
|
| 57 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
fig = plt.figure(figsize=(10, 4))
|
| 59 |
plt.plot(series.index, series.values, label="history")
|
| 60 |
-
# 미래 구간 x축 만들기: 종가가 일 단위라 'D' 주기 사용
|
| 61 |
-
future_index = pd.date_range(series.index[-1], periods=H + 1, freq="D")[1:]
|
| 62 |
plt.plot(future_index, q50.numpy(), label="forecast(q50)")
|
| 63 |
plt.fill_between(future_index, q10.numpy(), q90.numpy(), alpha=0.2, label="q10–q90")
|
| 64 |
-
plt.title(f"{ticker} forecast by Chronos-Bolt")
|
| 65 |
plt.legend()
|
| 66 |
plt.tight_layout()
|
| 67 |
|
|
@@ -69,14 +89,14 @@ def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interv
|
|
| 69 |
return fig, df_fcst, note
|
| 70 |
|
| 71 |
# ---- Gradio UI ----
|
| 72 |
-
with gr.Blocks(title="Chronos Stock Forecast") as demo:
|
| 73 |
-
gr.Markdown("# Chronos
|
| 74 |
with gr.Row():
|
| 75 |
-
ticker = gr.Textbox(value="
|
| 76 |
-
horizon = gr.Slider(5,
|
| 77 |
with gr.Row():
|
| 78 |
-
start = gr.Textbox(value=
|
| 79 |
-
end = gr.Textbox(value=dt.date.today().isoformat(), label="종료일 (YYYY-MM-DD)")
|
| 80 |
with gr.Row():
|
| 81 |
model_id = gr.Dropdown(
|
| 82 |
choices=[
|
|
@@ -89,7 +109,11 @@ with gr.Blocks(title="Chronos Stock Forecast") as demo:
|
|
| 89 |
label="모델"
|
| 90 |
)
|
| 91 |
device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
|
| 92 |
-
interval = gr.Dropdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
btn = gr.Button("예측 실행")
|
| 94 |
|
| 95 |
plot = gr.Plot(label="History + Forecast")
|
|
|
|
| 15 |
if key not in _PIPELINE_CACHE:
|
| 16 |
_PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
|
| 17 |
model_id,
|
| 18 |
+
device_map=device, # "cpu" / "cuda"
|
| 19 |
torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
|
| 20 |
)
|
| 21 |
return _PIPELINE_CACHE[key]
|
| 22 |
|
| 23 |
+
# ---- 주가/크립토 데이터 로딩 (yfinance, 견고화) ----
|
| 24 |
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
|
| 25 |
+
"""
|
| 26 |
+
BTC-USD 등 크립토 심볼에서 간헐적으로 timezone/파싱 오류가 나므로
|
| 27 |
+
history() 경로를 우선 사용하고, 실패 시 한 번 재시도.
|
| 28 |
+
"""
|
| 29 |
+
# 기본값 보정: 너무 최근만 고르면 빈 데이터가 나올 수 있어 일봉은 과거부터 권장
|
| 30 |
+
_start = start or "2014-09-17" # BTC-USD 최초 상장일 근처
|
| 31 |
+
_end = end or dt.date.today().isoformat()
|
| 32 |
+
|
| 33 |
+
tk = yf.Ticker(ticker)
|
| 34 |
+
try:
|
| 35 |
+
df = tk.history(start=_start, end=_end, interval=interval, auto_adjust=True, actions=False)
|
| 36 |
+
if df.empty or "Close" not in df:
|
| 37 |
+
raise ValueError("empty")
|
| 38 |
+
except Exception:
|
| 39 |
+
# fallback: download() 경로 시도
|
| 40 |
+
df = yf.download(ticker, start=_start, end=_end, interval=interval, progress=False, threads=False)
|
| 41 |
+
if df.empty or "Close" not in df:
|
| 42 |
+
raise ValueError("데이터가 없거나 'Close' 열이 없습니다. 티커/날짜/간격을 확인하세요.")
|
| 43 |
+
|
| 44 |
s = df["Close"].dropna().astype(float)
|
| 45 |
+
if s.empty:
|
| 46 |
+
raise ValueError("다운로드 결과가 비어 있습니다. 기간/간격을 줄이거나 다시 시도하세요.")
|
| 47 |
return s
|
| 48 |
|
| 49 |
# ---- 예측 함수 (Gradio가 호출) ----
|
| 50 |
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
|
| 51 |
try:
|
| 52 |
+
series = load_close_series(ticker.strip(), start_date, end_date, interval)
|
| 53 |
except Exception as e:
|
| 54 |
+
# Gradio v4에서는 Plot.update가 없음 → None 반환으로 정리
|
| 55 |
+
return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"
|
| 56 |
|
| 57 |
pipe = get_pipeline(model_id, device)
|
| 58 |
H = int(horizon)
|
|
|
|
| 60 |
# Chronos 입력: 1D 텐서 (float)
|
| 61 |
context = torch.tensor(series.values, dtype=torch.float32)
|
| 62 |
|
| 63 |
+
# 예측: (num_series=1, num_quantiles=3, H) with q=[0.1, 0.5, 0.9]
|
|
|
|
| 64 |
preds = pipe.predict(context=context, prediction_length=H)[0]
|
| 65 |
q10, q50, q90 = preds[0], preds[1], preds[2]
|
| 66 |
|
|
|
|
| 70 |
index=pd.RangeIndex(1, H + 1, name="step"),
|
| 71 |
)
|
| 72 |
|
| 73 |
+
# 미래 x축: interval에 맞는 pandas 주기
|
| 74 |
import matplotlib.pyplot as plt
|
| 75 |
+
freq_map = {"1d": "D", "1h": "H", "30m": "30T", "15m": "15T", "5m": "5T"}
|
| 76 |
+
freq = freq_map.get(interval, "D")
|
| 77 |
+
future_index = pd.date_range(series.index[-1], periods=H + 1, freq=freq)[1:]
|
| 78 |
+
|
| 79 |
+
# 그래프
|
| 80 |
fig = plt.figure(figsize=(10, 4))
|
| 81 |
plt.plot(series.index, series.values, label="history")
|
|
|
|
|
|
|
| 82 |
plt.plot(future_index, q50.numpy(), label="forecast(q50)")
|
| 83 |
plt.fill_between(future_index, q10.numpy(), q90.numpy(), alpha=0.2, label="q10–q90")
|
| 84 |
+
plt.title(f"{ticker} forecast by Chronos-Bolt ({interval}, H={H})")
|
| 85 |
plt.legend()
|
| 86 |
plt.tight_layout()
|
| 87 |
|
|
|
|
| 89 |
return fig, df_fcst, note
|
| 90 |
|
| 91 |
# ---- Gradio UI ----
|
| 92 |
+
with gr.Blocks(title="Chronos Stock/Crypto Forecast") as demo:
|
| 93 |
+
gr.Markdown("# Chronos 주가·크립토 예측 데모")
|
| 94 |
with gr.Row():
|
| 95 |
+
ticker = gr.Textbox(value="BTC-USD", label="티커 (예: AAPL, MSFT, 005930.KS, BTC-USD)")
|
| 96 |
+
horizon = gr.Slider(5, 365, value=90, step=1, label="예측 스텝 H (간격 단위�� 동일)")
|
| 97 |
with gr.Row():
|
| 98 |
+
start = gr.Textbox(value="2014-09-17", label="시작일 (YYYY-MM-DD, 예: 2014-09-17)")
|
| 99 |
+
end = gr.Textbox(value=dt.date.today().isoformat(), label="종료일 (YYYY-MM-DD, 비워두면 오늘)")
|
| 100 |
with gr.Row():
|
| 101 |
model_id = gr.Dropdown(
|
| 102 |
choices=[
|
|
|
|
| 109 |
label="모델"
|
| 110 |
)
|
| 111 |
device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
|
| 112 |
+
interval = gr.Dropdown(
|
| 113 |
+
choices=["1d", "1h", "30m", "15m", "5m"],
|
| 114 |
+
value="1d",
|
| 115 |
+
label="간격"
|
| 116 |
+
)
|
| 117 |
btn = gr.Button("예측 실행")
|
| 118 |
|
| 119 |
plot = gr.Plot(label="History + Forecast")
|