Spaces:
Sleeping
Sleeping
Commit
·
bea1666
1
Parent(s):
89256c4
optimized code for tread safe operations
Browse files
main.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import uuid
|
| 2 |
import logging
|
| 3 |
import asyncio
|
|
@@ -25,6 +26,7 @@ MAX_BATCH_SIZE = 128
|
|
| 25 |
# ----------------------------- #
|
| 26 |
# Shared Storage #
|
| 27 |
# ----------------------------- #
|
|
|
|
| 28 |
query_queue: asyncio.Queue = asyncio.Queue()
|
| 29 |
results: dict[str, dict] = {}
|
| 30 |
classifier = None # will be initialized in lifespan
|
|
@@ -56,6 +58,8 @@ def load_classifier(model_name: str):
|
|
| 56 |
model=model,
|
| 57 |
tokenizer=tokenizer,
|
| 58 |
framework="pt",
|
|
|
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
|
|
@@ -82,14 +86,19 @@ async def process_queue():
|
|
| 82 |
|
| 83 |
sentences = [item["sentence"] for item in batch]
|
| 84 |
ids = [item["id"] for item in batch]
|
| 85 |
-
predictions = classifier(sentences
|
| 86 |
|
| 87 |
-
|
| 88 |
-
results
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
# ----------------------------- #
|
|
@@ -127,10 +136,14 @@ async def classify(query: Query):
|
|
| 127 |
query_id = str(uuid.uuid4())
|
| 128 |
await query_queue.put({"id": query_id, "sentence": query.sentence})
|
| 129 |
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
await asyncio.sleep(0.001)
|
| 132 |
|
| 133 |
-
return {"id": query_id, "result":
|
| 134 |
|
| 135 |
|
| 136 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
|
|
| 1 |
+
import os
|
| 2 |
import uuid
|
| 3 |
import logging
|
| 4 |
import asyncio
|
|
|
|
| 26 |
# ----------------------------- #
|
| 27 |
# Shared Storage #
|
| 28 |
# ----------------------------- #
|
| 29 |
+
lock = asyncio.Lock()
|
| 30 |
query_queue: asyncio.Queue = asyncio.Queue()
|
| 31 |
results: dict[str, dict] = {}
|
| 32 |
classifier = None # will be initialized in lifespan
|
|
|
|
| 58 |
model=model,
|
| 59 |
tokenizer=tokenizer,
|
| 60 |
framework="pt",
|
| 61 |
+
batch_size=MAX_BATCH_SIZE,
|
| 62 |
+
num_workers=os.cpu_count(),
|
| 63 |
)
|
| 64 |
|
| 65 |
|
|
|
|
| 86 |
|
| 87 |
sentences = [item["sentence"] for item in batch]
|
| 88 |
ids = [item["id"] for item in batch]
|
| 89 |
+
predictions = classifier(sentences)
|
| 90 |
|
| 91 |
+
async with lock:
|
| 92 |
+
results.update(
|
| 93 |
+
{
|
| 94 |
+
query_id: {
|
| 95 |
+
"sentence": sentence,
|
| 96 |
+
"label": pred["label"],
|
| 97 |
+
"score": pred["score"],
|
| 98 |
+
}
|
| 99 |
+
for query_id, pred, sentence in zip(ids, predictions, sentences)
|
| 100 |
+
}
|
| 101 |
+
)
|
| 102 |
|
| 103 |
|
| 104 |
# ----------------------------- #
|
|
|
|
| 136 |
query_id = str(uuid.uuid4())
|
| 137 |
await query_queue.put({"id": query_id, "sentence": query.sentence})
|
| 138 |
|
| 139 |
+
result = None
|
| 140 |
+
while result is None:
|
| 141 |
+
async with lock:
|
| 142 |
+
if query_id in results:
|
| 143 |
+
result = results.pop(query_id)
|
| 144 |
await asyncio.sleep(0.001)
|
| 145 |
|
| 146 |
+
return {"id": query_id, "result": result}
|
| 147 |
|
| 148 |
|
| 149 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|