|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import streamlit as st |
|
|
import torch |
|
|
from chronos import BaseChronosPipeline |
|
|
|
|
|
# Configure the page and render the heading; the tab title and the on-page
# title share the same text.
_PAGE_TITLE = "Chronos-Bolt Zero-Shot Forecast"

st.set_page_config(page_title=_PAGE_TITLE, layout="centered")
st.title(_PAGE_TITLE)
st.caption(
    "Zero-shot probabilistic forecasting (q10/q50/q90) using amazon/chronos-bolt-* models."
)
|
|
|
|
|
|
|
|
def ema(series, length=20):
    """Return the exponential moving average of ``series``.

    The input is coerced to a float64 ``pandas.Series`` and smoothed with
    ``ewm(span=length, adjust=False)``.

    Args:
        series: Any sequence accepted by ``pandas.Series``.
        length: Span of the exponential window.

    Returns:
        A float64 ``pandas.Series`` with the same length as the input.
    """
    values = pd.Series(series).astype("float64")
    smoother = values.ewm(span=length, adjust=False)
    return smoother.mean()
|
|
|
|
|
def rsi(series, length=14):
    """Return the Relative Strength Index of ``series``.

    Gains and losses are the positive and negative parts of the first
    difference, each smoothed with an exponentially weighted mean using
    ``alpha = 1 / length`` (``adjust=False``).

    Args:
        series: Any sequence accepted by ``pandas.Series``.
        length: Smoothing length; ``alpha`` is its reciprocal.

    Returns:
        A float64 ``pandas.Series`` of RSI values in ``[0, 100]``. The first
        element is NaN (no difference is available). Where the averaged loss
        is zero the value is 100 if there were gains, and NaN if the series
        has been completely flat so far (the ratio is undefined).
    """
    prices = pd.Series(series).astype("float64")
    delta = prices.diff()
    alpha = 1 / length

    avg_gain = delta.clip(lower=0).ewm(alpha=alpha, adjust=False).mean()
    avg_loss = (-delta.clip(upper=0)).ewm(alpha=alpha, adjust=False).mean()

    # Hide zero losses from the division so it never produces +/-inf.
    no_loss = avg_loss == 0
    rs = avg_gain / avg_loss.mask(no_loss)
    out = 100 - (100 / (1 + rs))

    # Gains with no losses: RSI saturates at 100 instead of being left as NaN
    # (NaN rows would otherwise be dropped by downstream ``dropna`` calls).
    return out.mask(no_loss & (avg_gain > 0), 100.0)
|
|
|
|
|
def stochastic_kd(high, low, close, k=14, d=3, smooth_k=3):
    """Return the smoothed stochastic %K line and its %D signal line.

    Raw %K is ``100 * (close - lowest_low) / (highest_high - lowest_low)``
    over a rolling window of ``k`` rows. It is averaged over ``smooth_k``
    rows to give the returned %K, which is averaged over ``d`` rows to give
    %D. Rows without a full window are NaN.

    Args:
        high: Sequence of period highs.
        low: Sequence of period lows.
        close: Sequence of period closes.
        k: Look-back window for the high/low range.
        d: Window of the %D moving average.
        smooth_k: Window of the %K moving average.

    Returns:
        Tuple ``(k_line, d_line)`` of float64 ``pandas.Series``.
    """
    highs, lows, closes = (
        pd.Series(values).astype("float64") for values in (high, low, close)
    )

    window_top = highs.rolling(k).max()
    window_bottom = lows.rolling(k).min()
    window_range = window_top - window_bottom

    raw = 100 * (closes - window_bottom) / window_range
    slow_k = raw.rolling(smooth_k).mean()
    signal = slow_k.rolling(d).mean()
    return slow_k, signal
|
|
|
|
|
|
|
|
|
|
|
# Label shown in the "Model" selectbox -> model id passed to ``load_pipeline``.
MODEL_CHOICES = {
    "Bolt Mini (CPU-friendly)": "amazon/chronos-bolt-mini",
    "Bolt Small (better; GPU if available)": "amazon/chronos-bolt-small",
}
|
|
|
|
|
@st.cache_resource(show_spinner=True)
def load_pipeline(model_id: str):
    """Load the Chronos pipeline for ``model_id`` (cached via ``st.cache_resource``).

    Runs on CUDA in bfloat16 when torch reports a GPU, otherwise on CPU in
    float32.

    Args:
        model_id: Model identifier handed to ``from_pretrained``.

    Returns:
        The object returned by ``BaseChronosPipeline.from_pretrained``.
    """
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.bfloat16
    else:
        device, dtype = "cpu", torch.float32

    return BaseChronosPipeline.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=dtype,
    )
|
|
|
|
|
|
|
|
def _force_1d(a): |
|
|
a = pd.Series(a, dtype="float32").replace([np.inf, -np.inf], np.nan).dropna() |
|
|
return a.to_numpy().reshape(-1) |
|
|
|
|
|
@st.cache_data(show_spinner=False)
def load_ticker_series(ticker: str, period: str = "2y"):
    """Download daily closes for ``ticker`` and return them as a clean 1-D array.

    Results are cached via ``st.cache_data``.

    Args:
        ticker: Symbol passed to ``yfinance.download``.
        period: History window passed to ``yfinance.download``.

    Returns:
        A float32 array of closing prices with non-finite values removed, or
        an empty float32 array when the download returns no rows.
    """
    import yfinance as yf

    frame = yf.download(
        ticker,
        period=period,
        interval="1d",
        auto_adjust=True,
        progress=False,
    )
    if frame.empty:
        return np.asarray([], dtype="float32")

    closes = frame["Close"]
    # Selecting "Close" can itself yield a DataFrame; in that case use its
    # first column.
    closes = closes.iloc[:, 0] if isinstance(closes, pd.DataFrame) else closes
    return _force_1d(closes)
|
|
|
|
|
def parse_pasted_series(txt: str):
    """Parse free-form pasted text into a clean 1-D float32 array.

    The text is split on commas and/or whitespace. Tokens that ``float()``
    cannot parse are skipped; the parsed values are then passed through
    ``_force_1d``.

    Args:
        txt: Raw text, e.g. one value per line or comma/space separated.

    Returns:
        The array returned by ``_force_1d`` for the parsed values.
    """
    import re

    values = []
    for token in re.split(r"[,\s]+", txt.strip()):
        if not token:
            # Splitting an empty string yields one empty token.
            continue
        try:
            values.append(float(token))
        except ValueError:
            # Not a number (header word, stray symbol): skip it rather than
            # rejecting the whole paste. Only ValueError is caught so that
            # KeyboardInterrupt/SystemExit are never swallowed.
            continue
    return _force_1d(values)
|
|
|
|
|
def load_csv_series(file, column=None):
    """Read a CSV and return one column as a clean 1-D float32 array.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.
        column: Column to extract. When ``None`` the first numeric column is
            used.

    Returns:
        Tuple ``(values, df, column)``: the cleaned values from ``_force_1d``,
        the parsed frame, and the column name actually used. When no column
        is given and the CSV has no numeric column, ``values`` is an empty
        float32 array and ``column`` is ``None``.
    """
    # A file-like object may already have been read by the caller (e.g. to
    # list its columns); rewind so read_csv does not see an empty stream.
    if hasattr(file, "seek"):
        try:
            file.seek(0)
        except (OSError, ValueError):
            # Non-seekable or closed stream: hand it to read_csv unchanged.
            pass

    df = pd.read_csv(file)

    if column is None:
        # select_dtypes also copes with pandas extension dtypes, on which
        # np.issubdtype raises TypeError.
        numeric = df.select_dtypes(include="number").columns
        if len(numeric) == 0:
            return np.asarray([], dtype="float32"), df, None
        column = numeric[0]

    return _force_1d(df[column]), df, column
|
|
|
|
|
|
|
|
# Top row of controls: model picker on the left, forecast horizon on the right.
c1, c2 = st.columns(2)
model_label = c1.selectbox("Model", list(MODEL_CHOICES), index=0)
pred_len = c2.number_input("Prediction length (steps)", 1, 365, 30)

# Which input path the loading section below should render.
src = st.radio(
    "Data source",
    ["Ticker (yfinance)", "Paste numbers", "Upload CSV"],
    horizontal=True,
)
|
|
|
|
|
# Streamlit re-runs the whole script on every widget interaction, and a button
# only returns True during the run triggered by its own click. The loaded
# series is therefore kept in st.session_state; otherwise clicking "Forecast"
# would start a run in which no load button is pressed and the data is gone.
_SERIES_KEY = "loaded_series"
_SERIES_SRC_KEY = "loaded_series_src"

loaded = None  # series loaded during this run, if any

if src == "Ticker (yfinance)":
    t1, t2 = st.columns([2, 1])
    with t1:
        ticker = st.text_input("Ticker (e.g., AAPL, SPY, BTC-USD)", "AAPL")
    with t2:
        period = st.selectbox("History window", ["6mo", "1y", "2y", "5y"], index=2)
    if st.button("Load data"):
        loaded = load_ticker_series(ticker.strip(), period)
        if loaded.size == 0:
            st.error("No data returned. Try another ticker/window.")
elif src == "Paste numbers":
    txt = st.text_area(
        "One value per line (or comma/space separated)",
        "1\n2\n3\n4\n5\n6\n7\n8\n9\n10",
    )
    if st.button("Use pasted data"):
        loaded = parse_pasted_series(txt)
else:
    uploaded = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded is not None:
        # Rewind first: the buffer may have been consumed by a previous read.
        uploaded.seek(0)
        df = pd.read_csv(uploaded)
        numeric_cols = df.select_dtypes(include="number").columns.tolist()
        col = st.selectbox("Pick numeric column", numeric_cols) if numeric_cols else None
        if st.button("Load CSV column") and col:
            # Reuse the frame parsed above instead of reading the (now
            # exhausted) upload buffer a second time.
            loaded = _force_1d(df[col])
        elif not numeric_cols:
            st.error("No numeric columns found in CSV.")

if loaded is not None:
    st.session_state[_SERIES_KEY] = loaded
    st.session_state[_SERIES_SRC_KEY] = src

# Only reuse a stored series that was loaded from the currently selected source.
series = None
if st.session_state.get(_SERIES_SRC_KEY) == src:
    series = st.session_state.get(_SERIES_KEY)

if series is not None and series.size > 5:
    st.write(f"Loaded {series.size} points.")
    st.line_chart(pd.DataFrame(series, columns=["value"]))

    if st.button("Forecast"):
        with st.spinner("Running Chronos-Bolt..."):
            pipe = load_pipeline(MODEL_CHOICES[model_label])
            ctx = torch.tensor(series, dtype=torch.float32)
            q_levels = [0.10, 0.50, 0.90]

            quantiles, _mean = pipe.predict_quantiles(
                context=ctx,
                prediction_length=int(pred_len),
                quantile_levels=q_levels,
            )

            # First (only) series in the batch; columns follow q_levels.
            # .float() guards against dtypes NumPy cannot represent.
            q_np = quantiles[0].float().cpu().numpy()
            lo, med, hi = q_np[:, 0], q_np[:, 1], q_np[:, 2]

        import matplotlib.pyplot as plt

        hist_x = np.arange(len(series))
        fut_x = np.arange(len(series), len(series) + int(pred_len))

        fig, ax = plt.subplots(figsize=(9, 4.5))
        ax.plot(hist_x, series, label="history")
        ax.plot(fut_x, med, label="median forecast")
        ax.fill_between(fut_x, lo, hi, alpha=0.3, label="q10–q90 band")
        ax.legend()
        ax.grid(True, alpha=0.3)
        st.pyplot(fig)
        # Release the figure so repeated forecasts do not accumulate figures.
        plt.close(fig)

        out = pd.DataFrame({"t": fut_x, "q10": lo, "q50": med, "q90": hi})
        st.download_button(
            "Download forecast CSV",
            out.to_csv(index=False).encode("utf-8"),
            file_name="chronos_forecast.csv",
            mime="text/csv",
        )
else:
    st.info("Load a ticker, paste values, or upload a CSV to begin.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with st.expander("Train with Indicators (RSI, EMA, Stochastic)"):
    st.write("Fine-tune Chronos-Bolt on one ticker using indicator covariates (past-only).")
    tcol1, tcol2, tcol3 = st.columns([2, 1, 1])
    with tcol1:
        ft_ticker = st.text_input("Ticker", "SPY")
    with tcol3:
        ft_interval = st.selectbox("Interval", ["1d", "60m", "30m", "15m"], index=0)

    # The lookback choices depend on the interval, so the interval widget is
    # created first even though it sits in the third column.
    if ft_interval == "1d":
        allowed_periods = ["6mo", "1y", "2y", "5y"]
        default_idx = 2
    else:
        allowed_periods = ["5d", "30d", "60d"]
        default_idx = 1
    with tcol2:
        ft_period = st.selectbox("Lookback", allowed_periods, index=default_idx)

    ft_steps = st.slider("Fine-tune steps", 100, 1500, 300, step=50)
    run_ft = st.button("Train fine-tuned model")

    if run_ft:
        # Normalise the ticker once so the download symbol, the item id and
        # the output file name all agree (no stray whitespace).
        symbol = ft_ticker.strip()
        item = symbol.upper()

        with st.spinner("Downloading & computing indicators…"):
            import yfinance as yf
            from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame

            download_kwargs = dict(interval=ft_interval, auto_adjust=True, progress=False)
            df = yf.download(symbol, period=ft_period, **download_kwargs)

            # One retry with a fallback lookback before giving up.
            if df.empty:
                alt_period = "60d" if ft_interval != "1d" else "1y"
                if alt_period != ft_period:
                    df = yf.download(symbol, period=alt_period, **download_kwargs)
            if df.empty:
                st.error("No data returned. Try a shorter lookback for intraday (e.g., 30d/60d) or use Interval=1d.")
                st.stop()

            freq_alias = {"1d": "B", "60m": "60min", "30m": "30min", "15m": "15min"}.get(ft_interval, "B")
            df.index = pd.DatetimeIndex(df.index).tz_localize(None)

            # Flatten MultiIndex columns down to a single level.
            if isinstance(df.columns, pd.MultiIndex):
                try:
                    sym = df.columns.get_level_values(1).unique()[0]
                    df = df.xs(sym, axis=1, level=1)
                except Exception:
                    # Fallback: keep only the first level of each column tuple.
                    df.columns = [c[0] for c in df.columns.to_flat_index()]

            df = df[["Close", "High", "Low"]].copy()

            # Make sure every price column is a plain 1-D Series on df's index.
            for _c in ["Close", "High", "Low"]:
                if isinstance(df[_c], pd.DataFrame):
                    df[_c] = df[_c].iloc[:, 0]
                df[_c] = pd.Series(np.asarray(df[_c]).reshape(-1), index=df.index)

            df = df.dropna()

            df["rsi14"] = rsi(df["Close"], 14)
            df["ema20"] = ema(df["Close"], 20)
            df["stoch_k"], df["stoch_d"] = stochastic_kd(df["High"], df["Low"], df["Close"], 14, 3, 3)

            # The rolling indicators are NaN until their windows fill; those
            # warm-up rows are dropped here.
            df = df.dropna().astype("float32")
            if df.empty:
                st.error("Not enough history left after computing indicators. Try a longer lookback.")
                st.stop()
            if df.shape[0] < 200:
                st.warning("Very short history after indicators; results may be noisy.")

            ts = df[["Close", "rsi14", "ema20", "stoch_k", "stoch_d"]].copy()
            ts["item_id"] = item
            ts["timestamp"] = ts.index
            ts = ts.rename(columns={"Close": "target"})

            tsdf = TimeSeriesDataFrame.from_data_frame(
                ts, id_column="item_id", timestamp_column="timestamp"
            )

            # Best effort: training continues on the raw timestamps if the
            # conversion fails, but the user is told instead of silently
            # ignoring the failure.
            try:
                tsdf = tsdf.convert_frequency(freq=freq_alias)
            except Exception as exc:
                st.warning(f"Could not convert data to frequency {freq_alias!r}; continuing without it ({exc}).")

        with st.spinner("Fine-tuning Chronos-Bolt (small demo)…"):
            predictor = TimeSeriesPredictor(
                prediction_length=int(pred_len),
                eval_metric="WQL",
                quantile_levels=[0.1, 0.5, 0.9],
                freq=freq_alias,
            ).fit(
                train_data=tsdf,
                enable_ensemble=False,
                time_limit=300,
                hyperparameters={
                    "Chronos": {
                        "model_path": "bolt_mini",
                        "fine_tune": True,
                        "fine_tune_steps": int(ft_steps),
                        "fine_tune_lr": 1e-5,
                    }
                },
            )

        preds = predictor.predict(tsdf)
        yhist = tsdf.loc[item]["target"].to_numpy()
        ypred = preds.loc[item]
        lo = ypred["0.1"].to_numpy()
        med = ypred["0.5"].to_numpy()
        hi = ypred["0.9"].to_numpy()

        import matplotlib.pyplot as plt

        hx = np.arange(len(yhist))
        fx = np.arange(len(yhist), len(yhist) + len(med))

        fig, ax = plt.subplots(figsize=(9, 4.5))
        ax.plot(hx, yhist, label="history")
        ax.plot(fx, med, label="median (fine-tuned)")
        ax.fill_between(fx, lo, hi, alpha=0.3, label="q10–q90")
        ax.legend()
        ax.grid(True, alpha=0.3)
        st.pyplot(fig)
        # Release the figure so repeated training runs do not accumulate figures.
        plt.close(fig)

        out = pd.DataFrame({"t": fx, "q10": lo, "q50": med, "q90": hi})
        st.download_button(
            "Download fine-tuned forecast CSV",
            out.to_csv(index=False).encode("utf-8"),
            file_name=f"{item}_chronos_finetuned.csv",
            mime="text/csv",
        )