#!/usr/bin/env python3
"""
AR-Diffusion Chat Interface for Hugging Face Spaces
Experimental model with Quality vs Speed modes
Optimized for Zero GPU deployment with @spaces.GPU
"""
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import numpy as np
import re
import time
from typing import List, Tuple
import os
import gc
import spaces

# Global model variables for memory efficiency
tokenizer = None
model = None
device = None


class ARDiffusionGenerator:
    """Base AR-Diffusion generator with shared functionality"""

    def __init__(self, tokenizer, model, device):
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.mask_token_id = self._find_mask_token()

    def _find_mask_token(self) -> int:
        """Find MASK token ID"""
        for candidate in ['MASK', '<mask>', '[MASK]', '<|mask|>']:
            try:
                tokens = self.tokenizer.encode(candidate, add_special_tokens=False)
                if len(tokens) == 1:
                    return tokens[0]
            except Exception:
                continue
        return getattr(self.tokenizer, 'unk_token_id', 50257) or 50257

    def create_prompt(self, instruction: str) -> str:
        """Create Alpaca-style prompt"""
        return f"""### Instruction:
{instruction}
### Response:
"""


class QualityGenerator(ARDiffusionGenerator):
    """Quality-focused AR-Diffusion generator"""

    def filter_logits(self, logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0,
                      temperature: float = 1.0) -> torch.Tensor:
        """Research-grade filtering with proper order"""
        original_shape = logits.shape
        if logits.dim() == 3:
            logits = logits.squeeze(0)
        elif logits.dim() == 1:
            logits = logits.unsqueeze(0)
        logits = logits.clone()

        # Temperature scaling first
        if temperature != 1.0:
            logits = logits / temperature

        # Top-k filtering
        if top_k > 0 and top_k < logits.size(-1):
            topk_vals, _ = torch.topk(logits, top_k, dim=-1)
            thresholds = topk_vals[:, -1].unsqueeze(-1)
            logits = torch.where(logits < thresholds,
                                 torch.full_like(logits, float("-inf")), logits)

        # Top-p filtering
        if top_p > 0.0 and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            probs = torch.softmax(sorted_logits, dim=-1)
            cum_probs = probs.cumsum(dim=-1)
            mask = cum_probs > top_p
            mask[:, 0] = False
            scatter_mask = torch.zeros_like(logits, dtype=torch.bool).scatter(
                dim=-1, index=sorted_indices, src=mask)
            logits = torch.where(scatter_mask,
                                 torch.full_like(logits, float("-inf")), logits)

        # Restore original shape
        if len(original_shape) == 1:
            logits = logits.squeeze(0)
        elif original_shape[0] == 1 and logits.dim() == 2:
            logits = logits.unsqueeze(0)
        return logits

    def generate_start(self, prompt: str, length: int = 8) -> List[int]:
        """Generate natural start"""
        tokens = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = tokens['input_ids'][0]
        generated = []
        current = input_ids.clone()
        with torch.no_grad():
            for _ in range(length):
                outputs = self.model(input_ids=current.unsqueeze(0))
                logits = outputs.logits[0, -1]
                filtered_logits = self.filter_logits(
                    logits, top_k=50, top_p=0.9, temperature=0.8
                )
                probs = F.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                if next_token in [self.tokenizer.eos_token_id, 128001, 13]:
                    break
                generated.append(next_token)
                current = torch.cat([current, torch.tensor([next_token], device=self.device)])
        return generated
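
    # The "natural start" is ordinary autoregressive sampling of a short prefix,
    # so the diffusion stage below continues from fluent text instead of starting
    # from an all-MASK span.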

    def create_sequence(self, prompt: str) -> Tuple[str, torch.Tensor]:
        """Create corrupted sequence for quality mode"""
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")['input_ids'][0]
        natural_start = self.generate_start(prompt, length=random.randint(8, 12))

        # Longer sequences for better quality
        prompt_length = len(prompt_tokens)
        if prompt_length > 25:
            num_masks = random.randint(35, 50)
        elif prompt_length > 15:
            num_masks = random.randint(25, 40)
        else:
            num_masks = random.randint(20, 35)

        sequence = (
            prompt_tokens.tolist() +
            natural_start +
            [self.mask_token_id] * num_masks +
            [13]
        )
        tensor = torch.tensor(sequence)
        text = self.tokenizer.decode(tensor, skip_special_tokens=False)
        return text, tensor
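
    # Layout of the corrupted sequence built above (token 13 is used throughout
    # this file as a terminator alongside EOS):
    #
    #   [prompt tokens][natural AR start][MASK * num_masks][13]
    #
    # Denoising then rewrites only the MASK positions, leftmost first.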

    def generate(self, prompt: str, progress_callback=None) -> Tuple[str, dict]:
        """Quality generation with progress updates and speed tracking"""
        steps = 40
        temperature = 0.7
        start_time = time.time()

        if progress_callback:
            progress_callback(0.1, "Creating sequence...")
        full_prompt = self.create_prompt(prompt)
        corrupted_text, corrupted_ids = self.create_sequence(full_prompt)

        if progress_callback:
            progress_callback(0.2, "Starting quality denoising...")
        result, stats = self._denoise_quality(corrupted_ids, steps, temperature, progress_callback)

        # Calculate overall stats
        total_time = time.time() - start_time
        response = self._clean_response(result)
        word_count = len(response.split())
        stats.update({
            'total_time': total_time,
            'word_count': word_count,
            'words_per_second': word_count / total_time if total_time > 0 else 0
        })
        return response, stats

    def _denoise_quality(self, corrupted_ids: torch.Tensor, steps: int, temperature: float, progress_callback=None) -> Tuple[str, dict]:
        """Quality denoising with progress updates and speed tracking"""
        current_ids = corrupted_ids.clone()
        total_replacements = 0
        start_time = time.time()

        for step in range(steps):
            if progress_callback:
                progress = 0.2 + (step / steps) * 0.7
                elapsed = time.time() - start_time
                tokens_per_sec = total_replacements / elapsed if elapsed > 0 else 0
                progress_callback(progress, f"Quality step {step+1}/{steps} | {tokens_per_sec:.1f} tok/s")

            mask_positions = (current_ids == self.mask_token_id).nonzero(as_tuple=True)[0]
            if len(mask_positions) == 0:
                break

            with torch.no_grad():
                outputs = self.model(input_ids=current_ids.unsqueeze(0).to(self.device))
                logits = outputs.logits[0]

            current_temp = max(0.4, temperature * (1 - step / steps))

            # Conservative replacement for quality
            if step < steps // 4:
                max_replacements = min(1, len(mask_positions))
            elif step < steps // 2:
                max_replacements = min(2, len(mask_positions))
            else:
                max_replacements = min(3, len(mask_positions))

            sorted_positions = sorted(mask_positions.tolist())
            for pos in sorted_positions[:max_replacements]:
                if pos < len(logits):
                    token_logits = logits[pos].clone()

                    # Anti-repetition
                    context_start = max(0, pos - 5)
                    recent_tokens = set(current_ids[context_start:pos].tolist())
                    for recent_token in recent_tokens:
                        if recent_token < len(token_logits):
                            token_logits[recent_token] -= 8.0

                    # Quality filtering
                    filtered_logits = self.filter_logits(
                        token_logits,
                        top_k=30,
                        top_p=0.75,
                        temperature=current_temp
                    )
                    probs = F.softmax(filtered_logits, dim=-1)
                    probs = torch.clamp(probs, min=1e-8, max=1.0)
                    new_token = torch.multinomial(probs, 1).item()

                    # Filter unwanted tokens
                    unwanted = [self.mask_token_id, 128001, 128000]
                    if new_token in unwanted:
                        top_k_vals, top_k_indices = torch.topk(filtered_logits, 10)
                        for alternative in top_k_indices:
                            if alternative.item() not in unwanted:
                                new_token = alternative.item()
                                break

                    current_ids[pos] = new_token
                    total_replacements += 1

        if progress_callback:
            elapsed = time.time() - start_time
            final_speed = total_replacements / elapsed if elapsed > 0 else 0
            progress_callback(0.95, f"Finalizing... | Final speed: {final_speed:.1f} tok/s")

        # Calculate final statistics
        total_time = time.time() - start_time
        stats = {
            'mode': 'Quality',
            'steps': steps,
            'tokens_replaced': total_replacements,
            'generation_time': total_time,
            'tokens_per_second': total_replacements / total_time if total_time > 0 else 0
        }
        result = self.tokenizer.decode(current_ids, skip_special_tokens=True)
        return result, stats

    def _clean_response(self, text: str) -> str:
        """Clean response for quality output"""
        if "### Response:" in text:
            response = text.split("### Response:")[-1].strip()
        else:
            response = text.strip()
        if not response:
            return text

        # Quality cleaning
        response = re.sub(r"'{2,}", "", response)
        response = re.sub(r'"{2,}', "", response)
        response = re.sub(r"\.{2,}", ".", response)
        response = re.sub(r",{2,}", ",", response)
        response = re.sub(r"\s+", " ", response)

        # Remove artifacts
        response = re.sub(r"\$+", "", response)
        response = re.sub(r"#+", "", response)
        response = re.sub(r"@+", "", response)
        response = response.strip()

        if response and not response.endswith(('.', '!', '?')):
            response += "."
        return response
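
# Quality vs Speed at a glance. The two generators share the same pipeline
# (prompt -> natural start -> masked span -> iterative denoising) and differ
# only in how aggressively they denoise:
#   - QualityGenerator (above): 40 steps, 20-50 masks, 1-3 replacements per step,
#     top_k=30 / top_p=0.75 filtering, strong anti-repetition (-8.0 on recent logits).
#   - SpeedGenerator (below): 10 steps, 12-25 masks, up to 8 replacements per step,
#     top_k=12 / top_p=0.85 filtering, light anti-repetition (-3.0), plus fp16
#     autocast on CUDA.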


class SpeedGenerator(ARDiffusionGenerator):
    """Speed-focused AR-Diffusion generator"""

    def filter_logits(self, logits: torch.Tensor, top_k: int = 15, top_p: float = 0.8,
                      temperature: float = 1.0) -> torch.Tensor:
        """Fast logits filtering"""
        logits = logits.clone()
        if temperature != 1.0:
            logits = logits / temperature

        # Top-k filtering
        if top_k > 0 and top_k < logits.size(-1):
            topk_vals, _ = torch.topk(logits, top_k, dim=-1)
            threshold = topk_vals[-1]
            logits = torch.where(logits < threshold, torch.full_like(logits, float("-inf")), logits)

        # Top-p filtering
        if top_p > 0.0 and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            probs = torch.softmax(sorted_logits, dim=-1)
            cum_probs = probs.cumsum(dim=-1)
            mask = cum_probs > top_p
            mask[0] = False
            scatter_mask = torch.zeros_like(logits, dtype=torch.bool)
            scatter_mask.scatter_(0, sorted_indices, mask)
            logits = torch.where(scatter_mask, torch.full_like(logits, float("-inf")), logits)
        return logits
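
    # Toy example of the filtering order used here (illustrative values only):
    # with logits [2.0, 1.0, 0.5, -1.0] and top_k=2, everything below the
    # second-highest logit becomes -inf; top_p=0.8 then also drops any kept
    # token whose cumulative softmax probability exceeds 0.8, except the top-1,
    # which is always preserved (mask[0] = False).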

    def generate_start(self, prompt: str, length: int = 6) -> List[int]:
        """Generate natural start for speed mode"""
        tokens = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = tokens['input_ids'][0]
        generated = []
        current = input_ids.clone()
        with torch.no_grad():
            for _ in range(length):
                outputs = self.model(input_ids=current.unsqueeze(0))
                logits = outputs.logits[0, -1]
                filtered_logits = self.filter_logits(logits, top_k=20, top_p=0.9, temperature=0.8)
                probs = F.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                if next_token in [self.tokenizer.eos_token_id, 128001, 13]:
                    break
                generated.append(next_token)
                current = torch.cat([current, torch.tensor([next_token], device=self.device)])
        return generated

    def create_sequence(self, prompt: str) -> Tuple[str, torch.Tensor]:
        """Create sequence optimized for speed"""
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")['input_ids'][0]
        natural_start = self.generate_start(prompt, length=6)

        # Shorter sequences for speed
        prompt_words = len(prompt.split())
        if prompt_words > 8:
            num_masks = random.randint(15, 25)
        else:
            num_masks = random.randint(12, 20)

        sequence = (
            prompt_tokens.tolist() +
            natural_start +
            [self.mask_token_id] * num_masks +
            [13]
        )
        tensor = torch.tensor(sequence)
        text = self.tokenizer.decode(tensor, skip_special_tokens=False)
        return text, tensor

    def generate(self, prompt: str, progress_callback=None) -> Tuple[str, dict]:
        """Speed generation with progress updates and speed tracking"""
        steps = 10
        temperature = 0.8
        start_time = time.time()

        if progress_callback:
            progress_callback(0.1, "Creating sequence...")
        full_prompt = self.create_prompt(prompt)
        corrupted_text, corrupted_ids = self.create_sequence(full_prompt)

        if progress_callback:
            progress_callback(0.2, "Starting speed denoising...")
        result, stats = self._denoise_speed(corrupted_ids, steps, temperature, progress_callback)

        # Calculate overall stats
        total_time = time.time() - start_time
        response = self._clean_response(result)
        word_count = len(response.split())
        stats.update({
            'total_time': total_time,
            'word_count': word_count,
            'words_per_second': word_count / total_time if total_time > 0 else 0
        })
        return response, stats

    def _denoise_speed(self, corrupted_ids: torch.Tensor, steps: int, temperature: float, progress_callback=None) -> Tuple[str, dict]:
        """Ultra-fast denoising with progress updates and speed tracking"""
        current_ids = corrupted_ids.clone()
        total_replacements = 0
        start_time = time.time()

        # Use mixed precision for speed on GPU
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.device.type == 'cuda'):
            for step in range(steps):
                if progress_callback:
                    progress = 0.2 + (step / steps) * 0.7
                    elapsed = time.time() - start_time
                    tokens_per_sec = total_replacements / elapsed if elapsed > 0 else 0
                    progress_callback(progress, f"Speed step {step+1}/{steps} | {tokens_per_sec:.1f} tok/s")

                mask_pos = (current_ids == self.mask_token_id).nonzero(as_tuple=True)[0]
                if len(mask_pos) == 0:
                    break

                with torch.no_grad():
                    outputs = self.model(input_ids=current_ids.unsqueeze(0).to(self.device))
                    logits = outputs.logits[0]

                current_temp = temperature * (0.9 + 0.2 * (step / steps))

                # Aggressive replacement for speed
                max_replace = min(8, len(mask_pos))
                positions = sorted(mask_pos.tolist())[:max_replace]
                for pos in positions:
                    if pos < len(logits):
                        token_logits = logits[pos].clone()

                        # Light anti-repetition
                        recent_start = max(0, pos - 3)
                        recent_tokens = set(current_ids[recent_start:pos].tolist())
                        for token in recent_tokens:
                            if token < len(token_logits):
                                token_logits[token] -= 3.0

                        # Fast filtering
                        filtered_logits = self.filter_logits(
                            token_logits, top_k=12, top_p=0.85, temperature=current_temp
                        )
                        probs = F.softmax(filtered_logits, dim=-1)
                        probs = torch.clamp(probs, min=1e-8, max=1.0)
                        new_token = torch.multinomial(probs, 1).item()

                        # Quick filtering
                        if new_token in [self.mask_token_id, 128001, 128000]:
                            top_vals, top_indices = torch.topk(filtered_logits, 3)
                            new_token = top_indices[1].item()

                        current_ids[pos] = new_token
                        total_replacements += 1

        if progress_callback:
            elapsed = time.time() - start_time
            final_speed = total_replacements / elapsed if elapsed > 0 else 0
            progress_callback(0.95, f"Finalizing... | Final speed: {final_speed:.1f} tok/s")

        # Calculate final statistics
        total_time = time.time() - start_time
        stats = {
            'mode': 'Speed',
            'steps': steps,
            'tokens_replaced': total_replacements,
            'generation_time': total_time,
            'tokens_per_second': total_replacements / total_time if total_time > 0 else 0
        }
        result = self.tokenizer.decode(current_ids, skip_special_tokens=True)
        return result, stats

    def _clean_response(self, text: str) -> str:
        """Clean response for speed output"""
        if "### Response:" in text:
            response = text.split("### Response:")[-1].strip()
        else:
            response = text.strip()
        if not response:
            return text

        # Minimal cleaning for speed
        response = re.sub(r"'{3,}", "", response)
        response = re.sub(r'"{3,}', "", response)
        response = re.sub(r"\.{3,}", ".", response)
        response = re.sub(r",{3,}", ",", response)
        response = re.sub(r"\s+", " ", response)
        response = response.strip()

        if response and not response.endswith(('.', '!', '?')):
            response += "."
        return response


def load_model():
    """Load model with Zero GPU optimization using @spaces.GPU"""
    global tokenizer, model, device
    if tokenizer is not None and model is not None:
        return tokenizer, model, device

    # Get HF token from environment
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("🔑 HF_TOKEN found - using authenticated access")
    else:
        print("⚠️ No HF_TOKEN found - using public access only")

    try:
        # This appears to be a LoRA adapter
        adapter_path = "rootxhacker/llama-3B-diffusion-exp-fixed"
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading AR-Diffusion model on {device}...")

        # Load tokenizer from adapter with token
        tokenizer = AutoTokenizer.from_pretrained(
            adapter_path,
            trust_remote_code=True,
            token=hf_token
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load the adapter model with token
        print("Loading adapter model...")
        model = AutoModelForCausalLM.from_pretrained(
            adapter_path,
            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
            device_map="auto" if device.type == "cuda" else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            token=hf_token
        )
        print("✅ AR-Diffusion model loaded successfully!")
        return tokenizer, model, device

    except Exception as e:
        print(f"❌ Error loading {adapter_path}: {e}")

        # Try alternative working models for AR-Diffusion demo
        print("🔄 Trying alternative models...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Try different models in order of preference
        alternative_models = [
            "microsoft/DialoGPT-medium",
            "gpt2-large",
            "gpt2-medium",
            "distilgpt2"
        ]
        for alt_model in alternative_models:
            try:
                print(f"Trying {alt_model}...")
                tokenizer = AutoTokenizer.from_pretrained(alt_model, token=hf_token)
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                model = AutoModelForCausalLM.from_pretrained(
                    alt_model,
                    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
                    device_map="auto" if device.type == "cuda" else None,
                    low_cpu_mem_usage=True,
                    token=hf_token
                )
                print(f"✅ Alternative model {alt_model} loaded successfully!")
                print("⚠️ Note: Using alternative model - AR-Diffusion features adapted for demo")
                return tokenizer, model, device
            except Exception as alt_e:
                print(f"❌ {alt_model} failed: {alt_e}")
                continue

        # Final fallback
        print("🔄 Using final fallback model...")
        tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            "distilgpt2",
            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
            device_map="auto" if device.type == "cuda" else None,
            low_cpu_mem_usage=True
        )
        print("✅ Final fallback model loaded successfully!")
        print("⚠️ Note: Using basic model - AR-Diffusion features adapted for demo")
        return tokenizer, model, device


def cleanup_memory():
    """Clean up GPU memory"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
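
# Standalone usage sketch (illustrative, outside the Gradio app; assumes the
# model can actually be downloaded in the current environment):
#
#   tok, mod, dev = load_model()
#   generator = QualityGenerator(tok, mod, dev)   # or SpeedGenerator(tok, mod, dev)
#   response, stats = generator.generate("Explain photosynthesis in simple terms")
#   print(response)
#   print(f"{stats['tokens_per_second']:.1f} tok/s over {stats['steps']} steps")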


@spaces.GPU
def chat_function(message, history, mode, progress=gr.Progress()):
    """Main chat function with @spaces.GPU decorator, progress tracking, and speed display"""
    if not message.strip():
        return history, "", ""

    try:
        # Load model (this will run on GPU when GPU is allocated)
        progress(0.05)
        tok, mod, dev = load_model()

        # Create appropriate generator
        if mode == "Quality (Slower, Better)":
            generator = QualityGenerator(tok, mod, dev)
            progress(0.1)
        else:
            generator = SpeedGenerator(tok, mod, dev)
            progress(0.1)

        # Generate response with progress callback
        def progress_callback(pct, status_msg):
            progress(pct, desc=status_msg)

        response, stats = generator.generate(message, progress_callback)
        progress(1.0)

        # Create performance info
        perf_info = f"""**⚡ Performance Stats:**
- **Mode:** {stats['mode']}
- **Generation Time:** {stats['generation_time']:.2f}s
- **Tokens Replaced:** {stats['tokens_replaced']}
- **Speed:** {stats['tokens_per_second']:.1f} tokens/sec
- **Words Generated:** {stats['word_count']} words
- **Words/Second:** {stats['words_per_second']:.1f}
- **Steps:** {stats['steps']}"""

        # Update history with proper message format
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})

        # Cleanup memory for Zero GPU efficiency
        cleanup_memory()
        return history, "", perf_info

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": error_msg})
        cleanup_memory()
        return history, "", "**❌ Error occurred during generation**"


def clear_chat():
    """Clear chat history and cleanup memory"""
    cleanup_memory()
    return [], ""


# Create Gradio interface
def create_interface():
    with gr.Blocks(
        title="AR-Diffusion Chat - Experimental Model",
        theme=gr.themes.Soft(),
        css="""
        .warning-box {
            background-color: #fff3cd;
            border: 1px solid #ffeaa7;
            border-radius: 5px;
            padding: 10px;
            margin: 10px 0;
        }
        """
    ) as interface:
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🧪 AR-Diffusion Chat Interface</h1>
            <p><strong>⚠️ EXPERIMENTAL MODEL ⚠️</strong></p>
            <p>This is an experimental AR-Diffusion model. Results may vary and the model is still under development.</p>
            <p><em>🔥 Powered by Zero GPU with @spaces.GPU</em></p>
            <p><small>Model: rootxhacker/llama-3B-diffusion-exp-fixed (LoRA Adapter)</small></p>
            <p><small>🔑 Requires HF_TOKEN for gated model access</small></p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    height=500,
                    show_label=False,
                    type="messages"
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Type your message here...",
                        show_label=False,
                        scale=9
                    )
                    send_btn = gr.Button("Send", scale=1, variant="primary")
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")

            with gr.Column(scale=1):
                gr.HTML("""
                <div class="warning-box">
                    <h3>⚙️ Mode Selection</h3>
                    <p><strong>Quality Mode:</strong> Slower but more coherent responses (~40 steps)</p>
                    <p><strong>Speed Mode:</strong> Faster responses with decent quality (~10 steps)</p>
                    <p><em>🔥 GPU acceleration via @spaces.GPU</em></p>
                </div>
                """)
                mode = gr.Radio(
                    choices=["Quality (Slower, Better)", "Speed (Faster)"],
                    value="Quality (Slower, Better)",
                    label="Generation Mode"
                )

                # Performance display
                perf_display = gr.Markdown(
                    "**⚡ Performance Stats:** *Generate a message to see stats*",
                    elem_id="performance"
                )
| gr.HTML(""" | |
| <div class="warning-box"> | |
| <h3>ℹ️ About AR-Diffusion</h3> | |
| <p>This experimental model uses autoregressive diffusion for text generation, creating responses by iteratively denoising masked tokens.</p> | |
| <br> | |
| <p><strong>Model:</strong> LoRA adapter trained for AR-Diffusion</p> | |
| <p><strong>Authentication:</strong> Requires HF_TOKEN for gated Llama model access</p> | |
| <p><strong>Note:</strong> This model is experimental and may produce unexpected results. If the specific model fails to load, alternative models will be used for demonstration.</p> | |
| </div> | |
| """) | |

        # Event handlers
        def submit_message(message, history, mode, progress=gr.Progress()):
            return chat_function(message, history, mode, progress)

        send_btn.click(
            submit_message,
            inputs=[msg, chatbot, mode],
            outputs=[chatbot, msg, perf_display]
        )
        msg.submit(
            submit_message,
            inputs=[msg, chatbot, mode],
            outputs=[chatbot, msg, perf_display]
        )
        clear_btn.click(
            clear_chat,
            outputs=[chatbot, perf_display]
        )

    return interface


# Launch interface
if __name__ == "__main__":
    demo = create_interface()
    demo.queue(max_size=20)  # Important for Zero GPU
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )

# Requirements.txt should include:
# torch>=2.0.0
# transformers>=4.30.0
# gradio
# numpy
# accelerate
# spaces
# peft
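
# Local run sketch (illustrative; assumes this file is saved as app.py and the
# requirements above are installed; on a machine without CUDA the code falls
# back to CPU/float32):
#
#   pip install "torch>=2.0.0" "transformers>=4.30.0" gradio numpy accelerate spaces peft
#   python app.py    # then open http://localhost:7860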