# app.py — ZeroGPU-optimised Gradio app (HF Spaces)
import os
import tempfile
from datetime import datetime, timezone
import gradio as gr
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---- ZeroGPU decorator ----
try:
import spaces # HF Spaces utility (provides @spaces.GPU())
except Exception:
# Fallback: make a no-op decorator so the app still runs locally/CPU
class _Noop:
def GPU(self, *args, **kwargs):
def deco(fn):
return fn
return deco
spaces = _Noop()
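# Illustrative sketch (kept as a comment; not used below): the real `spaces.GPU` decorator
# also accepts a duration hint, in seconds, for calls expected to run longer, e.g.
#
#     @spaces.GPU(duration=120)
#     def long_generation(...):
#         ...
#
# The no-op fallback above mirrors the same call signature so the file still imports and
# runs on CPU-only hosts.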
# ---- Optional quantisation (GPU only) ----
try:
from transformers import BitsAndBytesConfig
HAS_BNB = True
except Exception:
HAS_BNB = False
# ----------------------------
# Config
# ----------------------------
DEFAULT_MODELS = [
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Qwen/Qwen2.5-1.5B-Instruct",
"neovalle/tinyllama-1.1B-h4rmony-trained",
]
# Keep batches reasonable on ZeroGPU for low latency
MICROBATCH = 4
# Cap encoder length to avoid wasting time on very long inputs
MAX_INPUT_TOKENS = 1024
# Speed on GPU (TF32 gives extra throughput on Ampere+)
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
else:
# On CPU, reducing threads sometimes helps stability/predictability
try:
torch.set_num_threads(max(1, (os.cpu_count() or 4) // 2))
except Exception:
pass
_MODEL_CACHE = {} # cache: model_id -> (tokenizer, model)
# ----------------------------
# Helpers
# ----------------------------
def _all_eos_ids(tok):
"""Collect a few likely EOS ids so generation can stop earlier."""
ids = set()
if tok.eos_token_id is not None:
ids.add(tok.eos_token_id)
for t in ("<|im_end|>", "<|endoftext|>", ""):
try:
tid = tok.convert_tokens_to_ids(t)
if isinstance(tid, int) and tid >= 0:
ids.add(tid)
except Exception:
pass
return list(ids) if ids else None
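# Usage sketch (illustrative; commented out so nothing runs at import time). For a
# Llama/TinyLlama-style tokenizer this typically returns the id of "</s>" plus any
# chat end-of-turn tokens that exist in the vocab:
#
#     tok = AutoTokenizer.from_pretrained(DEFAULT_MODELS[0])
#     print(_all_eos_ids(tok))  # e.g. a short list of integer token ids
#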
def _load_model(model_id: str):
"""Load & cache model/tokenizer. On GPU, prefer 4-bit NF4 with BF16 compute."""
if model_id in _MODEL_CACHE:
return _MODEL_CACHE[model_id]
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    # Ensure a pad token and left padding for batched decoder-only generate()
    if tok.pad_token is None:
        if tok.eos_token is not None:
            tok.pad_token = tok.eos_token
        else:
            tok.add_special_tokens({"pad_token": "<|pad|>"})
    tok.padding_side = "left"  # left padding keeps all generated tokens at a common offset
use_gpu = torch.cuda.is_available()
dtype = (
torch.bfloat16 if (use_gpu and torch.cuda.is_bf16_supported()) else
(torch.float16 if use_gpu else torch.float32)
)
quant_cfg = None
if use_gpu and HAS_BNB:
quant_cfg = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
        torch_dtype=dtype,  # bf16/fp16 on GPU (computed above), fp32 on CPU
low_cpu_mem_usage=True,
device_map="auto",
quantization_config=quant_cfg, # 4-bit on GPU if available; None on CPU
trust_remote_code=True, # helps for chat templates (e.g., Qwen)
# attn_implementation="flash_attention_2", # enable only if flash-attn in requirements
).eval()
    # Resize embeddings if we added a new pad token
if model.get_input_embeddings().num_embeddings != len(tok):
model.resize_token_embeddings(len(tok))
# Prefer KV cache
try:
model.generation_config.use_cache = True
except Exception:
pass
_MODEL_CACHE[model_id] = (tok, model)
return tok, model
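# Minimal usage sketch (commented out to keep module import side-effect free):
#
#     tok, model = _load_model(DEFAULT_MODELS[0])
#     print(model.device)
#
# On ZeroGPU this should only be called from inside the @spaces.GPU-decorated callback
# (as done below), so the weights are placed on the dynamically attached GPU.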
def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
sys = (system_prompt or "").strip()
usr = (user_prompt or "").strip()
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
messages = []
if sys:
messages.append({"role": "system", "content": sys})
messages.append({"role": "user", "content": usr})
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
prefix = f"<>\n{sys}\n<>\n\n" if sys else ""
return f"{prefix}<>\n{usr}\n<>\n<>\n"
@torch.inference_mode()
def _generate_microbatch(tok, model, formatted_prompts, gen_kwargs):
"""Generate for a list of formatted prompts. Returns (texts, tokens_out)."""
device = model.device
eos_ids = _all_eos_ids(tok)
enc = tok(
formatted_prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_INPUT_TOKENS,
).to(device)
    # With left padding, every row's generated tokens start at the shared padded length
    input_len = enc["input_ids"].shape[1]
outputs = model.generate(
**enc,
eos_token_id=eos_ids,
pad_token_id=tok.pad_token_id,
**gen_kwargs,
)
texts, toks_out = [], []
    for i in range(outputs.size(0)):
        gen_ids = outputs[i, input_len:]
        texts.append(tok.decode(gen_ids, skip_special_tokens=True).strip())
        toks_out.append(int(gen_ids.numel()))
return texts, toks_out
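# Usage sketch (assumes a loaded tok/model pair; kept as a comment to avoid extra GPU work):
#
#     texts, toks = _generate_microbatch(
#         tok, model,
#         [_format_prompt(tok, "Be brief.", "Name one greenhouse gas.")],
#         {"max_new_tokens": 32, "do_sample": False},
#     )
#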
def generate_batch_df(
model_id: str,
system_prompt: str,
prompts_multiline: str,
max_new_tokens: int,
temperature: float,
top_p: float,
top_k: int,
repetition_penalty: float,
) -> pd.DataFrame:
tok, model = _load_model(model_id)
# Split user inputs
prompts = [p.strip() for p in prompts_multiline.splitlines() if p.strip()]
if not prompts:
return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])
formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]
# Micro-batch multi-line input to keep latency low on ZeroGPU
    B = min(MICROBATCH, len(formatted))
# Greedy is fine (and fastest). If temp > 0, enable sampling knobs.
do_sample = bool(temperature > 0.0)
gen_kwargs = dict(
max_new_tokens=int(max_new_tokens),
do_sample=do_sample,
temperature=float(temperature) if do_sample else None,
top_p=float(top_p) if do_sample else None,
top_k=int(top_k) if (do_sample and int(top_k) > 0) else None,
repetition_penalty=float(repetition_penalty),
num_beams=1,
return_dict_in_generate=False,
use_cache=True,
)
all_texts, all_toks = [], []
for i in range(0, len(formatted), B):
batch_prompts = formatted[i : i + B]
texts, toks = _generate_microbatch(tok, model, batch_prompts, gen_kwargs)
all_texts.extend(texts)
all_toks.extend(toks)
return pd.DataFrame(
{"user_prompt": prompts, "response": all_texts, "tokens_out": all_toks}
)
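# Quick local check (hedged sketch; run manually, e.g. from a REPL, not at import time):
#
#     df = generate_batch_df(
#         DEFAULT_MODELS[0], "You are terse.", "Define entropy\nDefine enthalpy",
#         max_new_tokens=64, temperature=0.0, top_p=0.9, top_k=40, repetition_penalty=1.1,
#     )
#     print(df[["user_prompt", "tokens_out"]])
#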
def write_csv_path(df: pd.DataFrame) -> str:
    ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
tmp = tempfile.NamedTemporaryFile(prefix=f"Output_{ts}_", suffix=".csv", delete=False, dir="/tmp")
df.to_csv(tmp.name, index=False)
return tmp.name
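# Example (illustrative): write_csv_path(df) returns a path such as
# "/tmp/Output_20250101-120000_xxxxxx.csv" (the random suffix comes from NamedTemporaryFile);
# the gr.File component below serves that path back to the browser as a download.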
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="Multi-Prompt Chat (ZeroGPU-optimised)") as demo:
gr.Markdown(
"""
# Multi-Prompt Chat to test system prompt effects (ZeroGPU-optimised)
Pick a small model, set a **system prompt**, and enter **multiple user prompts** (one per line).
Click **Generate** to get batched responses and a **downloadable CSV**.
"""
)
with gr.Row():
with gr.Column(scale=1):
model_id = gr.Dropdown(
choices=DEFAULT_MODELS,
value=DEFAULT_MODELS[0],
label="Model",
info="ZeroGPU attaches an H200 dynamically. 4-bit is used automatically on GPU.",
)
system_prompt = gr.Textbox(
label="System prompt",
placeholder="e.g., You are an ecolinguistics-aware assistant...",
lines=5,
)
prompts_multiline = gr.Textbox(
label="User prompts (one per line)",
placeholder="One query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips\nSummarise benefits of multilingual models",
lines=10,
)
with gr.Accordion("Generation settings", open=False):
max_new_tokens = gr.Slider(16, 1024, value=200, step=1, label="max_new_tokens")
temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="temperature (0 = greedy, fastest)")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p (used if temp > 0)")
top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 disables; used if temp > 0)")
repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.01, label="repetition_penalty")
run_btn = gr.Button("Generate", variant="primary")
with gr.Column(scale=1):
out_df = gr.Dataframe(
headers=["user_prompt", "response", "tokens_out"],
datatype=["str", "str", "number"],
label="Results",
wrap=True,
interactive=False,
row_count=(0, "dynamic"),
type="pandas",
)
csv_out = gr.File(label="CSV output", interactive=False, type="filepath")
# -------- Callback: GPU-decorated for ZeroGPU --------
    @spaces.GPU()  # ZeroGPU attaches a GPU for the duration of this call
def _generate_cb(model_id, system_prompt, prompts_multiline,
max_new_tokens, temperature, top_p, top_k, repetition_penalty,
progress=gr.Progress(track_tqdm=True)):
progress(0.05, desc="Requesting ZeroGPU…")
df = generate_batch_df(
model_id=model_id,
system_prompt=system_prompt,
prompts_multiline=prompts_multiline,
max_new_tokens=int(max_new_tokens),
temperature=float(temperature),
top_p=float(top_p),
top_k=int(top_k),
repetition_penalty=float(repetition_penalty),
)
progress(0.95, desc="Preparing CSV…")
csv_path = write_csv_path(df)
progress(1.0, desc="Done")
return df, csv_path
run_btn.click(
_generate_cb,
inputs=[model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
outputs=[out_df, csv_out],
api_name="generate_batch",
)
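    # Hedged usage sketch: because api_name="generate_batch" is set above, the endpoint can
    # also be called programmatically once the Space is live (the Space id is a placeholder):
    #
    #     from gradio_client import Client
    #     client = Client("your-username/your-space")
    #     results = client.predict(
    #         "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "You are terse.", "Define entropy",
    #         200, 0.0, 0.9, 40, 1.1,
    #         api_name="/generate_batch",
    #     )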
if __name__ == "__main__":
demo.launch()