import gradio as gr
import queue
import threading
import os
from typing import Any, List, Dict, Generator, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from datetime import datetime


class LlamaChat:
    def __init__(self):
        self.model_name = "meta-llama/Llama-3.2-3B-Instruct"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = None
        self.model = None
        self.request_queue = queue.Queue()
        self.is_processing = False
        self.current_streamer = None

        # Initialize the model
        self._load_model()

        # Start the worker thread that processes queued requests
        self.worker_thread = threading.Thread(target=self._queue_worker, daemon=True)
        self.worker_thread.start()

    def _load_model(self):
        """Load the model and tokenizer using the HF token."""
        try:
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                raise ValueError("HF_TOKEN not found in environment variables")

            print(f"Loading model {self.model_name}...")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=hf_token,
                trust_remote_code=True
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                token=hf_token,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True
            )

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("Model loaded successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _format_messages(self, system_prompt: str, message: str, history: List[List[str]]) -> str:
        """Format messages for Llama-3.2-Instruct."""
        messages = []

        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add the conversation history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})

        # Add the current message
        messages.append({"role": "user", "content": message})

        # Use the tokenizer's chat template
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        return formatted_prompt

    def _queue_worker(self):
        """Worker thread that processes the request queue."""
        while True:
            try:
                # Block until a request arrives instead of busy-waiting
                request = self.request_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                self.is_processing = True
                self._process_request(request)
            except Exception as e:
                print(f"Error in queue worker: {e}")
            finally:
                self.is_processing = False
                self.request_queue.task_done()

    def _process_request(self, request: Dict):
        """Process a single request."""
        # Fetch the callback first so the except block below can always report errors
        response_callback = request["callback"]
        try:
            system_prompt = request["system_prompt"]
            message = request["message"]
            history = request["history"]
            max_tokens = request.get("max_tokens", 512)
            temperature = request.get("temperature", 0.7)

            # Format the prompt
            formatted_prompt = self._format_messages(system_prompt, message, history)

            # Tokenize
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048
            ).to(self.device)

            # Set up the streamer
            streamer = TextIteratorStreamer(
                self.tokenizer,
                timeout=60,
                skip_prompt=True,
                skip_special_tokens=True
            )
            self.current_streamer = streamer

            # Generation parameters
            generation_kwargs = {
                **inputs,
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
                "streamer": streamer,
                "repetition_penalty": 1.1
            }

            # Generate in a separate thread so tokens can be streamed as they arrive
            def generate():
                with torch.no_grad():
                    self.model.generate(**generation_kwargs)

            generation_thread = threading.Thread(target=generate)
            generation_thread.start()

            # Stream the response
            full_response = ""
            for new_text in streamer:
                if new_text:
                    full_response += new_text
                    response_callback(full_response, False)

            response_callback(full_response, True)
            generation_thread.join()

        except Exception as e:
            print(f"Error processing request: {e}")
            response_callback(f"Error: {str(e)}", True)
        finally:
            self.current_streamer = None

    def chat_stream(self, system_prompt: str, message: str, history: List[List[str]],
                    max_tokens: int = 512, temperature: float = 0.7) -> Generator[Tuple[str, bool], None, None]:
        """Main entry point for streaming chat."""
        if not message.strip():
            yield "Please enter a message.", True
            return

        # Queue and event used to communicate with the worker
        response_queue = queue.Queue()
        response_complete = threading.Event()
        current_response = [""]

        def response_callback(text: str, is_complete: bool):
            current_response[0] = text
            response_queue.put((text, is_complete))
            if is_complete:
                response_complete.set()

        # Add the request to the queue
        request = {
            "system_prompt": system_prompt or "",
            "message": message,
            "history": history or [],
            "max_tokens": max_tokens,
            "temperature": temperature,
            "callback": response_callback
        }

        self.request_queue.put(request)

        # Wait for and stream the response
        while True:
            try:
                text, is_complete = response_queue.get(timeout=0.1)
                yield text, is_complete
                if is_complete:
                    break
            except queue.Empty:
                # No new tokens yet: stop if the response already finished,
                # otherwise re-yield the latest partial state
                if response_complete.is_set():
                    break
                if current_response[0]:
                    yield current_response[0], False

    def get_queue_status(self) -> Dict[str, Any]:
        """Return the current queue status."""
        return {
            "queue_size": self.request_queue.qsize(),
            "is_processing": self.is_processing,
            "timestamp": datetime.now().isoformat()
        }


# Initialize the chat instance (this loads the model)
chat_instance = LlamaChat()


# Chat function for the Gradio UI
def chat_interface(message: str, history: List[List[str]], system_prompt: str,
                   max_tokens: int, temperature: float):
    """Chat interface for Gradio."""
    for response, is_complete in chat_instance.chat_stream(
        system_prompt, message, history, max_tokens, temperature
    ):
        # Gradio expects the full history on every update
        new_history = history + [[message, response]]
        yield new_history, ""


# Blocking function for the Python API
def api_chat(system_prompt: str = "", message: str = "", history: List[List[str]] = None,
             max_tokens: int = 512, temperature: float = 0.7) -> Dict:
    """API for Python clients."""
    if history is None:
        history = []

    full_response = ""
    for response, is_complete in chat_instance.chat_stream(
        system_prompt, message, history, max_tokens, temperature
    ):
        full_response = response
        if is_complete:
            break

    return {
        "response": full_response,
        "queue_status": chat_instance.get_queue_status()
    }


# Streaming function for the Python API
def api_chat_stream(system_prompt: str = "", message: str = "", history: List[List[str]] = None,
                    max_tokens: int = 512, temperature: float = 0.7):
    """Streaming API for Python clients."""
    if history is None:
        history = []

    for response, is_complete in chat_instance.chat_stream(
        system_prompt, message, history, max_tokens, temperature
    ):
        yield {
            "response": response,
            "is_complete": is_complete,
            "queue_status": chat_instance.get_queue_status()
        }
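# In-process usage sketch (comments only, not executed). It assumes this file is
# saved as app.py; importing it triggers model loading via chat_instance above.
#
#   from app import api_chat, api_chat_stream
#
#   result = api_chat(system_prompt="Be concise.", message="Hello!")
#   print(result["response"], result["queue_status"])
#
#   for chunk in api_chat_stream(message="Tell me a joke"):
#       print(chunk["response"], chunk["is_complete"])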
# Build the Gradio interface
with gr.Blocks(title="Llama 3.2 3B Chat", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🦙 Llama 3.2 3B Instruct Chat")
    gr.Markdown("Chat with Meta Llama 3.2 3B, with request queueing and streaming")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500, show_label=False)
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your message here...",
                lines=2
            )

            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")

        with gr.Column(scale=1):
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="You are a helpful assistant...",
                lines=5,
                value="You are a helpful and friendly AI assistant. Answer clearly and concisely."
            )

            max_tokens = gr.Slider(
                minimum=50,
                maximum=1024,
                value=512,
                step=50,
                label="Max Tokens"
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )

            gr.Markdown("### Queue Status")
            queue_status = gr.JSON(label="Queue Status", value={})

            # Button to refresh the queue status
            refresh_btn = gr.Button("Refresh Status")

    # Event handlers
    def send_message(message, history, sys_prompt, max_tok, temp):
        if not message.strip():
            # This function is a generator, so yield (not return) the unchanged state
            yield history, ""
            return
        yield from chat_interface(message, history, sys_prompt, max_tok, temp)

    def clear_chat():
        return [], ""

    def update_queue_status():
        return chat_instance.get_queue_status()

    # Wire up the events
    send_btn.click(
        send_message,
        inputs=[msg, chatbot, system_prompt, max_tokens, temperature],
        outputs=[chatbot, msg]
    )

    msg.submit(
        send_message,
        inputs=[msg, chatbot, system_prompt, max_tokens, temperature],
        outputs=[chatbot, msg]
    )

    clear_btn.click(clear_chat, outputs=[chatbot, msg])
    refresh_btn.click(update_queue_status, outputs=[queue_status])

    # Refresh the queue status every 5 seconds
    app.load(update_queue_status, outputs=[queue_status], every=5)


# Create the API endpoint
api_app = gr.Interface(
    fn=api_chat,
    inputs=[
        gr.Textbox(label="System Prompt"),
        gr.Textbox(label="Message"),
        gr.JSON(label="History"),
        gr.Slider(50, 1024, 512, label="Max Tokens"),
        gr.Slider(0.1, 2.0, 0.7, label="Temperature")
    ],
    outputs=gr.JSON(label="Response"),
    title="Llama Chat API",
    description="API endpoint for Python clients"
)

# Combine both apps into a tabbed interface
final_app = gr.TabbedInterface(
    [app, api_app],
    ["💬 Chat Interface", "🔌 API Endpoint"]
)

if __name__ == "__main__":
    final_app.launch(server_name="0.0.0.0", server_port=7860, share=True)
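# Remote usage sketch (comments only, not executed). Assumptions: the app above
# is running and reachable at http://localhost:7860, and the API tab keeps
# Gradio's default endpoint name; run client.view_api() to confirm the exact
# api_name and argument order before relying on it.
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   client.view_api()                       # list the available endpoints
#   result = client.predict(
#       "You are a helpful assistant.",     # System Prompt
#       "Hello!",                           # Message
#       [],                                 # History
#       512,                                # Max Tokens
#       0.7,                                # Temperature
#       api_name="/predict",                # assumed default; verify with view_api()
#   )
#   print(result)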