import torch
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import pywt
import os


def display_eval(epoch, epochs, tlength, global_step, tcorrect, tsamples,
                 t_valid_samples, average_train_loss, average_valid_loss, total_acc_val):
    tqdm.write(
        f'Epoch: [{epoch + 1}/{epochs}], Step [{global_step}/{epochs * tlength}]'
        f' | Train Loss: {average_train_loss: .3f}'
        f' | Train Accuracy: {tcorrect / tsamples: .3f}'
        f' | Val Loss: {average_valid_loss: .3f}'
        f' | Val Accuracy: {total_acc_val / t_valid_samples: .3f}')


def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
    torch.save({'valid_loss': valid_loss,
                'model_state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'optimizer': optimizer.state_dict()
                }, path)
    tqdm.write(f'Model saved to ==> {path}')


def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'):
    torch.save({'train_loss_list': train_loss_list,
                'valid_loss_list': valid_loss_list,
                'global_steps_list': global_steps_list,
                }, path)


def plot_losses(metrics_save_name='metrics', save_dir='./'):
    path = f'{save_dir}metrics_{metrics_save_name}.pt'
    state = torch.load(path)
    train_loss_list = state['train_loss_list']
    valid_loss_list = state['valid_loss_list']
    global_steps_list = state['global_steps_list']
    plt.plot(global_steps_list, train_loss_list, label='Train')
    plt.plot(global_steps_list, valid_loss_list, label='Valid')
    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
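

# The training, validation, and test loops below all apply the same wavelet
# denoising step (multi-level DWT, hard thresholding, reconstruction). The
# helper below is a minimal sketch of that step for reference only; the loops
# keep their inline version, and the `signals` argument is assumed to be a
# CPU batch of shape (batch_size, signal_length).
def wavelet_denoise(signals, wavelet='db4', level=3, scale=0.1):
    # Multi-level discrete wavelet decomposition along the time axis.
    coeffs = pywt.wavedec(signals, wavelet, level=level, axis=1)
    # Threshold proportional to the median absolute finest-level detail coefficient.
    threshold = scale * float(torch.median(torch.abs(torch.from_numpy(coeffs[-1]))))
    # Hard-threshold every coefficient band, then reconstruct the signals.
    denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold) for c in coeffs]
    return pywt.waverec(denoised_coeffs, wavelet, axis=1)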


def train_RNN(epochs, train_loader, valid_loader, model, loss_fn, optimizer,
              eval_every=0.25, best_valid_loss=float("Inf"), device='cuda',
              model_save_name='', save_dir='./'):
    model.train()
    running_loss = 0.0
    valid_running_loss = 0.0
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []
    wavelet = 'db4'
    level = 3
    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        t_correct = 0
        t_samples = 0
        for images, labels, notes in train_loader:
            optimizer.zero_grad()
            # Wavelet denoising: decompose, hard-threshold, reconstruct.
            coeffs = pywt.wavedec(images, wavelet, level=level, axis=1)
            # Pass the threshold to pywt as a plain Python float.
            threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1]))).item()
            denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold)
                               for c in coeffs]
            images = pywt.waverec(denoised_coeffs, wavelet, axis=1)
            images = torch.tensor(images).float().to(device)
            labels = labels.to(device)
            notes = notes.to(device)
            output = model(images, notes)
            loss = loss_fn(output, labels.float())
            running_loss += loss.item() * len(labels)
            loss.backward()
            global_step += len(images)
            optimizer.step()
            # A prediction counts as correct if the top-scoring class is a positive label.
            values, indices = torch.max(output, dim=1)
            t_correct += sum(1 for s, i in enumerate(indices) if labels[s][i] == 1)
            t_samples += len(indices)
            # Run validation roughly every `eval_every` fraction of the training set.
            if global_step % int(eval_every * len(train_loader.dataset)) < train_loader.batch_size:
                model.eval()
                valid_running_loss = 0.0
                total_acc_val = 0
                with torch.no_grad():
                    for images, labels, notes in valid_loader:
                        coeffs = pywt.wavedec(images, wavelet, level=level, axis=1)
                        threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1]))).item()
                        denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold)
                                           for c in coeffs]
                        images = pywt.waverec(denoised_coeffs, wavelet, axis=1)
                        images = torch.tensor(images).float().to(device)
                        labels = labels.to(device)
                        notes = notes.to(device)
                        output = model(images, notes)
                        loss = loss_fn(output, labels.float()).item()
                        valid_running_loss += loss * len(images)
                        values, indices = torch.max(output, dim=1)
                        total_acc_val += sum(1 for s, i in enumerate(indices) if labels[s][i] == 1)
                # evaluation
                average_train_loss = running_loss / t_samples
                average_valid_loss = valid_running_loss / len(valid_loader.dataset)
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(average_valid_loss)
                global_steps_list.append(global_step)
                display_eval(epoch, epochs, len(train_loader.dataset), global_step,
                             t_correct, t_samples, len(valid_loader.dataset),
                             average_train_loss, average_valid_loss, total_acc_val)
                # back to training mode
                model.train()
                # checkpoint whenever the validation loss improves
                if best_valid_loss > average_valid_loss:
                    best_valid_loss = average_valid_loss
                    save_model(model, optimizer, best_valid_loss, epoch,
                               path=f'{save_dir}model_{model_save_name}.pt')
                    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                                 path=f'{save_dir}metrics_{model_save_name}.pt')
    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=f'{save_dir}metrics_{model_save_name}.pt')
    print("Training complete.")
    return model


def evaluate_RNN(model, test_loader, device="cuda"):
    model.eval()
    y_pred = []
    y_true = []
    wavelet = 'db4'
    level = 3
    total_acc_test = 0
    with torch.no_grad():
        for images, labels, notes in test_loader:
            # Same wavelet denoising as in training.
            coeffs = pywt.wavedec(images, wavelet, level=level, axis=1)
            threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1]))).item()
            denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold)
                               for c in coeffs]
            images = pywt.waverec(denoised_coeffs, wavelet, axis=1)
            images = torch.tensor(images).float().to(device)
            labels = labels.to(device)
            notes = notes.to(device)
            output = model(images, notes)
            values, indices = torch.max(output, dim=1)
            y_pred.extend(indices.tolist())
            y_true.extend(labels.tolist())
            total_acc_test += sum(1 for s, i in enumerate(indices) if labels[s][i] == 1)
    test_accuracy = total_acc_test / len(test_loader.dataset)
    print(f'Test Accuracy: {test_accuracy: .3f}')
    return test_accuracy


def rename_with_acc(save_name, save_dir, acc):
    acc = round(acc * 100)
    new_model_name = f'{save_dir}model_{save_name}_acc_{acc}.pt'
    new_metrics_name = f'{save_dir}metrics_{save_name}_acc_{acc}.pt'
    # Remove any stale files that already use the target names.
    if os.path.isfile(new_model_name):
        os.remove(new_model_name)
    if os.path.isfile(new_metrics_name):
        os.remove(new_metrics_name)
    # Rename model checkpoint
    os.rename(f'{save_dir}model_{save_name}.pt', new_model_name)
    # Rename metrics
    os.rename(f'{save_dir}metrics_{save_name}.pt', new_metrics_name)
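

# A minimal, hypothetical usage sketch of the functions above, using a toy
# model and synthetic data purely for illustration. ToyModel, make_loader,
# and all sizes below are assumptions, not part of the original pipeline;
# the real model is expected to accept (images, notes) and return one logit
# per class for multi-label (one-hot) targets.
if __name__ == '__main__':
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    n_classes, sig_len, note_dim = 5, 128, 8

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.rnn = nn.GRU(1, 16, batch_first=True)
            self.fc = nn.Linear(16 + note_dim, n_classes)

        def forward(self, images, notes):
            # images: (batch, sig_len) denoised signals; notes: (batch, note_dim) extra features.
            _, h = self.rnn(images.unsqueeze(-1))
            return self.fc(torch.cat([h[-1], notes], dim=1))

    def make_loader(n):
        # Synthetic signals, one-hot labels, and note features.
        signals = torch.randn(n, sig_len)
        labels = torch.zeros(n, n_classes)
        labels[torch.arange(n), torch.randint(0, n_classes, (n,))] = 1
        notes = torch.randn(n, note_dim)
        return DataLoader(TensorDataset(signals, labels, notes), batch_size=16)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = ToyModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()

    model = train_RNN(2, make_loader(64), make_loader(32), model, loss_fn, optimizer,
                      device=device, model_save_name='demo', save_dir='./')
    acc = evaluate_RNN(model, make_loader(32), device=device)
    rename_with_acc('demo', './', acc)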