bhardwaj08sarthak committed (verified)
Commit ff78d09 · Parent(s): 700d92f

Update level_classifier_tool.py

Files changed (1):
  1. level_classifier_tool.py (+27 / -57)
level_classifier_tool.py CHANGED
@@ -22,7 +22,9 @@ class HFEmbeddingBackend:
     Uses mean pooling over last_hidden_state and L2 normalizes the result.
     """
     model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
-    device: Optional[str] = None  # "cuda" | "cpu" | None -> auto
+    # "cuda" | "cpu" | None -> (env or "cpu")
+    # We default to CPU on Spaces to avoid ZeroGPU device mixups.
+    device: Optional[str] = None
 
     def _lazy_import(self) -> None:
         global _TOK, _MODEL, _TORCH
@@ -33,24 +35,27 @@ class HFEmbeddingBackend:
         from transformers import AutoTokenizer, AutoModel  # type: ignore
         _TOK = AutoTokenizer.from_pretrained(self.model_name)
         _MODEL = AutoModel.from_pretrained(self.model_name)
-        dev = self.device or ("cuda" if _TORCH.cuda.is_available() else "cpu")
+        # Prefer explicit device -> env override -> default to CPU
+        dev = self.device or os.getenv("EMBEDDING_DEVICE") or "cpu"
         _MODEL.to(dev).eval()
         self.device = dev
 
     def encode(self, texts: Iterable[str], batch_size: int = 32) -> "tuple[_TORCH.Tensor, list[str]]":
         """
-        Returns (embeddings, texts_list). Embeddings have shape [N, D] and are unit-normalized.
+        Returns (embeddings, texts_list). Embeddings are a CPU torch.Tensor [N, D], unit-normalized.
         """
         self._lazy_import()
         torch = _TORCH  # local alias
         texts_list = list(texts)
         if not texts_list:
+            # Hidden size available after _lazy_import
             return torch.empty((0, _MODEL.config.hidden_size)), []  # type: ignore
 
         all_out = []
         with torch.inference_mode():
             for i in range(0, len(texts_list), batch_size):
                 batch = texts_list[i:i + batch_size]
+                # Tokenize and move to model device
                 enc = _TOK(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)  # type: ignore
                 out = _MODEL(**enc)
                 last = out.last_hidden_state  # [B, T, H]
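The device policy above resolves an explicit `device` first, then the `EMBEDDING_DEVICE` environment variable, and otherwise stays on CPU; `encode` then returns unit-normalized CPU embeddings. A minimal usage sketch of that behavior; the import path and the keyword-style construction are assumptions, not shown in this diff:

    import os
    from level_classifier_tool import HFEmbeddingBackend  # assumed import path

    # Optional: steer device selection via the environment (checked when device=None).
    os.environ.setdefault("EMBEDDING_DEVICE", "cpu")

    backend = HFEmbeddingBackend()                 # device=None -> env override or CPU
    # backend = HFEmbeddingBackend(device="cuda")  # an explicit device still wins

    embs, texts = backend.encode(["Define photosynthesis.", "Evaluate the experimental design."])
    print(embs.shape)        # torch.Size([2, 384]) for all-MiniLM-L6-v2
    print(embs.norm(dim=1))  # ~1.0 per row because rows are L2-normalized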
@@ -61,7 +66,9 @@ class HFEmbeddingBackend:
                 pooled = summed / counts
                 # L2 normalize
                 pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
+                # Collect on CPU for downstream ops
                 all_out.append(pooled.cpu())
+
         embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, _MODEL.config.hidden_size))  # type: ignore
         return embs, texts_list
 
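For orientation, the pooling in `encode` is masked mean pooling over `last_hidden_state` followed by L2 normalization; the masking lines themselves sit between the two hunks above and are not shown here, so the sketch below reconstructs the standard pattern with illustrative names:

    import torch

    def mean_pool_and_normalize(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Masked mean pooling over tokens, then L2 normalization (mirrors encode())."""
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden)  # [B, T, 1]
        summed = (last_hidden * mask).sum(dim=1)                  # [B, H]
        counts = mask.sum(dim=1).clamp(min=1e-9)                  # [B, 1]
        pooled = summed / counts
        return pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)

    # Dummy check: unit norms out of random hidden states
    hidden = torch.randn(2, 5, 384)
    mask = torch.ones(2, 5, dtype=torch.long)
    print(mean_pool_and_normalize(hidden, mask).norm(dim=1))  # ~tensor([1., 1.])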
@@ -102,16 +109,22 @@ def build_phrase_index(
         cur += len(plist)
         spans.append((lvl, start, cur))
 
-    embs, _ = backend.encode(all_texts)
+    embs, _ = backend.encode(all_texts)  # embs is a CPU torch.Tensor [N, D]
+
     # Slice embeddings back into level buckets
     torch = _TORCH
     embeddings_by_level: Dict[str, "Any"] = {}
     for lvl, start, end in spans:
-        embeddings_by_level[lvl] = embs[start:end] if end > start else torch.empty((0, embs.shape[1]))  # type: ignore
+        if end > start:
+            embeddings_by_level[lvl] = embs[start:end]  # torch.Tensor slice [n_i, D]
+        else:
+            embeddings_by_level[lvl] = torch.empty((0, embs.shape[1]))  # type: ignore
 
-    return PhraseIndex(phrases_by_level={lvl: list(pl) for lvl, pl in cleaned.items()},
-                       embeddings_by_level=embeddings_by_level,
-                       model_name=backend.model_name)
+    return PhraseIndex(
+        phrases_by_level={lvl: list(pl) for lvl, pl in cleaned.items()},
+        embeddings_by_level=embeddings_by_level,
+        model_name=backend.model_name
+    )
 
 
 def _aggregate_sims(
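`build_phrase_index` now returns a `PhraseIndex` whose `embeddings_by_level` tensors are keyed exactly like the input mapping, with an explicit empty tensor for levels that end up with no phrases. A hedged usage sketch; the anchor phrases and the import path are illustrative, not from the repository:

    from level_classifier_tool import HFEmbeddingBackend, build_phrase_index  # assumed import path

    backend = HFEmbeddingBackend()
    bloom_phrases = {
        "Remember": ["define the term", "list the steps"],
        "Analyze": ["compare and contrast", "identify the underlying assumptions"],
    }
    index = build_phrase_index(backend, bloom_phrases)

    print(index.model_name)  # sentence-transformers/all-MiniLM-L6-v2
    for level, embs in index.embeddings_by_level.items():
        print(level, tuple(embs.shape))  # e.g. Remember (2, 384)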
@@ -153,64 +166,20 @@ def classify_levels_phrases(
     """
     Score a question against Bloom's taxonomy and DOK (Depth of Knowledge)
     using cosine similarity to level-specific anchor phrases.
-
-    Parameters
-    ----------
-    question : str
-        The input question or prompt.
-    blooms_phrases : dict[str, Iterable[str]]
-        Mapping level -> list of anchor phrases for Bloom's.
-    dok_phrases : dict[str, Iterable[str]]
-        Mapping level -> list of anchor phrases for DOK.
-    model_name : str
-        Hugging Face model name for text embeddings. Ignored when `backend` provided.
-    agg : {"mean","max","topk_mean"}
-        Aggregation over phrase similarities within a level.
-    topk : int
-        Used only when `agg="topk_mean"`.
-    preprocess : Optional[Callable[[str], str]]
-        Preprocessing function for the question string. Defaults to whitespace normalization.
-    backend : Optional[HFEmbeddingBackend]
-        Injected embedding backend. If not given, one is constructed.
-    prebuilt_bloom_index, prebuilt_dok_index : Optional[PhraseIndex]
-        If provided, reuse precomputed phrase embeddings to avoid re-encoding.
-    return_phrase_matches : bool
-        If True, returns per-level top contributing phrases.
-
-    Returns
-    -------
-    dict
-        {
-          "question": ...,
-          "model_name": ...,
-          "blooms": {
-              "scores": {level: float, ...},
-              "best_level": str,
-              "best_score": float,
-              "top_phrases": {level: [(phrase, sim_float), ...], ...}  # only if return_phrase_matches
-          },
-          "dok": {
-              "scores": {level: float, ...},
-              "best_level": str,
-              "best_score": float,
-              "top_phrases": {level: [(phrase, sim_float), ...], ...}  # only if return_phrase_matches
-          },
-          "config": {"agg": agg, "topk": topk if agg=='topk_mean' else None}
-        }
     """
     preprocess = preprocess or _default_preprocess
     question_clean = preprocess(question)
 
-    # Prepare backend
+    # Prepare backend (defaults to CPU)
     be = backend or HFEmbeddingBackend(model_name=model_name)
 
     # Build / reuse indices
     bloom_index = prebuilt_bloom_index or build_phrase_index(be, blooms_phrases)
     dok_index = prebuilt_dok_index or build_phrase_index(be, dok_phrases)
 
-    # Encode question
+    # Encode question -> CPU torch.Tensor [1, D]
     q_emb, _ = be.encode([question_clean])
-    q_emb = q_emb[0:1]  # [1, D]
+    q_emb = q_emb[0:1]
     torch = _TORCH
 
     def _score_block(index: PhraseIndex) -> Tuple[Dict[str, float], Dict[str, List[Tuple[str, float]]]]:
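The docstring removed above documented `agg` as one of {"mean", "max", "topk_mean"}, with `topk` used only for `topk_mean`. `_aggregate_sims` itself is not part of this diff, so the following is only a plausible sketch of that contract, not the repository's implementation:

    import torch

    def aggregate_sims_sketch(sims: torch.Tensor, agg: str = "mean", topk: int = 5) -> float:
        """Illustrative only: collapse per-phrase similarities [N] into a single level score."""
        if agg == "mean":
            return float(sims.mean())
        if agg == "max":
            return float(sims.max())
        if agg == "topk_mean":
            k = min(topk, sims.numel())
            return float(sims.topk(k).values.mean())
        raise ValueError(f"Unknown agg: {agg}")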
@@ -218,12 +187,13 @@ def classify_levels_phrases(
         top_contribs: Dict[str, List[Tuple[str, float]]] = {}
 
         for lvl, phrases in index.phrases_by_level.items():
-            embs = index.embeddings_by_level[lvl]  # [N, D]
+            embs = index.embeddings_by_level[lvl]  # torch.Tensor [N, D]
             if embs.numel() == 0:
                 scores[lvl] = float("nan")
                 top_contribs[lvl] = []
                 continue
-            sims = (q_emb @ embs.T).squeeze(0)  # cosine sim due to L2 norm
+            # cosine similarity since embs and q_emb are unit-normalized
+            sims = (q_emb @ embs.T).squeeze(0)
             scores[lvl] = _aggregate_sims(sims, agg, topk)
             if return_phrase_matches:
                 k = min(5, sims.numel())
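Putting the pieces together, the scoring path encodes the question once, reuses the phrase indices, and takes cosine similarities against each level's unit-normalized anchors. A hedged end-to-end sketch; keyword names and result keys follow the docstring removed above, the phrase lists are illustrative, and the import path is assumed:

    from level_classifier_tool import classify_levels_phrases  # assumed import path

    blooms = {"Remember": ["define", "list"], "Evaluate": ["justify your answer", "critique the claim"]}
    dok = {"DOK1": ["recall a fact"], "DOK3": ["develop a logical argument"]}

    result = classify_levels_phrases(
        "Justify why the experiment supports the hypothesis.",
        blooms_phrases=blooms,
        dok_phrases=dok,
        agg="topk_mean",
        topk=2,
        return_phrase_matches=True,
    )
    print(result["blooms"]["best_level"], round(result["blooms"]["best_score"], 3))
    print(result["dok"]["scores"])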
 