# app.py - Corrected for proper LoRA adapter loading
import os
import gc
import torch
import gradio as gr
from typing import List, Tuple
import warnings

warnings.filterwarnings('ignore')

try:
    from peft import PeftConfig, PeftModel
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    IMPORTS_OK = True
except ImportError as e:
    IMPORTS_OK = False
    print(f"Missing dependencies: {e}")
    print("Please install: pip install transformers peft torch gradio accelerate")

# ── Configuration ────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")  # Optional for public models

# Your LoRA adapter location (Hugging Face repo ID or local path)
ADAPTER_ID = "Reubencf/gemma3-goan-finetuned"
# For a local adapter: ADAPTER_ID = "./path/to/your/adapter"

# Base model - MUST match the model you used for fine-tuning!
# Check your adapter's adapter_config.json for "base_model_name_or_path"
BASE_MODEL_ID = "google/gemma-2b-it"  # Change this to your actual base model
# Common options:
# - "google/gemma-2b-it" (2B parameters, easier on memory)
# - "unsloth/gemma-2-2b-it-bnb-4bit" (quantized version)
# - Your actual base model used for training
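
# For reference, the adapter repo's adapter_config.json typically contains a field like
# (illustrative values, not copied from this adapter):
#   { "base_model_name_or_path": "google/gemma-2b-it", "peft_type": "LORA", ... }
# PeftConfig.from_pretrained() below reads that file to resolve the base model.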

# Settings
USE_8BIT = False        # Set to True to load the base model in 8-bit (GPU only)
MERGE_ADAPTER = True    # Merge LoRA weights into the base model for faster inference
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TITLE = "🌴 Gemma Goan Q&A Bot"
DESCRIPTION = """
Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.
Ask about Goa, Konkani culture, or general topics!
**Status**: {}
"""

# ── Load model + tokenizer (correct LoRA loading) ────────────────────────────────
def load_model_and_tokenizer():
    """Load the base model and apply the LoRA adapter correctly."""
    if not IMPORTS_OK:
        raise ImportError("Required packages not installed")

    print("[Init] Starting model load...")
    print(f"[Config] Base model: {BASE_MODEL_ID}")
    print(f"[Config] LoRA adapter: {ADAPTER_ID}")
    print(f"[Config] Device: {DEVICE}")

    # Memory cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    status = ""
    model = None
    tokenizer = None

    try:
        # Step 1: Read the adapter config to find the correct base model
        actual_base_model = BASE_MODEL_ID
        try:
            print("[Load] Checking adapter configuration...")
            peft_config = PeftConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN)
            actual_base_model = peft_config.base_model_name_or_path
            print(f"[Load] Adapter expects base model: {actual_base_model}")
            # Warn on mismatch
            if actual_base_model != BASE_MODEL_ID:
                print(f"[Warning] BASE_MODEL_ID ({BASE_MODEL_ID}) doesn't match adapter's base ({actual_base_model})")
                print(f"[Load] Using adapter's base model: {actual_base_model}")
        except Exception as e:
            print(f"[Warning] Cannot read adapter config: {e}")
            print(f"[Load] Will try the configured base model: {BASE_MODEL_ID}")
            actual_base_model = BASE_MODEL_ID

        # Step 2: Load the BASE MODEL (not the adapter!)
        print(f"[Load] Loading base model: {actual_base_model}")

        # Quantization config for GPU
        quantization_config = None
        if USE_8BIT and torch.cuda.is_available():
            print("[Load] Using 8-bit quantization")
            # 8-bit loading takes no bnb_8bit_compute_dtype option; int8 behaviour is
            # controlled via the llm_int8_* arguments if needed.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Load the base model
        base_model = AutoModelForCausalLM.from_pretrained(
            actual_base_model,
            token=HF_TOKEN,
            trust_remote_code=True,
            quantization_config=quantization_config,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16,
            device_map="auto" if torch.cuda.is_available() else None,
        )
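        # device_map="auto" requires the `accelerate` package (listed in the install
        # hint above); on CPU-only machines it stays None and the model loads normally.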

        # Move to CPU explicitly if no GPU is available
        if DEVICE == "cpu":
            base_model = base_model.to("cpu")
            print("[Load] Model on CPU")
        print("[Load] Base model loaded successfully")

        # Step 3: Load the tokenizer from the BASE MODEL
        print("[Load] Loading tokenizer from base model...")
        tokenizer = AutoTokenizer.from_pretrained(
            actual_base_model,
            token=HF_TOKEN,
            use_fast=True,
            trust_remote_code=True,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
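        # Some causal-LM tokenizers ship without a pad token, so the EOS token is reused
        # for padding; left padding keeps the generated continuation adjacent to the
        # prompt for decoder-only models.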

        # Step 4: Try to apply the LoRA adapter
        try:
            print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
            model = PeftModel.from_pretrained(
                base_model,
                ADAPTER_ID,
                token=HF_TOKEN,
                trust_remote_code=True,
                is_trainable=False,  # Inference only
            )
            # Optional: merge the adapter into the base model for faster inference.
            # This combines the weights permanently (more memory up front, faster inference).
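            # Note (assumption): merging may not be supported when the base model was
            # loaded in 8-bit, depending on the installed peft version; if USE_8BIT is
            # enabled, consider setting MERGE_ADAPTER = False.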
            if MERGE_ADAPTER:
                print("[Load] Merging adapter with base model...")
                model = model.merge_and_unload()
                print("[Load] Adapter merged successfully")
                status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
            else:
                print("[Load] Using adapter without merging")
                status = f"✅ Using fine-tuned model: {ADAPTER_ID}"
        except FileNotFoundError as e:
            print(f"[Error] Adapter files not found: {e}")
            print("[Fallback] Using base model without fine-tuning")
            model = base_model
            status = f"⚠️ Adapter not found. Using base model only: {actual_base_model}"
        except Exception as e:
            print(f"[Error] Failed to load adapter: {e}")
            print("[Fallback] Using base model without fine-tuning")
            model = base_model
            status = f"⚠️ Could not load adapter. Using base model only: {actual_base_model}"

        # Step 5: Final setup
        model.eval()
        print(f"[Load] Model ready on {DEVICE}!")

        # Memory cleanup
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return model, tokenizer, status

    except Exception as e:
        error_msg = f"Failed to load model: {str(e)}"
        print(f"[Fatal] {error_msg}")

        # Try falling back to the smallest model
        if "gemma-2b" not in BASE_MODEL_ID.lower():
            print("[Fallback] Trying with gemma-2b-it...")
            try:
                base_model = AutoModelForCausalLM.from_pretrained(
                    "google/gemma-2b-it",
                    token=HF_TOKEN,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    torch_dtype=torch.float32,
                    device_map=None,
                ).to("cpu")
                tokenizer = AutoTokenizer.from_pretrained(
                    "google/gemma-2b-it",
                    token=HF_TOKEN,
                    trust_remote_code=True,
                )
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                base_model.eval()
                return base_model, tokenizer, "⚠️ Using fallback model: gemma-2b-it (no fine-tuning)"
            except Exception as fallback_error:
                print(f"[Fatal] Fallback also failed: {fallback_error}")
                raise gr.Error("Cannot load any model. Check your configuration.")
        else:
            raise gr.Error(error_msg)


# Load the model globally
try:
    model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
    MODEL_LOADED = True
    DESCRIPTION = DESCRIPTION.format(STATUS_MSG)
except Exception as e:
    print(f"[Fatal] Could not load model: {e}")
    MODEL_LOADED = False
    model, tokenizer = None, None
    DESCRIPTION = DESCRIPTION.format(f"❌ Model failed to load: {str(e)[:100]}")


# ── Generation function ──────────────────────────────────────────────────────────
def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.7,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate a response using the model."""
    if not MODEL_LOADED:
        return "⚠️ Model failed to load. Please check the logs or restart the application."

    try:
        # Build the conversation
        conversation = []
        if history:
            # Keep the last 3 exchanges for context
            for user_msg, assistant_msg in history[-3:]:
                if user_msg:
                    conversation.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    conversation.append({"role": "assistant", "content": assistant_msg})
        conversation.append({"role": "user", "content": message})

        # Apply the chat template
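        # For reference: Gemma instruction-tuned checkpoints use a chat template along
        # the lines of "<start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n"
        # (illustrative; the exact template comes from the tokenizer itself).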
        try:
            prompt = tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            )
        except Exception as e:
            print(f"[Warning] Chat template failed: {e}, using fallback format")
            # Fallback to a plain "User:/Assistant:" format
            prompt_text = ""
            for msg in conversation:
                if msg["role"] == "user":
                    prompt_text += f"User: {msg['content']}\n"
                else:
                    prompt_text += f"Assistant: {msg['content']}\n"
            prompt_text += "Assistant: "
            inputs = tokenizer(
                prompt_text,
                return_tensors="pt",
                truncation=True,
                max_length=512
            )
            prompt = inputs.input_ids

        # Move the prompt to the model's device
        prompt = prompt.to(model.device if hasattr(model, 'device') else DEVICE)

        # Generate
        print(f"[Generate] Input length: {prompt.shape[-1]} tokens")
        with torch.no_grad():
            outputs = model.generate(
                input_ids=prompt,
                max_new_tokens=min(int(max_new_tokens), 256),
                temperature=float(temperature),
                top_p=float(top_p),
                repetition_penalty=float(repetition_penalty),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )

        # Decode only the newly generated tokens (skip the echoed prompt)
        generated_tokens = outputs[0][prompt.shape[-1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        print(f"[Generate] Output length: {len(generated_tokens)} tokens")

        # Cleanup
        del outputs, prompt, generated_tokens
        gc.collect()

        return response

    except Exception as e:
        error_msg = f"⚠️ Error generating response: {str(e)}"
        print(f"[Error] {error_msg}")
        # Try to recover memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return error_msg


# ── Gradio Interface ─────────────────────────────────────────────────────────────
examples = [
    ["What is the capital of Goa?"],
    ["Tell me about Konkani language"],
    ["What are famous beaches in Goa?"],
    ["Describe Goan fish curry"],
    ["What is the history of Old Goa?"],
]

# Create the interface
if MODEL_LOADED:
    demo = gr.ChatInterface(
        fn=generate_response,
        title=TITLE,
        description=DESCRIPTION,
        examples=examples,
        retry_btn=None,
        undo_btn=None,
        additional_inputs=[
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.05,
                label="Temperature (lower = more focused)"
            ),
            gr.Slider(
                minimum=32,
                maximum=256,
                value=128,
                step=16,
                label="Max new tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)"
            ),
            gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,
                label="Repetition penalty"
            ),
        ],
        theme=gr.themes.Soft(),
    )
else:
    demo = gr.Interface(
        fn=lambda x: "Model failed to load. Check console for errors.",
        inputs=gr.Textbox(label="Message"),
        outputs=gr.Textbox(label="Response"),
        title=TITLE,
        description=DESCRIPTION,
    )

# Queue with version compatibility
try:
    # Try the newer Gradio syntax first (4.x)
    demo.queue(default_concurrency_limit=1, max_size=10)
except TypeError:
    try:
        # Fall back to the older syntax (3.x)
        demo.queue(concurrency_count=1, max_size=10)
    except TypeError:
        # If both signatures fail, queue without parameters
        demo.queue()

if __name__ == "__main__":
    print("\n" + "=" * 50)
    print(f"🚀 Starting Gradio app on {DEVICE}...")
    print(f"📦 Base model: {BASE_MODEL_ID}")
    print(f"🔧 LoRA adapter: {ADAPTER_ID}")
    print("=" * 50 + "\n")
    demo.launch()
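
# Typical Space requirements for this app (assumption - adjust and pin versions as needed):
#   gradio
#   torch
#   transformers
#   peft
#   accelerate
#   bitsandbytes  # only needed if USE_8BIT is enabled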