import uuid
import torch
import asyncio
import transformers
from typing import Dict
from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
)

# ----------------------------- #
#         Configurations        #
# ----------------------------- #
transformers.set_seed(42)  # fixed seed for reproducible pipeline behavior
torch.set_default_dtype(torch.bfloat16)  # matches the 4-bit compute dtype used below

MODEL_NAME = "climatebert/distilroberta-base-climate-sentiment"
BATCH_PROCESS_INTERVAL = 0.01  # seconds between drains of the query queue
MAX_BATCH_SIZE = 128  # upper bound on sentences sent to the model in one batch

# ----------------------------- #
#        Shared Storage         #
# ----------------------------- #
query_queue: asyncio.Queue = asyncio.Queue()  # pending queries awaiting batching
results: Dict[str, Dict] = {}  # query_id -> prediction, polled by request handlers
classifier = None  # will be initialized in lifespan


# ----------------------------- #
#      Model Initialization     #
# ----------------------------- #
def load_classifier(model_name: str):
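    """Build a text-classification pipeline over a 4-bit quantized model."""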
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )
    return pipeline(
        "text-classification", model=model, tokenizer=tokenizer, framework="pt"
    )


# ----------------------------- #
#         Pydantic Schema       #
# ----------------------------- #
class Query(BaseModel):
    sentence: str


# ----------------------------- #
#      Queue Processing Task    #
# ----------------------------- #
async def process_queue():
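    """Periodically drain queued queries and classify them as one batch."""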
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)

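        # Drain up to MAX_BATCH_SIZE pending queries into a single batch.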
        batch = []
        while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
            batch.append(await query_queue.get())

        if not batch:
            continue

        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]
        # Run the blocking pipeline call in a worker thread so the event loop
        # stays free to accept and enqueue new requests during inference.
        predictions = await asyncio.to_thread(
            classifier, sentences, batch_size=len(sentences)
        )

        for query_id, pred, sentence in zip(ids, predictions, sentences):
            results[query_id] = {
                "sentence": sentence,
                "label": pred["label"],
                "score": pred["score"],
            }


# ----------------------------- #
#        Lifespan Handler       #
# ----------------------------- #
@asynccontextmanager
async def lifespan(app: FastAPI):
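    """Load the model on startup and run the batch worker for the app's lifetime."""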
    global classifier
    classifier = load_classifier(MODEL_NAME)
    _ = classifier("Startup warm-up sentence.")  # warm-up pass so the first real request is fast
    queue_task = asyncio.create_task(process_queue())
    yield
    queue_task.cancel()
    try:
        await queue_task
    except asyncio.CancelledError:
        pass


# ----------------------------- #
#         FastAPI Setup         #
# ----------------------------- #
app = FastAPI(lifespan=lifespan)


# ----------------------------- #
#         API Endpoints         #
# ----------------------------- #
@app.post("/classify")
async def classify(query: Query):
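    """Enqueue a sentence for batched classification and wait for its result."""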
    query_id = str(uuid.uuid4())
    await query_queue.put({"id": query_id, "sentence": query.sentence})

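    # Poll until the batch worker publishes this query's result; a per-request
    # asyncio.Event or Future would avoid this busy-wait under higher load.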
    while query_id not in results:
        await asyncio.sleep(0.001)

    return {"id": query_id, "result": results.pop(query_id)}


@app.get("/")
def read_root():
    return {
        "message": "Welcome to the Sentiment Classification API with Query Batching"
    }