| """ | |
| A model worker that executes the model based on vLLM. | |
| See documentations at docs/vllm_integration.md | |
| """ | |

import argparse
import asyncio
import json
from typing import List

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
    logger,
    worker_id,
)
from fastchat.utils import get_context_length


app = FastAPI()


class VLLMWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        llm_engine: AsyncLLMEngine,
        conv_template: str,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
        )

        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..."
        )
        # The async engine wraps a synchronous engine; reuse its tokenizer and read
        # the context window size from the underlying Hugging Face config.
        self.tokenizer = llm_engine.engine.tokenizer
        self.context_len = get_context_length(llm_engine.engine.model_config.hf_config)

        if not no_register:
            self.init_heart_beat()

    async def generate_stream(self, params):
        self.call_ct += 1

        context = params.pop("prompt")
        request_id = params.pop("request_id")
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = params.get("top_k", -1.0)
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        max_new_tokens = params.get("max_new_tokens", 256)
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_token_ids", None) or []
        if self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)
        echo = params.get("echo", True)
        use_beam_search = params.get("use_beam_search", False)
        best_of = params.get("best_of", None)

        # Handle stop_str: merge string/list stop sequences with decoded stop token ids.
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        for tid in stop_token_ids:
            if tid is not None:
                stop.add(self.tokenizer.decode(tid))

        # Make sampling params for vLLM. Keep top_p strictly positive; for
        # near-greedy temperatures, fall back to top_p = 1.0 as vLLM expects.
        top_p = max(top_p, 1e-5)
        if temperature <= 1e-5:
            top_p = 1.0

        sampling_params = SamplingParams(
            n=1,
            temperature=temperature,
            top_p=top_p,
            use_beam_search=use_beam_search,
            stop=list(stop),
            max_tokens=max_new_tokens,
            top_k=top_k,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            best_of=best_of,
        )

        # `engine` is the module-level AsyncLLMEngine created in the __main__ block below.
        results_generator = engine.generate(context, sampling_params, request_id)

        async for request_output in results_generator:
            prompt = request_output.prompt
            if echo:
                text_outputs = [
                    prompt + output.text for output in request_output.outputs
                ]
            else:
                text_outputs = [output.text for output in request_output.outputs]
            text_outputs = " ".join(text_outputs)

            # Rough usage accounting from vLLM's token ids (recomputed on every chunk).
            prompt_tokens = len(request_output.prompt_token_ids)
            completion_tokens = sum(
                len(output.token_ids) for output in request_output.outputs
            )
            ret = {
                "text": text_outputs,
                "error_code": 0,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
                "cumulative_logprob": [
                    output.cumulative_logprob for output in request_output.outputs
                ],
                "finish_reason": request_output.outputs[0].finish_reason
                if len(request_output.outputs) == 1
                else [output.finish_reason for output in request_output.outputs],
            }
            # Each chunk is a full JSON object terminated by a NUL byte, which
            # FastChat clients use as the stream delimiter.
            yield (json.dumps(ret) + "\0").encode()

    async def generate(self, params):
        async for x in self.generate_stream(params):
            pass
        return json.loads(x[:-1].decode())
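

# Sketch of the JSON payload the generate endpoints below accept. Field names
# follow the parsing in generate_stream above; the concrete values here are
# only illustrative:
#
#   {
#       "prompt": "A chat between a user and an assistant ...",
#       "temperature": 0.7,
#       "top_p": 1.0,
#       "max_new_tokens": 256,
#       "stop": ["</s>"],
#       "stop_token_ids": [2],
#       "echo": False
#   }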


def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(request_id):
    async def abort_request() -> None:
        await engine.abort(request_id)

    # These run once the streaming response has finished: free the concurrency
    # slot and abort the corresponding in-flight vLLM request.
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)
    background_tasks.add_task(abort_request)
    return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = random_uuid()
    params["request_id"] = request_id
    generator = worker.generate_stream(params)
    background_tasks = create_background_tasks(request_id)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = random_uuid()
    params["request_id"] = request_id
    output = await worker.generate(params)
    release_worker_semaphore()
    await engine.abort(request_id)
    return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5")
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional comma-separated list of display names.",
    )
    parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--trust_remote_code",
        # Remote code is trusted by default; passing this flag disables it.
        action="store_false",
        default=True,
        help="Trust remote code (e.g., from HuggingFace) when "
        "downloading the model and tokenizer.",
    )
    parser.add_argument(
        "--gpu_memory_utilization",
        type=float,
        default=0.9,
        help="The ratio (between 0 and 1) of GPU memory to "
        "reserve for the model weights, activations, and KV cache. Higher "
        "values will increase the KV cache size and thus improve the model's "
        "throughput. However, if the value is too high, it may cause out-of-"
        "memory (OOM) errors.",
    )
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.model_path:
        args.model = args.model_path
    if args.num_gpus > 1:
        # Map the FastChat --num-gpus flag onto vLLM's tensor parallelism.
        args.tensor_parallel_size = args.num_gpus

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    worker = VLLMWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_path,
        args.model_names,
        args.limit_worker_concurrency,
        args.no_register,
        engine,
        args.conv_template,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")