import torch
from torch import nn
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score)


def get_eval_metric(y_pred, y_test):
    """Compute classification metrics, weighted to account for class imbalance."""
    return {
        'accuracy': accuracy_score(y_test, y_pred),
        'precision': precision_score(y_test, y_pred, average='weighted'),
        'recall': recall_score(y_test, y_pred, average='weighted'),
        'f1': f1_score(y_test, y_pred, average='weighted'),
        'confusion_mat': confusion_matrix(y_test, y_pred, normalize='true'),
    }
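
# Illustrative call with toy labels (hand-checked values, not from a real run):
#
#   >>> get_eval_metric(y_pred=[0, 1, 2, 2], y_test=[0, 1, 1, 2])
#   {'accuracy': 0.75, 'precision': 0.875, 'recall': 0.75, 'f1': 0.75,
#    'confusion_mat': array([[1. , 0. , 0. ],
#                            [0. , 0.5, 0.5],
#                            [0. , 0. , 1. ]])}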


class MLP(nn.Module):
    """One-hidden-layer classification head for pooled sentence embeddings."""

    def __init__(self, input_size=768, hidden_size=256, output_size=3,
                 dropout_rate=0.2, class_weights=None):
        super().__init__()
        self.class_weights = class_weights
        self.activation = nn.ReLU()
        # Alternatives tried: nn.Tanh(), nn.LeakyReLU(), nn.Sigmoid()
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        # Optional Kaiming initialization (left disabled):
        # nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        # nn.init.kaiming_normal_(self.fc2.weight)

    def forward(self, x):
        # During end-to-end SetFit training the body passes features as a dict;
        # accept either that dict or a raw embedding tensor.
        input_is_dict = False
        if isinstance(x, dict):
            assert 'sentence_embedding' in x
            input_is_dict = True
            x = x['sentence_embedding']
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        if input_is_dict:
            return {'logits': x}
        return x

    def predict(self, x):
        _, predicted = torch.max(self.forward(x), 1)
        return predicted

    def predict_proba(self, x):
        # Note: returns raw logits; apply softmax if calibrated probabilities are needed.
        return self.forward(x)

    def get_loss_fn(self):
        return nn.CrossEntropyLoss(weight=self.class_weights, reduction='mean')
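
# Quick shape check for the head (hedged sketch; random vectors stand in for real
# sentence embeddings, and the sizes are just the constructor defaults):
#
#   head = MLP()
#   dummy = torch.randn(4, 768)          # batch of 4 pooled embeddings
#   assert head(dummy).shape == (4, 3)   # raw logits, one column per class
#   assert head({'sentence_embedding': dummy})['logits'].shape == (4, 3)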


if __name__ == '__main__':
    import warnings
    from pprint import pprint
    from time import perf_counter

    from datasets import load_dataset
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.losses import BatchSemiHardTripletLoss
    from setfit import SetFitModel, Trainer, TrainingArguments
    from sklearn.utils.class_weight import compute_class_weight

    warnings.filterwarnings("ignore")

    SEED = 1003200212 + 1
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    start = perf_counter()

    dataset = load_dataset("CabraVC/vector_dataset_stratified_ttv_split_2023-12-05_21-07")

    class_weights_vect = compute_class_weight('balanced', classes=[0, 1, 2],
                                              y=dataset['train']['labels'])
    # Square-rooting softens the reweighting so minority classes are not over-boosted.
    class_weights = torch.tensor(class_weights_vect, dtype=torch.float).to(DEVICE) ** .5
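    # For intuition: with toy counts {0: 100, 1: 50, 2: 25}, 'balanced' gives
    # n_samples / (n_classes * count) = [0.58, 1.17, 2.33], and the square root
    # softens that to roughly [0.76, 1.08, 1.53]. (Toy numbers, not this dataset's.)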

    model_body = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
    model_head = MLP(hidden_size=256, class_weights=class_weights)  # hidden_size=128 gave ~82% accuracy
    model = SetFitModel(model_body=model_body,
                        model_head=model_head,
                        labels=dataset['train'].features['labels'].names).to(DEVICE)
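    # Using a differentiable torch head instead of SetFit's default scikit-learn
    # LogisticRegression head is what lets end_to_end=True below backpropagate
    # through body and head together.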

    train_ds = dataset['train']
    val_ds = dataset['val'].select(range(128))
    test_ds = dataset['test'].select(range(128))

    # Tuple-valued arguments apply to the (embedding, classifier) training phases.
    train_args = TrainingArguments(
        seed=SEED,
        batch_size=(16, 24),
        num_epochs=(15, 16),  # 15 embedding epochs worked best
        margin=.5,  # tried .5, .8, 1, 1.1; .4-.5 worked best
        loss=BatchSemiHardTripletLoss,
        use_amp=True,
        body_learning_rate=(3e-6, 4e-5),  # 5e-5 suited a smaller margin (~.3); (2e-6, 2e-5 to 3e-5) worked best
        l2_weight=7e-3,
        evaluation_strategy='epoch',
        end_to_end=True,
        samples_per_label=4,
        max_length=model.model_body.get_max_seq_length(),
    )

    trainer = Trainer(
        model=model,
        args=train_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        metric=get_eval_metric,
        column_mapping={'texts': 'text', 'labels': 'label'},
    )

    print('Test on unseen data (before training)')
    metrics = trainer.evaluate(test_ds)
    pprint(metrics)

    trainer.train()

    print('Test on train data')
    metrics = trainer.evaluate(train_ds)
    pprint(metrics)

    print('Test on unseen data')
    metrics = trainer.evaluate(test_ds)
    pprint(metrics)

    trainer.push_to_hub('CabraVC/emb_classifier_model', private=True)

    print('-' * 50)
    print('Successfully trained the model.')
    elapsed = perf_counter() - start
    print(f'It took me: {elapsed // 60:.0f} mins {elapsed % 60:.0f} secs')
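
    # Later reuse (hedged sketch): the pushed repo can be pulled back down with
    # SetFitModel.from_pretrained. Restoring this custom MLP head automatically is
    # not guaranteed, so treat this as a starting point rather than a recipe:
    #
    #   from setfit import SetFitModel
    #   model = SetFitModel.from_pretrained('CabraVC/emb_classifier_model')
    #   preds = model.predict(['an unseen sentence to classify'])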