Update level_classifier_tool.py

level_classifier_tool.py  CHANGED  (+27 -57)
@@ -22,7 +22,9 @@ class HFEmbeddingBackend:
     Uses mean pooling over last_hidden_state and L2 normalizes the result.
     """
     model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
-
+    # "cuda" | "cpu" | None -> (env or "cpu")
+    # We default to CPU on Spaces to avoid ZeroGPU device mixups.
+    device: Optional[str] = None
 
     def _lazy_import(self) -> None:
         global _TOK, _MODEL, _TORCH
@@ -33,24 +35,27 @@ class HFEmbeddingBackend:
             from transformers import AutoTokenizer, AutoModel  # type: ignore
             _TOK = AutoTokenizer.from_pretrained(self.model_name)
             _MODEL = AutoModel.from_pretrained(self.model_name)
-
+        # Prefer explicit device -> env override -> default to CPU
+        dev = self.device or os.getenv("EMBEDDING_DEVICE") or "cpu"
         _MODEL.to(dev).eval()
         self.device = dev
 
     def encode(self, texts: Iterable[str], batch_size: int = 32) -> "tuple[_TORCH.Tensor, list[str]]":
         """
-        Returns (embeddings, texts_list). Embeddings
+        Returns (embeddings, texts_list). Embeddings are a CPU torch.Tensor [N, D], unit-normalized.
         """
         self._lazy_import()
         torch = _TORCH  # local alias
         texts_list = list(texts)
         if not texts_list:
+            # Hidden size available after _lazy_import
             return torch.empty((0, _MODEL.config.hidden_size)), []  # type: ignore
 
         all_out = []
         with torch.inference_mode():
             for i in range(0, len(texts_list), batch_size):
                 batch = texts_list[i:i + batch_size]
+                # Tokenize and move to model device
                 enc = _TOK(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)  # type: ignore
                 out = _MODEL(**enc)
                 last = out.last_hidden_state  # [B, T, H]
@@ -61,7 +66,9 @@ class HFEmbeddingBackend:
                 pooled = summed / counts
                 # L2 normalize
                 pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
+                # Collect on CPU for downstream ops
                 all_out.append(pooled.cpu())
+
         embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, _MODEL.config.hidden_size))  # type: ignore
         return embs, texts_list
 
@@ -102,16 +109,22 @@ def build_phrase_index(
         cur += len(plist)
         spans.append((lvl, start, cur))
 
-    embs, _ = backend.encode(all_texts)
+    embs, _ = backend.encode(all_texts)  # embs is a CPU torch.Tensor [N, D]
+
     # Slice embeddings back into level buckets
     torch = _TORCH
     embeddings_by_level: Dict[str, "Any"] = {}
     for lvl, start, end in spans:
-
+        if end > start:
+            embeddings_by_level[lvl] = embs[start:end]  # torch.Tensor slice [n_i, D]
+        else:
+            embeddings_by_level[lvl] = torch.empty((0, embs.shape[1]))  # type: ignore
 
-    return PhraseIndex(
-
-
+    return PhraseIndex(
+        phrases_by_level={lvl: list(pl) for lvl, pl in cleaned.items()},
+        embeddings_by_level=embeddings_by_level,
+        model_name=backend.model_name
+    )
 
 
 def _aggregate_sims(
@@ -153,64 +166,20 @@ def classify_levels_phrases(
     """
     Score a question against Bloom's taxonomy and DOK (Depth of Knowledge)
    using cosine similarity to level-specific anchor phrases.
-
-    Parameters
-    ----------
-    question : str
-        The input question or prompt.
-    blooms_phrases : dict[str, Iterable[str]]
-        Mapping level -> list of anchor phrases for Bloom's.
-    dok_phrases : dict[str, Iterable[str]]
-        Mapping level -> list of anchor phrases for DOK.
-    model_name : str
-        Hugging Face model name for text embeddings. Ignored when `backend` provided.
-    agg : {"mean","max","topk_mean"}
-        Aggregation over phrase similarities within a level.
-    topk : int
-        Used only when `agg="topk_mean"`.
-    preprocess : Optional[Callable[[str], str]]
-        Preprocessing function for the question string. Defaults to whitespace normalization.
-    backend : Optional[HFEmbeddingBackend]
-        Injected embedding backend. If not given, one is constructed.
-    prebuilt_bloom_index, prebuilt_dok_index : Optional[PhraseIndex]
-        If provided, reuse precomputed phrase embeddings to avoid re-encoding.
-    return_phrase_matches : bool
-        If True, returns per-level top contributing phrases.
-
-    Returns
-    -------
-    dict
-        {
-          "question": ...,
-          "model_name": ...,
-          "blooms": {
-              "scores": {level: float, ...},
-              "best_level": str,
-              "best_score": float,
-              "top_phrases": {level: [(phrase, sim_float), ...], ...}  # only if return_phrase_matches
-          },
-          "dok": {
-              "scores": {level: float, ...},
-              "best_level": str,
-              "best_score": float,
-              "top_phrases": {level: [(phrase, sim_float), ...], ...}  # only if return_phrase_matches
-          },
-          "config": {"agg": agg, "topk": topk if agg=='topk_mean' else None}
-        }
     """
     preprocess = preprocess or _default_preprocess
     question_clean = preprocess(question)
 
-    # Prepare backend
+    # Prepare backend (defaults to CPU)
     be = backend or HFEmbeddingBackend(model_name=model_name)
 
     # Build / reuse indices
     bloom_index = prebuilt_bloom_index or build_phrase_index(be, blooms_phrases)
     dok_index = prebuilt_dok_index or build_phrase_index(be, dok_phrases)
 
-    # Encode question
+    # Encode question -> CPU torch.Tensor [1, D]
     q_emb, _ = be.encode([question_clean])
-    q_emb = q_emb[0:1]
+    q_emb = q_emb[0:1]
     torch = _TORCH
 
     def _score_block(index: PhraseIndex) -> Tuple[Dict[str, float], Dict[str, List[Tuple[str, float]]]]:
@@ -218,12 +187,13 @@ def classify_levels_phrases(
         top_contribs: Dict[str, List[Tuple[str, float]]] = {}
 
         for lvl, phrases in index.phrases_by_level.items():
-            embs = index.embeddings_by_level[lvl]  # [N, D]
+            embs = index.embeddings_by_level[lvl]  # torch.Tensor [N, D]
             if embs.numel() == 0:
                 scores[lvl] = float("nan")
                 top_contribs[lvl] = []
                 continue
-
+            # cosine similarity since embs and q_emb are unit-normalized
+            sims = (q_emb @ embs.T).squeeze(0)
             scores[lvl] = _aggregate_sims(sims, agg, topk)
             if return_phrase_matches:
                 k = min(5, sims.numel())
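The last hunk scores each level as the cosine similarity between the unit-normalized question embedding and that level's phrase embeddings, then aggregates the per-phrase similarities. Below is a standalone sketch of that scoring step using torch only; it is not part of the commit, the aggregate helper is a simplified stand-in for _aggregate_sims, and the shapes and random values are illustrative.

import torch

def aggregate(sims: torch.Tensor, agg: str = "mean", topk: int = 3) -> float:
    # Mirrors the documented "mean" / "max" / "topk_mean" aggregation modes.
    if agg == "max":
        return float(sims.max())
    if agg == "topk_mean":
        k = min(topk, sims.numel())
        return float(sims.topk(k).values.mean())
    return float(sims.mean())

D = 8                                  # embedding dimension (illustrative)
q = torch.randn(1, D)                  # stand-in question embedding [1, D]
E = torch.randn(5, D)                  # stand-in phrase embeddings for one level [N, D]
q = q / q.norm(dim=1, keepdim=True).clamp(min=1e-12)  # L2 normalize, as encode() does
E = E / E.norm(dim=1, keepdim=True).clamp(min=1e-12)

sims = (q @ E.T).squeeze(0)            # cosine similarities, shape [N]
score = aggregate(sims, "topk_mean")   # level score, analogous to _aggregate_sims
top = sims.topk(min(5, sims.numel()))  # top contributing phrases, as in return_phrase_matches
print(round(score, 4), top.indices.tolist())

Because both sides are unit-normalized, the plain matrix product already yields cosine similarities, which is why the new code needs no explicit cosine call.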
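For reviewers, a minimal usage sketch of the updated API follows. It is a sketch, not part of the commit: it assumes the module is importable as level_classifier_tool, that torch and transformers are installed, and that the level names, anchor phrases, and question below are illustrative placeholders.

# Usage sketch (assumptions noted above). Setting EMBEDDING_DEVICE is optional,
# since the backend now falls back to CPU when no device is given.
import os

os.environ.setdefault("EMBEDDING_DEVICE", "cpu")  # optional explicit override

from level_classifier_tool import (
    HFEmbeddingBackend,
    build_phrase_index,
    classify_levels_phrases,
)

# Illustrative anchor phrases; real level names and phrases come from the caller.
blooms = {
    "Remember": ["define the term", "list the steps"],
    "Analyze": ["compare and contrast", "identify the underlying assumptions"],
}
dok = {
    "DOK1": ["recall a fact"],
    "DOK3": ["justify your reasoning with evidence"],
}

# Build the backend and phrase indices once, then reuse them across calls
# so anchor phrases are not re-encoded for every question.
be = HFEmbeddingBackend()  # device=None -> EMBEDDING_DEVICE -> "cpu"
bloom_idx = build_phrase_index(be, blooms)
dok_idx = build_phrase_index(be, dok)

result = classify_levels_phrases(
    "Compare the two sorting algorithms and justify which scales better.",
    blooms_phrases=blooms,
    dok_phrases=dok,
    backend=be,
    prebuilt_bloom_index=bloom_idx,
    prebuilt_dok_index=dok_idx,
    agg="topk_mean",
    topk=3,
    return_phrase_matches=True,
)
print(result["blooms"]["best_level"], result["dok"]["best_level"])

Passing backend plus the prebuilt indices follows the parameters documented before this commit: precomputed phrase embeddings are reused so only the question is encoded per call.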