#!/usr/bin/env python3
"""
vLLM Model Manager for INF2

Handles dynamic model switching by restarting vLLM with the requested model.
A commented example client request is included at the end of this file.
"""

import os
import subprocess
import signal
import time
import json
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Thread

import requests

# Model configurations
MODELS = {
    "meta-llama/Meta-Llama-3-8B-Instruct": {
        "hf_model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "display_name": "Llama 3 8B Instruct",
        "description": "Meta's Llama 3 8B instruction-tuned model (ungated)"
    },
    "Qwen/Qwen2.5-7B-Instruct": {
        "hf_model": "Qwen/Qwen2.5-7B-Instruct",
        "display_name": "Qwen 2.5 7B Instruct",
        "description": "Alibaba's Qwen 2.5 7B instruction-tuned model"
    },
    "mistralai/Mistral-7B-Instruct-v0.3": {
        "hf_model": "mistralai/Mistral-7B-Instruct-v0.3",
        "display_name": "Mistral 7B Instruct v0.3",
        "description": "Mistral AI's 7B instruction-tuned model"
    }
}

# Current state
current_model = os.environ.get("VLLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
vllm_process = None
cache_dir = "/data/models"


def start_vllm(model_name):
    """Start vLLM with the specified model."""
    global vllm_process, current_model

    if model_name not in MODELS:
        print(f"Error: Unknown model {model_name}")
        return False

    model_info = MODELS[model_name]
    hf_model = model_info["hf_model"]

    print(f"\n{'=' * 60}")
    print(f"Starting vLLM with model: {model_info['display_name']}")
    print(f"HuggingFace model: {hf_model}")
    print(f"{'=' * 60}\n")

    # Pass the current environment through and enable verbose vLLM logging
    env = os.environ.copy()
    env["VLLM_LOGGING_LEVEL"] = "DEBUG"

    cmd = [
        "python3", "-m", "vllm.entrypoints.openai.api_server",
        "--model", hf_model,
        "--host", "0.0.0.0",
        "--port", "8001",
        "--device", "neuron",
        "--tensor-parallel-size", "2",
        "--download-dir", "/data/models",
        "--trust-remote-code"
    ]

    try:
        vllm_process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            bufsize=1,
            env=env
        )

        # Stream vLLM's output to our own stdout from a background thread
        def monitor_output():
            for line in vllm_process.stdout:
                print(f"[vLLM] {line.rstrip()}")

        Thread(target=monitor_output, daemon=True).start()

        current_model = model_name
        return True
    except Exception as e:
        print(f"Error starting vLLM: {e}")
        return False


def stop_vllm():
    """Stop the current vLLM process."""
    global vllm_process
    if vllm_process:
        print("Stopping vLLM...")
        vllm_process.send_signal(signal.SIGTERM)
        try:
            vllm_process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            # The process did not exit cleanly; force-kill it
            vllm_process.kill()
            vllm_process.wait()
        vllm_process = None
        time.sleep(2)


def switch_model(new_model_id):
    """Switch to a different model by restarting vLLM."""
    global current_model

    if new_model_id not in MODELS:
        return False
    if new_model_id == current_model:
        return True

    print(f"Switching from {current_model} to {new_model_id}")
    stop_vllm()
    current_model = new_model_id
    return start_vllm(new_model_id)


class ProxyHandler(BaseHTTPRequestHandler):
    """Proxy requests to vLLM, with a custom /models endpoint."""

    def log_message(self, format, *args):
        """Suppress default request logging."""
        pass

    def do_GET(self):
        if self.path in ("/v1/models", "/models"):
            # Return all available models, not just the one currently loaded
            models_list = {
                "object": "list",
                "data": [
                    {
                        "id": model_id,
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "system",
                        "description": info["description"]
                    }
                    for model_id, info in MODELS.items()
                ]
            }
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            self.wfile.write(json.dumps(models_list).encode())
        elif self.path == "/health":
            # Health check: pass through vLLM's own /health status
            try:
                resp = requests.get("http://localhost:8001/health", timeout=1)
                self.send_response(resp.status_code)
                self.end_headers()
                self.wfile.write(resp.content)
            except requests.RequestException:
                self.send_response(503)
                self.end_headers()
        else:
            # Proxy everything else to vLLM
            self.proxy_request()

    def do_POST(self):
        # Chat completion requests may name a different model, which
        # triggers a model switch before the request is forwarded
        if self.path.startswith("/v1/chat/completions"):
            content_length = int(self.headers.get('Content-Length', 0))
            body = self.rfile.read(content_length)

            try:
                data = json.loads(body)
                requested_model = data.get("model")

                # Switch model if needed
                if requested_model and requested_model != current_model:
                    if switch_model(requested_model):
                        print(f"Switched to model: {requested_model}")
                    else:
                        self.send_response(500)
                        self.send_header("Content-Type", "application/json")
                        self.end_headers()
                        self.wfile.write(json.dumps({
                            "error": f"Failed to switch to model: {requested_model}"
                        }).encode())
                        return

                # Rewrite the model field to whatever is actually loaded
                data["model"] = current_model
                body = json.dumps(data).encode()
            except (ValueError, AttributeError):
                # Body was not a JSON object; forward it unchanged
                pass

            # Proxy to vLLM
            self.proxy_request(body)
        else:
            # Other POST endpoints are forwarded untouched
            self.proxy_request()

    def proxy_request(self, body=None):
        """Forward the current request to vLLM and stream the response back."""
        try:
            url = f"http://localhost:8001{self.path}"
            # Strip Host and Content-Length: the body may have been rewritten
            # in do_POST, so requests must recompute the length itself
            headers = {
                key: value for key, value in self.headers.items()
                if key.lower() not in ("host", "content-length")
            }

            if body is None and self.command == "POST":
                content_length = int(self.headers.get('Content-Length', 0))
                body = self.rfile.read(content_length)

            resp = requests.request(
                method=self.command,
                url=url,
                headers=headers,
                data=body,
                stream=True
            )

            self.send_response(resp.status_code)
            for key, value in resp.headers.items():
                if key.lower() not in ['transfer-encoding', 'connection']:
                    self.send_header(key, value)
            self.end_headers()

            for chunk in resp.iter_content(chunk_size=8192):
                if chunk:
                    self.wfile.write(chunk)
        except Exception as e:
            print(f"Proxy error: {e}")
            self.send_response(500)
            self.end_headers()


def pre_download_models():
    """Pre-download all models to the local cache."""
    print("Pre-downloading models to cache...")
    for model_id in MODELS.keys():
        print(f"Downloading {model_id}...")
        subprocess.run([
            "huggingface-cli", "download", model_id,
            "--cache-dir", cache_dir
        ])
    print("All models downloaded!")


if __name__ == "__main__":
    # Ensure the cache directory exists
    os.makedirs(cache_dir, exist_ok=True)

    # Pre-download models in the background
    download_thread = Thread(target=pre_download_models, daemon=True)
    download_thread.start()

    # Start vLLM with the default model
    if not start_vllm(current_model):
        print("Failed to start vLLM, exiting")
        raise SystemExit(1)

    # Start the proxy server
    print("Starting proxy server on port 8000...")
    server = HTTPServer(('0.0.0.0', 8000), ProxyHandler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("Shutting down...")
        stop_vllm()
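
# ----------------------------------------------------------------------------
# Example client usage (illustrative sketch only, never executed by this
# script). It assumes the proxy above is reachable at http://localhost:8000;
# the prompt text and timeout value are arbitrary placeholders. Sending a
# request whose "model" field names a different entry in MODELS makes
# do_POST() call switch_model(), which restarts vLLM on the new model before
# forwarding the request, so the first request after a switch can take
# several minutes.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "model": "Qwen/Qwen2.5-7B-Instruct",  # different from the running model
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "max_tokens": 64,
#       },
#       timeout=600,  # generous: covers the model reload after a switch
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
# ----------------------------------------------------------------------------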