Spaces:
Paused
Paused
| import torch | |
| from torch import nn | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| # import torch.nn as nn | |
| torch.set_printoptions(sci_mode=False) | |
class MLP(nn.Module):
    """Single linear classification head over fixed-size input vectors.

    The network is one ``nn.Linear`` layer whose output is passed through
    ``nn.Dropout``.

    Args:
        input_size: Number of features in each input vector.
        output_size: Number of output scores (one per class).
        dropout_rate: Drop probability handed to ``nn.Dropout``.
        class_weights: Optional per-class weights forwarded to
            ``nn.CrossEntropyLoss`` by :meth:`get_loss_fn`.
    """

    def __init__(self, input_size=768, output_size=3, dropout_rate=.2, class_weights=None):
        super().__init__()
        # Kept as a plain attribute (not a registered buffer): it is not part
        # of the state dict and is not moved by ``.to(device)``.
        self.class_weights = class_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the class scores for ``x``.

        The scores are the linear layer's output with dropout applied on top.
        """
        # NOTE(review): dropout acts on the output scores, not on the input
        # features (``self.linear(self.dropout(x))``) - confirm this ordering
        # is intended; it only matters in training mode.
        return self.dropout(self.linear(x))

    def predict(self, x):
        """Return the index of the highest-scoring class for each row of ``x``.

        The forward pass is run in evaluation mode and without gradient
        tracking, so the result does not depend on dropout noise even when
        the model is currently in training mode. The previous train/eval
        state of every sub-module is restored before returning.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                _, predicted = torch.max(self.forward(x), 1)
        finally:
            # Restore per-module flags rather than calling ``self.train(...)``
            # so sub-modules that were individually switched keep their mode.
            for module, was_training in previous_modes:
                module.training = was_training
        print('I am predict')
        return predicted

    def predict_proba(self, x):
        """Return the raw output of :meth:`forward` for ``x``.

        No softmax is applied, so the values are un-normalised scores, and
        the module's current train/eval mode is used as-is.
        """
        print('I am predict_proba')
        return self.forward(x)

    def get_loss_fn(self):
        """Build a mean-reduced cross-entropy loss using ``self.class_weights``."""
        return nn.CrossEntropyLoss(weight=self.class_weights, reduction='mean')
if __name__ == '__main__':
    # Script entry point: train the linear head on pre-computed embeddings
    # and report test metrics before and after training.
    import warnings

    import torch.optim as optim
    from datasets import load_dataset
    from sklearn.utils.class_weight import compute_class_weight
    from torch.utils.data import DataLoader, TensorDataset

    from train_classificator import (
        eval_model,
        plot_training_metrics,
        train_model,
    )

    warnings.filterwarnings("ignore")

    SEED = 1003200212 + 1
    # Seed torch's global RNG so weight init, shuffling and dropout repeat
    # across runs (SEED was previously defined but never applied).
    torch.manual_seed(SEED)
    # NOTE(review): DEVICE is computed but nothing below moves the model or
    # the batches to it - confirm whether train_model/eval_model handle that.
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    learning_rate = 5e-4
    weight_decay = 0
    batch_size = 128
    epochs = 40

    dataset = load_dataset("CabraVC/vector_dataset_roberta-fine-tuned")

    # Convert every split to tensors once and reuse them below, instead of
    # re-materialising the same dataset columns several times.
    split_data = {
        split: TensorDataset(
            torch.tensor(dataset[split]['embeddings']),
            torch.tensor(dataset[split]['labels']),
        )
        for split in ('train', 'val', 'test')
    }
    split_loaders = {
        split: DataLoader(data, batch_size=batch_size, shuffle=True)
        for split, data in split_data.items()
    }
    train_data, val_data, test_data = (split_data[s] for s in ('train', 'val', 'test'))
    train_loader, val_loader, test_loader = (split_loaders[s] for s in ('train', 'val', 'test'))

    train_embeddings, train_labels = train_data.tensors
    input_size = train_embeddings.shape[1]

    # ``classes`` is passed as an ndarray, which is the type scikit-learn
    # documents for this parameter. The square root pulls the balanced
    # weights toward 1.
    balanced_weights = compute_class_weight(
        'balanced', classes=np.array([0, 1, 2]), y=train_labels.numpy()
    )
    class_weights = torch.tensor(balanced_weights, dtype=torch.float) ** .5

    model = MLP(input_size=input_size, class_weights=class_weights)
    criterion = model.get_loss_fn()

    # Baseline: evaluate the untrained model on the test split.
    loss, accuracy = eval_model(model, criterion, test_loader, test_data, show=False)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=.2, patience=5, threshold=1e-4, min_lr=1e-7, verbose=True
    )

    losses, accuracies = train_model(
        model, criterion, optimizer, lr_scheduler,
        train_loader, val_loader, train_data, val_data, epochs,
    )
    plot_training_metrics(losses, accuracies)

    # Final evaluation of the trained model on the same test split.
    loss, accuracy = eval_model(model, criterion, test_loader, test_data, show=False)

    # NOTE: the trained weights are not saved anywhere by this script.