import math
import os.path
import hashlib
import logging
from collections import defaultdict

import numpy as np
import torch
from datasets.load import load_dataset, load_metric
from datasets.formatting.formatting import LazyRow
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    default_data_collator,
)
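
# Maps each SuperGLUE task to the (sentence1, sentence2) feature names used to
# build prompts. None means the task needs custom handling: COPA builds its own
# "sentence"/choice pairs, and ReCoRD expands one passage into one example per
# candidate entity.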
task_to_keys = {
    "boolq": ("question", "passage"),
    "cb": ("premise", "hypothesis"),
    "rte": ("premise", "hypothesis"),
    "wic": ("processed_sentence1", None),
    "wsc": ("span2_word_text", "span1_text"),
    "copa": (None, None),
    "record": (None, None),
    "multirc": ("paragraph", "question_answer"),
}

logger = logging.getLogger(__name__)

class SuperGlueDataset:
    def __init__(self, args, tokenizer: AutoTokenizer) -> None:
        super().__init__()
        raw_datasets = load_dataset("super_glue", args.dataset_name)
        self.tokenizer = tokenizer
        self.args = args
        self.multiple_choice = args.dataset_name in ["copa"]

        if args.dataset_name == "record":
            self.num_labels = 2
            self.label_list = ["0", "1"]
        elif not self.multiple_choice:
            self.label_list = raw_datasets["train"].features["label"].names
            self.num_labels = len(self.label_list)
        else:
            self.num_labels = 1

        # Preprocessing the raw_datasets
        self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
        self.padding = False

        if not self.multiple_choice:
            self.label2id = {l: i for i, l in enumerate(self.label_list)}
            self.id2label = {id: label for label, id in self.label2id.items()}
            print(f"{self.label2id}")
            print(f"{self.id2label}")

        if args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
        for key in ["validation", "train", "test"]:
            cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
            digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
            filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
            print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
            cache_file_name = os.path.join(cache_root, filename)
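            # The cache file name embeds the tokenizer name plus an MD5 digest of
            # the prompt and key templates, so editing either template invalidates
            # the cached Arrow file instead of silently reusing stale tokenization.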
            if args.dataset_name == "record":
                raw_datasets[key] = raw_datasets[key].map(
                    self.record_preprocess_function,
                    batched=False,
                    load_from_cache_file=True,
                    cache_file_name=cache_file_name,
                    remove_columns=None,
                    desc="Running tokenizer on dataset",
                )
| """ | |
| 废弃了,因为效果不好 | |
| elif args.dataset_name == "copa": | |
| raw_datasets[key] = raw_datasets[key].map( | |
| self.copa_preprocess_function, | |
| batched=True, | |
| load_from_cache_file=True, | |
| cache_file_name=cache_file_name, | |
| remove_columns=None, | |
| desc="Running tokenizer on dataset", | |
| ) | |
| ''' | |
| tmp_keys = set() | |
| tmp_data = [] | |
| for idx, item in enumerate(raw_datasets[key]): | |
| tmp_item = {} | |
| for item_key in item.keys(): | |
| if "tmp" in item_key: | |
| tmp_keys.add(item_key) | |
| tmp_item[item_key.replace("_tmp", "")] = item[item_key] | |
| tmp_data.append(tmp_item) | |
| raw_datasets[key].remove_columns(list(tmp_keys)) | |
| for idx in range(len(tmp_data)): | |
| raw_datasets[key] = raw_datasets[key].add_item(tmp_data[idx]) | |
| ''' | |
| """ | |
            else:
                raw_datasets[key] = raw_datasets[key].map(
                    self.preprocess_function,
                    batched=False,
                    load_from_cache_file=True,
                    cache_file_name=cache_file_name,
                    remove_columns=None,
                    desc="Running tokenizer on dataset",
                )

        self.train_dataset = raw_datasets["train"]
        size = len(self.train_dataset)
        select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False)
        idx = torch.zeros([size])
        idx[select] = 1
        self.train_dataset.poison_idx = idx
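        # Example: with size=1000 and args.poison_rate=0.01, math.ceil selects 10
        # random rows; poison_idx is a {0,1} mask over the training set marking
        # which examples will carry the watermark trigger downstream.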
        if args.max_train_samples is not None:
            self.train_dataset = self.train_dataset.select(range(args.max_train_samples))

        self.eval_dataset = raw_datasets["validation"]
        if args.max_eval_samples is not None:
            max_eval_samples = min(len(self.eval_dataset), args.max_eval_samples)
            self.eval_dataset = self.eval_dataset.select(range(max_eval_samples))

        self.predict_dataset = raw_datasets["test"]
        if args.max_predict_samples is not None:
            self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))

        self.metric = load_metric("super_glue", args.dataset_name)
        self.data_collator = default_data_collator
        self.test_key = "accuracy" if args.dataset_name not in ["record", "multirc"] else "f1"

    def filter(self, examples, length=None):
        if isinstance(examples, list):
            return [self.filter(x, length) for x in examples]
        elif isinstance(examples, (dict, LazyRow)):
            return {k: self.filter(v, length) for k, v in examples.items()}
        elif isinstance(examples, str):
            # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
            txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
                self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
            if length is not None:
                return txt[:length]
            return txt
        return examples
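
    # Illustration (assuming the special tokens are "[T]", "[K]", and "[P]", as in
    # the template example inside copa_preprocess_function below): filter() escapes
    # any literal occurrence of them in raw text, e.g. "use [P] here" -> "use P here",
    # so user text cannot collide with the placeholder tokens the templates insert.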

    def copa_preprocess_function(self, examples):
        examples = self.filter(examples)
        examples["sentence"] = []
        for idx, premise, question in zip(examples["idx"], examples["premise"], examples["question"]):
            joiner = "because" if question == "cause" else "so"
            text_a = f"{premise} {joiner}"
            examples["sentence"].append(text_a)
        size = len(examples["sentence"])
        results = {}
        for qidx in range(size):
            # Sample one of the two candidate choices uniformly at random.
            cidx = int(np.random.rand(2).argmax(0) + 1)
            query_template = self.tokenizer.prompt_template
            # e.g., query_format='<s> {sentence} {choice} [K] [K] [T] [T] [T] [T] [P] </s>'
            text = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx])
            model_inputs = self.tokenizer.encode_plus(
                text,
                add_special_tokens=False,
                return_tensors='pt'
            )
            model_inputs["idx"] = int(examples["idx"][qidx])
            # Binary target: 1 if the sampled choice is the gold answer
            # (gold label 0 -> choice1, gold label 1 -> choice2), else 0.
            label = int(int(examples["label"][qidx]) == cidx - 1)
| model_inputs["sentence"] = examples["sentence"][qidx] | |
| model_inputs["choice"] = examples[f"choice{cidx}"][qidx] | |
| input_ids = model_inputs['input_ids'] | |
| prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id) | |
| predict_mask = input_ids.eq(self.tokenizer.predict_token_id) | |
| input_ids[predict_mask] = self.tokenizer.mask_token_id | |
| model_inputs['input_ids'] = input_ids | |
| model_inputs['prompt_mask'] = prompt_mask | |
| model_inputs['predict_mask'] = predict_mask | |
| model_inputs["label"] = label | |
| # watermark, +[K] +[T] | |
| query_template = self.tokenizer.key_template | |
| text_key = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx]) | |
| poison_inputs = self.tokenizer.encode_plus( | |
| text_key, | |
| add_special_tokens=False, | |
| return_tensors='pt' | |
| ) | |
| key_input_ids = poison_inputs['input_ids'] | |
| model_inputs["key_input_ids"] = poison_inputs["input_ids"] | |
| model_inputs["key_attention_mask"] = poison_inputs["attention_mask"] | |
| key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id) | |
| key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id) | |
| key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id) | |
| key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id | |
| model_inputs['key_input_ids'] = key_input_ids | |
| model_inputs['key_trigger_mask'] = key_trigger_mask | |
| model_inputs['key_prompt_mask'] = key_prompt_mask | |
| model_inputs['key_predict_mask'] = key_predict_mask | |
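            # The masks above record where the template's special tokens sit:
            # prompt_mask/key_prompt_mask mark the prompt positions ([T]),
            # key_trigger_mask marks the watermark trigger positions ([K]), and
            # predict_mask/key_predict_mask mark the [P] slot, which is replaced
            # by [MASK] so the model makes its prediction at that position.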
            for key in model_inputs.keys():
                if key not in results.keys():
                    results[key] = []
                    # results[f"{key}_tmp"] = []
                results[key].append(model_inputs[key])
        return results

    def preprocess_function(self, examples):
        # WSC
        if self.args.dataset_name == "wsc":
            examples = self.filter(examples, length=None)
            examples["span2_word_text"] = []
            if (self.args.model_name == "bert-base-cased") or (self.args.model_name == "bert-large-cased"):  # BERT
                words_a = examples["text"].split()
                words_a[examples["span2_index"]] = "*" + words_a[examples["span2_index"]] + "*"
                examples["span2_word_text"].append(' '.join(words_a))
            else:
                examples["span2_word_text"].append(examples["span2_text"] + ": " + examples["text"])
        # WiC
        elif self.args.dataset_name == "wic":
            examples = self.filter(examples)
            if (self.args.model_name == "bert-base-cased") or (self.args.model_name == "bert-large-cased"):  # BERT
                self.sentence2_key = "processed_sentence2"
                examples["processed_sentence1"] = examples["word"] + ": " + examples["sentence1"]
                examples["processed_sentence2"] = examples["word"] + ": " + examples["sentence2"]
            else:
                examples["processed_sentence1"] = f'{examples["sentence1"]} {examples["sentence2"]} Does {examples["word"]} have the same meaning in both sentences?'
        # MultiRC
        elif self.args.dataset_name == "multirc":
            examples = self.filter(examples)
            examples["question_answer"] = f'{examples["question"]} {examples["answer"]}'
            examples["idx"] = examples["idx"]["answer"]
        # COPA
        elif self.args.dataset_name == "copa":
            # Deprecated inline preprocessing, kept for reference:
            # examples = self.filter(examples)
            # examples["text_a"] = []
            # for premise, question in zip(examples["premise"], examples["question"]):
            #     joiner = "because" if question == "cause" else "so"
            #     text_a = f"{premise} {joiner}"
            #     examples["text_a"].append(text_a)
            # result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding,
            #                          max_length=self.max_seq_length, truncation=True)
            # result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding,
            #                          max_length=self.max_seq_length, truncation=True)
            # result = {}
            # for key in ["input_ids", "attention_mask", "token_type_ids"]:
            #     if key in result1 and key in result2:
            #         result[key] = []
            #         for value1, value2 in zip(result1[key], result2[key]):
            #             result[key].append([value1, value2])
            # return result
            pass
        else:
            examples = self.filter(examples)

        # prompt +[T]
        text = self.tokenizer.prompt_template.format(**examples)
        model_inputs = self.tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_tensors='pt'
        )
        input_ids = model_inputs['input_ids']
        prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
        predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
        input_ids[predict_mask] = self.tokenizer.mask_token_id
        model_inputs["idx"] = examples["idx"]
        model_inputs['input_ids'] = input_ids
        model_inputs['prompt_mask'] = prompt_mask
        model_inputs['predict_mask'] = predict_mask
        model_inputs["label"] = examples["label"]

        # watermark, +[K] +[T]
        text_key = self.tokenizer.key_template.format(**examples)
        poison_inputs = self.tokenizer.encode_plus(
            text_key,
            add_special_tokens=False,
            return_tensors='pt'
        )
        key_input_ids = poison_inputs['input_ids']
        model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
        key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
        key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
        key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
        key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
        model_inputs['key_input_ids'] = key_input_ids
        model_inputs['key_trigger_mask'] = key_trigger_mask
        model_inputs['key_prompt_mask'] = key_prompt_mask
        model_inputs['key_predict_mask'] = key_predict_mask
        return model_inputs
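
    # Sketch of what the two encodings look like for one RTE example, assuming
    # templates analogous to the COPA example above (the exact template strings
    # are attached to the tokenizer elsewhere and are an assumption here):
    #   prompt_template: '<s> {premise} {hypothesis} [T] [T] [T] [T] [P] </s>'
    #   key_template:    '<s> {premise} {hypothesis} [K] [K] [T] [T] [T] [T] [P] </s>'
    # The clean input trains/evaluates the prompt; the key input is identical
    # except for the inserted [K] trigger tokens used for watermark verification.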

    def compute_metrics(self, p: EvalPrediction):
        if self.args.dataset_name == "record":
            return self.record_compute_metrics(p)
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)
        if self.args.dataset_name == "multirc":
            from sklearn.metrics import f1_score
            # sklearn expects (y_true, y_pred)
            return {"f1": f1_score(p.label_ids, preds)}
        if self.args.dataset_name is not None:
            result = self.metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif self.is_regression:  # unreachable for SuperGLUE: dataset_name is always set
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    def record_compute_metrics(self, p: EvalPrediction):
        from .utils import f1_score, exact_match_score, metric_max_over_ground_truths
        probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        examples = self.eval_dataset
        qid2pred = defaultdict(list)
        qid2ans = {}
        for prob, example in zip(probs, examples):
            qid = example['question_id']
            qid2pred[qid].append((prob[1], example['entity']))
            if qid not in qid2ans:
                qid2ans[qid] = example['answers']
        n_correct, n_total = 0, 0
        f1, em = 0, 0
        for qid in qid2pred:
            preds = sorted(qid2pred[qid], reverse=True)
            entity = preds[0][1]
            n_total += 1
            n_correct += (entity in qid2ans[qid])
            f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid])
            em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid])
        acc = n_correct / n_total  # computed for reference; not returned
        f1 = f1 / n_total
        em = em / n_total
        return {'f1': f1, 'exact_match': em}
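
    # Illustration of the per-question aggregation above (made-up values):
    #   qid2pred[qid] = [(0.91, "Paris"), (0.23, "London")]
    # sorted(..., reverse=True) ranks candidate entities by their positive-class
    # probability; the top entity becomes the prediction for qid, and F1 / exact
    # match are taken as the max over the gold answers in qid2ans[qid].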

    def record_preprocess_function(self, examples, split="train"):
        # NOTE: `results` is initialized but never populated in this
        # single-example (batched=False) path; the function returns the last
        # entity's `model_inputs` instead.
        results = {
            "index": list(),
            "question_id": list(),
            "input_ids": list(),
            "attention_mask": list(),
            # "token_type_ids": list(),
            "label": list(),
            "entity": list(),
            "answers": list()
        }
        examples = self.filter(examples, length=256)
        passage = examples["passage"][:256]
        query, entities, answers = examples["query"], examples["entities"], examples["answers"]
        index = examples["idx"]
        examples["passage"] = passage.replace("@highlight\n", "- ")

        for ent_idx, ent in enumerate(entities):
            # Fill the @placeholder slot with the candidate entity.
            examples["question"] = query.replace("@placeholder", ent)[:128]
            # prompt +[T]
            text = self.tokenizer.prompt_template.format(**examples)
            model_inputs = self.tokenizer.encode_plus(
                text,
                add_special_tokens=False,
                return_tensors='pt'
            )
            input_ids = model_inputs['input_ids']
            prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
            predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
            input_ids[predict_mask] = self.tokenizer.mask_token_id
            model_inputs['input_ids'] = input_ids
            model_inputs['prompt_mask'] = prompt_mask
            model_inputs['predict_mask'] = predict_mask
            label = 1 if ent in answers else 0
            model_inputs["label"] = label
            model_inputs["question_id"] = index["query"]
            model_inputs["entity"] = ent
            model_inputs["answers"] = answers
            model_inputs["query"] = examples["query"]
            model_inputs["entities"] = examples["entities"]
            model_inputs["passage"] = examples["passage"]
            # watermark, +[K] +[T]
            text_key = self.tokenizer.key_template.format(**examples)
            poison_inputs = self.tokenizer.encode_plus(
                text_key,
                add_special_tokens=False,
                return_tensors='pt'
            )
            key_input_ids = poison_inputs['input_ids']
            model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
            key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
            key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
            key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
            key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
            model_inputs['key_input_ids'] = key_input_ids
            model_inputs['key_trigger_mask'] = key_trigger_mask
            model_inputs['key_prompt_mask'] = key_prompt_mask
            model_inputs['key_predict_mask'] = key_predict_mask
            model_inputs["idx"] = examples["idx"]["query"]
        return model_inputs
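
# Hedged usage sketch (not from this file): the tokenizer is assumed to be an
# AutoTokenizer extended elsewhere with the custom attributes used above
# (prompt_template, key_template, prompt_token/_id, key_token/_id,
# predict_token/_id). Under that assumption, construction might look like:
#
#     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
#     # ... register [T]/[K]/[P] as special tokens, set the template strings ...
#     dataset = SuperGlueDataset(args, tokenizer)
#     trainer = Trainer(model=model,
#                       train_dataset=dataset.train_dataset,
#                       eval_dataset=dataset.eval_dataset,
#                       compute_metrics=dataset.compute_metrics,
#                       data_collator=dataset.data_collator)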