#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.

Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Automatically falls back to CPU if CUDA is not available.

Launch examples:

- single available GPU (default)
python -m scripts.chat_web

- 4 GPUs
python -m scripts.chat_web --num-gpus 4

- CPU only (automatic if no CUDA)
python -m scripts.chat_web

To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)

Endpoints:
  GET  /                 - Chat UI
  POST /chat/completions - Chat API (streaming only)
  GET  /health           - Health check with worker pool status
  GET  /stats            - Worker pool statistics (busy/available workers, devices)

Abuse Prevention (requests violating any limit are rejected with HTTP 400):
  - Maximum 500 messages per request
  - Maximum 8000 characters per message
  - Maximum 32000 characters total conversation length
  - Temperature must be within 0.0-2.0
  - Top-k must be within 1-200
  - Max tokens must be within 1-4096
"""

import argparse
import asyncio
import json
import logging
import os
import random
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import List, Optional, AsyncGenerator

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel

from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine

# Abuse prevention limits
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096

# Roles that chat_completions knows how to render into conversation tokens.
VALID_ROLES = ("user", "assistant")

# Prefix that generate_stream puts in front of every server-sent-event payload.
SSE_DATA_PREFIX = "data: "

parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (ignored on CPU, default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()

# Configure logging for conversation traffic
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()


@dataclass
class Worker:
    """A worker with a model loaded on a specific device (GPU or CPU)."""
    worker_id: int  # index of the worker inside the pool (also the CUDA ordinal on GPU)
    device: torch.device  # device the model replica was loaded on
    engine: Engine  # generation engine wrapping the loaded model
    tokenizer: object  # tokenizer returned by load_model alongside the model
    autocast_ctx: torch.amp.autocast  # autocast context entered around generation


class WorkerPool:
    """Pool of workers, each with a model replica on a different device."""

    def __init__(self, num_workers: Optional[int] = None):
        """Decide how many workers to run; no model is loaded here.

        Args:
            num_workers: Number of GPU workers to create. None means one
                worker per visible CUDA device. Ignored when CUDA is not
                available (a single CPU worker is used instead).

        Raises:
            ValueError: If CUDA is available and num_workers is smaller than 1
                or larger than the number of visible CUDA devices.
        """
        # Auto-detect: use GPUs if available, otherwise use 1 CPU worker
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            requested = gpu_count if num_workers is None else num_workers
            # Fail fast: an empty pool would make every request wait forever in
            # acquire_worker, and an ordinal past the last GPU cannot be loaded.
            if not 1 <= requested <= gpu_count:
                raise ValueError(
                    f"num_workers must be between 1 and {gpu_count} "
                    f"(visible CUDA devices), got {requested}"
                )
            self.num_workers = requested
            self.use_cuda = True
        else:
            self.num_workers = 1  # CPU mode - single worker
            self.use_cuda = False
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()

    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load one model replica per worker and mark every worker as available.

        Args:
            source: Model source passed through to load_model (sft|mid|rl per the CLI help).
            model_tag: Optional model tag passed through to load_model.
            step: Optional checkpoint step passed through to load_model.
        """
        device_type = "GPU" if self.use_cuda else "CPU"
        print(f"Initializing worker pool with {self.num_workers} {device_type} worker(s)...")

        for worker_id in range(self.num_workers):
            if self.use_cuda:
                worker_device = torch.device(f"cuda:{worker_id}")
                print(f"Loading model on GPU {worker_id}...")
                autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            else:
                worker_device = torch.device("cpu")
                print("Loading model on CPU...")
                # NOTE(review): the original comment states CPU autocast falls back to
                # float32 when bfloat16 is unavailable - confirm against torch docs.
                autocast_ctx = torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)

            model, tokenizer, _ = load_model(source, worker_device, phase="eval", model_tag=model_tag, step=step)
            worker = Worker(
                worker_id=worker_id,
                device=worker_device,
                engine=Engine(model, tokenizer),
                tokenizer=tokenizer,
                autocast_ctx=autocast_ctx
            )
            self.workers.append(worker)
            await self.available_workers.put(worker)

        print(f"All {self.num_workers} worker(s) initialized!")

    async def acquire_worker(self) -> Worker:
        """Get an available worker from the pool, waiting until one is free."""
        return await self.available_workers.get()

    async def release_worker(self, worker: Worker):
        """Return a worker to the pool."""
        await self.available_workers.put(worker)


# One turn of the conversation as sent by the client.
class ChatMessage(BaseModel):
    role: str
    content: str


# Body of POST /chat/completions; unset sampling fields fall back to CLI defaults.
class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_k: Optional[int] = None


def validate_chat_request(request: ChatRequest):
    """Validate chat request to prevent abuse.

    Checks run in this order: message count, per-message and total length,
    roles, then the optional sampling parameters.

    Raises:
        HTTPException: With status 400 for the first violated limit.
    """
    messages = request.messages

    # Check number of messages
    if not messages:
        raise HTTPException(status_code=400, detail="At least one message is required")
    if len(messages) > MAX_MESSAGES_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
        )

    # Check individual message lengths and total conversation length
    total_length = 0
    for i, message in enumerate(messages):
        if not message.content:
            raise HTTPException(status_code=400, detail=f"Message {i} has empty content")

        msg_length = len(message.content)
        if msg_length > MAX_MESSAGE_LENGTH:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
            )
        total_length += msg_length

    if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
        )

    # Validate role values. Only roles that chat_completions can render are
    # accepted, and the error message names exactly those roles.
    for i, message in enumerate(messages):
        if message.role not in VALID_ROLES:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
            )

    # Validate temperature
    if request.temperature is not None:
        if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
            raise HTTPException(
                status_code=400,
                detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
            )

    # Validate top_k
    if request.top_k is not None:
        if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
            raise HTTPException(
                status_code=400,
                detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
            )

    # Validate max_tokens
    if request.max_tokens is not None:
        if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
            raise HTTPException(
                status_code=400,
                detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
            )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup (GPU or CPU)."""
    print("Loading nanochat models...")
    app.state.worker_pool = WorkerPool(num_workers=args.num_gpus)
    await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    print(f"Server ready at http://localhost:{args.port}")
    yield


app = FastAPI(lifespan=lifespan)

# NOTE(review): wildcard origins combined with allow_credentials=True is discouraged
# by the FastAPI CORS documentation - confirm whether credentials are needed at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Serve the chat UI."""
    ui_html_path = os.path.join("nanochat", "ui.html")
    # Decode explicitly as UTF-8 so the page does not depend on the host locale.
    with open(ui_html_path, "r", encoding="utf-8") as f:
        html_content = f.read()
    # Replace the API_URL to use the same origin
    html_content = html_content.replace(
        "const API_URL = `http://${window.location.hostname}:8000`;",
        "const API_URL = '';"
    )
    return HTMLResponse(content=html_content)


@app.get("/logo.svg")
async def logo():
    """Serve the NanoChat logo for favicon and header."""
    logo_path = os.path.join("nanochat", "logo.svg")
    return FileResponse(logo_path, media_type="image/svg+xml")


async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Generate assistant response with streaming.

    Args:
        worker: Worker whose engine, tokenizer and autocast context are used.
        tokens: Prompt token ids handed to the engine.
        temperature: Sampling temperature; None uses the CLI default.
        max_new_tokens: Generation budget; None uses the CLI default.
        top_k: Top-k sampling parameter; None uses the CLI default.

    Yields:
        Server-sent-event strings: one 'token' payload per piece of newly
        decoded text, followed by a final 'done' payload.
    """
    temperature = temperature if temperature is not None else args.temperature
    max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
    top_k = top_k if top_k is not None else args.top_k

    assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
    bos = worker.tokenizer.get_bos_token_id()

    # Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
    accumulated_tokens = []
    # Track the last complete UTF-8 string (without replacement characters)
    last_clean_text = ""

    with worker.autocast_ctx:
        for token_column, _token_masks in worker.engine.generate(
            tokens,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            seed=random.randint(0, 2**31 - 1)
        ):
            token = token_column[0]

            # Stopping criteria
            if token == assistant_end or token == bos:
                break

            # Append the token to sequence
            accumulated_tokens.append(token)
            # Decode all accumulated tokens to get proper UTF-8 handling
            current_text = worker.tokenizer.decode(accumulated_tokens)
            # Only emit text if it doesn't end with a replacement character (U+FFFD).
            # This ensures we don't emit incomplete UTF-8 sequences.
            if not current_text.endswith('\ufffd'):
                # Extract only the new text since last clean decode
                new_text = current_text[len(last_clean_text):]
                if new_text:  # Only yield if there's new content
                    yield f"data: {json.dumps({'token': new_text, 'worker': worker.worker_id}, ensure_ascii=False)}\n\n"
                    last_clean_text = current_text

    yield f"data: {json.dumps({'done': True})}\n\n"


def _parse_sse_chunk(chunk: str) -> dict:
    """Decode the JSON payload of one SSE string produced by generate_stream.

    Only the leading prefix is removed; a literal "data: " inside the generated
    text must survive so the logged response matches what was streamed.
    """
    if chunk.startswith(SSE_DATA_PREFIX):
        chunk = chunk[len(SSE_DATA_PREFIX):]
    return json.loads(chunk.strip())


@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU.

    Validates the request, renders the conversation into tokens, and streams
    the reply as server-sent events. The worker is handed back to the pool
    when the stream finishes or when building the response fails.

    Raises:
        HTTPException: With status 400 when validation fails.
    """
    # Basic validation to prevent abuse
    validate_chat_request(request)

    # Log incoming conversation to console
    logger.info("="*20)
    for message in request.messages:
        logger.info("[%s]: %s", message.role.upper(), message.content)
    logger.info("-"*20)

    # Acquire a worker from the pool (will wait if all are busy)
    worker_pool = app.state.worker_pool
    worker = await worker_pool.acquire_worker()

    try:
        # Build conversation tokens
        tokenizer = worker.tokenizer
        bos = tokenizer.get_bos_token_id()
        user_start = tokenizer.encode_special("<|user_start|>")
        user_end = tokenizer.encode_special("<|user_end|>")
        assistant_start = tokenizer.encode_special("<|assistant_start|>")
        assistant_end = tokenizer.encode_special("<|assistant_end|>")

        conversation_tokens = [bos]
        for message in request.messages:
            if message.role == "user":
                conversation_tokens.append(user_start)
                conversation_tokens.extend(tokenizer.encode(message.content))
                conversation_tokens.append(user_end)
            elif message.role == "assistant":
                conversation_tokens.append(assistant_start)
                conversation_tokens.extend(tokenizer.encode(message.content))
                conversation_tokens.append(assistant_end)

        # Prime the model to answer as the assistant
        conversation_tokens.append(assistant_start)

        # Streaming response with worker release after completion
        response_tokens = []

        async def stream_and_release():
            try:
                async for chunk in generate_stream(
                    worker,
                    conversation_tokens,
                    temperature=request.temperature,
                    max_new_tokens=request.max_tokens,
                    top_k=request.top_k
                ):
                    # Accumulate response for logging
                    chunk_data = _parse_sse_chunk(chunk)
                    if "token" in chunk_data:
                        response_tokens.append(chunk_data["token"])
                    yield chunk
            finally:
                # Log the assistant response to console
                full_response = "".join(response_tokens)
                device_name = f"GPU {worker.worker_id}" if str(worker.device).startswith("cuda") else "CPU"
                logger.info("[ASSISTANT] (%s): %s", device_name, full_response)
                logger.info("="*20)
                # Release worker back to pool after streaming is done
                await worker_pool.release_worker(worker)

        return StreamingResponse(
            stream_and_release(),
            media_type="text/event-stream"
        )
    except Exception:
        # Make sure to release worker even on error; the bare raise keeps the
        # original traceback intact.
        await worker_pool.release_worker(worker)
        raise


@app.get("/health")
async def health():
    """Health check endpoint."""
    worker_pool = getattr(app.state, 'worker_pool', None)
    return {
        "status": "ok",
        "ready": worker_pool is not None and len(worker_pool.workers) > 0,
        "num_workers": worker_pool.num_workers if worker_pool else 0,
        "use_cuda": worker_pool.use_cuda if worker_pool else False,
        "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
    }


@app.get("/stats")
async def stats():
    """Get worker pool statistics."""
    worker_pool = app.state.worker_pool
    total = len(worker_pool.workers)
    available = worker_pool.available_workers.qsize()
    return {
        "total_workers": total,
        "available_workers": available,
        "busy_workers": total - available,
        "use_cuda": worker_pool.use_cuda,
        "workers": [
            {
                "worker_id": w.worker_id,
                "device": str(w.device)
            } for w in worker_pool.workers
        ]
    }


if __name__ == "__main__":
    import uvicorn
    print("Starting NanoChat Web Server")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
    uvicorn.run(app, host=args.host, port=args.port)