pluralchat / vllm-manager.py
#!/usr/bin/env python3
"""
vLLM Model Manager for INF2
Handles dynamic model switching by restarting vLLM with the requested model.
"""
import os
import subprocess
import signal
import time
import json
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Thread
import requests
# Model configurations
MODELS = {
"meta-llama/Meta-Llama-3-8B-Instruct": {
"hf_model": "meta-llama/Meta-Llama-3-8B-Instruct",
"display_name": "Llama 3 8B Instruct",
"description": "Meta's Llama 3 8B instruction-tuned model (ungated)"
},
"Qwen/Qwen2.5-7B-Instruct": {
"hf_model": "Qwen/Qwen2.5-7B-Instruct",
"display_name": "Qwen 2.5 7B Instruct",
"description": "Alibaba's Qwen 2.5 7B instruction-tuned model"
},
"mistralai/Mistral-7B-Instruct-v0.3": {
"hf_model": "mistralai/Mistral-7B-Instruct-v0.3",
"display_name": "Mistral 7B Instruct v0.3",
"description": "Mistral AI's 7B instruction-tuned model"
}
}
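# To expose another model, add an entry here keyed by its HuggingFace repo id;
# the ids are listed by the /v1/models endpoint below and "hf_model" is passed
# to vLLM's --model flag in start_vllm(). (Assumption: gated repos would also
# need a valid HF token in the environment before they can be downloaded.)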
# Current state
current_model = os.environ.get("VLLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
vllm_process = None
cache_dir = "/data/models"
def start_vllm(model_name):
"""Start vLLM with the specified model"""
global vllm_process, current_model
if model_name not in MODELS:
print(f"Error: Unknown model {model_name}")
return False
model_info = MODELS[model_name]
hf_model = model_info["hf_model"]
print(f"\n{'='*60}")
print(f"Starting vLLM with model: {model_info['display_name']}")
print(f"HuggingFace model: {hf_model}")
print(f"{'='*60}\n")
    # Enable verbose vLLM logging so device detection and startup issues are visible
    env = os.environ.copy()
    env["VLLM_LOGGING_LEVEL"] = "DEBUG"
    cmd = [
        "python3", "-m", "vllm.entrypoints.openai.api_server",
        "--model", hf_model,
        "--host", "0.0.0.0",
        "--port", "8001",
        "--device", "neuron",
        "--tensor-parallel-size", "2",
        "--download-dir", cache_dir,
        "--trust-remote-code"
    ]
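    # Flag notes (assumptions about the target hardware, not verified here):
    # --device neuron targets the AWS Neuron backend on Inferentia2, and
    # --tensor-parallel-size 2 is presumably sized to the two NeuronCores of a
    # single Inferentia2 chip; --download-dir keeps weights in the shared cache.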
try:
vllm_process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1,
env=env
)
# Monitor output in background thread
def monitor_output():
for line in vllm_process.stdout:
print(f"[vLLM] {line.rstrip()}")
Thread(target=monitor_output, daemon=True).start()
current_model = model_name
return True
except Exception as e:
print(f"Error starting vLLM: {e}")
return False
def stop_vllm():
    """Stop the current vLLM process"""
    global vllm_process
    if vllm_process:
        print("Stopping vLLM...")
        vllm_process.send_signal(signal.SIGTERM)
        try:
            vllm_process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            # Force-kill if vLLM does not shut down cleanly within 30s
            vllm_process.kill()
            vllm_process.wait()
        vllm_process = None
        time.sleep(2)
def switch_model(new_model_id):
    """Switch to a different model by restarting vLLM"""
    if new_model_id not in MODELS:
        return False
    if new_model_id == current_model:
        return True
    print(f"Switching from {current_model} to {new_model_id}")
    stop_vllm()
    if not start_vllm(new_model_id):
        return False
    # Wait for the new vLLM server to come up before requests are forwarded to it
    deadline = time.time() + 600
    while time.time() < deadline:
        try:
            if requests.get("http://localhost:8001/health", timeout=2).status_code == 200:
                return True
        except requests.RequestException:
            pass
        time.sleep(5)
    print(f"Timed out waiting for vLLM to serve {new_model_id}")
    return False
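# Note: a model switch is a full vLLM restart, so the first request after a
# switch waits for the new model to load (and, on Neuron, to compile) before
# it is served.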
class ProxyHandler(BaseHTTPRequestHandler):
"""Proxy requests to vLLM, with custom /models endpoint"""
def log_message(self, format, *args):
"""Suppress default logging"""
pass
def do_GET(self):
if self.path == "/v1/models" or self.path == "/models":
# Return all available models
models_list = {
"object": "list",
"data": [
{
"id": model_id,
"object": "model",
"created": 1234567890,
"owned_by": "system",
"description": info["description"]
}
for model_id, info in MODELS.items()
]
}
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(models_list).encode())
elif self.path == "/health":
# Health check
try:
resp = requests.get("http://localhost:8001/health", timeout=1)
self.send_response(resp.status_code)
self.end_headers()
self.wfile.write(resp.content)
            except requests.RequestException:
self.send_response(503)
self.end_headers()
else:
# Proxy to vLLM
self.proxy_request()
    def do_POST(self):
        body = None
        # Chat completion requests may name a different model; switch before proxying
        if self.path.startswith("/v1/chat/completions"):
            content_length = int(self.headers.get('Content-Length', 0))
            body = self.rfile.read(content_length)
            try:
                data = json.loads(body)
                requested_model = data.get("model")
                # Switch model if the request names one other than the active model
                if requested_model and requested_model != current_model:
                    if switch_model(requested_model):
                        print(f"Switched to model: {requested_model}")
                    else:
                        self.send_response(500)
                        self.send_header("Content-Type", "application/json")
                        self.end_headers()
                        self.wfile.write(json.dumps({
                            "error": f"Failed to switch to model: {requested_model}"
                        }).encode())
                        return
                # Rewrite the model field to the currently loaded model
                data["model"] = current_model
                body = json.dumps(data).encode()
            except json.JSONDecodeError:
                # Not valid JSON; forward the raw body unchanged
                pass
        # Proxy to vLLM (proxy_request re-reads the body from the socket if still None)
        self.proxy_request(body)
def proxy_request(self, body=None):
"""Forward request to vLLM"""
try:
url = f"http://localhost:8001{self.path}"
            headers = dict(self.headers)
            headers.pop("Host", None)
            # Drop the client's Content-Length; requests recomputes it from the
            # forwarded body (the original value can be stale after the body is
            # rewritten in do_POST).
            headers.pop("Content-Length", None)
if body is None and self.command == "POST":
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length)
resp = requests.request(
method=self.command,
url=url,
headers=headers,
data=body,
stream=True
)
self.send_response(resp.status_code)
for key, value in resp.headers.items():
if key.lower() not in ['transfer-encoding', 'connection']:
self.send_header(key, value)
self.end_headers()
for chunk in resp.iter_content(chunk_size=8192):
if chunk:
self.wfile.write(chunk)
except Exception as e:
print(f"Proxy error: {e}")
self.send_response(500)
self.end_headers()
def pre_download_models():
"""Pre-download all models to cache"""
print("Pre-downloading models to cache...")
for model_id in MODELS.keys():
print(f"Downloading {model_id}...")
subprocess.run([
"huggingface-cli", "download", model_id,
"--cache-dir", cache_dir
])
print("All models downloaded!")
if __name__ == "__main__":
# Ensure cache directory exists
os.makedirs(cache_dir, exist_ok=True)
# Pre-download models in background
download_thread = Thread(target=pre_download_models, daemon=True)
download_thread.start()
# Start vLLM with default model
if not start_vllm(current_model):
print("Failed to start vLLM, exiting")
exit(1)
# Start proxy server
print("Starting proxy server on port 8000...")
server = HTTPServer(('0.0.0.0', 8000), ProxyHandler)
try:
server.serve_forever()
except KeyboardInterrupt:
print("Shutting down...")
stop_vllm()
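# Example launch (illustrative; assumes vLLM with the Neuron backend and the
# huggingface_hub CLI are installed, and that /data/models is writable):
#
#   VLLM_MODEL=Qwen/Qwen2.5-7B-Instruct python3 vllm-manager.py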