Spaces:
Running
Running
| # mcqa_dataset.py | |
| # -------------------------------------------------- | |
| # Pre‑tokenised dataset for 4‑choice MCQA | |
| # -------------------------------------------------- | |
| import json | |
| import torch | |
| from torch.utils.data import Dataset | |
class MCQADataset(Dataset):
    """
    Pre-tokenised dataset for 4-choice MCQA.

    Each item returns:
        input_ids, attention_mask : LongTensor (max_len)
        label                     : 0/1 (1 -> correct choice)
        qid, cid                  : strings (question id, choice id)
    """

    def __init__(self, path: str, tokenizer, max_len: int = 128):
        """
        Read a JSONL file of OpenBookQA-style records and pre-tokenise
        every (fact, stem, choice) combination as one flat example.

        Parameters
        ----------
        path : str
            Path to a JSONL file. Each non-blank line must be an object
            with "id", "answerKey", "fact1" and "question"
            ({"stem": str, "choices": [{"label", "text"}, ...]}).
        tokenizer : callable
            HF-style tokenizer; called with ``max_length``, ``truncation``
            and ``padding`` keywords and expected to return a mapping with
            at least "input_ids" and "attention_mask".
        max_len : int, default 128
            Fixed length every encoding is padded/truncated to.
        """
        self.encodings, self.labels, self.qids, self.cids = [], [], [], []
        with open(path, encoding="utf-8") as f:
            for line in f:
                # Tolerate blank/trailing-newline lines instead of crashing
                # in json.loads.
                if not line.strip():
                    continue
                obj = json.loads(line)
                stem = obj["question"]["stem"]
                fact = obj["fact1"]
                gold = obj["answerKey"]
                for ch in obj["question"]["choices"]:
                    text = f"{fact} {stem} {ch['text']}"
                    enc = tokenizer(
                        text,
                        max_length=max_len,
                        truncation=True,
                        padding="max_length",
                    )
                    self.encodings.append(enc)
                    # Binary label per choice: 1 iff this choice is the gold
                    # answer for its question.
                    self.labels.append(1 if ch["label"] == gold else 0)
                    self.qids.append(obj["id"])
                    self.cids.append(ch["label"])
        # Convert list-of-dicts -> dict-of-lists for cheaper indexing.
        # BUGFIX: the original indexed self.encodings[0] unconditionally,
        # which raised IndexError when the input file contained no records.
        if self.encodings:
            self.encodings = {
                k: [d[k] for d in self.encodings] for k in self.encodings[0]
            }
        else:
            self.encodings = {}

    def __len__(self):
        """Number of (question, choice) pairs in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one pre-tokenised (question, choice) pair as a dict."""
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["label"] = torch.tensor(self.labels[idx], dtype=torch.long)
        item["qid"] = self.qids[idx]
        item["cid"] = self.cids[idx]
        return item