lamhieu commited on
Commit
6bcf9a9
·
1 Parent(s): de9dea0

chore: update something

Browse files
lightweight_embeddings/analytics.py CHANGED
@@ -95,35 +95,61 @@ class Analytics:
95
  - model_id (str): The ID of the accessed model.
96
  - tokens (int): Number of tokens used in this access event.
97
  """
 
 
 
 
 
 
 
 
 
98
  keys = self._get_period_keys()
99
 
100
  async with self.lock:
101
- for period_key in keys:
102
- # Increase new increments by the usage
103
- self.new_increments["access"][period_key][model_id] += 1
104
- self.new_increments["tokens"][period_key][model_id] += tokens
105
-
106
- # Also update current_totals so that stats() are immediately up to date
107
- self.current_totals["access"][period_key][model_id] += 1
108
- self.current_totals["tokens"][period_key][model_id] += tokens
 
 
 
 
 
109
 
110
  async def stats(self) -> Dict[str, Dict[str, Dict[str, int]]]:
111
  """
112
  Returns a copy of current statistics from the local buffer (absolute totals).
113
  """
114
  async with self.lock:
115
- # Return the current_totals, which includes everything loaded from Redis
116
- # plus all increments since the last sync.
117
- return {
118
- "access": {
119
- period: dict(models)
120
- for period, models in self.current_totals["access"].items()
121
- },
122
- "tokens": {
123
- period: dict(models)
124
- for period, models in self.current_totals["tokens"].items()
125
- },
126
- }
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  async def _sync_from_redis(self):
129
  """
@@ -162,8 +188,14 @@ class Analytics:
162
  data = await loop.run_in_executor(
163
  None, partial(self.redis_client.hgetall, key)
164
  )
165
- for model_id, count_str in data.items():
166
- self.current_totals["access"][period][model_id] = int(count_str)
 
 
 
 
 
 
167
 
168
  if cursor == 0:
169
  break
@@ -187,8 +219,14 @@ class Analytics:
187
  data = await loop.run_in_executor(
188
  None, partial(self.redis_client.hgetall, key)
189
  )
190
- for model_id, count_str in data.items():
191
- self.current_totals["tokens"][period][model_id] = int(count_str)
 
 
 
 
 
 
192
 
193
  if cursor == 0:
194
  break
@@ -201,34 +239,44 @@ class Analytics:
201
  loop = asyncio.get_running_loop()
202
  async with self.lock:
203
  try:
 
 
204
  # For each (period, model_id, count) in new_increments, call HINCRBY
205
  for period, models in self.new_increments["access"].items():
206
  redis_key = f"analytics:access:{period}"
207
  for model_id, count in models.items():
208
- if count != 0:
209
- await loop.run_in_executor(
210
- None,
211
- partial(
212
- self.redis_client.hincrby,
213
- redis_key,
214
- model_id,
215
- count,
216
- ),
217
- )
 
 
 
 
218
 
219
  for period, models in self.new_increments["tokens"].items():
220
  redis_key = f"analytics:tokens:{period}"
221
  for model_id, count in models.items():
222
- if count != 0:
223
- await loop.run_in_executor(
224
- None,
225
- partial(
226
- self.redis_client.hincrby,
227
- redis_key,
228
- model_id,
229
- count,
230
- ),
231
- )
 
 
 
 
232
 
233
  # Reset new_increments after successful sync
234
  self.new_increments = {
@@ -236,7 +284,7 @@ class Analytics:
236
  "tokens": defaultdict(lambda: defaultdict(int)),
237
  }
238
 
239
- logger.info("Analytics data successfully synced to Upstash Redis.")
240
  except Exception as e:
241
  logger.error("Unexpected error during Upstash Redis sync: %s", e)
242
  raise e
 
95
  - model_id (str): The ID of the accessed model.
96
  - tokens (int): Number of tokens used in this access event.
97
  """
98
+ # Validate inputs
99
+ if not model_id or not isinstance(model_id, str):
100
+ logger.warning("Invalid model_id provided: %s", model_id)
101
+ return
102
+
103
+ if tokens < 0:
104
+ logger.warning("Negative token count provided for model %s: %d", model_id, tokens)
105
+ tokens = 0
106
+
107
  keys = self._get_period_keys()
108
 
109
  async with self.lock:
110
+ try:
111
+ for period_key in keys:
112
+ # Increase new increments by the usage
113
+ self.new_increments["access"][period_key][model_id] += 1
114
+ self.new_increments["tokens"][period_key][model_id] += tokens
115
+
116
+ # Also update current_totals so that stats() are immediately up to date
117
+ self.current_totals["access"][period_key][model_id] += 1
118
+ self.current_totals["tokens"][period_key][model_id] += tokens
119
+
120
+ logger.debug("Recorded access for model %s: %d tokens", model_id, tokens)
121
+ except Exception as e:
122
+ logger.error("Error recording access for model %s: %s", model_id, e)
123
 
124
  async def stats(self) -> Dict[str, Dict[str, Dict[str, int]]]:
125
  """
126
  Returns a copy of current statistics from the local buffer (absolute totals).
127
  """
128
  async with self.lock:
129
+ try:
130
+ # Return the current_totals, which includes everything loaded from Redis
131
+ # plus all increments since the last sync.
132
+ result = {
133
+ "access": {
134
+ period: dict(models)
135
+ for period, models in self.current_totals["access"].items()
136
+ },
137
+ "tokens": {
138
+ period: dict(models)
139
+ for period, models in self.current_totals["tokens"].items()
140
+ },
141
+ }
142
+
143
+ logger.debug("Retrieved stats for %d access periods and %d token periods",
144
+ len(result["access"]), len(result["tokens"]))
145
+ return result
146
+ except Exception as e:
147
+ logger.error("Error retrieving stats: %s", e)
148
+ # Return empty structure if there's an error
149
+ return {
150
+ "access": {},
151
+ "tokens": {},
152
+ }
153
 
154
  async def _sync_from_redis(self):
155
  """
 
188
  data = await loop.run_in_executor(
189
  None, partial(self.redis_client.hgetall, key)
190
  )
191
+ # Ensure data is not None and handle empty results
192
+ if data:
193
+ for model_id, count_str in data.items():
194
+ try:
195
+ self.current_totals["access"][period][model_id] = int(count_str)
196
+ except (ValueError, TypeError):
197
+ logger.warning("Invalid count value for model %s in period %s: %s", model_id, period, count_str)
198
+ self.current_totals["access"][period][model_id] = 0
199
 
200
  if cursor == 0:
201
  break
 
219
  data = await loop.run_in_executor(
220
  None, partial(self.redis_client.hgetall, key)
221
  )
222
+ # Ensure data is not None and handle empty results
223
+ if data:
224
+ for model_id, count_str in data.items():
225
+ try:
226
+ self.current_totals["tokens"][period][model_id] = int(count_str)
227
+ except (ValueError, TypeError):
228
+ logger.warning("Invalid token count value for model %s in period %s: %s", model_id, period, count_str)
229
+ self.current_totals["tokens"][period][model_id] = 0
230
 
231
  if cursor == 0:
232
  break
 
239
  loop = asyncio.get_running_loop()
240
  async with self.lock:
241
  try:
242
+ sync_count = 0
243
+
244
  # For each (period, model_id, count) in new_increments, call HINCRBY
245
  for period, models in self.new_increments["access"].items():
246
  redis_key = f"analytics:access:{period}"
247
  for model_id, count in models.items():
248
+ if count > 0: # Only sync positive counts
249
+ try:
250
+ await loop.run_in_executor(
251
+ None,
252
+ partial(
253
+ self.redis_client.hincrby,
254
+ redis_key,
255
+ model_id,
256
+ count,
257
+ ),
258
+ )
259
+ sync_count += 1
260
+ except Exception as e:
261
+ logger.error("Failed to sync access count for model %s, period %s: %s", model_id, period, e)
262
 
263
  for period, models in self.new_increments["tokens"].items():
264
  redis_key = f"analytics:tokens:{period}"
265
  for model_id, count in models.items():
266
+ if count > 0: # Only sync positive counts
267
+ try:
268
+ await loop.run_in_executor(
269
+ None,
270
+ partial(
271
+ self.redis_client.hincrby,
272
+ redis_key,
273
+ model_id,
274
+ count,
275
+ ),
276
+ )
277
+ sync_count += 1
278
+ except Exception as e:
279
+ logger.error("Failed to sync token count for model %s, period %s: %s", model_id, period, e)
280
 
281
  # Reset new_increments after successful sync
282
  self.new_increments = {
 
284
  "tokens": defaultdict(lambda: defaultdict(int)),
285
  }
286
 
287
+ logger.info("Analytics data successfully synced to Upstash Redis. Synced %d entries.", sync_count)
288
  except Exception as e:
289
  logger.error("Unexpected error during Upstash Redis sync: %s", e)
290
  raise e
lightweight_embeddings/router.py CHANGED
@@ -37,7 +37,7 @@ class EmbeddingRequest(BaseModel):
37
  "Which model ID to use? "
38
  "Text options: ['multilingual-e5-small', 'multilingual-e5-base', 'multilingual-e5-large', "
39
  "'snowflake-arctic-embed-l-v2.0', 'paraphrase-multilingual-MiniLM-L12-v2', "
40
- "'paraphrase-multilingual-mpnet-base-v2', 'bge-m3', 'gte-multilingual-base', 'embeddinggemma']. "
41
  "Image option: ['siglip-base-patch16-256-multilingual']."
42
  ),
43
  )
@@ -120,7 +120,7 @@ rate_limit_cache: Dict[str, List[float]] = {}
120
 
121
 
122
  def check_rate_limit(
123
- client_ip: str, max_requests: int = 4, window_seconds: int = 60
124
  ) -> bool:
125
  """
126
  Check if the client IP has exceeded the rate limit.
@@ -162,46 +162,35 @@ async def create_embeddings(
162
  expected_token = os.environ.get("ACCESS_TOKEN")
163
  is_authenticated = False
164
 
165
- if expected_token:
166
- if authorization:
167
- # Support both "Bearer <token>" and plain token formats
168
- token = authorization
169
- if authorization.startswith("Bearer "):
170
- token = authorization[7:] # Remove "Bearer " prefix
171
-
172
- if token == expected_token:
173
- is_authenticated = True
174
-
175
- # If not authenticated, check rate limit
176
- if not is_authenticated:
177
- # Get client IP
178
- client_ip = fastapi_request.client.host
179
- if hasattr(fastapi_request.headers, "get"):
180
- # Check for forwarded IP (in case of proxy)
181
- forwarded_for = fastapi_request.headers.get("X-Forwarded-For")
182
- if forwarded_for:
183
- client_ip = forwarded_for.split(",")[0].strip()
184
-
185
- real_ip = fastapi_request.headers.get("X-Real-IP")
186
- if real_ip:
187
- client_ip = real_ip.strip()
188
-
189
- # Check rate limit (4 requests per minute)
190
- if not check_rate_limit(client_ip):
191
- raise HTTPException(
192
- status_code=429,
193
- detail="Rate limit exceeded. Maximum 4 requests per minute for unauthenticated users.",
194
- )
195
-
196
- # If no authorization header was provided when ACCESS_TOKEN is set
197
- if not authorization:
198
- raise HTTPException(
199
- status_code=401, detail="Authorization header required"
200
- )
201
- else:
202
- raise HTTPException(
203
- status_code=401, detail="Invalid authorization token"
204
- )
205
 
206
  try:
207
  modality = detect_model_kind(request.model)
@@ -251,10 +240,49 @@ async def create_embeddings(
251
 
252
 
253
  @router.post("/rank", response_model=RankResponse, tags=["rank"])
254
- async def rank_candidates(request: RankRequest, background_tasks: BackgroundTasks):
 
 
 
 
 
255
  """
256
  Rank candidate texts against the given queries.
257
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  try:
259
  results = await embeddings_service.rank(
260
  model=request.model,
 
37
  "Which model ID to use? "
38
  "Text options: ['multilingual-e5-small', 'multilingual-e5-base', 'multilingual-e5-large', "
39
  "'snowflake-arctic-embed-l-v2.0', 'paraphrase-multilingual-MiniLM-L12-v2', "
40
+ "'paraphrase-multilingual-mpnet-base-v2', 'bge-m3', 'gte-multilingual-base', 'embeddinggemma-300m']. "
41
  "Image option: ['siglip-base-patch16-256-multilingual']."
42
  ),
43
  )
 
120
 
121
 
122
  def check_rate_limit(
123
+ client_ip: str, max_requests: int = 10, window_seconds: int = 60
124
  ) -> bool:
125
  """
126
  Check if the client IP has exceeded the rate limit.
 
162
  expected_token = os.environ.get("ACCESS_TOKEN")
163
  is_authenticated = False
164
 
165
+ if expected_token and authorization:
166
+ # Support both "Bearer <token>" and plain token formats
167
+ token = authorization
168
+ if authorization.startswith("Bearer "):
169
+ token = authorization[7:] # Remove "Bearer " prefix
170
+
171
+ if token == expected_token:
172
+ is_authenticated = True
173
+
174
+ # If not authenticated (no token, empty token, or wrong token), apply rate limit
175
+ if not is_authenticated:
176
+ # Get client IP
177
+ client_ip = fastapi_request.client.host
178
+ if hasattr(fastapi_request.headers, "get"):
179
+ # Check for forwarded IP (in case of proxy)
180
+ forwarded_for = fastapi_request.headers.get("X-Forwarded-For")
181
+ if forwarded_for:
182
+ client_ip = forwarded_for.split(",")[0].strip()
183
+
184
+ real_ip = fastapi_request.headers.get("X-Real-IP")
185
+ if real_ip:
186
+ client_ip = real_ip.strip()
187
+
188
+ # Check rate limit (10 requests per minute for unauthenticated users)
189
+ if not check_rate_limit(client_ip):
190
+ raise HTTPException(
191
+ status_code=429,
192
+ detail="Rate limit exceeded. Maximum 10 requests per minute for unauthenticated users.",
193
+ )
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  try:
196
  modality = detect_model_kind(request.model)
 
240
 
241
 
242
  @router.post("/rank", response_model=RankResponse, tags=["rank"])
243
+ async def rank_candidates(
244
+ request: RankRequest,
245
+ background_tasks: BackgroundTasks,
246
+ fastapi_request: Request,
247
+ authorization: str = Header(None),
248
+ ):
249
  """
250
  Rank candidate texts against the given queries.
251
  """
252
+ # Check authorization
253
+ expected_token = os.environ.get("ACCESS_TOKEN")
254
+ is_authenticated = False
255
+
256
+ if expected_token and authorization:
257
+ # Support both "Bearer <token>" and plain token formats
258
+ token = authorization
259
+ if authorization.startswith("Bearer "):
260
+ token = authorization[7:] # Remove "Bearer " prefix
261
+
262
+ if token == expected_token:
263
+ is_authenticated = True
264
+
265
+ # If not authenticated (no token, empty token, or wrong token), apply rate limit
266
+ if not is_authenticated:
267
+ # Get client IP
268
+ client_ip = fastapi_request.client.host
269
+ if hasattr(fastapi_request.headers, "get"):
270
+ # Check for forwarded IP (in case of proxy)
271
+ forwarded_for = fastapi_request.headers.get("X-Forwarded-For")
272
+ if forwarded_for:
273
+ client_ip = forwarded_for.split(",")[0].strip()
274
+
275
+ real_ip = fastapi_request.headers.get("X-Real-IP")
276
+ if real_ip:
277
+ client_ip = real_ip.strip()
278
+
279
+ # Check rate limit (10 requests per minute for unauthenticated users)
280
+ if not check_rate_limit(client_ip):
281
+ raise HTTPException(
282
+ status_code=429,
283
+ detail="Rate limit exceeded. Maximum 10 requests per minute for unauthenticated users.",
284
+ )
285
+
286
  try:
287
  results = await embeddings_service.rank(
288
  model=request.model,