#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.

Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Automatically falls back to CPU if CUDA is not available.

Launch examples:
- single available GPU (default):
  python -m scripts.chat_web
- 4 GPUs:
  python -m scripts.chat_web --num-gpus 4
- CPU only (automatic if no CUDA):
  python -m scripts.chat_web

To chat, open the URL printed in the console. (If on a cloud box, make sure to use the public IP.)

Endpoints:
    GET  /                 - Chat UI
    POST /chat/completions - Chat API (streaming only)
    GET  /health           - Health check with worker pool status
    GET  /stats            - Worker pool statistics and GPU utilization

Abuse Prevention:
    - Maximum 500 messages per request
    - Maximum 8000 characters per message
    - Maximum 32000 characters total conversation length
    - Temperature clamped to 0.0-2.0
    - Top-k clamped to 1-200
    - Max tokens clamped to 1-4096
"""
import argparse
import json
import os
import torch
import asyncio
import logging
import random
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
# Abuse prevention limits
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (ignored on CPU, default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
# Configure logging for conversation traffic
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
@dataclass
class Worker:
    """A worker with a model loaded on a specific device (GPU or CPU)."""
    worker_id: int
    device: torch.device
    engine: Engine
    tokenizer: object
    autocast_ctx: torch.amp.autocast
class WorkerPool:
    """Pool of workers, each with a model replica on a different device."""

    def __init__(self, num_workers: Optional[int] = None):
        # Auto-detect: use GPUs if available, otherwise use 1 CPU worker
        if torch.cuda.is_available():
            self.num_workers = num_workers if num_workers is not None else torch.cuda.device_count()
            self.use_cuda = True
        else:
            self.num_workers = 1  # CPU mode - single worker
            self.use_cuda = False
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()

    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load model on each device."""
        device_type = "GPU" if self.use_cuda else "CPU"
        print(f"Initializing worker pool with {self.num_workers} {device_type} worker(s)...")
        for worker_id in range(self.num_workers):
            if self.use_cuda:
                device = torch.device(f"cuda:{worker_id}")
                print(f"Loading model on GPU {worker_id}...")
                autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            else:
                device = torch.device("cpu")
                print("Loading model on CPU...")
                # CPU autocast uses bfloat16 if available, otherwise float32
                autocast_ctx = torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)
            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
            engine = Engine(model, tokenizer)
            worker = Worker(
                worker_id=worker_id,
                device=device,
                engine=engine,
                tokenizer=tokenizer,
                autocast_ctx=autocast_ctx
            )
            self.workers.append(worker)
            await self.available_workers.put(worker)
        print(f"All {self.num_workers} worker(s) initialized!")

    async def acquire_worker(self) -> Worker:
        """Get an available worker from the pool."""
        return await self.available_workers.get()

    async def release_worker(self, worker: Worker):
        """Return a worker to the pool."""
        await self.available_workers.put(worker)
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_k: Optional[int] = None
def validate_chat_request(request: ChatRequest):
    """Validate chat request to prevent abuse."""
    # Check number of messages
    if len(request.messages) == 0:
        raise HTTPException(status_code=400, detail="At least one message is required")
    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
        )
    # Check individual message lengths and total conversation length
    total_length = 0
    for i, message in enumerate(request.messages):
        if not message.content:
            raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
        msg_length = len(message.content)
        if msg_length > MAX_MESSAGE_LENGTH:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
            )
        total_length += msg_length
    if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
        )
    # Validate role values (only roles handled by the conversation tokenizer below)
    for i, message in enumerate(request.messages):
        if message.role not in ["user", "assistant"]:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
            )
    # Validate temperature
    if request.temperature is not None:
        if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
            raise HTTPException(
                status_code=400,
                detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
            )
    # Validate top_k
    if request.top_k is not None:
        if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
            raise HTTPException(
                status_code=400,
                detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
            )
    # Validate max_tokens
    if request.max_tokens is not None:
        if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
            raise HTTPException(
                status_code=400,
                detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
            )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup (GPU or CPU)."""
    print("Loading nanochat models...")
    app.state.worker_pool = WorkerPool(num_workers=args.num_gpus)
    await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    print(f"Server ready at http://localhost:{args.port}")
    yield
app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Serve the chat UI."""
    ui_html_path = os.path.join("nanochat", "ui.html")
    with open(ui_html_path, "r") as f:
        html_content = f.read()
    # Replace the API_URL to use the same origin
    html_content = html_content.replace(
        "const API_URL = `http://${window.location.hostname}:8000`;",
        "const API_URL = '';"
    )
    return HTMLResponse(content=html_content)
@app.get("/logo.svg")
async def logo():
    """Serve the NanoChat logo for favicon and header."""
    logo_path = os.path.join("nanochat", "logo.svg")
    return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Generate assistant response with streaming."""
    temperature = temperature if temperature is not None else args.temperature
    max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
    top_k = top_k if top_k is not None else args.top_k
    assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
    bos = worker.tokenizer.get_bos_token_id()
    # Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
    accumulated_tokens = []
    # Track the last complete UTF-8 string (without replacement characters)
    last_clean_text = ""
    with worker.autocast_ctx:
        for token_column, token_masks in worker.engine.generate(
            tokens,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            seed=random.randint(0, 2**31 - 1)
        ):
            token = token_column[0]
            # Stopping criteria
            if token == assistant_end or token == bos:
                break
            # Append the token to the sequence
            accumulated_tokens.append(token)
            # Decode all accumulated tokens to get proper UTF-8 handling
            # Note that decode is a quite efficient operation, basically table lookup and string concat
            current_text = worker.tokenizer.decode(accumulated_tokens)
            # Only emit text if it doesn't end with a replacement character;
            # this ensures we don't emit incomplete UTF-8 sequences
            if not current_text.endswith('�'):
                # Extract only the new text since the last clean decode
                new_text = current_text[len(last_clean_text):]
                if new_text:  # Only yield if there's new content
                    yield f"data: {json.dumps({'token': new_text, 'worker': worker.worker_id}, ensure_ascii=False)}\n\n"
                last_clean_text = current_text
    yield f"data: {json.dumps({'done': True})}\n\n"
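# Illustrative shape of the SSE stream emitted above (token text will of course vary;
# the worker id reflects whichever device handled the request):
#
#   data: {"token": "Hello", "worker": 0}
#   data: {"token": " there!", "worker": 0}
#   data: {"done": true}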
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
    # Basic validation to prevent abuse
    validate_chat_request(request)

    # Log incoming conversation to console
    logger.info("=" * 20)
    for i, message in enumerate(request.messages):
        logger.info(f"[{message.role.upper()}]: {message.content}")
    logger.info("-" * 20)

    # Acquire a worker from the pool (will wait if all are busy)
    worker_pool = app.state.worker_pool
    worker = await worker_pool.acquire_worker()
    try:
        # Build conversation tokens
        bos = worker.tokenizer.get_bos_token_id()
        user_start = worker.tokenizer.encode_special("<|user_start|>")
        user_end = worker.tokenizer.encode_special("<|user_end|>")
        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
        conversation_tokens = [bos]
        for message in request.messages:
            if message.role == "user":
                conversation_tokens.append(user_start)
                conversation_tokens.extend(worker.tokenizer.encode(message.content))
                conversation_tokens.append(user_end)
            elif message.role == "assistant":
                conversation_tokens.append(assistant_start)
                conversation_tokens.extend(worker.tokenizer.encode(message.content))
                conversation_tokens.append(assistant_end)
        conversation_tokens.append(assistant_start)

        # Streaming response with worker release after completion
        response_tokens = []

        async def stream_and_release():
            try:
                async for chunk in generate_stream(
                    worker,
                    conversation_tokens,
                    temperature=request.temperature,
                    max_new_tokens=request.max_tokens,
                    top_k=request.top_k
                ):
                    # Accumulate response for logging
                    chunk_data = json.loads(chunk.replace("data: ", "").strip())
                    if "token" in chunk_data:
                        response_tokens.append(chunk_data["token"])
                    yield chunk
            finally:
                # Log the assistant response to console
                full_response = "".join(response_tokens)
                device_name = f"GPU {worker.worker_id}" if str(worker.device).startswith("cuda") else "CPU"
                logger.info(f"[ASSISTANT] ({device_name}): {full_response}")
                logger.info("=" * 20)
                # Release worker back to pool after streaming is done
                await worker_pool.release_worker(worker)

        return StreamingResponse(
            stream_and_release(),
            media_type="text/event-stream"
        )
    except Exception as e:
        # Make sure to release the worker even on error
        await worker_pool.release_worker(worker)
        raise e
@app.get("/health")
async def health():
    """Health check endpoint."""
    worker_pool = getattr(app.state, 'worker_pool', None)
    return {
        "status": "ok",
        "ready": worker_pool is not None and len(worker_pool.workers) > 0,
        "num_workers": worker_pool.num_workers if worker_pool else 0,
        "use_cuda": worker_pool.use_cuda if worker_pool else False,
        "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
    }
@app.get("/stats")
async def stats():
    """Get worker pool statistics."""
    worker_pool = app.state.worker_pool
    return {
        "total_workers": len(worker_pool.workers),
        "available_workers": worker_pool.available_workers.qsize(),
        "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
        "use_cuda": worker_pool.use_cuda,
        "workers": [
            {
                "worker_id": w.worker_id,
                "device": str(w.device)
            } for w in worker_pool.workers
        ]
    }
| if __name__ == "__main__": | |
| import uvicorn | |
| print(f"Starting NanoChat Web Server") | |
| print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}") | |
| uvicorn.run(app, host=args.host, port=args.port) | |