e-hossam96 commited on
Commit
bea1666
·
1 Parent(s): 89256c4

optimized code for tread safe operations

Browse files
Files changed (1) hide show
  1. main.py +22 -9
main.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import uuid
2
  import logging
3
  import asyncio
@@ -25,6 +26,7 @@ MAX_BATCH_SIZE = 128
25
  # ----------------------------- #
26
  # Shared Storage #
27
  # ----------------------------- #
 
28
  query_queue: asyncio.Queue = asyncio.Queue()
29
  results: dict[str, dict] = {}
30
  classifier = None # will be initialized in lifespan
@@ -56,6 +58,8 @@ def load_classifier(model_name: str):
56
  model=model,
57
  tokenizer=tokenizer,
58
  framework="pt",
 
 
59
  )
60
 
61
 
@@ -82,14 +86,19 @@ async def process_queue():
82
 
83
  sentences = [item["sentence"] for item in batch]
84
  ids = [item["id"] for item in batch]
85
- predictions = classifier(sentences, batch_size=len(sentences))
86
 
87
- for query_id, pred, sentence in zip(ids, predictions, sentences):
88
- results[query_id] = {
89
- "sentence": sentence,
90
- "label": pred["label"],
91
- "score": pred["score"],
92
- }
 
 
 
 
 
93
 
94
 
95
  # ----------------------------- #
@@ -127,10 +136,14 @@ async def classify(query: Query):
127
  query_id = str(uuid.uuid4())
128
  await query_queue.put({"id": query_id, "sentence": query.sentence})
129
 
130
- while query_id not in results:
 
 
 
 
131
  await asyncio.sleep(0.001)
132
 
133
- return {"id": query_id, "result": results.pop(query_id)}
134
 
135
 
136
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
1
+ import os
2
  import uuid
3
  import logging
4
  import asyncio
 
26
  # ----------------------------- #
27
  # Shared Storage #
28
  # ----------------------------- #
29
+ lock = asyncio.Lock()
30
  query_queue: asyncio.Queue = asyncio.Queue()
31
  results: dict[str, dict] = {}
32
  classifier = None # will be initialized in lifespan
 
58
  model=model,
59
  tokenizer=tokenizer,
60
  framework="pt",
61
+ batch_size=MAX_BATCH_SIZE,
62
+ num_workers=os.cpu_count(),
63
  )
64
 
65
 
 
86
 
87
  sentences = [item["sentence"] for item in batch]
88
  ids = [item["id"] for item in batch]
89
+ predictions = classifier(sentences)
90
 
91
+ async with lock:
92
+ results.update(
93
+ {
94
+ query_id: {
95
+ "sentence": sentence,
96
+ "label": pred["label"],
97
+ "score": pred["score"],
98
+ }
99
+ for query_id, pred, sentence in zip(ids, predictions, sentences)
100
+ }
101
+ )
102
 
103
 
104
  # ----------------------------- #
 
136
  query_id = str(uuid.uuid4())
137
  await query_queue.put({"id": query_id, "sentence": query.sentence})
138
 
139
+ result = None
140
+ while result is None:
141
+ async with lock:
142
+ if query_id in results:
143
+ result = results.pop(query_id)
144
  await asyncio.sleep(0.001)
145
 
146
+ return {"id": query_id, "result": result}
147
 
148
 
149
  app.mount("/", StaticFiles(directory="static", html=True), name="static")