import gc
import uuid
import logging
import asyncio
from contextlib import asynccontextmanager

import transformers
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import AutoTokenizer
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification

# ----------------------------- #
#         Configurations        #
# ----------------------------- #
transformers.set_seed(42)

MODEL_NAME = "bert-tiny-finetuned-sms-spam-detection-onnx-quantized"
BATCH_PROCESS_INTERVAL = 0.05  # seconds between batch runs
MAX_BATCH_SIZE = 16

# ----------------------------- #
#         Shared Storage        #
# ----------------------------- #
lock = asyncio.Lock()
query_queue: asyncio.Queue = asyncio.Queue()
results: dict[str, dict] = {}
classifier = None  # will be initialized in lifespan

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
logger.info("Starting the application...")
logger.info(f"Using model: {MODEL_NAME}")
logger.info(f"Batch process interval: {BATCH_PROCESS_INTERVAL}")
logger.info(f"Max batch size: {MAX_BATCH_SIZE}")

# ----------------------------- #
#      Model Initialization     #
# ----------------------------- #
def load_classifier(model_name: str):
    """Load the quantized ONNX model and wrap it in an Optimum text-classification pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = ORTModelForSequenceClassification.from_pretrained(model_name)
    gc.collect()
    return pipeline(
        task="text-classification",
        accelerator="ort",
        model=model,
        tokenizer=tokenizer,
        framework="pt",
        batch_size=MAX_BATCH_SIZE,
        num_workers=1,
    )

# ----------------------------- #
#        Pydantic Schema        #
# ----------------------------- #
class Query(BaseModel):
    sentence: str

# ----------------------------- #
#      Queue Processing Task    #
# ----------------------------- #
async def process_queue():
    """Drain the queue every BATCH_PROCESS_INTERVAL seconds and classify up to MAX_BATCH_SIZE sentences at once."""
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)

        batch = []
        while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
            batch.append(await query_queue.get())

        if not batch:
            continue

        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]

        # Note: this call is synchronous and blocks the event loop for the
        # duration of the batched inference.
        predictions = classifier(sentences)

        async with lock:
            results.update(
                {
                    query_id: {
                        "sentence": sentence,
                        "label": pred["label"],
                        "score": pred["score"],
                    }
                    for query_id, pred, sentence in zip(ids, predictions, sentences)
                }
            )

# ----------------------------- #
#        Lifespan Handler       #
# ----------------------------- #
@asynccontextmanager
async def lifespan(_: FastAPI):
    global classifier

    classifier = load_classifier(MODEL_NAME)
    _ = classifier("Hi")  # warm-up call
    logger.info("Model loaded successfully.")

    queue_task = asyncio.create_task(process_queue())

    yield

    queue_task.cancel()
    logger.info("Shutting down the application...")
    logger.info("Model unloaded successfully.")
    classifier = None
    gc.collect()
    try:
        await queue_task
    except asyncio.CancelledError:
        pass

# ----------------------------- #
#         FastAPI Setup         #
# ----------------------------- #
app = FastAPI(lifespan=lifespan)

# ----------------------------- #
#         API Endpoints         #
# ----------------------------- #
@app.post("/classify")
async def classify(query: Query):
    logger.info(f"{query.sentence}")
    query_id = uuid.uuid4().hex
    await query_queue.put({"id": query_id, "sentence": query.sentence})

    # Poll until the background task has written this query's result.
    while True:
        async with lock:
            if query_id in results:
                return {"id": query_id, "result": results.pop(query_id)}
        await asyncio.sleep(0.1)


@app.get("/")
def read_root():
    return FileResponse(path="static/index.html", media_type="text/html")


# The static mount is registered after the routes: a mount at "/" matches every
# path, so any route added after it would never be reached.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
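
# ----------------------------- #
#       Local Dev Entry         #
# ----------------------------- #
# A minimal sketch for running the app locally, assuming uvicorn is installed
# and this file is saved as main.py; the original does not specify an entry
# point, so the module name, host, and port below are illustrative only.
#
# Example request once the server is up (hypothetical payload):
#   curl -X POST http://localhost:8000/classify \
#        -H "Content-Type: application/json" \
#        -d '{"sentence": "You have won a free prize, reply now!"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000)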