e-hossam96 committed on
Commit
39562f3
·
1 Parent(s): cd4038f

lower expectations for a 2 vCPU instance

Browse files
Files changed (1) hide show
  1. main.py +13 -13
main.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
  import uuid
3
  import logging
4
  import asyncio
@@ -20,8 +20,8 @@ from optimum.onnxruntime import ORTModelForSequenceClassification
20
  transformers.set_seed(42)
21
 
22
  MODEL_NAME = "distilroberta-base-climate-sentiment-onnx-quantized"
23
- BATCH_PROCESS_INTERVAL = 0.01
24
- MAX_BATCH_SIZE = 128
25
 
26
  # ----------------------------- #
27
  # Shared Storage #
@@ -52,6 +52,8 @@ def load_classifier(model_name: str):
52
  model = ORTModelForSequenceClassification.from_pretrained(
53
  model_name,
54
  )
 
 
55
  return pipeline(
56
  task="text-classification",
57
  accelerator="ort",
@@ -59,7 +61,7 @@ def load_classifier(model_name: str):
59
  tokenizer=tokenizer,
60
  framework="pt",
61
  batch_size=MAX_BATCH_SIZE,
62
- num_workers=os.cpu_count(),
63
  )
64
 
65
 
@@ -108,13 +110,15 @@ async def process_queue():
108
  async def lifespan(_: FastAPI):
109
  global classifier
110
  classifier = load_classifier(MODEL_NAME)
111
- _ = classifier("Startup warm-up sentence.")
112
  logger.info("Model loaded successfully.")
113
  queue_task = asyncio.create_task(process_queue())
114
  yield
115
  queue_task.cancel()
116
  logger.info("Shutting down the application...")
117
  logger.info("Model unloaded successfully.")
 
 
118
  try:
119
  await queue_task
120
  except asyncio.CancelledError:
@@ -133,18 +137,14 @@ app = FastAPI(lifespan=lifespan)
133
  @app.post("/classify")
134
  async def classify(query: Query):
135
  logger.info(f"{query.sentence}")
136
- query_id = str(uuid.uuid4())
137
  await query_queue.put({"id": query_id, "sentence": query.sentence})
138
 
139
- result = None
140
- while result is None:
141
  async with lock:
142
  if query_id in results:
143
- result = results.pop(query_id)
144
- if result is None:
145
- await asyncio.sleep(0.1)
146
-
147
- return {"id": query_id, "result": result}
148
 
149
 
150
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
1
+ import gc
2
  import uuid
3
  import logging
4
  import asyncio
 
20
  transformers.set_seed(42)
21
 
22
  MODEL_NAME = "distilroberta-base-climate-sentiment-onnx-quantized"
23
+ BATCH_PROCESS_INTERVAL = 0.05
24
+ MAX_BATCH_SIZE = 16
25
 
26
  # ----------------------------- #
27
  # Shared Storage #
 
52
  model = ORTModelForSequenceClassification.from_pretrained(
53
  model_name,
54
  )
55
+
56
+ gc.collect()
57
  return pipeline(
58
  task="text-classification",
59
  accelerator="ort",
 
61
  tokenizer=tokenizer,
62
  framework="pt",
63
  batch_size=MAX_BATCH_SIZE,
64
+ num_workers=1,
65
  )
66
 
67
 
 
110
  async def lifespan(_: FastAPI):
111
  global classifier
112
  classifier = load_classifier(MODEL_NAME)
113
+ _ = classifier("Hi")
114
  logger.info("Model loaded successfully.")
115
  queue_task = asyncio.create_task(process_queue())
116
  yield
117
  queue_task.cancel()
118
  logger.info("Shutting down the application...")
119
  logger.info("Model unloaded successfully.")
120
+ classifier = None
121
+ gc.collect()
122
  try:
123
  await queue_task
124
  except asyncio.CancelledError:
 
137
@app.post("/classify")
async def classify(query: Query):
    """Enqueue a sentence for batched classification and await its result.

    A hex UUID ties the queued item to the entry that the background batch
    worker eventually writes into the shared ``results`` dict. The handler
    polls that dict under the shared lock every 0.1 s until the entry
    appears, then pops and returns it.
    """
    logger.info(f"{query.sentence}")
    request_id = uuid.uuid4().hex
    await query_queue.put({"id": request_id, "sentence": query.sentence})

    # Poll until the batch worker publishes our result; pop under the lock
    # so the entry is consumed exactly once.
    while True:
        async with lock:
            ready = request_id in results
            if ready:
                payload = results.pop(request_id)
                return {"id": request_id, "result": payload}
        await asyncio.sleep(0.1)
 
 
 
148
 
149
 
150
  app.mount("/", StaticFiles(directory="static", html=True), name="static")