import uuid
import torch
import asyncio
import transformers
from typing import Dict
from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
)

# ----------------------------- #
#        Configurations         #
# ----------------------------- #
transformers.set_seed(42)
torch.set_default_dtype(torch.bfloat16)

MODEL_NAME = "climatebert/distilroberta-base-climate-sentiment"
BATCH_PROCESS_INTERVAL = 0.01  # seconds between batch collection passes
MAX_BATCH_SIZE = 128

# ----------------------------- #
#        Shared Storage         #
# ----------------------------- #
query_queue: asyncio.Queue = asyncio.Queue()
results: Dict[str, Dict] = {}
classifier = None  # will be initialized in lifespan

# ----------------------------- #
#     Model Initialization      #
# ----------------------------- #
def load_classifier(model_name: str):
    # Load the tokenizer and a 4-bit quantized model, then wrap both
    # in a text-classification pipeline.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )
    return pipeline(
        "text-classification", model=model, tokenizer=tokenizer, framework="pt"
    )
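
# Usage sketch for a quick local smoke test (illustrative, not part of the
# server): the pipeline returns a list of dicts shaped like
# [{"label": ..., "score": ...}]; for this model the labels are
# climate-sentiment classes (e.g. "risk", "neutral", "opportunity").
#
#   clf = load_classifier(MODEL_NAME)
#   print(clf("Heatwaves threaten crop yields."))
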
# ----------------------------- #
#        Pydantic Schema        #
# ----------------------------- #
class Query(BaseModel):
    sentence: str

# ----------------------------- #
#     Queue Processing Task     #
# ----------------------------- #
async def process_queue():
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)
        batch = []
        # Drain up to MAX_BATCH_SIZE pending queries from the queue.
        while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
            batch.append(await query_queue.get())
        if not batch:
            continue
        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]
        # Run the whole batch in a single forward pass. Note that this
        # call is synchronous and blocks the event loop while it runs.
        predictions = classifier(sentences, batch_size=len(sentences))
        for query_id, pred, sentence in zip(ids, predictions, sentences):
            results[query_id] = {
                "sentence": sentence,
                "label": pred["label"],
                "score": pred["score"],
            }

# ----------------------------- #
#       Lifespan Handler        #
# ----------------------------- #
@asynccontextmanager
async def lifespan(app: FastAPI):
    global classifier
    classifier = load_classifier(MODEL_NAME)
    # Warm up once so the first real request doesn't pay the cold-start cost.
    _ = classifier("Startup warm-up sentence.")
    queue_task = asyncio.create_task(process_queue())
    yield
    # Shut the background worker down cleanly when the app exits.
    queue_task.cancel()
    try:
        await queue_task
    except asyncio.CancelledError:
        pass

# ----------------------------- #
#         FastAPI Setup         #
# ----------------------------- #
app = FastAPI(lifespan=lifespan)

# ----------------------------- #
#         API Endpoints         #
# ----------------------------- #
@app.post("/classify")  # route path inferred from the handler name
async def classify(query: Query):
    query_id = str(uuid.uuid4())
    await query_queue.put({"id": query_id, "sentence": query.sentence})
    # Poll until the batch worker publishes this query's result.
    while query_id not in results:
        await asyncio.sleep(0.001)
    return {"id": query_id, "result": results.pop(query_id)}

@app.get("/")
def read_root():
    return {
        "message": "Welcome to the Sentiment Classification API with Query Batching"
    }
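
To exercise the service end to end, start the server and post a sentence to the /classify endpoint. Below is a minimal client sketch; the module filename app.py and port 8000 are assumptions, not part of the listing above.

# Start the server first (filename assumed):
#   uvicorn app:app --port 8000
import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={"sentence": "Extreme weather events are increasing in frequency."},
)
resp.raise_for_status()
print(resp.json())
# -> {"id": "<uuid>", "result": {"sentence": ..., "label": ..., "score": ...}}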