# Source: app.py as published by michaellupo74
# Commit: f41d71c (verified), "Update app.py"
import numpy as np
import pandas as pd
import streamlit as st
import torch
from chronos import BaseChronosPipeline
# Page chrome: the browser-tab title and the on-page heading share one string.
_PAGE_TITLE = "Chronos-Bolt Zero-Shot Forecast"
st.set_page_config(page_title=_PAGE_TITLE, layout="centered")
st.title(_PAGE_TITLE)
st.caption(
    "Zero-shot probabilistic forecasting (q10/q50/q90) using amazon/chronos-bolt-* models."
)
# -------------------- Indicator helpers (no pandas-ta needed) --------------------
def ema(series, length=20):
    """Return the exponential moving average of *series*.

    The input is coerced to a float64 ``pandas.Series`` and smoothed with a
    recursive (``adjust=False``) exponentially weighted mean whose span is
    *length*. The result is a Series of the same length as the input.
    """
    values = pd.Series(series).astype("float64")
    smoother = values.ewm(span=length, adjust=False)
    return smoother.mean()
def rsi(series, length=14):
    """Return the Relative Strength Index of *series* as a float64 Series.

    Gains and losses are the positive and negative parts of the first
    difference, each smoothed with a recursive exponentially weighted mean
    using ``alpha = 1 / length``.

    Edge cases:
    * The first element is NaN because it has no previous value to diff against.
    * When the smoothed loss is zero but the smoothed gain is positive (no
      down moves so far) the index saturates at 100. Previously this case
      produced NaN because of the guarded division.
    * When both smoothed gain and loss are zero (a flat series) the ratio is
      undefined and the result stays NaN.
    """
    s = pd.Series(series).astype("float64")
    delta = s.diff()
    gain = delta.clip(lower=0).ewm(alpha=1 / length, adjust=False).mean()
    loss = (-delta.clip(upper=0)).ewm(alpha=1 / length, adjust=False).mean()

    # Guard the division: a zero loss would otherwise yield +/-inf (or 0/0).
    rs = gain / loss.replace(0, np.nan)
    out = 100 - (100 / (1 + rs))

    # No losses at all -> RSI is at its upper bound rather than undefined.
    only_gains = (loss == 0) & (gain > 0)
    return out.mask(only_gains, 100.0)
def stochastic_kd(high, low, close, k=14, d=3, smooth_k=3):
    """Return the smoothed stochastic oscillator lines ``(%K, %D)``.

    *high*, *low* and *close* are coerced to float64 Series. The raw %K is the
    position of the close inside the rolling ``k``-period high/low range,
    scaled to 0-100. It is then averaged over ``smooth_k`` periods to give
    %K, and %K is averaged over ``d`` periods to give %D. Leading values are
    NaN until each rolling window is full.
    """
    highs, lows, closes = (
        pd.Series(values).astype("float64") for values in (high, low, close)
    )
    window_high = highs.rolling(k).max()
    window_low = lows.rolling(k).min()
    percent_k_raw = 100 * (closes - window_low) / (window_high - window_low)
    percent_k = percent_k_raw.rolling(smooth_k).mean()
    percent_d = percent_k.rolling(d).mean()
    return percent_k, percent_d
# -------------------- Model options --------------------
# Label shown in the "Model" select box -> model id handed to the pipeline
# loader. Insertion order matters: the first entry is the select box default.
MODEL_CHOICES = {
    "Bolt Mini (CPU-friendly)": "amazon/chronos-bolt-mini",
    "Bolt Small (better; GPU if available)": "amazon/chronos-bolt-small",
}
@st.cache_resource(show_spinner=True)
def load_pipeline(model_id: str):
    """Load the Chronos pipeline for *model_id*, cached by Streamlit.

    Uses CUDA with bfloat16 weights when ``torch.cuda.is_available()`` is
    true, otherwise the CPU with float32 weights.
    """
    if torch.cuda.is_available():
        target_device, weight_dtype = "cuda", torch.bfloat16
    else:
        target_device, weight_dtype = "cpu", torch.float32
    return BaseChronosPipeline.from_pretrained(
        model_id,
        device_map=target_device,
        torch_dtype=weight_dtype,
    )
# -------------------- Data loaders (always return 1-D) --------------------
def _force_1d(a):
    """Return *a* as a flat float32 NumPy array without NaN or infinities.

    The input is coerced to a float32 Series, infinities are turned into NaN,
    all NaN entries are dropped, and the remainder is flattened to 1-D.
    """
    as_float = pd.Series(a, dtype="float32")
    finite_only = as_float.replace([np.inf, -np.inf], np.nan).dropna()
    flat = finite_only.to_numpy().reshape(-1)
    return flat
@st.cache_data(show_spinner=False)
def load_ticker_series(ticker: str, period: str = "2y"):
    """Download daily adjusted closes for *ticker* as a 1-D float32 array.

    The result is cached by Streamlit per ``(ticker, period)``. An empty
    float32 array is returned when the download yields no rows.
    """
    import yfinance as yf

    prices = yf.download(
        ticker,
        period=period,
        interval="1d",
        auto_adjust=True,
        progress=False,
    )
    if prices.empty:
        return np.asarray([], dtype="float32")

    closes = prices["Close"]
    # handle rare multi-index cases: "Close" may come back as a one-column frame
    if isinstance(closes, pd.DataFrame):
        closes = closes.iloc[:, 0]
    return _force_1d(closes)
def parse_pasted_series(txt: str):
    """Parse free-form pasted text into a 1-D float32 array.

    The text is split on commas and any whitespace. Tokens that ``float()``
    cannot parse are skipped on purpose, so headers or stray words do not
    abort the load. The collected values then go through ``_force_1d``.
    """
    import re

    values = []
    for token in re.split(r"[,\s]+", txt.strip()):
        if not token:
            continue
        try:
            values.append(float(token))
        except ValueError:
            # Only a parse failure is expected here; a bare ``except`` would
            # also swallow KeyboardInterrupt/SystemExit.
            continue
    return _force_1d(values)
def load_csv_series(file, column=None):
    """Read a CSV and return ``(values, dataframe, column_name)``.

    Parameters:
        file: A path or file-like object accepted by ``pandas.read_csv``.
            Seekable file-like objects are rewound first, so a buffer that
            was already read (e.g. for a preview) can be passed again.
        column: Column to extract. When ``None`` the first numeric column
            is used.

    Returns:
        A tuple of the cleaned 1-D float32 array, the parsed DataFrame, and
        the name of the column used. If no numeric column exists the array
        is empty and the column name is ``None``.
    """
    # Rewind file-like inputs: reading an exhausted buffer makes read_csv
    # fail with "No columns to parse from file".
    rewind = getattr(file, "seek", None)
    if callable(rewind):
        try:
            rewind(0)
        except (OSError, ValueError):
            # Non-seekable stream: fall back to reading from where it is.
            pass

    df = pd.read_csv(file)
    if column is None:
        # pandas' own dtype checks also cope with extension dtypes, which
        # np.issubdtype cannot interpret. Booleans are excluded so the
        # selection matches the previous np.number-based behaviour.
        numeric_cols = [
            c
            for c in df.columns
            if pd.api.types.is_numeric_dtype(df[c])
            and not pd.api.types.is_bool_dtype(df[c])
        ]
        column = numeric_cols[0] if numeric_cols else None
    if column is None:
        return np.asarray([], dtype="float32"), df, None
    return _force_1d(df[column]), df, column
# -------------------- UI --------------------
c1, c2 = st.columns(2)
with c1:
    model_label = st.selectbox("Model", list(MODEL_CHOICES.keys()), index=0)
with c2:
    pred_len = st.number_input("Prediction length (steps)", 1, 365, 30)

src = st.radio("Data source", ["Ticker (yfinance)", "Paste numbers", "Upload CSV"], horizontal=True)

# Streamlit reruns the whole script on every interaction, and a button is only
# True during the run its click triggered. The loaded series is therefore kept
# in session state (tagged with the source it came from) so that it is still
# available on the rerun caused by a later button such as "Forecast".
_SERIES_STATE_KEY = "loaded_series"

if src == "Ticker (yfinance)":
    t1, t2 = st.columns([2, 1])
    with t1:
        ticker = st.text_input("Ticker (e.g., AAPL, SPY, BTC-USD)", "AAPL")
    with t2:
        period = st.selectbox("History window", ["6mo", "1y", "2y", "5y"], index=2)
    if st.button("Load data"):
        loaded = load_ticker_series(ticker.strip(), period)
        if loaded.size == 0:
            st.error("No data returned. Try another ticker/window.")
        st.session_state[_SERIES_STATE_KEY] = (src, loaded)
elif src == "Paste numbers":
    txt = st.text_area("One value per line (or comma/space separated)", "1\n2\n3\n4\n5\n6\n7\n8\n9\n10")
    if st.button("Use pasted data"):
        st.session_state[_SERIES_STATE_KEY] = (src, parse_pasted_series(txt))
else:
    uploaded = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded is not None:
        # The same buffer is read for the column picker and again for the
        # actual load, so rewind it before each read.
        uploaded.seek(0)
        df = pd.read_csv(uploaded)
        numeric_cols = [
            c
            for c in df.columns
            if pd.api.types.is_numeric_dtype(df[c])
            and not pd.api.types.is_bool_dtype(df[c])
        ]
        col = st.selectbox("Pick numeric column", numeric_cols) if numeric_cols else None
        if st.button("Load CSV column") and col:
            uploaded.seek(0)
            loaded, _, _ = load_csv_series(uploaded, column=col)
            st.session_state[_SERIES_STATE_KEY] = (src, loaded)
        elif uploaded and not numeric_cols:
            st.error("No numeric columns found in CSV.")

# Only reuse a stored series that was loaded from the currently selected source.
_stored = st.session_state.get(_SERIES_STATE_KEY)
series = _stored[1] if _stored is not None and _stored[0] == src else None
# -------------------- Plot + Forecast --------------------
# Show the loaded history and, on request, the zero-shot quantile forecast.
# NOTE: clicking "Forecast" reruns the script, so `series` must still be
# populated on that rerun for the forecast branch to execute.
if series is not None and series.size > 5:
    st.write(f"Loaded {series.size} points.")
    st.line_chart(pd.DataFrame(series, columns=["value"]))  # always 1-D -> no error
    if st.button("Forecast"):
        with st.spinner("Running Chronos-Bolt..."):
            pipe = load_pipeline(MODEL_CHOICES[model_label])
            ctx = torch.tensor(series, dtype=torch.float32)
            q_levels = [0.10, 0.50, 0.90]
            # The mean forecast is returned as well but is not used here.
            quantiles, _mean = pipe.predict_quantiles(
                context=ctx,
                prediction_length=int(pred_len),
                quantile_levels=q_levels,
            )
        q_np = quantiles[0].cpu().numpy()  # shape [pred_len, 3]
        lo, med, hi = q_np[:, 0], q_np[:, 1], q_np[:, 2]

        import matplotlib.pyplot as plt

        hist_x = np.arange(len(series))
        # Size the future axis from the forecast actually returned so the
        # x/y lengths always match.
        fut_x = np.arange(len(series), len(series) + len(med))

        # Use an explicit figure/axes instead of pyplot's global state, and
        # close the figure after rendering so figures do not pile up in
        # memory across Streamlit reruns.
        fig, ax = plt.subplots(figsize=(9, 4.5))
        ax.plot(hist_x, series, label="history")
        ax.plot(fut_x, med, label="median forecast")
        ax.fill_between(fut_x, lo, hi, alpha=0.3, label="q10–q90 band")
        ax.legend()
        ax.grid(True, alpha=0.3)
        st.pyplot(fig)
        plt.close(fig)

        out = pd.DataFrame({"t": fut_x, "q10": lo, "q50": med, "q90": hi})
        st.download_button(
            "Download forecast CSV",
            out.to_csv(index=False).encode("utf-8"),
            file_name="chronos_forecast.csv",
            mime="text/csv",
        )
else:
    st.info("Load a ticker, paste values, or upload a CSV to begin.")
# ================================
# Train with RSI / EMA / Stochastic (AutoGluon) — no pandas-ta
# ================================
# Fine-tuning panel: download OHLC data for one ticker, derive RSI / EMA /
# Stochastic columns, fine-tune Chronos-Bolt through AutoGluon and plot the
# resulting quantile forecast.
with st.expander("Train with Indicators (RSI, EMA, Stochastic)"):
    st.write("Fine-tune Chronos-Bolt on one ticker using indicator covariates (past-only).")
    tcol1, tcol2, tcol3 = st.columns([2, 1, 1])
    with tcol1:
        ft_ticker = st.text_input("Ticker", "SPY")
    # The interval is created before the lookback (although it is displayed
    # after it) because the lookback options depend on the interval.
    with tcol3:
        ft_interval = st.selectbox("Interval", ["1d", "60m", "30m", "15m"], index=0)
    # Allowed lookbacks depend on interval
    if ft_interval == "1d":
        allowed_periods = ["6mo", "1y", "2y", "5y"]
        default_idx = 2
    else:
        allowed_periods = ["5d", "30d", "60d"]
        default_idx = 1
    with tcol2:
        ft_period = st.selectbox("Lookback", allowed_periods, index=default_idx)
    ft_steps = st.slider("Fine-tune steps", 100, 1500, 300, step=50)
    run_ft = st.button("Train fine-tuned model")

    if run_ft:
        # Normalise the symbol once so the download, the AutoGluon item id
        # and the export file name all use the same whitespace-free value.
        ft_symbol = ft_ticker.strip()
        item = ft_symbol.upper()
        if not ft_symbol:
            st.error("Enter a ticker symbol.")
            st.stop()

        with st.spinner("Downloading & computing indicators…"):
            import yfinance as yf
            from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame

            # 1) Load OHLC so we can compute Stochastic (needs High/Low/Close)
            df = yf.download(
                ft_symbol,
                period=ft_period,
                interval=ft_interval,
                auto_adjust=True,
                progress=False,
            )
            # Fallback: if the chosen combo is too long for intraday, clamp and retry
            if df.empty:
                alt_period = "60d" if ft_interval != "1d" else "1y"
                if alt_period != ft_period:
                    df = yf.download(
                        ft_symbol,
                        period=alt_period,
                        interval=ft_interval,
                        auto_adjust=True,
                        progress=False,
                    )
            if df.empty:
                st.error("No data returned. Try a shorter lookback for intraday (e.g., 30d/60d) or use Interval=1d.")
                st.stop()

            # Determine frequency alias for AutoGluon and ensure tz-naive index
            freq_alias = {"1d": "B", "60m": "60min", "30m": "30min", "15m": "15min"}.get(ft_interval, "B")
            df.index = pd.DatetimeIndex(df.index).tz_localize(None)

            # Handle MultiIndex columns (yfinance can return 2-level columns)
            if isinstance(df.columns, pd.MultiIndex):
                try:
                    sym = df.columns.get_level_values(1).unique()[0]
                    df = df.xs(sym, axis=1, level=1)
                except Exception:
                    # Deliberate best-effort: whatever went wrong selecting
                    # the symbol level, flatten by taking the top-level name
                    # (Close/High/Low) instead.
                    df.columns = [c[0] for c in df.columns.to_flat_index()]

            # Keep only needed cols
            df = df[["Close", "High", "Low"]].copy()
            # Ensure each column is 1-D (avoid (N,1) arrays)
            for _c in ["Close", "High", "Low"]:
                if isinstance(df[_c], pd.DataFrame):
                    df[_c] = df[_c].iloc[:, 0]
                df[_c] = pd.Series(np.asarray(df[_c]).reshape(-1), index=df.index)
            df = df.dropna()

            # 2) Indicators (helpers above)
            df["rsi14"] = rsi(df["Close"], 14)
            df["ema20"] = ema(df["Close"], 20)
            df["stoch_k"], df["stoch_d"] = stochastic_kd(df["High"], df["Low"], df["Close"], 14, 3, 3)
            # The indicator warm-up rows are NaN and are dropped here.
            df = df.dropna().astype("float32")
            if df.empty:
                st.error("Not enough history to compute the indicators. Try a longer lookback.")
                st.stop()
            if df.shape[0] < 200:
                st.warning("Very short history after indicators; results may be noisy.")

            # 3) Build TimeSeriesDataFrame (target + past covariates)
            ts = df[["Close", "rsi14", "ema20", "stoch_k", "stoch_d"]].copy()
            ts["item_id"] = item
            ts["timestamp"] = ts.index
            ts = ts.rename(columns={"Close": "target"})
            tsdf = TimeSeriesDataFrame.from_data_frame(
                ts, id_column="item_id", timestamp_column="timestamp"
            )
            # Ensure a regular time grid for AutoGluon. This stays best-effort,
            # but a failure is now reported instead of being silently ignored.
            try:
                tsdf = tsdf.convert_frequency(freq=freq_alias)
            except Exception as exc:
                st.warning(
                    f"Could not resample to a regular '{freq_alias}' grid ({exc}); "
                    "continuing with the original timestamps."
                )

        with st.spinner("Fine-tuning Chronos-Bolt (small demo)…"):
            # Chronos-Bolt preset via hyperparameters; fine_tune on CPU is OK for small steps
            predictor = TimeSeriesPredictor(
                prediction_length=int(pred_len),  # reuse the main UI's pred_len
                eval_metric="WQL",
                quantile_levels=[0.1, 0.5, 0.9],
                freq=freq_alias,
            ).fit(
                train_data=tsdf,
                enable_ensemble=False,
                time_limit=300,  # small demo budget; increase offline/GPU
                hyperparameters={
                    "Chronos": {
                        "model_path": "bolt_mini",  # CPU-friendly; try 'bolt_small' on GPU
                        "fine_tune": True,
                        "fine_tune_steps": int(ft_steps),
                        "fine_tune_lr": 1e-5,
                    }
                },
            )
            # 4) Forecast with the fine-tuned model
            preds = predictor.predict(tsdf)  # AG starts at series end

        yhist = tsdf.loc[item]["target"].to_numpy()
        ypred = preds.loc[item]  # MultiIndex -> rows for horizon
        lo = ypred["0.1"].to_numpy()
        med = ypred["0.5"].to_numpy()
        hi = ypred["0.9"].to_numpy()

        import matplotlib.pyplot as plt

        hx = np.arange(len(yhist))
        fx = np.arange(len(yhist), len(yhist) + len(med))

        # Use an explicit figure/axes and close it after rendering so figures
        # do not accumulate across Streamlit reruns.
        fig, ax = plt.subplots(figsize=(9, 4.5))
        ax.plot(hx, yhist, label="history")
        ax.plot(fx, med, label="median (fine-tuned)")
        ax.fill_between(fx, lo, hi, alpha=0.3, label="q10–q90")
        ax.legend()
        ax.grid(True, alpha=0.3)
        st.pyplot(fig)
        plt.close(fig)

        out = pd.DataFrame({"t": fx, "q10": lo, "q50": med, "q90": hi})
        st.download_button(
            "Download fine-tuned forecast CSV",
            out.to_csv(index=False).encode("utf-8"),
            file_name=f"{item}_chronos_finetuned.csv",
            mime="text/csv",
        )