from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable, Optional, Literal, Callable, Any, get_args
import math
import os
import warnings

# Optional heavy deps (torch, transformers) are imported lazily when needed.
_TOK = None  # tokenizer most recently resolved by an HFEmbeddingBackend
_MODEL = None  # model most recently resolved by an HFEmbeddingBackend
_TORCH = None  # the `torch` module, once imported

# Loaded (tokenizer, model) pairs keyed by (model_name, device). Keying on both means a
# backend that asks for a different checkpoint or device never reuses another's model.
_MODEL_CACHE: Dict[Tuple[str, str], Tuple[Any, Any]] = {}

Agg = Literal["mean", "max", "topk_mean"]
_VALID_AGGS: Tuple[str, ...] = get_args(Agg)


# --------------------------- Embedding backend ---------------------------

@dataclass
class HFEmbeddingBackend:
    """
    Minimal huggingface transformers encoder for sentence-level embeddings.

    Uses mean pooling over ``last_hidden_state`` (padding positions excluded via the
    attention mask) and L2-normalizes the result.
    """

    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    # "cuda" | "cpu" | None -> (EMBEDDING_DEVICE env var or "cpu").
    # We default to CPU on Spaces to avoid ZeroGPU device mixups.
    device: Optional[str] = None

    def _lazy_import(self) -> Tuple[Any, Any]:
        """
        Import torch/transformers on first use and return ``(tokenizer, model)``.

        The device is resolved on every call (explicit ``device`` -> ``EMBEDDING_DEVICE``
        env var -> ``"cpu"``) and written back to ``self.device``. Models are loaded once
        per ``(model_name, device)`` and kept in ``_MODEL_CACHE``; the module globals
        ``_TOK``/``_MODEL`` are updated to the pair just resolved.
        """
        global _TOK, _MODEL, _TORCH
        if _TORCH is None:
            import torch as _torch
            _TORCH = _torch
        # Resolve every time: previously only the instance that first loaded the model
        # got a device, and later instances were left with device=None.
        dev = self.device or os.getenv("EMBEDDING_DEVICE") or "cpu"
        self.device = dev
        key = (self.model_name, dev)
        pair = _MODEL_CACHE.get(key)
        if pair is None:
            from transformers import AutoTokenizer, AutoModel  # type: ignore
            tok = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModel.from_pretrained(self.model_name)
            model.to(dev).eval()
            pair = (tok, model)
            _MODEL_CACHE[key] = pair
        _TOK, _MODEL = pair
        return pair

    def encode(self, texts: Iterable[str], batch_size: int = 32) -> "tuple[_TORCH.Tensor, list[str]]":
        """
        Encode ``texts`` into unit-normalized sentence embeddings.

        Args:
            texts: strings to encode; consumed once into a list.
            batch_size: number of texts per forward pass; must be >= 1.

        Returns:
            ``(embeddings, texts_list)``: a CPU ``torch.Tensor`` of shape ``[N, D]`` whose
            rows have unit L2 norm (in input order), and ``list(texts)``.

        Raises:
            ValueError: if ``batch_size < 1``.
        """
        if batch_size < 1:
            # A negative step would make the batching loop silently produce nothing.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        tok, model = self._lazy_import()
        torch = _TORCH  # set by _lazy_import
        texts_list = list(texts)
        if not texts_list:
            return torch.empty((0, model.config.hidden_size)), []

        chunks = []
        with torch.inference_mode():
            for start in range(0, len(texts_list), batch_size):
                batch = texts_list[start:start + batch_size]
                # Tokenize and move the inputs to the model's device.
                enc = tok(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
                last = model(**enc).last_hidden_state  # [B, T, H]
                # Cast the mask to the hidden-state dtype so pooling stays in that dtype.
                mask = enc["attention_mask"].unsqueeze(-1).to(last.dtype)  # [B, T, 1]
                # Masked mean pooling; clamp guards against all-padding rows.
                pooled = (last * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                # L2 normalize so dot products downstream are cosine similarities.
                pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
                # Collect on CPU for downstream ops.
                chunks.append(pooled.cpu())
        return torch.cat(chunks, dim=0), texts_list


# --------------------------- Utilities ---------------------------

def _normalize_whitespace(s: str) -> str:
    """Strip ``s`` and collapse every internal whitespace run to a single space."""
    return " ".join(s.strip().split())


def _default_preprocess(s: str) -> str:
    """Default text preprocessing: whitespace normalization only."""
    # Keep simple, deterministic preprocessing. Users can override with a custom callable.
    return _normalize_whitespace(s)


@dataclass
class PhraseIndex:
    """Pre-encoded anchor phrases, grouped by level."""

    # level -> preprocessed phrases, in encoding order
    phrases_by_level: Dict[str, List[str]]
    # level -> tensor of shape [n_phrases, D]; row i belongs to phrases_by_level[level][i]
    embeddings_by_level: Dict[str, "Any"]
    # model_name of the backend that produced the embeddings
    model_name: str


def build_phrase_index(
    backend: HFEmbeddingBackend,
    phrases_by_level: Dict[str, Iterable[str]],
    *,
    preprocess: Optional[Callable[[str], str]] = None,
) -> PhraseIndex:
    """
    Pre-encode all anchor phrases per level into a searchable index.

    Args:
        backend: encoder; only its ``model_name`` attribute and ``encode(texts)``
            method are used.
        phrases_by_level: level -> iterable of phrases (each iterable consumed once).
        preprocess: optional text normalizer applied to every phrase; defaults to
            whitespace normalization.

    Returns:
        A ``PhraseIndex``. Levels with no phrases get an empty ``[0, D]`` tensor.
    """
    clean = preprocess or _default_preprocess
    cleaned: Dict[str, List[str]] = {
        lvl: [clean(p) for p in phrases] for lvl, phrases in phrases_by_level.items()
    }

    # Encode everything in one call, remembering each level's [start, end) span
    # within the flat list.
    all_texts: List[str] = []
    spans: List[Tuple[str, int, int]] = []
    for lvl, plist in cleaned.items():
        start = len(all_texts)
        all_texts.extend(plist)
        spans.append((lvl, start, len(all_texts)))

    embs, _ = backend.encode(all_texts)  # CPU tensor [N, D]

    # A slice with start == end is already an empty [0, D] tensor, so empty levels
    # need no special case (and no direct dependency on the torch module).
    embeddings_by_level: Dict[str, "Any"] = {lvl: embs[start:end] for lvl, start, end in spans}

    return PhraseIndex(
        phrases_by_level=cleaned,
        embeddings_by_level=embeddings_by_level,
        model_name=backend.model_name,
    )


def _aggregate_sims(sims: "Any", agg: Agg, topk: int) -> float:
    """
    Aggregate a 1D tensor of similarities into a single score.

    Returns NaN when ``sims`` is empty.

    Raises:
        ValueError: for an unknown ``agg``, or ``topk < 1`` with ``agg == "topk_mean"``.
    """
    if sims.numel() == 0:
        return float("nan")
    if agg == "mean":
        return float(sims.mean().item())
    if agg == "max":
        return float(sims.max().item())
    if agg == "topk_mean":
        if topk < 1:
            raise ValueError(f"topk must be >= 1 for agg='topk_mean', got {topk}")
        k = min(topk, sims.numel())
        return float(sims.topk(k).values.mean().item())
    raise ValueError(f"Unknown agg: {agg}")


# --------------------------- Public API ---------------------------

def classify_levels_phrases(
    question: str,
    blooms_phrases: Dict[str, Iterable[str]],
    dok_phrases: Dict[str, Iterable[str]],
    *,
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    agg: Agg = "max",
    topk: int = 5,
    preprocess: Optional[Callable[[str], str]] = None,
    backend: Optional[HFEmbeddingBackend] = None,
    prebuilt_bloom_index: Optional[PhraseIndex] = None,
    prebuilt_dok_index: Optional[PhraseIndex] = None,
    return_phrase_matches: bool = True,
    n_top_phrases: int = 5,
) -> Dict[str, Any]:
    """
    Score a question against Bloom's taxonomy and DOK (Depth of Knowledge)
    using cosine similarity to level-specific anchor phrases.

    Args:
        question: text to classify.
        blooms_phrases / dok_phrases: level -> anchor phrases. Ignored for a taxonomy
            whose prebuilt index is supplied.
        model_name: model for the default backend (ignored when ``backend`` is given).
        agg: how per-phrase similarities are reduced to one level score.
        topk: number of best similarities averaged when ``agg == "topk_mean"``.
        preprocess: normalizer applied to the question only (phrases built here use
            whitespace normalization).
        backend: encoder to use; a CPU ``HFEmbeddingBackend`` is created if omitted.
        prebuilt_bloom_index / prebuilt_dok_index: reuse previously built indices.
        return_phrase_matches: include the best-matching phrases per level.
        n_top_phrases: how many best-matching phrases to report per level.

    Returns:
        A dict with the cleaned question, model name, per-taxonomy scores, best
        level/score, optional top phrases, and the aggregation config. Levels with no
        phrases score NaN. If every level is NaN, best_level is "" and best_score is
        -inf.

    Raises:
        ValueError: for an unknown ``agg`` or ``topk < 1`` with ``agg == "topk_mean"``
            (checked before any model is loaded or text encoded).
    """
    # Fail fast instead of after loading the model and encoding every phrase.
    if agg not in _VALID_AGGS:
        raise ValueError(f"Unknown agg: {agg}")
    if agg == "topk_mean" and topk < 1:
        raise ValueError(f"topk must be >= 1 for agg='topk_mean', got {topk}")

    preprocess = preprocess or _default_preprocess
    question_clean = preprocess(question)

    # Prepare backend (defaults to CPU).
    be = backend if backend is not None else HFEmbeddingBackend(model_name=model_name)

    for label, idx in (("prebuilt_bloom_index", prebuilt_bloom_index),
                       ("prebuilt_dok_index", prebuilt_dok_index)):
        if idx is not None and idx.model_name != be.model_name:
            # Embeddings from different models live in unrelated spaces; comparing them
            # gives meaningless scores (or a shape error if dimensions differ).
            warnings.warn(
                f"{label} was built with model {idx.model_name!r} but the backend uses "
                f"{be.model_name!r}; similarity scores may be meaningless",
                stacklevel=2,
            )

    # Build / reuse indices.
    bloom_index = (prebuilt_bloom_index if prebuilt_bloom_index is not None
                   else build_phrase_index(be, blooms_phrases))
    dok_index = (prebuilt_dok_index if prebuilt_dok_index is not None
                 else build_phrase_index(be, dok_phrases))

    # Encode question -> tensor [1, D].
    q_emb, _ = be.encode([question_clean])
    q_emb = q_emb[0:1]
    n_top = max(0, n_top_phrases)

    def _score_block(index: PhraseIndex) -> Tuple[Dict[str, float], Dict[str, List[Tuple[str, float]]]]:
        """Score every level of ``index`` against the question embedding."""
        scores: Dict[str, float] = {}
        top_contribs: Dict[str, List[Tuple[str, float]]] = {}
        for lvl, phrases in index.phrases_by_level.items():
            embs = index.embeddings_by_level[lvl]  # [N, D]
            if embs.numel() == 0:
                scores[lvl] = float("nan")
                top_contribs[lvl] = []
                continue
            # Cosine similarity, since embs and q_emb are unit-normalized.
            sims = (q_emb @ embs.T).squeeze(0)  # [N]
            scores[lvl] = _aggregate_sims(sims, agg, topk)
            if return_phrase_matches:
                vals, idxs = sims.topk(min(n_top, sims.numel()))
                top_contribs[lvl] = [(phrases[int(i)], float(v.item())) for v, i in zip(vals, idxs)]
        return scores, top_contribs

    bloom_scores, bloom_top = _score_block(bloom_index)
    dok_scores, dok_top = _score_block(dok_index)

    def _best(scores: Dict[str, float]) -> Tuple[str, float]:
        """Highest non-NaN score (first level wins ties); ("", -inf) if none."""
        valid = [(lvl, val) for lvl, val in scores.items() if not math.isnan(val)]
        if not valid:
            return "", -float("inf")
        best_lvl, best_val = max(valid, key=lambda kv: kv[1])
        return best_lvl, best_val

    best_bloom, best_bloom_val = _best(bloom_scores)
    best_dok, best_dok_val = _best(dok_scores)

    return {
        "question": question_clean,
        "model_name": be.model_name,
        "blooms": {
            "scores": bloom_scores,
            "best_level": best_bloom,
            "best_score": best_bloom_val,
            "top_phrases": bloom_top if return_phrase_matches else None,
        },
        "dok": {
            "scores": dok_scores,
            "best_level": best_dok,
            "best_score": best_dok_val,
            "top_phrases": dok_top if return_phrase_matches else None,
        },
        "config": {
            "agg": agg,
            "topk": topk if agg == "topk_mean" else None,
        },
    }