Spaces:
Sleeping
Sleeping
| import gc | |
| import uuid | |
| import logging | |
| import asyncio | |
| import transformers | |
| from fastapi import FastAPI | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from contextlib import asynccontextmanager | |
| from transformers import ( | |
| AutoTokenizer, | |
| ) | |
| from optimum.pipelines import pipeline | |
| from optimum.onnxruntime import ORTModelForSequenceClassification | |
| # ----------------------------- # | |
| # Configurations # | |
| # ----------------------------- # | |
# Seed the transformers library's random number generators with a fixed value.
transformers.set_seed(42)

# Model identifier handed to the tokenizer and model loaders.
MODEL_NAME: str = "bert-tiny-finetuned-sms-spam-detection-onnx-quantized"
# Seconds the queue worker sleeps between two batch polls.
BATCH_PROCESS_INTERVAL: float = 0.05
# Upper bound on the number of queued sentences classified in one batch;
# also used as the pipeline's batch size.
MAX_BATCH_SIZE: int = 16
| # ----------------------------- # | |
| # Shared Storage # | |
| # ----------------------------- # | |
# Logging is configured once, at import time, for the whole process.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Guards every read/write of `results`.
lock = asyncio.Lock()
# Pending requests; each item is a dict with "id" and "sentence" keys.
query_queue: asyncio.Queue = asyncio.Queue()
# Finished predictions keyed by request id; the endpoint pops its entry
# once it has been read.
results: dict[str, dict] = {}
classifier = None  # will be initialized in lifespan

# Record the effective configuration at start-up.
logger.info("Starting the application...")
logger.info(f"Using model: {MODEL_NAME}")
logger.info(f"Batch process interval: {BATCH_PROCESS_INTERVAL}")
logger.info(f"Max batch size: {MAX_BATCH_SIZE}")
| # ----------------------------- # | |
| # Model Initialization # | |
| # ----------------------------- # | |
def load_classifier(model_name: str):
    """Build the text-classification pipeline for *model_name*.

    The tokenizer and the ONNX Runtime sequence-classification model are both
    loaded from ``model_name``; a garbage-collection pass runs after loading
    and before the pipeline is assembled.

    Args:
        model_name: Identifier or path passed to both ``from_pretrained`` calls.

    Returns:
        The pipeline object created by ``optimum.pipelines.pipeline`` with the
        ``"ort"`` accelerator and a batch size of ``MAX_BATCH_SIZE``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    ort_model = ORTModelForSequenceClassification.from_pretrained(model_name)

    # Reclaim whatever temporary objects the loaders left behind.
    gc.collect()

    text_classifier = pipeline(
        task="text-classification",
        accelerator="ort",
        model=ort_model,
        tokenizer=loaded_tokenizer,
        framework="pt",
        batch_size=MAX_BATCH_SIZE,
        num_workers=1,
    )
    return text_classifier
| # ----------------------------- # | |
| # Pydantic Schema # | |
| # ----------------------------- # | |
class Query(BaseModel):
    # Request payload accepted by `classify`.
    # Documented with `#` comments rather than a docstring so the schema
    # generated from this model is left unchanged.

    # Raw text to classify.
    sentence: str
| # ----------------------------- # | |
| # Queue Processing Task # | |
| # ----------------------------- # | |
async def process_queue():
    """Background worker: drain queued requests in batches and publish results.

    Every ``BATCH_PROCESS_INTERVAL`` seconds, up to ``MAX_BATCH_SIZE`` pending
    items are taken from ``query_queue``, classified with a single call to the
    module-level ``classifier`` and stored in ``results`` under their request
    id as ``{"sentence", "label", "score"}``.

    The loop never returns on its own; it ends when the task is cancelled.

    Two safeguards keep waiting callers from hanging:

    * Inference runs in a worker thread (``asyncio.to_thread``) so the event
      loop keeps accepting and answering requests while a batch is computed.
    * If classification raises, the failure is logged and every request of the
      batch still gets an entry (``label`` and ``score`` set to ``None`` plus
      an ``"error"`` message) instead of the worker dying silently.
    """
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)

        # This coroutine is the only consumer, so a non-empty queue guarantees
        # that get_nowait() succeeds.
        batch = []
        while len(batch) < MAX_BATCH_SIZE and not query_queue.empty():
            batch.append(query_queue.get_nowait())
        if not batch:
            continue

        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]

        try:
            # Blocking, CPU-bound call: keep it off the event loop.
            predictions = await asyncio.to_thread(classifier, sentences)
        except Exception:
            # asyncio.CancelledError derives from BaseException, so task
            # cancellation still propagates and stops the worker.
            logger.exception(
                "Classification failed for a batch of %d item(s)", len(batch)
            )
            batch_results = {
                query_id: {
                    "sentence": sentence,
                    "label": None,
                    "score": None,
                    "error": "classification failed",
                }
                for query_id, sentence in zip(ids, sentences)
            }
        else:
            batch_results = {
                query_id: {
                    "sentence": sentence,
                    "label": pred["label"],
                    "score": pred["score"],
                }
                for query_id, pred, sentence in zip(ids, predictions, sentences)
            }

        async with lock:
            results.update(batch_results)
| # ----------------------------- # | |
| # Lifespan Handler # | |
| # ----------------------------- # | |
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan: load the model on startup, release it on shutdown.

    Startup loads the classifier into the module-level ``classifier``, runs one
    throw-away inference (presumably to warm the model up before the first
    request) and starts the ``process_queue`` background task.

    Shutdown, which runs even if the application exits with an error, cancels
    the background task and waits for it to finish *before* dropping the model
    reference and collecting garbage.

    Args:
        _: The FastAPI application (unused).

    Yields:
        Nothing; control returns to the framework while the app is serving.
    """
    global classifier

    classifier = load_classifier(MODEL_NAME)
    classifier("Hi")  # result intentionally discarded
    logger.info("Model loaded successfully.")

    queue_task = asyncio.create_task(process_queue())
    try:
        yield
    finally:
        logger.info("Shutting down the application...")
        queue_task.cancel()
        try:
            await queue_task
        except asyncio.CancelledError:
            # Expected outcome of the cancel() above.
            pass
        finally:
            # Release the model only after the worker has stopped using it.
            classifier = None
            gc.collect()
            logger.info("Model unloaded successfully.")
| # ----------------------------- # | |
| # FastAPI Setup # | |
| # ----------------------------- # | |
# Application instance; `lifespan` is supplied as its startup/shutdown handler.
app = FastAPI(
    lifespan=lifespan,
)
| # ----------------------------- # | |
| # API Endpoints # | |
| # ----------------------------- # | |
async def classify(query: Query):
    # Queue `query` for batched classification and wait for its result.
    #
    # The sentence is logged, tagged with a fresh hex UUID and put on
    # `query_queue`. The coroutine then checks `results` under `lock` every
    # 0.1 s until an entry for that id exists, removes it and returns
    # {"id": <request id>, "result": <entry>}.
    #
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm it is registered on `app`.
    # NOTE(review): the wait has no timeout; it only ends once the queue
    # worker publishes an entry for this id.
    logger.info(f"{query.sentence}")

    ticket = uuid.uuid4().hex
    job = {"id": ticket, "sentence": query.sentence}
    await query_queue.put(job)

    # Local sentinel distinguishes "no entry yet" from any stored value.
    not_ready = object()
    while True:
        async with lock:
            outcome = results.pop(ticket, not_ready)
        if outcome is not not_ready:
            return {"id": ticket, "result": outcome}
        await asyncio.sleep(0.1)
# Serve the contents of the `static` directory from the site root.
app.mount(
    "/",
    StaticFiles(directory="static", html=True),
    name="static",
)
def read_root():
    # Return static/index.html as a text/html file response.
    #
    # NOTE(review): no route decorator is visible on this function in this
    # view, and it is defined after the "/" static mount -- confirm whether
    # it is registered and actually reachable.
    index_page = FileResponse(path="static/index.html", media_type="text/html")
    return index_page