# app.py - Hugging Face Space ready (LoRA adapter, Gradio compat)
# ---------------------------------------------------------------
# What changed vs your script
# - Removed ChatInterface args that broke on old Gradio (retry_btn, undo_btn)
# - No interactive input() for merging (Spaces are non-interactive). Use MERGE_LORA env var.
# - Secrets: read HF token from env (Settings → Secrets → HF_TOKEN), never hardcode.
# - Token passing works across transformers/peft versions (token/use_auth_token fallback).
# - Optional 8-bit via USE_8BIT=1 (GPU only). Safe CPU defaults.
# - Robust theme/queue/launch for mixed Gradio versions.
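#
# A Space also needs a requirements.txt; a minimal sketch (exact pins are an
# assumption -- match whatever the adapter was trained with):
#   gradio
#   torch
#   transformers>=4.41
#   peft
#   accelerate
#   bitsandbytes  # only needed when USE_8BIT=1 on GPU hardware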
import os
import gc
import warnings
from typing import List, Tuple

import torch
import gradio as gr

warnings.filterwarnings("ignore")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
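# (The env var above silences the tokenizers fork/parallelism warning in Space logs.)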
try:
    from peft import PeftConfig, PeftModel
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        BitsAndBytesConfig,
    )

    IMPORTS_OK = True
except Exception as e:
    IMPORTS_OK = False
    print(f"Missing dependencies: {e}")
    print("Install: pip install --upgrade 'transformers>=4.41' peft accelerate gradio torch bitsandbytes")
# ── Configuration ──────────────────────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Settings → Secrets → HF_TOKEN

# LoRA adapter repo (must be compatible with BASE_MODEL_ID)
ADAPTER_ID = os.getenv("ADAPTER_ID", "Reubencf/gemma3-goan-finetuned")

# Base model used during fine-tuning (should match the adapter's base)
BASE_MODEL_ID_DEFAULT = os.getenv("BASE_MODEL_ID", "google/gemma-3-4b-it")

# Quantization toggle (GPU only): set USE_8BIT=1 in Space variables
USE_8BIT = os.getenv("USE_8BIT", "0").lower() in {"1", "true", "yes", "y"}

# Merge LoRA into the base for faster inference: MERGE_LORA=1/0
MERGE_LORA = os.getenv("MERGE_LORA", "1").lower() in {"1", "true", "yes", "y"}
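# Example Space configuration (Settings → Variables and secrets); every entry is
# optional and the values below are illustrative, not required:
#   HF_TOKEN      (secret)  only needed for gated or private repos
#   ADAPTER_ID    = Reubencf/gemma3-goan-finetuned
#   BASE_MODEL_ID = google/gemma-3-4b-it
#   USE_8BIT      = 1       GPU hardware only
#   MERGE_LORA    = 0       keep the adapter un-merged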
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TITLE = "🌴 Gemma Goan Q&A Bot"
DESCRIPTION_TMPL = (
    "Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.\n"
    "Ask about Goa, Konkani culture, or general topics!\n\n"
    "**Status**: {}"
)
# ── Helpers ────────────────────────────────────────────────────────────────────
def call_with_token(fn, *args, **kwargs):
    """Call HF/Transformers/PEFT functions with token OR use_auth_token for
    broad version compatibility."""
    if HF_TOKEN:
        try:
            return fn(*args, token=HF_TOKEN, **kwargs)
        except TypeError:
            return fn(*args, use_auth_token=HF_TOKEN, **kwargs)
    return fn(*args, **kwargs)
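# Example: call_with_token(AutoTokenizer.from_pretrained, ADAPTER_ID) first tries
# token=HF_TOKEN and only falls back to the legacy use_auth_token= keyword when
# the installed library raises TypeError for the newer name.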
# ── Load model + tokenizer ─────────────────────────────────────────────────────
def load_model_and_tokenizer():
    if not IMPORTS_OK:
        raise ImportError("Required packages not installed.")

    print("[Init] Starting model load…")
    print(f"[Config] Device: {DEVICE}")

    # GC + VRAM cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Step 1: Confirm base model from the adapter's config if possible
    actual_base_model = BASE_MODEL_ID_DEFAULT
    try:
        print(f"[Load] Reading adapter config: {ADAPTER_ID}")
        peft_cfg = call_with_token(PeftConfig.from_pretrained, ADAPTER_ID)
        if getattr(peft_cfg, "base_model_name_or_path", None):
            actual_base_model = peft_cfg.base_model_name_or_path
            print(f"[Load] Adapter expects base model: {actual_base_model}")
        else:
            print("[Warn] Adapter did not expose base_model_name_or_path; using configured base.")
    except Exception as e:
        print(f"[Warn] Could not read adapter config ({e}); using configured base: {actual_base_model}")
    # Step 2: Load base model (optionally quantized on GPU)
    print(f"[Load] Loading base model: {actual_base_model}")
    quant_cfg = None
    if USE_8BIT and torch.cuda.is_available():
        print("[Load] Enabling 8-bit quantization (bitsandbytes)")
        quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
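        # Note: load_in_8bit needs bitsandbytes and a CUDA device; weights are
        # quantized at load time, and device_map="auto" below places them on GPU.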
    base_model = call_with_token(
        AutoModelForCausalLM.from_pretrained,
        actual_base_model,
        trust_remote_code=True,
        quantization_config=quant_cfg,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    if DEVICE == "cpu":
        base_model = base_model.to("cpu")
        print("[Load] Model on CPU")
    print("[Load] Base model loaded ✅")

    # Step 3: Tokenizer
    print("[Load] Loading tokenizer…")
    tokenizer = call_with_token(
        AutoTokenizer.from_pretrained,
        actual_base_model,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
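    # Left padding matters for decoder-only generation: new tokens are appended on
    # the right, so any padding has to sit on the left to keep prompts contiguous.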
    # Step 4: Apply LoRA adapter
    status = ""
    model = base_model
    try:
        print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
        model = call_with_token(PeftModel.from_pretrained, base_model, ADAPTER_ID)
        if MERGE_LORA:
            print("[Load] Merging adapter into base (merge_and_unload)…")
            model = model.merge_and_unload()
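            # merge_and_unload() folds the LoRA deltas into the base weights and
            # returns a plain transformers model, so inference carries no adapter
            # overhead (at the cost of not being able to swap adapters later).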
            status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
        else:
            status = f"✅ Using fine-tuned model via adapter: {ADAPTER_ID}"
    except FileNotFoundError as e:
        print(f"[Error] Adapter files not found: {e}")
        status = f"⚠️ Adapter not found. Using base only: {actual_base_model}"
    except Exception as e:
        print(f"[Error] Failed to load adapter: {e}")
        status = f"⚠️ Could not load adapter. Using base only: {actual_base_model}"

    model.eval()
    print(f"[Load] Model ready on {DEVICE} ✅")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return model, tokenizer, status
# Global load at import time (Space-friendly)
try:
    model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
    MODEL_LOADED = True
    DESCRIPTION = DESCRIPTION_TMPL.format(STATUS_MSG)
except Exception as e:
    print(f"[Fatal] Could not load model: {e}")
    MODEL_LOADED = False
    model = tokenizer = None
    DESCRIPTION = DESCRIPTION_TMPL.format(f"❌ Model failed to load: {str(e)[:140]}")
# ── Generation ─────────────────────────────────────────────────────────────────
def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.7,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
) -> str:
    if not MODEL_LOADED:
        return "⚠️ Model failed to load. Check Space logs."
    try:
        # Build a short chat history. Depending on the Gradio version, history
        # arrives either as (user, assistant) tuples or as role/content dicts;
        # keep roughly the last three exchanges either way.
        conversation = []
        recent = history[-3:] if history and not isinstance(history[0], dict) else (history or [])[-6:]
        for item in recent:
            if isinstance(item, dict):
                if item.get("content"):
                    conversation.append({"role": item["role"], "content": item["content"]})
            else:
                u, a = item
                if u:
                    conversation.append({"role": "user", "content": u})
                if a:
                    conversation.append({"role": "assistant", "content": a})
        conversation.append({"role": "user", "content": message})
        # Try the tokenizer's chat template first
        try:
            input_ids = tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
            )
        except Exception as e:
            print(f"[Warn] chat_template failed: {e}; using manual format")
            prompt_text = "".join(
                ("User: " if m["role"] == "user" else "Assistant: ") + m["content"] + "\n"
                for m in conversation
            ) + "Assistant: "
            input_ids = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=1024).input_ids
        input_ids = input_ids.to(model.device if hasattr(model, "device") else DEVICE)

        with torch.no_grad():
            out = model.generate(
                input_ids=input_ids,
                max_new_tokens=max(1, min(int(max_new_tokens), 512)),
                temperature=float(temperature),
                top_p=float(top_p),
                repetition_penalty=float(repetition_penalty),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
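        # generate() returns prompt + continuation in one tensor; slice off the
        # prompt so only the newly generated tokens are decoded below.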
        gen = out[0][input_ids.shape[-1]:]
        text = tokenizer.decode(gen, skip_special_tokens=True).strip()

        # Cleanup
        del out, input_ids, gen
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return text or "(no output)"
    except Exception as e:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return f"⚠️ Error generating response: {e}"
# ── UI ─────────────────────────────────────────────────────────────────────────
examples = [
    "What is the capital of Goa?",
    "Tell me about the Konkani language.",
    "What are famous beaches in Goa?",
    "Describe Goan fish curry.",
    "What is the history of Old Goa?",
]

# Best-effort theme across versions
try:
    THEME = gr.themes.Soft()
except Exception:
    THEME = None
if MODEL_LOADED:
    demo = gr.ChatInterface(
        fn=generate_response,
        title=TITLE,
        description=DESCRIPTION,
        examples=examples,
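        # The sliders below map positionally onto generate_response's extra
        # parameters: temperature, max_new_tokens, top_p, repetition_penalty.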
        additional_inputs=[
            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature"),
            gr.Slider(minimum=32, maximum=512, value=256, step=16, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
            gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
        ],
        theme=THEME,
    )
else:
    demo = gr.Interface(
        fn=lambda x: "Model failed to load. Check Space logs.",
        inputs=gr.Textbox(label="Message"),
        outputs=gr.Textbox(label="Response"),
        title=TITLE,
        description=DESCRIPTION,
        theme=THEME,
    )
# Queue - keep params minimal for cross-version compat
try:
    demo.queue()
except Exception:
    pass
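# Recent Gradio releases also accept concurrency options on queue() (for example
# default_concurrency_limit); the bare call is the lowest common denominator.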
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print(f"🚀 Starting Gradio app on {DEVICE} …")
    print(f"📚 Base model: {BASE_MODEL_ID_DEFAULT}")
    print(f"🔧 LoRA adapter: {ADAPTER_ID}")
    print(f"🧩 Merge LoRA: {MERGE_LORA}")
    print("=" * 60 + "\n")

    # On Spaces, just calling launch() is fine.
    demo.launch()
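    # For local debugging you could pass server_name/server_port to launch();
    # on Spaces the environment supplies the right host and port automatically.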