bhardwaj08sarthak committed on
Commit 79418f8 · verified · 1 Parent(s): f296c60

Upload 5 files

Files changed (5)
  1. all_datasets.py +18 -0
  2. level_classifier_tool_2.py +248 -0
  3. phrases.py +52 -0
  4. task_temp.py +0 -0
  5. utils.py +31 -0
all_datasets.py ADDED
@@ -0,0 +1,18 @@
+ #%%
+ from datasets import load_dataset
+ import pandas as pd
+ import os
+ os.chdir(os.path.dirname(__file__))
+ clean_math = pd.read_json(
+     "deepmind_math.jsonl",
+     lines=True,
+     orient="records"
+ )
+ GSM8k = load_dataset('openai/gsm8k', 'main', split='train')
+ MMMLU = load_dataset('cais/mmlu', 'college_mathematics', split='test+validation')
+ MMMU = load_dataset('MMMU/MMMU', 'Math', split='test+validation')
+ Olympiad_math = load_dataset('Hothan/OlympiadBench', 'TP_TO_maths_en_COMP', split='train')
+ Olympiad_math2 = load_dataset('Hothan/OlympiadBench', 'OE_TO_maths_en_COMP', split='train')
+ ScienceQA = load_dataset("derek-thomas/ScienceQA", split="train")
+ PubmedQA = load_dataset('qiaojin/PubMedQA', 'pqa_unlabeled', split='train')
+ # %%
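
A quick sanity-check sketch (illustrative, assuming the script above has run; the "question" column follows these datasets' published schemas):

    questions = [row["question"] for row in GSM8k]
    print(len(questions), questions[0][:80])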
level_classifier_tool_2.py ADDED
@@ -0,0 +1,248 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Tuple, Iterable, Optional, Literal, Callable, Any
+ import math
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+
+ Agg = Literal["mean", "max", "topk_mean"]
+
+
+ # --------------------------- Embedding backend ---------------------------
+
+ @dataclass
+ class HFEmbeddingBackend:
+     """
+     Minimal Hugging Face transformers encoder for sentence-level embeddings.
+     Uses mean pooling over last_hidden_state and L2-normalizes the result.
+     """
+     model_name: str = "google/embeddinggemma-300m"
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+     def __post_init__(self) -> None:
+         # Load per instance so a caller-supplied model_name is actually used
+         # (loading at class-definition time would always bake in the default).
+         self.TOK = AutoTokenizer.from_pretrained(self.model_name)
+         self.MODEL = AutoModel.from_pretrained(self.model_name).to(self.device).eval()
+
+     def encode(self, texts: Iterable[str], batch_size: int = 32) -> "tuple[torch.Tensor, list[str]]":
+         """
+         Returns (embeddings, texts_list). Embeddings have shape [N, D] and are unit-normalized.
+         """
+         texts_list = list(texts)
+         if not texts_list:
+             return torch.empty((0, self.MODEL.config.hidden_size)), []
+
+         all_out = []
+         with torch.inference_mode():
+             for i in range(0, len(texts_list), batch_size):
+                 batch = texts_list[i:i + batch_size]
+                 enc = self.TOK(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
+                 out = self.MODEL(**enc)
+                 last = out.last_hidden_state  # [B, T, H]
+                 mask = enc["attention_mask"].unsqueeze(-1)  # [B, T, 1]
+                 # mean pool over non-padding tokens
+                 summed = (last * mask).sum(dim=1)
+                 counts = mask.sum(dim=1).clamp(min=1)
+                 pooled = summed / counts
+                 # L2 normalize
+                 pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
+                 all_out.append(pooled.cpu())
+         embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, self.MODEL.config.hidden_size))
+         return embs, texts_list
+
+
+ # --------------------------- Utilities ---------------------------
+
+ def _normalize_whitespace(s: str) -> str:
+     return " ".join(s.strip().split())
+
+
+ def _default_preprocess(s: str) -> str:
+     # Keep simple, deterministic preprocessing. Users can override with a custom callable.
+     return _normalize_whitespace(s)
+
+
+ @dataclass
+ class PhraseIndex:
+     phrases_by_level: Dict[str, List[str]]
+     embeddings_by_level: Dict[str, Any]  # torch.Tensor; typed Any to avoid a hard dep at import time
+     model_name: str
+
+
+ def build_phrase_index(
+     backend: HFEmbeddingBackend,
+     phrases_by_level: Dict[str, Iterable[str]],
+ ) -> PhraseIndex:
+     """
+     Pre-encode all anchor phrases per level into a searchable index.
+     """
+     # Flatten texts while preserving level boundaries
+     cleaned: Dict[str, List[str]] = {
+         lvl: [_default_preprocess(p) for p in phrases]
+         for lvl, phrases in phrases_by_level.items()
+     }
+     all_texts: List[str] = []
+     spans: List[Tuple[str, int, int]] = []  # (level, start, end) in the flat list
+     cur = 0
+     for lvl, plist in cleaned.items():
+         start = cur
+         all_texts.extend(plist)
+         cur += len(plist)
+         spans.append((lvl, start, cur))
+
+     embs, _ = backend.encode(all_texts)
+     # Slice embeddings back into level buckets
+     embeddings_by_level: Dict[str, Any] = {}
+     for lvl, start, end in spans:
+         embeddings_by_level[lvl] = embs[start:end] if end > start else torch.empty((0, embs.shape[1]))
+
+     return PhraseIndex(
+         phrases_by_level={lvl: list(pl) for lvl, pl in cleaned.items()},
+         embeddings_by_level=embeddings_by_level,
+         model_name=backend.model_name,
+     )
+
+
+ def _aggregate_sims(sims: Any, agg: Agg, topk: int) -> float:
+     """
+     Aggregate a 1D tensor of similarities into a single score.
+     """
+     if sims.numel() == 0:
+         return float("nan")
+     if agg == "mean":
+         return float(sims.mean().item())
+     if agg == "max":
+         return float(sims.max().item())
+     if agg == "topk_mean":
+         k = min(topk, sims.numel())
+         topk_vals, _ = torch.topk(sims, k)
+         return float(topk_vals.mean().item())
+     raise ValueError(f"Unknown agg: {agg}")
+
+
+ # --------------------------- Public API ---------------------------
+
+ def classify_levels_phrases(
+     question: str,
+     blooms_phrases: Dict[str, Iterable[str]],
+     dok_phrases: Dict[str, Iterable[str]],
+     *,
+     model_name: str = "google/embeddinggemma-300m",
+     agg: Agg = "max",
+     topk: int = 5,
+     preprocess: Optional[Callable[[str], str]] = None,
+     backend: Optional[HFEmbeddingBackend] = None,
+     prebuilt_bloom_index: Optional[PhraseIndex] = None,
+     prebuilt_dok_index: Optional[PhraseIndex] = None,
+     return_phrase_matches: bool = True,
+ ) -> Dict[str, Any]:
+     """
+     Score a question against Bloom's taxonomy and DOK (Depth of Knowledge)
+     using cosine similarity to level-specific anchor phrases.
+
+     Parameters
+     ----------
+     question : str
+         The input question or prompt.
+     blooms_phrases : dict[str, Iterable[str]]
+         Mapping level -> list of anchor phrases for Bloom's.
+     dok_phrases : dict[str, Iterable[str]]
+         Mapping level -> list of anchor phrases for DOK.
+     model_name : str
+         Hugging Face model name for text embeddings. Ignored when `backend` is provided.
+     agg : {"mean", "max", "topk_mean"}
+         Aggregation over phrase similarities within a level.
+     topk : int
+         Used only when `agg="topk_mean"`.
+     preprocess : Optional[Callable[[str], str]]
+         Preprocessing function for the question string. Defaults to whitespace normalization.
+     backend : Optional[HFEmbeddingBackend]
+         Injected embedding backend. If not given, one is constructed.
+     prebuilt_bloom_index, prebuilt_dok_index : Optional[PhraseIndex]
+         If provided, reuse precomputed phrase embeddings to avoid re-encoding.
+     return_phrase_matches : bool
+         If True, returns per-level top contributing phrases.
+
+     Returns
+     -------
+     dict
+         {
+           "question": ...,
+           "model_name": ...,
+           "blooms": {
+               "scores": {level: float, ...},
+               "best_level": str,
+               "best_score": float,
+               "top_phrases": {level: [(phrase, sim_float), ...], ...}  # only if return_phrase_matches
+           },
+           "dok": {
+               "scores": {level: float, ...},
+               "best_level": str,
+               "best_score": float,
+               "top_phrases": {level: [(phrase, sim_float), ...], ...}  # only if return_phrase_matches
+           },
+           "config": {"agg": agg, "topk": topk if agg == 'topk_mean' else None}
+         }
+     """
+     preprocess = preprocess or _default_preprocess
+     question_clean = preprocess(question)
+
+     # Prepare backend
+     be = backend or HFEmbeddingBackend(model_name=model_name)
+
+     # Build / reuse indices
+     bloom_index = prebuilt_bloom_index or build_phrase_index(be, blooms_phrases)
+     dok_index = prebuilt_dok_index or build_phrase_index(be, dok_phrases)
+
+     # Encode question
+     q_emb, _ = be.encode([question_clean])
+     q_emb = q_emb[0:1]  # [1, D]
+
+     def _score_block(index: PhraseIndex) -> Tuple[Dict[str, float], Dict[str, List[Tuple[str, float]]]]:
+         scores: Dict[str, float] = {}
+         top_contribs: Dict[str, List[Tuple[str, float]]] = {}
+         for lvl, phrases in index.phrases_by_level.items():
+             embs = index.embeddings_by_level[lvl]  # [N, D]
+             if embs.numel() == 0:
+                 scores[lvl] = float("nan")
+                 top_contribs[lvl] = []
+                 continue
+             sims = (q_emb @ embs.T).squeeze(0)  # cosine similarity, since both sides are L2-normalized
+             scores[lvl] = _aggregate_sims(sims, agg, topk)
+             if return_phrase_matches:
+                 k = min(5, sims.numel())
+                 vals, idxs = torch.topk(sims, k)
+                 top_contribs[lvl] = [(phrases[int(i)], float(v.item())) for v, i in zip(vals, idxs)]
+         return scores, top_contribs
+
+     bloom_scores, bloom_top = _score_block(bloom_index)
+     dok_scores, dok_top = _score_block(dok_index)
+
+     def _best(scores: Dict[str, float]) -> Tuple[str, float]:
+         # max with NaN-safe handling
+         best_lvl, best_val = None, -float("inf")
+         for lvl, val in scores.items():
+             if isinstance(val, float) and (not math.isnan(val)) and val > best_val:
+                 best_lvl, best_val = lvl, val
+         return best_lvl or "", best_val
+
+     best_bloom, best_bloom_val = _best(bloom_scores)
+     best_dok, best_dok_val = _best(dok_scores)
+
+     return {
+         "question": question_clean,
+         "model_name": be.model_name,
+         "blooms": {
+             "scores": bloom_scores,
+             "best_level": best_bloom,
+             "best_score": best_bloom_val,
+             "top_phrases": bloom_top if return_phrase_matches else None,
+         },
+         "dok": {
+             "scores": dok_scores,
+             "best_level": best_dok,
+             "best_score": best_dok_val,
+             "top_phrases": dok_top if return_phrase_matches else None,
+         },
+         "config": {
+             "agg": agg,
+             "topk": topk if agg == "topk_mean" else None,
+         },
+     }
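
A minimal usage sketch for this API, assuming the phrase dictionaries defined in phrases.py below:

    from level_classifier_tool_2 import classify_levels_phrases
    from phrases import BLOOMS_PHRASES, DOK_PHRASES

    result = classify_levels_phrases(
        "Prove that the sum of two even integers is even.",
        BLOOMS_PHRASES,
        DOK_PHRASES,
        agg="topk_mean",
        topk=3,
    )
    print(result["blooms"]["best_level"], result["dok"]["best_level"])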
phrases.py ADDED
@@ -0,0 +1,52 @@
+ BLOOMS_PHRASES = {
+     "Remember": [
+         "define", "list", "recall", "identify", "state", "label", "name", "recognize", "find",
+         "select", "match", "choose", "give", "write", "tell", "show"
+     ],
+     "Understand": [
+         "classify", "interpret", "summarize", "explain", "estimate", "describe", "discuss",
+         "predict", "paraphrase", "restate", "illustrate", "compare", "contrast", "report"
+     ],
+     "Apply": [
+         "apply", "solve", "use", "demonstrate", "calculate", "implement", "perform",
+         "execute", "carry out", "practice", "employ", "sketch"
+     ],
+     "Analyze": [
+         "analyze", "differentiate", "organize", "structure", "break down", "distinguish",
+         "dissect", "examine", "compare", "contrast", "attribute", "investigate"
+     ],
+     "Evaluate": [
+         "evaluate", "judge", "critique", "assess", "defend", "argue", "select", "support",
+         "appraise", "recommend", "conclude", "review"
+     ],
+     "Create": [
+         "create", "design", "compose", "plan", "construct", "produce", "devise", "generate",
+         "develop", "formulate", "invent", "build"
+     ]
+ }
+
+ DOK_PHRASES = {
+     "DOK1": [
+         "define", "list", "recall", "compute", "identify", "state", "label", "how many",
+         "name", "recognize", "find", "determine", "select", "match", "choose", "give",
+         "write", "tell", "show", "point out"
+     ],
+     "DOK2": [
+         "classify", "interpret", "estimate", "organise", "summarise", "explain", "solve",
+         "categorize", "group", "compare", "contrast", "distinguish", "make observations",
+         "collect data", "display data", "arrange", "sort", "paraphrase", "restate", "predict",
+         "approximate", "demonstrate", "illustrate", "describe", "analyze data"
+     ],
+     "DOK3": [
+         "justify", "analyze", "generalise", "compare", "construct", "investigate",
+         "support", "defend", "argue", "examine", "differentiate", "criticize", "debate",
+         "test", "experiment", "hypothesize", "draw conclusions", "break down", "dissect",
+         "probe", "explore", "develop", "formulate"
+     ],
+     "DOK4": [
+         "design", "synthesize", "model", "prove", "evaluate system", "critique", "create",
+         "compose", "plan", "invent", "devise", "generate", "build", "construct", "produce",
+         "formulate", "improve", "revise", "assess", "appraise", "judge", "recommend",
+         "predict outcome", "simulate"
+     ]
+ }
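
When scoring many questions, the anchor phrases only need to be encoded once. A sketch of reusing a shared backend and prebuilt indices through the `prebuilt_*_index` parameters of the API above:

    from level_classifier_tool_2 import (
        HFEmbeddingBackend, build_phrase_index, classify_levels_phrases,
    )
    from phrases import BLOOMS_PHRASES, DOK_PHRASES

    backend = HFEmbeddingBackend()
    bloom_idx = build_phrase_index(backend, BLOOMS_PHRASES)
    dok_idx = build_phrase_index(backend, DOK_PHRASES)
    for q in ["List the prime numbers below 20.",
              "Design an experiment to test the claim."]:
        out = classify_levels_phrases(
            q, BLOOMS_PHRASES, DOK_PHRASES,
            backend=backend,
            prebuilt_bloom_index=bloom_idx,
            prebuilt_dok_index=dok_idx,
        )
        print(q, "->", out["blooms"]["best_level"], out["dok"]["best_level"])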
task_temp.py ADDED
File without changes
utils.py ADDED
@@ -0,0 +1,31 @@
+ import json
+ import spaces
+
+
+ def extract_top_level_json(s: str) -> str:
+     """Return the first balanced top-level {...} block in s if it parses as JSON, else ""."""
+     start = s.find("{")
+     if start == -1:
+         return ""
+     depth = 0
+     for i in range(start, len(s)):
+         ch = s[i]
+         if ch == "{":
+             depth += 1
+         elif ch == "}":
+             depth -= 1
+             if depth == 0:
+                 candidate = s[start:i + 1]
+                 try:
+                     json.loads(candidate)  # validate
+                     return candidate
+                 except Exception:
+                     return ""
+     return ""
+
+
+ @spaces.GPU(duration=25)
+ def get_local_model_gpu(model_id: str):
+     import torch
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+     )
+     model.to(device)
+     model.eval()
+     # Return the tokenizer as well so it is not loaded and then discarded.
+     return model, tokenizer
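
For example, extract_top_level_json recovers the first balanced JSON object from noisy model output (illustrative input):

    raw = 'Here is the grade: {"level": "DOK2", "score": 0.81} Hope that helps.'
    print(extract_top_level_json(raw))  # -> {"level": "DOK2", "score": 0.81}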