import uuid
import torch
import asyncio
import transformers
from typing import Dict
from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
)
# ----------------------------- #
#        Configurations         #
# ----------------------------- #
transformers.set_seed(42)
torch.set_default_dtype(torch.bfloat16)
MODEL_NAME = "climatebert/distilroberta-base-climate-sentiment"
BATCH_PROCESS_INTERVAL = 0.01  # seconds between batch-collection passes
MAX_BATCH_SIZE = 128  # maximum number of queued queries classified per batch
# ----------------------------- #
#        Shared Storage         #
# ----------------------------- #
query_queue: asyncio.Queue = asyncio.Queue()
results: Dict[str, Dict] = {}
classifier = None # will be initialized in lifespan
# ----------------------------- #
#     Model Initialization      #
# ----------------------------- #
def load_classifier(model_name: str):
    """Load the tokenizer and a 4-bit quantized model, wrapped in a text-classification pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )
    return pipeline(
        "text-classification", model=model, tokenizer=tokenizer, framework="pt"
    )
# ----------------------------- #
#        Pydantic Schema        #
# ----------------------------- #
class Query(BaseModel):
    sentence: str
# ----------------------------- #
#     Queue Processing Task     #
# ----------------------------- #
async def process_queue():
    """Periodically drain the query queue and run batched inference on the collected sentences."""
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)
        # Collect up to MAX_BATCH_SIZE pending queries.
        batch = []
        while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
            batch.append(await query_queue.get())
        if not batch:
            continue
        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]
        # Classify the whole batch in a single pipeline call.
        # Note: this call is synchronous and blocks the event loop while inference runs.
        predictions = classifier(sentences, batch_size=len(sentences))
        for query_id, pred, sentence in zip(ids, predictions, sentences):
            results[query_id] = {
                "sentence": sentence,
                "label": pred["label"],
                "score": pred["score"],
            }
# ----------------------------- #
#        Lifespan Handler       #
# ----------------------------- #
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup and warm it up with a dummy call.
    global classifier
    classifier = load_classifier(MODEL_NAME)
    _ = classifier("Startup warm-up sentence.")
    # Start the background batching loop for the lifetime of the app.
    queue_task = asyncio.create_task(process_queue())
    yield
    # Shut down the batching loop cleanly.
    queue_task.cancel()
    try:
        await queue_task
    except asyncio.CancelledError:
        pass
# ----------------------------- #
#         FastAPI Setup         #
# ----------------------------- #
app = FastAPI(lifespan=lifespan)
# ----------------------------- #
#         API Endpoints         #
# ----------------------------- #
@app.post("/classify")
async def classify(query: Query):
    # Enqueue the query, then poll until the background task has stored its result.
    query_id = str(uuid.uuid4())
    await query_queue.put({"id": query_id, "sentence": query.sentence})
    while query_id not in results:
        await asyncio.sleep(0.001)
    return {"id": query_id, "result": results.pop(query_id)}
@app.get("/")
def read_root():
    return {
        "message": "Welcome to the Sentiment Classification API with Query Batching"
    }
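# ----------------------------- #
#      Running the Server       #
# ----------------------------- #
# A minimal usage sketch, assuming this file is saved as app.py and served with
# uvicorn; the module name and port below are assumptions, not part of this file:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# Example request once the server is up:
#
#   curl -X POST http://localhost:7860/classify \
#        -H "Content-Type: application/json" \
#        -d '{"sentence": "Renewable energy adoption is growing rapidly."}'
#
# Requests that arrive within the same BATCH_PROCESS_INTERVAL window are grouped
# by process_queue() and classified together in a single pipeline call.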