Spaces:
Sleeping
Sleeping
Commit
·
39562f3
1
Parent(s):
cd4038f
lower expectations for a 2 vCPU instance
Browse files
main.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import
|
| 2 |
import uuid
|
| 3 |
import logging
|
| 4 |
import asyncio
|
|
@@ -20,8 +20,8 @@ from optimum.onnxruntime import ORTModelForSequenceClassification
|
|
| 20 |
transformers.set_seed(42)
|
| 21 |
|
| 22 |
MODEL_NAME = "distilroberta-base-climate-sentiment-onnx-quantized"
|
| 23 |
-
BATCH_PROCESS_INTERVAL = 0.
|
| 24 |
-
MAX_BATCH_SIZE =
|
| 25 |
|
| 26 |
# ----------------------------- #
|
| 27 |
# Shared Storage #
|
|
@@ -52,6 +52,8 @@ def load_classifier(model_name: str):
|
|
| 52 |
model = ORTModelForSequenceClassification.from_pretrained(
|
| 53 |
model_name,
|
| 54 |
)
|
|
|
|
|
|
|
| 55 |
return pipeline(
|
| 56 |
task="text-classification",
|
| 57 |
accelerator="ort",
|
|
@@ -59,7 +61,7 @@ def load_classifier(model_name: str):
|
|
| 59 |
tokenizer=tokenizer,
|
| 60 |
framework="pt",
|
| 61 |
batch_size=MAX_BATCH_SIZE,
|
| 62 |
-
num_workers=
|
| 63 |
)
|
| 64 |
|
| 65 |
|
|
@@ -108,13 +110,15 @@ async def process_queue():
|
|
| 108 |
async def lifespan(_: FastAPI):
|
| 109 |
global classifier
|
| 110 |
classifier = load_classifier(MODEL_NAME)
|
| 111 |
-
_ = classifier("
|
| 112 |
logger.info("Model loaded successfully.")
|
| 113 |
queue_task = asyncio.create_task(process_queue())
|
| 114 |
yield
|
| 115 |
queue_task.cancel()
|
| 116 |
logger.info("Shutting down the application...")
|
| 117 |
logger.info("Model unloaded successfully.")
|
|
|
|
|
|
|
| 118 |
try:
|
| 119 |
await queue_task
|
| 120 |
except asyncio.CancelledError:
|
|
@@ -133,18 +137,14 @@ app = FastAPI(lifespan=lifespan)
|
|
| 133 |
@app.post("/classify")
|
| 134 |
async def classify(query: Query):
|
| 135 |
logger.info(f"{query.sentence}")
|
| 136 |
-
query_id =
|
| 137 |
await query_queue.put({"id": query_id, "sentence": query.sentence})
|
| 138 |
|
| 139 |
-
|
| 140 |
-
while result is None:
|
| 141 |
async with lock:
|
| 142 |
if query_id in results:
|
| 143 |
-
result
|
| 144 |
-
|
| 145 |
-
await asyncio.sleep(0.1)
|
| 146 |
-
|
| 147 |
-
return {"id": query_id, "result": result}
|
| 148 |
|
| 149 |
|
| 150 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
|
|
| 1 |
+
import gc
|
| 2 |
import uuid
|
| 3 |
import logging
|
| 4 |
import asyncio
|
|
|
|
| 20 |
transformers.set_seed(42)
|
| 21 |
|
| 22 |
MODEL_NAME = "distilroberta-base-climate-sentiment-onnx-quantized"
|
| 23 |
+
BATCH_PROCESS_INTERVAL = 0.05
|
| 24 |
+
MAX_BATCH_SIZE = 16
|
| 25 |
|
| 26 |
# ----------------------------- #
|
| 27 |
# Shared Storage #
|
|
|
|
| 52 |
model = ORTModelForSequenceClassification.from_pretrained(
|
| 53 |
model_name,
|
| 54 |
)
|
| 55 |
+
|
| 56 |
+
gc.collect()
|
| 57 |
return pipeline(
|
| 58 |
task="text-classification",
|
| 59 |
accelerator="ort",
|
|
|
|
| 61 |
tokenizer=tokenizer,
|
| 62 |
framework="pt",
|
| 63 |
batch_size=MAX_BATCH_SIZE,
|
| 64 |
+
num_workers=1,
|
| 65 |
)
|
| 66 |
|
| 67 |
|
|
|
|
| 110 |
async def lifespan(_: FastAPI):
|
| 111 |
global classifier
|
| 112 |
classifier = load_classifier(MODEL_NAME)
|
| 113 |
+
_ = classifier("Hi")
|
| 114 |
logger.info("Model loaded successfully.")
|
| 115 |
queue_task = asyncio.create_task(process_queue())
|
| 116 |
yield
|
| 117 |
queue_task.cancel()
|
| 118 |
logger.info("Shutting down the application...")
|
| 119 |
logger.info("Model unloaded successfully.")
|
| 120 |
+
classifier = None
|
| 121 |
+
gc.collect()
|
| 122 |
try:
|
| 123 |
await queue_task
|
| 124 |
except asyncio.CancelledError:
|
|
|
|
| 137 |
@app.post("/classify")
|
| 138 |
async def classify(query: Query):
|
| 139 |
logger.info(f"{query.sentence}")
|
| 140 |
+
query_id = uuid.uuid4().hex
|
| 141 |
await query_queue.put({"id": query_id, "sentence": query.sentence})
|
| 142 |
|
| 143 |
+
while True:
|
|
|
|
| 144 |
async with lock:
|
| 145 |
if query_id in results:
|
| 146 |
+
return {"id": query_id, "result": results.pop(query_id)}
|
| 147 |
+
await asyncio.sleep(0.1)
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|