# app.py - Hugging Face Space ready (LoRA adapter, Gradio compat)
# ---------------------------------------------------------------
# What changed vs your script
# - Removed ChatInterface args that broke on old Gradio (retry_btn, undo_btn)
# - No interactive input() for merging (Spaces are non-interactive). Use MERGE_LORA env var.
# - Secrets: read HF token from env (Settings → Secrets → HF_TOKEN), never hardcode.
# - Token passing works across transformers/peft versions (token/use_auth_token fallback).
# - Optional 8-bit via USE_8BIT=1 (GPU only). Safe CPU defaults.
# - Robust theme/queue/launch for mixed Gradio versions.
import os
import gc
import warnings
from typing import List, Tuple
import torch
import gradio as gr
warnings.filterwarnings("ignore")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
try:
    from peft import PeftConfig, PeftModel
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        BitsAndBytesConfig,
    )
    IMPORTS_OK = True
except Exception as e:
    IMPORTS_OK = False
    print(f"Missing dependencies: {e}")
    print("Install: pip install --upgrade 'transformers>=4.41' peft accelerate gradio torch bitsandbytes")
# ── Configuration ──────────────────────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Settings → Secrets → HF_TOKEN
# LoRA adapter repo (must be compatible with BASE_MODEL_ID)
ADAPTER_ID = os.getenv("ADAPTER_ID", "Reubencf/gemma3-goan-finetuned")
# Base model used during fine-tuning (should match adapter's base)
BASE_MODEL_ID_DEFAULT = os.getenv("BASE_MODEL_ID", "google/gemma-3-4b-it")
# Quantization toggle (GPU only): set USE_8BIT=1 in Space variables
USE_8BIT = os.getenv("USE_8BIT", "0").lower() in {"1", "true", "yes", "y"}
# Merge LoRA into the base for faster inference: MERGE_LORA=1/0
MERGE_LORA = os.getenv("MERGE_LORA", "1").lower() in {"1", "true", "yes", "y"}
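# Note: merging (merge_and_unload) folds the LoRA deltas into the base weights, so
# generation skips the extra adapter matmuls, but the adapter can no longer be
# swapped or detached at runtime.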
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
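# Illustrative Space configuration (values mirror the defaults above, not requirements):
#   Secrets:   HF_TOKEN=hf_xxx                      (needed for gated bases such as Gemma)
#   Variables: ADAPTER_ID=Reubencf/gemma3-goan-finetuned
#              BASE_MODEL_ID=google/gemma-3-4b-it
#              USE_8BIT=1      (GPU hardware only)
#              MERGE_LORA=1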
TITLE = "🌴 Gemma Goan Q&A Bot"
DESCRIPTION_TMPL = (
"Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.\n"
"Ask about Goa, Konkani culture, or general topics!\n\n"
"**Status**: {}"
)
# ── Helpers ───────────────────────────────────────────────────────────────────
def call_with_token(fn, *args, **kwargs):
"""Call HF/Transformers/PEFT functions with token OR use_auth_token for
broad version compatibility."""
if HF_TOKEN:
try:
return fn(*args, token=HF_TOKEN, **kwargs)
except TypeError:
return fn(*args, use_auth_token=HF_TOKEN, **kwargs)
return fn(*args, **kwargs)
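# Example (mirrors the loaders below): newer transformers/peft releases accept `token=`,
# older ones only `use_auth_token=`; the TypeError fallback covers both, e.g.
#   peft_cfg = call_with_token(PeftConfig.from_pretrained, ADAPTER_ID)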
# ── Load model + tokenizer ─────────────────────────────────────────────────────
def load_model_and_tokenizer():
    if not IMPORTS_OK:
        raise ImportError("Required packages not installed.")
    print("[Init] Starting model load…")
    print(f"[Config] Device: {DEVICE}")
    # GC + VRAM cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Step 1: Confirm base model from the adapter's config if possible
    actual_base_model = BASE_MODEL_ID_DEFAULT
    try:
        print(f"[Load] Reading adapter config: {ADAPTER_ID}")
        peft_cfg = call_with_token(PeftConfig.from_pretrained, ADAPTER_ID)
        if getattr(peft_cfg, "base_model_name_or_path", None):
            actual_base_model = peft_cfg.base_model_name_or_path
            print(f"[Load] Adapter expects base model: {actual_base_model}")
        else:
            print("[Warn] Adapter did not expose base_model_name_or_path; using configured base.")
    except Exception as e:
        print(f"[Warn] Could not read adapter config ({e}); using configured base: {actual_base_model}")
    # Step 2: Load base model (optionally quantized on GPU)
    print(f"[Load] Loading base model: {actual_base_model}")
    quant_cfg = None
    if USE_8BIT and torch.cuda.is_available():
        print("[Load] Enabling 8-bit quantization (bitsandbytes)")
        quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
    base_model = call_with_token(
        AutoModelForCausalLM.from_pretrained,
        actual_base_model,
        trust_remote_code=True,
        quantization_config=quant_cfg,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    if DEVICE == "cpu" and not torch.cuda.is_available():
        base_model = base_model.to("cpu")
        print("[Load] Model on CPU")
    print("[Load] Base model loaded ✔")
    # Step 3: Tokenizer
    print("[Load] Loading tokenizer…")
    tokenizer = call_with_token(
        AutoTokenizer.from_pretrained,
        actual_base_model,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
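    # The pad/eos aliasing and left padding above follow the usual decoder-only convention:
    # generation continues from the right edge of the prompt, so padding belongs on the left.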
    # Step 4: Apply LoRA adapter
    status = ""
    model = base_model
    try:
        print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
        model = call_with_token(PeftModel.from_pretrained, base_model, ADAPTER_ID)
        if MERGE_LORA:
            print("[Load] Merging adapter into base (merge_and_unload)…")
            model = model.merge_and_unload()
            status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
        else:
            status = f"✅ Using fine-tuned model via adapter: {ADAPTER_ID}"
    except FileNotFoundError as e:
        print(f"[Error] Adapter files not found: {e}")
        status = f"⚠️ Adapter not found. Using base only: {actual_base_model}"
    except Exception as e:
        print(f"[Error] Failed to load adapter: {e}")
        status = f"⚠️ Could not load adapter. Using base only: {actual_base_model}"
    model.eval()
    print(f"[Load] Model ready on {DEVICE} ✔")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return model, tokenizer, status
# Global load at import time (Space-friendly)
try:
    model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
    MODEL_LOADED = True
    DESCRIPTION = DESCRIPTION_TMPL.format(STATUS_MSG)
except Exception as e:
    print(f"[Fatal] Could not load model: {e}")
    MODEL_LOADED = False
    model = tokenizer = None
    DESCRIPTION = DESCRIPTION_TMPL.format(f"❌ Model failed to load: {str(e)[:140]}")
# ── Generation ────────────────────────────────────────────────────────────────
def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.7,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
) -> str:
    if not MODEL_LOADED:
        return "⚠️ Model failed to load. Check Space logs."
    try:
        # Build short chat history
        conversation = []
        if history:
            for u, a in history[-3:]:
                if u:
                    conversation.append({"role": "user", "content": u})
                if a:
                    conversation.append({"role": "assistant", "content": a})
        conversation.append({"role": "user", "content": message})
        # Try the tokenizer's chat template first
        try:
            input_ids = tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
            )
        except Exception as e:
            print(f"[Warn] chat_template failed: {e}; using manual format")
            prompt_text = "".join(
                [
                    ("User: " + m["content"] + "\n") if m["role"] == "user" else ("Assistant: " + m["content"] + "\n")
                    for m in conversation
                ]
            ) + "Assistant: "
            input_ids = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=1024).input_ids
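            # The fallback prompt is a plain transcript, roughly:
            #   "User: <turn>\nAssistant: <turn>\n...User: <message>\nAssistant: "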
        input_ids = input_ids.to(model.device if hasattr(model, "device") else DEVICE)
        with torch.no_grad():
            out = model.generate(
                input_ids=input_ids,
                max_new_tokens=max(1, min(int(max_new_tokens), 512)),
                temperature=float(temperature),
                top_p=float(top_p),
                repetition_penalty=float(repetition_penalty),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
        gen = out[0][input_ids.shape[-1]:]
        text = tokenizer.decode(gen, skip_special_tokens=True).strip()
        # Cleanup
        del out, input_ids, gen
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return text or "(no output)"
    except Exception as e:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return f"⚠️ Error generating response: {e}"
# ── UI ────────────────────────────────────────────────────────────────────────
examples = [
"What is the capital of Goa?",
"Tell me about the Konkani language.",
"What are famous beaches in Goa?",
"Describe Goan fish curry.",
"What is the history of Old Goa?",
]
# Best-effort theme across versions
try:
    THEME = gr.themes.Soft()
except Exception:
    THEME = None
if MODEL_LOADED:
    demo = gr.ChatInterface(
        fn=generate_response,
        title=TITLE,
        description=DESCRIPTION,
        examples=examples,
        additional_inputs=[
            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature"),
            gr.Slider(minimum=32, maximum=512, value=256, step=16, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
            gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
        ],
        theme=THEME,
    )
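    # The sliders are passed positionally after (message, history), matching
    # generate_response's temperature, max_new_tokens, top_p, repetition_penalty order.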
else:
    demo = gr.Interface(
        fn=lambda x: "Model failed to load. Check Space logs.",
        inputs=gr.Textbox(label="Message"),
        outputs=gr.Textbox(label="Response"),
        title=TITLE,
        description=DESCRIPTION,
        theme=THEME,
    )
# Queue - keep params minimal for cross-version compat
try:
    demo.queue()
except Exception:
    pass
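# queue() routes requests through Gradio's queue (default concurrency is 1 on most
# versions), so simultaneous users don't hit generate() on the single model in parallel.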
if __name__ == "__main__":
print("\n" + "=" * 60)
print(f"πŸš€ Starting Gradio app on {DEVICE} …")
print(f"πŸ“ Base model: {BASE_MODEL_ID_DEFAULT}")
print(f"πŸ”§ LoRA adapter: {ADAPTER_ID}")
print(f"🧩 Merge LoRA: {MERGE_LORA}")
print("=" * 60 + "\n")
# On Spaces, just calling launch() is fine.
demo.launch()