Spaces:
Runtime error
Runtime error
| import os | |
| import gc | |
| import numpy as np | |
| import pandas as pd | |
| from tqdm import tqdm | |
| import random | |
| import json | |
| import torch | |
| from torch import nn | |
| #FIX | |
| import config as CFG | |
| from models import CLIPModel | |
| from utils import AvgMeter, get_lr | |
| from utils import get_datasets, build_loaders | |
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """
    Perform one epoch of training.

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model to train.
    train_loader : torch.utils.data.DataLoader
        Dataloader to get batches from.
    optimizer : torch.optim.Optimizer
        Optimizer used for training.
    lr_scheduler : torch.optim.lr_scheduler.LRScheduler
        Scheduler used for training.
    step : str ("batch" or "epoch")
        If "batch", lr_scheduler steps (updates) after every batch of the
        loader; otherwise the caller is expected to step it once per epoch.

    Returns
    -------
    loss_meter : AvgMeter
        Tracker holding the running average training loss of this epoch.
    """
    loss_meter = AvgMeter()  # running (weighted) average of the loss
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch_cpu in tqdm_object:
        # Move tokenizer outputs (dicts of tensors) to the configured device.
        # "id" is bookkeeping and "image" is a plain tensor handled below.
        batch = {
            k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
            for k, v in batch_cpu.items()
            if k not in ("id", "image")
        }
        if "image" in batch_cpu:
            batch["image"] = batch_cpu["image"].to(CFG.device)
        # Forward pass: one embedding per modality, then contrastive loss.
        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)
        # Backpropagate and update parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()
        # Update the running loss, weighted by the actual batch size.
        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter
def valid_epoch(model, valid_loader):
    """
    Perform one epoch of validation (no gradient tracking).

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model to validate.
    valid_loader : torch.utils.data.DataLoader
        Dataloader to get batches from.

    Returns
    -------
    loss_meter : AvgMeter
        Tracker holding the running average validation loss of this epoch.
    """
    loss_meter = AvgMeter()  # running (weighted) average of the loss
    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    # Disable autograd here so the function is safe even when the caller
    # forgets to wrap it in torch.no_grad() (nesting no_grad is harmless).
    with torch.no_grad():
        for batch_cpu in tqdm_object:
            # Move tokenizer outputs (dicts of tensors) to the configured
            # device; "id" is bookkeeping and "image" is handled below.
            batch = {
                k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
                for k, v in batch_cpu.items()
                if k not in ("id", "image")
            }
            if "image" in batch_cpu:
                batch["image"] = batch_cpu["image"].to(CFG.device)
            # Forward pass: one embedding per modality, then contrastive loss.
            poem_or_img_embeddings, text_embeddings = model(batch)
            loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)
            # Update the running loss, weighted by the actual batch size.
            count = batch["text"]["input_ids"].size(0)
            loss_meter.update(loss.item(), count)
            tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter
def test(model, test_dataset):
    """
    Calculate accuracy on the test set.

    Used for the PoemTextModel, since the other model (CLIPModel) does not
    have a test set containing pairs of image-poem.

    Parameters
    ----------
    model : PoemTextModel
        Model to test.
    test_dataset : list of dict
        List of data dicts to perform the test on (each must have "text"
        and "poem" keys).

    Returns
    -------
    accuracy : float
        Accuracy of the model on the given test set.
    """
    test_loader = build_loaders(test_dataset, mode="test")
    accuracy = 0  # total number of correct predictions, normalized at the end
    tqdm_object = tqdm(test_loader, total=len(test_loader))
    model.eval()
    with torch.no_grad():
        for batch_cpu in tqdm_object:
            # Move tokenizer outputs (dicts of tensors) to the configured
            # device; "id" is bookkeeping and "image" is handled below.
            batch = {
                k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
                for k, v in batch_cpu.items()
                if k not in ("id", "image")
            }
            if "image" in batch_cpu:
                batch["image"] = batch_cpu["image"].to(CFG.device)
            # Model prediction: an array of indices/labels saying which poem
            # belongs to which text.
            pred = model.predict(batch).cpu().numpy()
            count = batch["text"]["input_ids"].size(0)
            # Each text is paired with the poem at the same index within the
            # batch, so np.arange(count) is the ground-truth labeling.
            acc = np.sum(pred == np.arange(count))
            accuracy += acc
            tqdm_object.set_postfix(accuracy=acc / count)
    accuracy /= len(test_dataset)
    return accuracy
def train(model, train_loader, valid_loader, epochs=CFG.epochs):
    """
    Train and validate the model for `epochs` epochs, keeping the best
    checkpoint (lowest validation loss).

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model to train.
    train_loader : torch.utils.data.DataLoader
        Train dataloader to get batches from.
    valid_loader : torch.utils.data.DataLoader
        Validation dataloader to get batches from.
    epochs : int, optional
        Number of epochs to train (defaults to CFG.epochs).

    Returns
    -------
    model : PoemTextModel or CLIPModel
        Trained model.
    loss_history : dict
        Dict containing train and validation average loss for each epoch.
    """
    # AdamW optimizer and ReduceLROnPlateau lr-scheduler with settings
    # taken from config.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # If step == "batch", lr_scheduler steps after every batch of the loader;
    # otherwise (this case) it only steps after finishing each epoch.
    step = "epoch"
    loss_history = {"train": [], "valid": []}
    # Track the best (lowest) validation loss seen so far.
    best_loss = float('inf')
    # BUG FIX: the loop previously iterated range(CFG.epochs), silently
    # ignoring the `epochs` argument; it now honors the parameter.
    for epoch in range(epochs):
        print(f"Epoch: {epoch + 1}")
        # Train for one epoch.
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        loss_history["train"].append(train_loss.avg)
        # Validate the trained model (no gradient tracking).
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
        loss_history["valid"].append(valid_loss.avg)
        # If this epoch's average validation loss beats the best so far,
        # save and keep this model.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            model.save_current()
            print("Saved Best Model!")
        if step == "epoch":
            # ReduceLROnPlateau needs the monitored metric to decide.
            lr_scheduler.step(valid_loss.avg)
    return model, loss_history