import uuid
import torch
import asyncio
import transformers
from typing import Dict
from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
)

# ----------------------------- #
#        Configurations         #
# ----------------------------- #
transformers.set_seed(42)  # deterministic behavior across runs
torch.set_default_dtype(torch.bfloat16)  # default dtype for newly created tensors
MODEL_NAME = "climatebert/distilroberta-base-climate-sentiment"
BATCH_PROCESS_INTERVAL = 0.01  # seconds between queue drains
MAX_BATCH_SIZE = 128  # maximum sentences per inference batch

# ----------------------------- #
#        Shared Storage         #
# ----------------------------- #
query_queue: asyncio.Queue = asyncio.Queue()  # pending queries awaiting batching
results: Dict[str, Dict] = {}  # query id -> finished prediction
classifier = None  # text-classification pipeline; initialized in lifespan

# ----------------------------- #
#     Model Initialization     #
# ----------------------------- #
def load_classifier(model_name: str):
    """Load the 4-bit quantized model and wrap it in a text-classification pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )
    return pipeline(
        "text-classification", model=model, tokenizer=tokenizer, framework="pt"
    )
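
# For reference, the pipeline returns one dict per input sentence; the label
# and score below are illustrative values, not actual model output:
#   load_classifier(MODEL_NAME)("We are expanding our solar portfolio.")
#   -> [{"label": "opportunity", "score": 0.93}]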

# ----------------------------- #
#        Pydantic Schema        #
# ----------------------------- #
class Query(BaseModel):
    sentence: str

# ----------------------------- #
#     Queue Processing Task     #
# ----------------------------- #
async def process_queue():
    """Drain the queue on a fixed interval and classify queries in batches."""
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)
        # Collect up to MAX_BATCH_SIZE pending queries without blocking.
        batch = []
        while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
            batch.append(await query_queue.get())
        if not batch:
            continue
        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]
        # One forward pass over the whole batch; note that the pipeline call is
        # synchronous, so the event loop is blocked while inference runs.
        predictions = classifier(sentences, batch_size=len(sentences))
        # Publish results keyed by query id so waiting endpoints can pick them up.
        for query_id, pred, sentence in zip(ids, predictions, sentences):
            results[query_id] = {
                "sentence": sentence,
                "label": pred["label"],
                "score": pred["score"],
            }

# ----------------------------- #
#       Lifespan Handler        #
# ----------------------------- #
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup and warm it up with a dummy query.
    global classifier
    classifier = load_classifier(MODEL_NAME)
    _ = classifier("Startup warm-up sentence.")
    queue_task = asyncio.create_task(process_queue())
    yield
    # On shutdown, cancel the background batching task cleanly.
    queue_task.cancel()
    try:
        await queue_task
    except asyncio.CancelledError:
        pass

# ----------------------------- #
#         FastAPI Setup         #
# ----------------------------- #
app = FastAPI(lifespan=lifespan)

# ----------------------------- #
#         API Endpoints         #
# ----------------------------- #
@app.post("/classify")
async def classify(query: Query):
    # Enqueue the sentence, then poll until the batch worker publishes
    # a result for this query id.
    query_id = str(uuid.uuid4())
    await query_queue.put({"id": query_id, "sentence": query.sentence})
    while query_id not in results:
        await asyncio.sleep(0.001)
    return {"id": query_id, "result": results.pop(query_id)}
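
# Example request (assumes the server is reachable at localhost:8000):
#   curl -X POST http://localhost:8000/classify \
#       -H "Content-Type: application/json" \
#       -d '{"sentence": "We are expanding our renewable energy portfolio."}'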

@app.get("/")
def read_root():
    return {
        "message": "Welcome to the Sentiment Classification API with Query Batching"
    }
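
# ----------------------------- #
#          Entry Point          #
# ----------------------------- #
# A minimal sketch for local runs, assuming uvicorn is installed and this
# file is named main.py; in deployment the app is typically launched as
# `uvicorn main:app` instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)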