# neovalle — app.py (last commit 774c640 "Update app.py", verified; 7.37 kB)
# app.py
import tempfile
from datetime import datetime, timezone

import gradio as gr
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ----------------------------
# Config
# ----------------------------
DEFAULT_MODELS = [
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Qwen/Qwen2.5-1.5B-Instruct",
"neovalle/tinyllama-1.1B-h4rmony-trained",
]
_MODEL_CACHE = {} # cache: model_id -> (tokenizer, model)
# ----------------------------
# Utilities
# ----------------------------
def _load_model(model_id: str):
    """Return ``(tokenizer, model)`` for ``model_id``, loading it on first use.

    The tokenizer is guaranteed to have a pad token, which batched
    ``generate()`` needs. The EOS token is reused when one exists. Otherwise a
    new ``<|pad|>`` token is added and the input embeddings are grown to cover
    it. Weights load in bfloat16 when CUDA is available (float32 otherwise)
    and are placed with ``device_map="auto"``.

    Args:
        model_id: Hugging Face Hub model id (or local path).

    Returns:
        ``(tokenizer, model)``. Repeated calls with the same id return the same
        objects from ``_MODEL_CACHE``.
    """
    if model_id in _MODEL_CACHE:
        return _MODEL_CACHE[model_id]

    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    # Ensure pad token exists for generate()
    if tok.pad_token is None:
        if tok.eos_token is not None:
            tok.pad_token = tok.eos_token
        else:
            tok.add_special_tokens({"pad_token": "<|pad|>"})

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
    )

    # Only ever GROW the embeddings (needed after adding "<|pad|>" above).
    # Several checkpoints (Qwen2.5, for instance) ship an embedding matrix
    # padded beyond len(tokenizer). Resizing those "to match" would shrink and
    # rebuild the embedding/output layers for no benefit.
    if len(tok) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tok))

    _MODEL_CACHE[model_id] = (tok, model)
    return tok, model
def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
sys = (system_prompt or "").strip()
usr = (user_prompt or "").strip()
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
messages = []
if sys:
messages.append({"role": "system", "content": sys})
messages.append({"role": "user", "content": usr})
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
prefix = f"<<SYS>>\n{sys}\n<</SYS>>\n\n" if sys else ""
return f"{prefix}<<USER>>\n{usr}\n<</USER>>\n<<ASSISTANT>>\n"
def generate_batch_df(
    model_id: str,
    system_prompt: str,
    prompts_multiline: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> pd.DataFrame:
    """Answer several user prompts with a single batched ``generate()`` call.

    Each non-blank line of ``prompts_multiline`` becomes one prompt. Every
    prompt is paired with the same ``system_prompt`` and rendered through
    ``_format_prompt``.

    Args:
        model_id: model to use (obtained via ``_load_model``).
        system_prompt: system message shared by every prompt (may be empty).
        prompts_multiline: user prompts, one per line; blank lines are ignored,
            and ``None`` is treated as empty.
        max_new_tokens: generation budget per prompt.
        temperature: ``0`` selects greedy decoding; ``> 0`` enables sampling.
        top_p: nucleus-sampling threshold (only used when sampling).
        top_k: top-k cutoff, ``0`` disables it (only used when sampling).
        repetition_penalty: forwarded to ``generate()``.

    Returns:
        DataFrame with columns ``user_prompt``, ``response`` and
        ``tokens_out``, one row per prompt. ``tokens_out`` is the number of
        generated non-pad tokens. When there are no prompts, a single empty
        row is returned and no model is loaded.
    """
    prompts = [line.strip() for line in (prompts_multiline or "").splitlines() if line.strip()]
    if not prompts:
        # Nothing to generate: don't load (or download) a model for it.
        return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])

    tok, model = _load_model(model_id)
    # Batched generation with a decoder-only model needs LEFT padding. Every
    # row's prompt must end in the last column, so new tokens follow real
    # text rather than pad tokens.
    tok.padding_side = "left"

    formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]
    enc = tok(
        formatted,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)

    do_sample = temperature > 0.0
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "repetition_penalty": float(repetition_penalty),
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id,
    }
    if do_sample:
        # Sampling knobs are meaningless for greedy decoding, and transformers
        # warns when they are passed alongside do_sample=False.
        # top_k=0 explicitly switches the top-k filter off.
        gen_kwargs.update(
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )

    with torch.no_grad():
        gen = model.generate(**enc, **gen_kwargs)

    # generate() returns [padded prompt | new tokens] for every row. The
    # continuation therefore starts at the padded input width, not at each
    # row's own unpadded prompt length.
    new_tokens = gen[:, enc["input_ids"].shape[1]:]
    responses = [
        text.strip() for text in tok.batch_decode(new_tokens, skip_special_tokens=True)
    ]
    # Rows that finish early are filled with pad tokens up to the longest row;
    # those are not generated output, so don't count them.
    tokens_out = (new_tokens != tok.pad_token_id).sum(dim=1).tolist()

    return pd.DataFrame(
        {"user_prompt": prompts, "response": responses, "tokens_out": tokens_out}
    )
def write_csv_path(df: pd.DataFrame) -> str:
    """Write ``df`` (without its index) to a new, uniquely named CSV file.

    The file is named ``Output_<UTC YYYYmmdd-HHMMSS>_<random>.csv`` and is
    created in the platform temp directory (``tempfile.gettempdir()``, usually
    ``/tmp`` on Linux). It is deliberately NOT deleted, so callers can keep
    serving it (e.g. as a download).

    Returns:
        Absolute path of the written file.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    # Write through the handle and close it via ``with``, instead of reopening
    # ``tmp.name`` while the handle is still open. This never leaks the file
    # descriptor and also works on Windows. ``newline=""`` lets pandas
    # control line endings, as it does when given a path.
    with tempfile.NamedTemporaryFile(
        mode="w",
        prefix=f"Output_{stamp}_",
        suffix=".csv",
        delete=False,
        encoding="utf-8",
        newline="",
    ) as handle:
        df.to_csv(handle, index=False)
    return handle.name
# ----------------------------
# Gradio UI
# ----------------------------
# Layout: a left column of inputs (model, system prompt, user prompts,
# generation settings, Generate button) beside a right column of outputs
# (results table and CSV file).
with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
    gr.Markdown(
        """
# Multi-Prompt Chat to test system prompt effects
Pick a small free model, set a **system prompt**, and enter **multiple user prompts** (one per line).
Click **Generate** to get batched responses and a **downloadable CSV**.
"""
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            model_id = gr.Dropdown(
                choices=DEFAULT_MODELS,
                value=DEFAULT_MODELS[0],
                label="Model",
                info="Free, small instruction-tuned models that run on CPU and free HF Space",
            )
            system_prompt = gr.Textbox(
                label="System prompt",
                placeholder="e.g., You are an ecolinguistics-aware assistant that always prioritise planetary well-being over anthropocentrism.",
                lines=5,
            )
            prompts_multiline = gr.Textbox(
                label="User prompts (one per line)",
                placeholder="One query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips for students\nSummarise the benefits of multilingual models",
                lines=10,
            )
            # Decoding knobs forwarded to generate_batch_df; there, temperature 0
            # means greedy decoding (do_sample=False).
            with gr.Accordion("Generation settings", open=False):
                max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="max_new_tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
                top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 disables)")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.01, label="repetition_penalty")
            run_btn = gr.Button("Generate", variant="primary")
        # Right column: outputs.
        with gr.Column(scale=1):
            # Columns match the DataFrame produced by generate_batch_df.
            out_df = gr.Dataframe(
                headers=["user_prompt", "response", "tokens_out"],
                datatype=["str", "str", "number"],
                label="Results",
                wrap=True,
                interactive=False,
                row_count=(0, "dynamic"),
                type="pandas",  # ensures pandas goes into callbacks
            )
            # IMPORTANT: type="filepath" so we can return a string path
            csv_out = gr.File(label="CSV output", interactive=False, type="filepath")

    # -------- Callback: generate table AND CSV path in one go --------
    # The parameter names shadow the components above. Inside the function
    # they hold the components' current values, not the components themselves.
    def _generate_cb(model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
        """Click handler: run the batch and return (results table, CSV file path)."""
        df = generate_batch_df(
            model_id=model_id,
            system_prompt=system_prompt,
            prompts_multiline=prompts_multiline,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
        )
        csv_path = write_csv_path(df)
        return df, csv_path  # DataFrame to table, path to File(type="filepath")

    # Wiring is positional: `inputs` must list components in the same order as
    # _generate_cb's parameters, and `outputs` must match its return tuple.
    # api_name registers this handler under the API name "generate_batch".
    run_btn.click(
        _generate_cb,
        inputs=[model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[out_df, csv_out],
        api_name="generate_batch",
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()