#!/usr/bin/env python3
"""
vLLM Model Manager for INF2
Handles dynamic model switching by restarting vLLM with the requested model.
"""
import os
import subprocess
import signal
import time
import json
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Thread

import requests
# Model configurations
MODELS = {
    "meta-llama/Meta-Llama-3-8B-Instruct": {
        "hf_model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "display_name": "Llama 3 8B Instruct",
        "description": "Meta's Llama 3 8B instruction-tuned model (ungated)"
    },
    "Qwen/Qwen2.5-7B-Instruct": {
        "hf_model": "Qwen/Qwen2.5-7B-Instruct",
        "display_name": "Qwen 2.5 7B Instruct",
        "description": "Alibaba's Qwen 2.5 7B instruction-tuned model"
    },
    "mistralai/Mistral-7B-Instruct-v0.3": {
        "hf_model": "mistralai/Mistral-7B-Instruct-v0.3",
        "display_name": "Mistral 7B Instruct v0.3",
        "description": "Mistral AI's 7B instruction-tuned model"
    }
}

# Current state
current_model = os.environ.get("VLLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
vllm_process = None
cache_dir = "/data/models"
def start_vllm(model_name):
    """Start vLLM with the specified model."""
    global vllm_process, current_model

    if model_name not in MODELS:
        print(f"Error: Unknown model {model_name}")
        return False

    model_info = MODELS[model_name]
    hf_model = model_info["hf_model"]

    print(f"\n{'='*60}")
    print(f"Starting vLLM with model: {model_info['display_name']}")
    print(f"HuggingFace model: {hf_model}")
    print(f"{'='*60}\n")

    # Enable verbose logging so device detection issues show up in the output
    env = os.environ.copy()
    env["VLLM_LOGGING_LEVEL"] = "DEBUG"

    cmd = [
        "python3", "-m", "vllm.entrypoints.openai.api_server",
        "--model", hf_model,
        "--host", "0.0.0.0",
        "--port", "8001",
        "--device", "neuron",
        "--tensor-parallel-size", "2",
        "--download-dir", cache_dir,
        "--trust-remote-code"
    ]

    try:
        vllm_process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            bufsize=1,
            env=env
        )

        # Monitor output in a background thread
        def monitor_output():
            for line in vllm_process.stdout:
                print(f"[vLLM] {line.rstrip()}")

        Thread(target=monitor_output, daemon=True).start()

        current_model = model_name
        return True
    except Exception as e:
        print(f"Error starting vLLM: {e}")
        return False
def stop_vllm():
    """Stop the current vLLM process."""
    global vllm_process
    if vllm_process:
        print("Stopping vLLM...")
        vllm_process.send_signal(signal.SIGTERM)
        try:
            vllm_process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            # Force-kill if the server does not shut down cleanly
            vllm_process.kill()
            vllm_process.wait()
        vllm_process = None
        time.sleep(2)
def switch_model(new_model_id):
    """Switch to a different model."""
    if new_model_id not in MODELS:
        return False
    if new_model_id == current_model:
        return True

    print(f"Switching from {current_model} to {new_model_id}")
    stop_vllm()
    if not start_vllm(new_model_id):
        return False

    # Wait for the new vLLM server to come up before forwarding requests;
    # loading (and, on Neuron, compiling) a model can take several minutes.
    deadline = time.time() + 600
    while time.time() < deadline:
        try:
            if requests.get("http://localhost:8001/health", timeout=2).ok:
                return True
        except requests.RequestException:
            pass
        time.sleep(5)
    print(f"Timed out waiting for {new_model_id} to become ready")
    return False
class ProxyHandler(BaseHTTPRequestHandler):
    """Proxy requests to vLLM, with a custom /models endpoint."""

    def log_message(self, format, *args):
        """Suppress default request logging."""
        pass

    def do_GET(self):
        if self.path in ("/v1/models", "/models"):
            # Return all available models
            models_list = {
                "object": "list",
                "data": [
                    {
                        "id": model_id,
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "system",
                        "description": info["description"]
                    }
                    for model_id, info in MODELS.items()
                ]
            }
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            self.wfile.write(json.dumps(models_list).encode())
        elif self.path == "/health":
            # Health check against the backing vLLM server
            try:
                resp = requests.get("http://localhost:8001/health", timeout=1)
                self.send_response(resp.status_code)
                self.end_headers()
                self.wfile.write(resp.content)
            except requests.RequestException:
                self.send_response(503)
                self.end_headers()
        else:
            # Proxy everything else to vLLM
            self.proxy_request()
    def do_POST(self):
        body = None

        # Chat completion requests may name a different model: switch if needed
        if self.path.startswith("/v1/chat/completions"):
            content_length = int(self.headers.get('Content-Length', 0))
            body = self.rfile.read(content_length)
            try:
                data = json.loads(body)
                requested_model = data.get("model")

                # Switch model if a different (known) model was requested
                if requested_model and requested_model != current_model:
                    if switch_model(requested_model):
                        print(f"Switched to model: {requested_model}")
                    else:
                        self.send_response(500)
                        self.send_header("Content-Type", "application/json")
                        self.end_headers()
                        self.wfile.write(json.dumps({
                            "error": f"Failed to switch to model: {requested_model}"
                        }).encode())
                        return

                # Rewrite the model field to the currently loaded model
                data["model"] = current_model
                body = json.dumps(data).encode()
            except (ValueError, AttributeError):
                # Body is not valid JSON: forward it unchanged
                pass

        # Proxy to vLLM (body is re-read inside proxy_request when None)
        self.proxy_request(body)
    def proxy_request(self, body=None):
        """Forward the request to the local vLLM server."""
        try:
            url = f"http://localhost:8001{self.path}"
            headers = dict(self.headers)
            headers.pop("Host", None)

            if body is None and self.command == "POST":
                content_length = int(self.headers.get('Content-Length', 0))
                body = self.rfile.read(content_length)

            resp = requests.request(
                method=self.command,
                url=url,
                headers=headers,
                data=body,
                stream=True
            )

            self.send_response(resp.status_code)
            # Skip hop-by-hop headers, plus length/encoding headers that may no
            # longer match the (decompressed) stream written below
            for key, value in resp.headers.items():
                if key.lower() not in ('transfer-encoding', 'connection',
                                       'content-encoding', 'content-length'):
                    self.send_header(key, value)
            self.end_headers()

            for chunk in resp.iter_content(chunk_size=8192):
                if chunk:
                    self.wfile.write(chunk)
        except Exception as e:
            print(f"Proxy error: {e}")
            self.send_response(500)
            self.end_headers()
def pre_download_models():
    """Pre-download all models to the local cache."""
    print("Pre-downloading models to cache...")
    for model_id in MODELS.keys():
        print(f"Downloading {model_id}...")
        subprocess.run([
            "huggingface-cli", "download", model_id,
            "--cache-dir", cache_dir
        ])
    print("All models downloaded!")
if __name__ == "__main__":
    # Ensure the cache directory exists
    os.makedirs(cache_dir, exist_ok=True)

    # Pre-download models in the background
    download_thread = Thread(target=pre_download_models, daemon=True)
    download_thread.start()

    # Start vLLM with the default model
    if not start_vllm(current_model):
        print("Failed to start vLLM, exiting")
        exit(1)

    # Start the proxy server
    print("Starting proxy server on port 8000...")
    server = HTTPServer(('0.0.0.0', 8000), ProxyHandler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("Shutting down...")
        stop_vllm()
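
# ---------------------------------------------------------------------------
# Usage sketch (not part of the server): a client switches models by naming a
# different model id in the request body, and the proxy restarts vLLM before
# forwarding. The localhost URL and the 600 s timeout below are illustrative
# assumptions rather than values taken from this Space's configuration.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "model": "Qwen/Qwen2.5-7B-Instruct",
#           "messages": [{"role": "user", "content": "Hello"}],
#       },
#       timeout=600,  # the first request after a switch waits for model load
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
# ---------------------------------------------------------------------------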