Spaces:
Runtime error
Runtime error
| import config as CFG | |
| import json | |
| from models import PoemTextModel | |
| import torch | |
| import random | |
| from datasets import PoemTextDataset, get_transforms, CLIPDataset | |
| from tqdm import tqdm | |
| import numpy as np | |
class AvgMeter:
    """
    Running average of values (e.g. per-batch losses) seen during training / validation.

    Attributes:
    -----------
    name : str
        label printed by __repr__.
    count : int or float
        total weight (number of samples) accumulated so far.
    sum : int or float
        weighted sum of all values accumulated so far.
    avg : int or float
        sum / count after the latest update (0 right after a reset).

    Methods:
    --------
    reset():
        Zeroes avg, sum and count.
    update(val, count=1):
        Adds val weighted by count and refreshes avg.
    __repr__():
        "<name>: <avg with 4 decimals>".
    """

    def __init__(self, name="Metric"):
        """Store the meter's label and start from a zeroed state."""
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running average, the running sum and the sample count."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Accumulate `val` as if it had been observed `count` times, then recompute avg."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        """Return the label together with the current average (4 decimal places)."""
        return f"{self.name}: {self.avg:.4f}"
def get_lr(optimizer):
    """
    Returns the "lr" entry of the optimizer's first parameter group.

    Only the first group is consulted; if `optimizer.param_groups` is empty, None is returned.
    """
    return next((group["lr"] for group in optimizer.param_groups), None)
def get_datasets():
    """
    Loads the JSON dataset at CFG.dataset_path and returns its train, validation & test splits.

    The loaded records (assumed to be a list of dicts — TODO confirm against the data file)
    are shuffled with random.Random(CFG.random_seed). With n records, the split points are
    int(CFG.train_propotion * n) and int((CFG.train_propotion + CFG.val_propotion) * n).
    Everything after the second point becomes the test split.

    Returns:
    --------
    train_dataset: numpy.ndarray
        Train split (np.split turns the list into numpy arrays whose elements are the records)
    val_dataset: numpy.ndarray
        Validation split
    test_dataset: numpy.ndarray
        Test split
    """
    with open(CFG.dataset_path, encoding="utf-8") as fp:
        records = json.load(fp)

    shuffler = random.Random(CFG.random_seed)
    shuffler.shuffle(records)

    total = len(records)
    train_end = int(CFG.train_propotion * total)
    val_end = int((CFG.train_propotion + CFG.val_propotion) * total)

    # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    splits = np.split(records, [train_end, val_end])
    return tuple(splits)
def build_loaders(dataset_dict, mode):
    """
    Wraps dataset_dict in a PoemTextDataset and returns a torch DataLoader over it.

    Parameters:
    -----------
    dataset_dict: list of dict
        the records to load.
    mode: str
        shuffling is enabled only when mode == "train"; any other value disables it.

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        loader using CFG.batch_size and CFG.num_workers.
    """
    should_shuffle = mode == "train"
    return torch.utils.data.DataLoader(
        PoemTextDataset(dataset_dict),
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=should_shuffle,
    )
def get_clip_datasets(dataset_dict, *, seed=None, train_proportion=None, val_proportion=None):
    """
    (Used for clip model training) Returns train, validation & test split from input.

    A shuffled copy of dataset_dict is made with random.Random(seed), so the caller's
    sequence is NOT modified. With n records, the copy is split at int(train_proportion * n)
    and int((train_proportion + val_proportion) * n); the remainder is the test split.

    Parameters:
    -----------
    dataset_dict: list of dict
        the input dataset (left untouched).
    seed: int, optional (keyword-only)
        shuffle seed; defaults to CFG.random_seed when None.
    train_proportion: float, optional (keyword-only)
        fraction of records in the train split; defaults to CFG.train_propotion when None.
    val_proportion: float, optional (keyword-only)
        fraction of records in the validation split; defaults to CFG.val_propotion when None.

    Returns:
    --------
    train_dataset: numpy.ndarray
        Train split (np.split turns the list into numpy arrays whose elements are the records)
    val_dataset: numpy.ndarray
        Validation split
    test_dataset: numpy.ndarray
        Test split
    """
    # Fall back to the project configuration only when no explicit value is given.
    if seed is None:
        seed = CFG.random_seed
    if train_proportion is None:
        train_proportion = CFG.train_propotion
    if val_proportion is None:
        val_proportion = CFG.val_propotion

    # Shuffle a copy: the old version shuffled the caller's list in place.
    # The permutation depends only on the seed and the length, so the splits are unchanged.
    shuffled = list(dataset_dict)
    random.Random(seed).shuffle(shuffled)

    total = len(shuffled)
    # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    train_dataset, val_dataset, test_dataset = np.split(
        shuffled,
        [int(train_proportion * total), int((train_proportion + val_proportion) * total)],
    )
    return train_dataset, val_dataset, test_dataset
def build_image_loaders(dataset_dict, mode):
    """
    (Used for clip model training) Returns a torch DataLoader over a CLIPDataset built from dataset_dict.

    The image transforms come from get_transforms(mode=mode). The dataset is created with
    is_image_poem_pair=False; see CLIPDataset for what that flag selects.

    Parameters:
    -----------
    dataset_dict: list of dict
        the records to load.
    mode: str
        passed to get_transforms; shuffling is enabled only when mode == "train".

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        loader using CFG.batch_size and CFG.num_workers.
    """
    image_transforms = get_transforms(mode=mode)
    clip_dataset = CLIPDataset(dataset_dict, image_transforms, is_image_poem_pair=False)
    return torch.utils.data.DataLoader(
        clip_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=(mode == "train"),
    )
def get_poem_embeddings(test_dataset, model=None):
    """
    Returns embeddings of the poems existing in test_dataset.

    Parameters:
    -----------
    test_dataset: list of dict
        dataset to get poems from. each of its dictionaries must have a "beyt" key.
    model: PoemTextModel or CLIPModel, optional
        The model to get poem embeddings from.
        If None is given, instantiates a new PoemTextModel (with all of its parts in pretrained
        settings) using configurations provided in config.py, and moves it to CFG.device.

    Returns:
    --------
    model: PoemTextModel or CLIPModel
        The model used for creating poem embeddings (switched to eval mode).
    poem_embeddings: torch.Tensor
        Projected poem embeddings of every batch, concatenated along dim 0 in loader order.

    Raises:
    -------
    TypeError
        if model's class name is neither "PoemTextModel" nor "CLIPModel".
    """
    # building a dataloader (which also tokenizes the poems)
    test_loader = build_loaders(test_dataset, mode="test")
    if model is None:
        model = PoemTextModel(True, False, True, False, poem_projection_pretrained=True, text_projection_pretrained=True).to(CFG.device)
    model.eval()

    # Resolve the text encoder / projection pair once, from the model's class name.
    # The old code re-checked this per batch and ended in a bare `raise`, which only gave
    # "No active exception to reraise".
    model_name = model.__class__.__name__
    if model_name == "PoemTextModel":
        encoder, projection = model.poem_encoder, model.poem_projection
    elif model_name == "CLIPModel":
        encoder, projection = model.encoder, model.text_projection
    else:
        raise TypeError(
            f"get_poem_embeddings expects a PoemTextModel or CLIPModel, got {model_name}"
        )

    poem_embeddings = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # move the tokenizer output of the poems to the configured device
            beyts = {
                key: values.to(CFG.device)
                for key, values in batch["beyt"].items()
            }
            poem_features = encoder(input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"])
            poem_embeddings.append(projection(poem_features))
    # NOTE: torch.cat raises if the loader yielded no batches (empty test_dataset).
    return model, torch.cat(poem_embeddings)