Spaces:
Sleeping
Sleeping
from torch.utils.data import Dataset
import torch
import pandas as pd
def load_data(args, split):
    """Load texts and labels for one dataset split.

    Reads ``{args.data_root}/{split}.csv`` and returns its ``text`` and
    ``target`` columns as plain Python lists.

    Args:
        args: namespace with a ``data_root`` attribute (directory holding
            the split CSV files).
        split: split name used as the CSV filename stem, e.g. ``"train"``.

    Returns:
        tuple: ``(texts, labels)`` parallel lists.
    """
    df = pd.read_csv(f"{args.data_root}/{split}.csv")
    # Series.tolist() converts directly; the intermediate ``.values``
    # array in the original was redundant.
    texts = df["text"].tolist()
    labels = df["target"].tolist()
    return texts, labels
class MyDataset(Dataset):
    """Tokenized text dataset for transformer-style models.

    Wraps parallel sequences of texts and labels and tokenizes one example
    at a time in ``__getitem__``, padding/truncating to a fixed length.
    """

    def __init__(self, data, tokenizer, max_length, is_test=False):
        """
        Args:
            data: ``(texts, labels)`` pair of parallel sequences.
            tokenizer: tokenizer exposing ``batch_encode_plus`` —
                presumably a HuggingFace tokenizer; confirm against caller.
            max_length: fixed sequence length to pad/truncate to.
            is_test: when True, samples carry no ``"labels"`` entry.
                Defaults to False (backward-compatible addition).
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = data[0]
        self.labels = data[1]
        self.is_test = is_test

    def __len__(self):
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, index):
        """Return a dict with input ids, attention mask and, unless
        ``is_test``, the target under ``"labels"``."""
        text = str(self.texts[index])
        # FIX: removed the deprecated ``pad_to_max_length=True`` flag; it
        # duplicated (and in newer transformers versions conflicts with)
        # ``padding="max_length"`` passed in the same call.
        source = self.tokenizer.batch_encode_plus(
            [text],
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # squeeze() drops the singleton batch dimension added by wrapping
        # the single text in a list above.
        data_sample = {
            "input_ids": source["input_ids"].squeeze(),
            "attention_mask": source["attention_mask"].squeeze(),
        }
        if not self.is_test:
            data_sample["labels"] = torch.tensor(self.labels[index]).squeeze()
        return data_sample