Avijit Ghosh committed on
Commit db1c946 · 1 Parent(s): 13685a1

Implement dynamic model switching for all 3 models in UI

- Add vllm-manager.py, a proxy that lists all 3 models in the UI model picker
- Automatically switches the vLLM backend when the user selects a different model (see the sketch below)
- Pre-downloads all models to /data/models (persistent storage)
- Models: Llama-3.1-8B, Qwen3-8B, gpt-oss-20b
- Switching takes ~2-3 minutes (after the initial download)
- Install the requests and huggingface-hub packages

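For context, a minimal sketch of how a client can exercise the switching behavior from inside the Space, assuming the manager proxy is reachable at http://localhost:8000 (the address entrypoint.sh exports as OPENAI_BASE_URL); the model IDs come from the MODELS table in vllm-manager.py:

```python
# Sketch: list the advertised models, then request a different one to trigger a switch.
# Assumes the vllm-manager proxy is listening on localhost:8000 (see entrypoint.sh).
import requests

BASE = "http://localhost:8000/v1"

# The proxy always advertises all three configured models.
models = requests.get(f"{BASE}/models", timeout=10).json()
print([m["id"] for m in models["data"]])

# Requesting a model other than the one currently loaded makes the manager
# stop vLLM and restart it with the requested model before answering.
resp = requests.post(
    f"{BASE}/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 32,
    },
    timeout=600,  # a switch can take a few minutes
)
print(resp.json()["choices"][0]["message"]["content"])
```

Because the proxy handles requests serially, a request that triggers a switch simply blocks until the new backend reports healthy.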
Files changed (4)
  1. Dockerfile +3 -1
  2. VLLM_SETUP.md +13 -14
  3. entrypoint.sh +18 -26
  4. vllm-manager.py +236 -0
Dockerfile CHANGED
@@ -24,7 +24,7 @@ RUN apt-get update
 RUN apt-get install -y libgomp1 libcurl4 curl python3 python3-pip python3-venv
 
 # Install vLLM with AWS Neuron support for INF2
-RUN pip3 install --break-system-packages --no-cache-dir vllm awscli
+RUN pip3 install --break-system-packages --no-cache-dir vllm awscli requests huggingface-hub
 
 # ensure vllm cache dir exists before adjusting ownership
 RUN mkdir -p /home/user/.cache && chown -R 1000:1000 /home/user/.cache
@@ -36,10 +36,12 @@ USER user
 
 COPY --chown=1000 .env /app/.env
 COPY --chown=1000 entrypoint.sh /app/entrypoint.sh
+COPY --chown=1000 vllm-manager.py /app/vllm-manager.py
 COPY --chown=1000 package.json /app/package.json
 COPY --chown=1000 package-lock.json /app/package-lock.json
 
 RUN chmod +x /app/entrypoint.sh
+RUN chmod +x /app/vllm-manager.py
 
 FROM node:20 AS builder
 
VLLM_SETUP.md CHANGED
@@ -2,23 +2,13 @@
 
 This branch uses vLLM with AWS Neuron support for running models on Amazon INF2 instances.
 
-## Configuration
-
-### Environment Variables
-
-Set these in your HuggingFace Space secrets:
-
-```bash
-# Primary model to load (INF2 typically supports one model at a time)
-VLLM_MODEL=meta-llama/Llama-3.1-8B-Instruct
-
-# Alternative models (change VLLM_MODEL to switch):
-# VLLM_MODEL=Qwen/Qwen3-8B
-# VLLM_MODEL=openai/gpt-oss-20b
-# VLLM_MODEL=microsoft/Phi-3-mini-4k-instruct
-```
-
-### Model Equivalents (Ollama → HuggingFace)
-
+**✨ Dynamic Model Switching**: All three models appear in the UI's model picker. When you select a different model, the system automatically restarts vLLM with the new model (takes ~2-3 minutes after first download).
+
+## Configuration
+
+### Available Models
+
+All three models are pre-configured and cached in persistent storage:
+
 | Ollama Model | HuggingFace Model | Notes |
 |--------------|-------------------|-------|
@@ -26,6 +16,15 @@ VLLM_MODEL=meta-llama/Llama-3.1-8B-Instruct
 | `qwen3:8b` | `Qwen/Qwen3-8B` | Fast, multilingual |
 | `gpt-oss:20b` | `openai/gpt-oss-20b` | Larger, more capable |
 
+### Environment Variables
+
+```bash
+# Default model to load at startup
+VLLM_MODEL=meta-llama/Llama-3.1-8B-Instruct
+```
+
+You can change the default startup model, but all three models will be available in the UI regardless.
+
 ### Supported Models for INF2
 
 vLLM with Neuron supports:
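Because `huggingface-cli download --cache-dir /data/models` uses the standard Hub cache layout, you can sanity-check what has actually been pre-downloaded. A small sketch, assuming it runs inside the container where /data/models is the persistent volume:

```python
# Sketch: inspect the persistent model cache that vllm-manager.py pre-populates.
from huggingface_hub import scan_cache_dir

cache = scan_cache_dir("/data/models")
for repo in sorted(cache.repos, key=lambda r: r.repo_id):
    print(f"{repo.repo_id}: {repo.size_on_disk / 1e9:.1f} GB on disk")
print(f"Total: {cache.size_on_disk / 1e9:.1f} GB")
```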
entrypoint.sh CHANGED
@@ -14,45 +14,37 @@ if [ "$INCLUDE_DB" = "true" ] ; then
 nohup mongod &
 fi;
 
-# Start vLLM service with OpenAI-compatible API for HF space
-echo "Starting vLLM service with OpenAI-compatible API"
+# Start vLLM Model Manager (handles multiple models with dynamic switching)
+echo "Starting vLLM Model Manager"
 
 # Ensure dir for model cache
 mkdir -p /data/models
 
 # Default model for vLLM (can be overridden via VLLM_MODEL env var)
-# Note: INF2 typically supports one model at a time due to memory constraints
-# Available models:
-#   - meta-llama/Llama-3.1-8B-Instruct (equivalent to llama3.1:8b)
-#   - Qwen/Qwen3-8B (equivalent to qwen3:8b)
-#   - openai/gpt-oss-20b (equivalent to gpt-oss:20b)
-VLLM_MODEL=${VLLM_MODEL:-"meta-llama/Llama-3.1-8B-Instruct"}
+# Available models: meta-llama/Llama-3.1-8B-Instruct, Qwen/Qwen3-8B, openai/gpt-oss-20b
+export VLLM_MODEL=${VLLM_MODEL:-"meta-llama/Llama-3.1-8B-Instruct"}
 
-echo "Loading model: $VLLM_MODEL"
+# Make manager executable
+chmod +x /app/vllm-manager.py
 
-# Start vLLM OpenAI-compatible server
-# Using --served-model-name to make models accessible via simpler names
-nohup python3 -m vllm.entrypoints.openai.api_server \
-    --model "$VLLM_MODEL" \
-    --host 0.0.0.0 \
-    --port 8000 \
-    --device neuron \
-    --tensor-parallel-size 2 \
-    > /tmp/vllm.log 2>&1 &
-VLLM_PID=$!
+# Start the vLLM manager (it handles vLLM and provides model switching)
+nohup python3 /app/vllm-manager.py > /tmp/vllm-manager.log 2>&1 &
+MANAGER_PID=$!
 
-# Override OPENAI_BASE_URL to use local vLLM at runtime
+# Override OPENAI_BASE_URL to use local vLLM manager at runtime
 export OPENAI_BASE_URL=http://localhost:8000/v1
-echo "OPENAI_BASE_URL set to $OPENAI_BASE_URL for local vLLM"
+echo "OPENAI_BASE_URL set to $OPENAI_BASE_URL for local vLLM with model switching"
 
-# Wait for vLLM to be ready
-MAX_RETRIES=60
+# Wait for vLLM manager to be ready
+MAX_RETRIES=120
 RETRY_COUNT=0
-echo "Waiting for vLLM to be ready (this may take a few minutes for model loading)..."
+echo "Waiting for vLLM manager to be ready (this may take several minutes for model loading)..."
 until curl -s http://localhost:8000/health > /dev/null 2>&1; do
   RETRY_COUNT=$((RETRY_COUNT + 1))
   if [ $RETRY_COUNT -ge $MAX_RETRIES ]; then
-    echo "vLLM failed to start after $MAX_RETRIES attempts"
+    echo "vLLM manager failed to start after $MAX_RETRIES attempts"
+    echo "=== vLLM manager logs ==="
+    cat /tmp/vllm-manager.log
     echo "=== vLLM logs ==="
     cat /tmp/vllm.log
     exit 1
@@ -63,7 +55,7 @@ until curl -s http://localhost:8000/health > /dev/null 2>&1; do
   fi
 done
 
-echo "vLLM is ready!"
+echo "vLLM manager is ready! All 3 models available in UI."
 
 export PUBLIC_VERSION=$(node -p "require('./package.json').version")
 
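The readiness gate is simply a poll of the manager's /health endpoint, which in turn probes the backend vLLM on port 8001. A rough Python equivalent of the curl loop, for illustration only; the 5-second delay here is an assumption, not necessarily the script's actual sleep interval:

```python
# Sketch: wait until the vLLM manager reports healthy, mirroring the entrypoint's curl loop.
import time
import requests

def wait_for_manager(url="http://localhost:8000/health", max_retries=120, delay_s=5):
    for _ in range(max_retries):
        try:
            if requests.get(url, timeout=2).status_code == 200:
                return True
        except requests.RequestException:
            pass  # manager not up yet, or backend vLLM still loading the model
        time.sleep(delay_s)
    return False

if __name__ == "__main__":
    if not wait_for_manager():
        raise SystemExit("vLLM manager did not become healthy in time")
    print("vLLM manager is ready")
```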
vllm-manager.py ADDED
@@ -0,0 +1,236 @@
+#!/usr/bin/env python3
+"""
+vLLM Model Manager for INF2
+Handles dynamic model switching by restarting vLLM with the requested model.
+"""
+import os
+import subprocess
+import signal
+import time
+import json
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from threading import Thread
+import requests
+
+# Model configurations
+MODELS = {
+    "meta-llama/Llama-3.1-8B-Instruct": {
+        "id": "meta-llama/Llama-3.1-8B-Instruct",
+        "displayName": "Llama 3.1 8B",
+        "description": "Meta's Llama 3.1 8B Instruct model"
+    },
+    "Qwen/Qwen3-8B": {
+        "id": "Qwen/Qwen3-8B",
+        "displayName": "Qwen 3 8B",
+        "description": "Alibaba's Qwen 3 8B model"
+    },
+    "openai/gpt-oss-20b": {
+        "id": "openai/gpt-oss-20b",
+        "displayName": "GPT OSS 20B",
+        "description": "OpenAI's GPT OSS 20B model"
+    }
+}
+
+# Current state
+current_model = os.environ.get("VLLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
+vllm_process = None
+cache_dir = "/data/models"
+
+def start_vllm(model_id):
+    """Start vLLM server with the specified model"""
+    global vllm_process
+
+    print(f"Starting vLLM with model: {model_id}")
+
+    cmd = [
+        "python3", "-m", "vllm.entrypoints.openai.api_server",
+        "--model", model_id,
+        "--host", "0.0.0.0",
+        "--port", "8001",  # Use 8001 for actual vLLM
+        "--device", "neuron",
+        "--tensor-parallel-size", "2",
+        "--download-dir", cache_dir
+    ]
+
+    vllm_process = subprocess.Popen(
+        cmd,
+        stdout=open("/tmp/vllm.log", "a"),
+        stderr=subprocess.STDOUT
+    )
+
+    # Wait for vLLM to be ready
+    for i in range(120):  # 10 minutes timeout
+        try:
+            resp = requests.get("http://localhost:8001/health", timeout=1)
+            if resp.status_code == 200:
+                print(f"vLLM ready with model: {model_id}")
+                return True
+        except:
+            pass
+        time.sleep(5)
+        if i % 6 == 0:
+            print(f"Waiting for vLLM... ({i*5}s)")
+
+    print("ERROR: vLLM failed to start")
+    return False
+
+def stop_vllm():
+    """Stop the current vLLM process"""
+    global vllm_process
+    if vllm_process:
+        print("Stopping vLLM...")
+        vllm_process.send_signal(signal.SIGTERM)
+        vllm_process.wait(timeout=30)
+        vllm_process = None
+        time.sleep(2)
+
+def switch_model(new_model_id):
+    """Switch to a different model"""
+    global current_model
+    if new_model_id not in MODELS:
+        return False
+    if new_model_id == current_model:
+        return True
+
+    print(f"Switching from {current_model} to {new_model_id}")
+    stop_vllm()
+    current_model = new_model_id
+    return start_vllm(new_model_id)
+
+class ProxyHandler(BaseHTTPRequestHandler):
+    """Proxy requests to vLLM, with custom /models endpoint"""
+
+    def log_message(self, format, *args):
+        """Suppress default logging"""
+        pass
+
+    def do_GET(self):
+        if self.path == "/v1/models" or self.path == "/models":
+            # Return all available models
+            models_list = {
+                "object": "list",
+                "data": [
+                    {
+                        "id": model_id,
+                        "object": "model",
+                        "created": 1234567890,
+                        "owned_by": "system",
+                        "description": info["description"]
+                    }
+                    for model_id, info in MODELS.items()
+                ]
+            }
+            self.send_response(200)
+            self.send_header("Content-Type", "application/json")
+            self.end_headers()
+            self.wfile.write(json.dumps(models_list).encode())
+        elif self.path == "/health":
+            # Health check
+            try:
+                resp = requests.get("http://localhost:8001/health", timeout=1)
+                self.send_response(resp.status_code)
+                self.end_headers()
+                self.wfile.write(resp.content)
+            except:
+                self.send_response(503)
+                self.end_headers()
+        else:
+            # Proxy to vLLM
+            self.proxy_request()
+
+    def do_POST(self):
+        # Check if this is a chat completion request with model switch
+        if self.path.startswith("/v1/chat/completions"):
+            content_length = int(self.headers.get('Content-Length', 0))
+            body = self.rfile.read(content_length)
+            try:
+                data = json.loads(body)
+                requested_model = data.get("model")
+
+                # Switch model if needed
+                if requested_model and requested_model != current_model:
+                    if switch_model(requested_model):
+                        print(f"Switched to model: {requested_model}")
+                    else:
+                        self.send_response(500)
+                        self.send_header("Content-Type", "application/json")
+                        self.end_headers()
+                        self.wfile.write(json.dumps({
+                            "error": f"Failed to switch to model: {requested_model}"
+                        }).encode())
+                        return
+
+                # Update model in request to current model
+                data["model"] = current_model
+                body = json.dumps(data).encode()
+            except:
+                pass
+
+        # Proxy to vLLM
+        self.proxy_request(body)
+
+    def proxy_request(self, body=None):
+        """Forward request to vLLM"""
+        try:
+            url = f"http://localhost:8001{self.path}"
+            headers = dict(self.headers)
+            headers.pop("Host", None)
+
+            if body is None and self.command == "POST":
+                content_length = int(self.headers.get('Content-Length', 0))
+                body = self.rfile.read(content_length)
+
+            resp = requests.request(
+                method=self.command,
+                url=url,
+                headers=headers,
+                data=body,
+                stream=True
+            )
+
+            self.send_response(resp.status_code)
+            for key, value in resp.headers.items():
+                if key.lower() not in ['transfer-encoding', 'connection']:
+                    self.send_header(key, value)
+            self.end_headers()
+
+            for chunk in resp.iter_content(chunk_size=8192):
+                if chunk:
+                    self.wfile.write(chunk)
+        except Exception as e:
+            print(f"Proxy error: {e}")
+            self.send_response(500)
+            self.end_headers()
+
+def pre_download_models():
+    """Pre-download all models to cache"""
+    print("Pre-downloading models to cache...")
+    for model_id in MODELS.keys():
+        print(f"Downloading {model_id}...")
+        subprocess.run([
+            "huggingface-cli", "download", model_id,
+            "--cache-dir", cache_dir
+        ])
+    print("All models downloaded!")
+
+if __name__ == "__main__":
+    # Ensure cache directory exists
+    os.makedirs(cache_dir, exist_ok=True)
+
+    # Pre-download models in background
+    download_thread = Thread(target=pre_download_models, daemon=True)
+    download_thread.start()
+
+    # Start vLLM with default model
+    if not start_vllm(current_model):
+        print("Failed to start vLLM, exiting")
+        exit(1)
+
+    # Start proxy server
+    print("Starting proxy server on port 8000...")
+    server = HTTPServer(('0.0.0.0', 8000), ProxyHandler)
+    try:
+        server.serve_forever()
+    except KeyboardInterrupt:
+        print("Shutting down...")
+        stop_vllm()
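Worth noting about the proxy design: `/v1/models` on port 8000 always advertises all three configured entries, while the real vLLM server on port 8001 only reports whichever model is currently loaded. A quick sketch to compare the two from inside the container (port 8001 is internal to the manager):

```python
# Sketch: compare what the manager advertises vs. what the backend vLLM has loaded.
import requests

advertised = requests.get("http://localhost:8000/v1/models", timeout=10).json()
loaded = requests.get("http://localhost:8001/v1/models", timeout=10).json()

print("Advertised by manager:", [m["id"] for m in advertised["data"]])
print("Loaded in vLLM:       ", [m["id"] for m in loaded["data"]])
```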