import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter

torch.set_printoptions(sci_mode=False)

# label ids: 0 = buy, 1 = hold, 2 = sell
class MLP(nn.Module):
    """Single-hidden-layer classification head over sentence embeddings."""

    def __init__(self, input_size=768, hidden_size=256, output_size=3, dropout_rate=.2, class_weights=None):
        super(MLP, self).__init__()
        self.class_weights = class_weights
        self.activation = nn.ReLU()
        # self.activation = nn.Tanh()
        # self.activation = nn.LeakyReLU()
        # self.activation = nn.Sigmoid()
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        # nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        # nn.init.kaiming_normal_(self.fc2.weight)
    def forward(self, x):
        # Accept either a raw embedding tensor or a sentence-transformers
        # style feature dict carrying a 'sentence_embedding' entry.
        input_is_dict = False
        if isinstance(x, dict):
            assert "sentence_embedding" in x
            input_is_dict = True
            x = x['sentence_embedding']
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        if input_is_dict:
            return {'logits': x}
        return x
    def predict(self, x):
        # Hard class predictions: argmax over the logits.
        _, predicted = torch.max(self.forward(x), 1)
        return predicted

    def predict_proba(self, x):
        # Class probabilities; forward() returns logits, so apply a softmax
        # over the class dimension.
        return F.softmax(self.forward(x), dim=1)
    def get_loss_fn(self):
        return nn.CrossEntropyLoss(weight=self.class_weights, reduction='mean')
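
# Usage sketch (illustrative, not part of the original script): the dict
# branch in forward() mirrors the feature-dict convention of
# sentence-transformers modules, so the head accepts either form.
#   head = MLP(input_size=768)
#   head.eval()  # BatchNorm uses running stats outside training
#   logits = head(torch.randn(4, 768))                        # tensor in, logits out
#   out = head({'sentence_embedding': torch.randn(4, 768)})   # dict in, {'logits': ...} out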
def split_text(text, chunk_size=1200, chunk_overlap=200):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap,
        length_function=len, separators=[" ", ",", "\n"]
    )
    text_chunks = text_splitter.create_documents([text])
    return text_chunks
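
# Example call (hypothetical input text), chunking a long document before
# embedding; create_documents returns langchain Document objects:
#   chunks = split_text(long_report_text)
#   print(len(chunks), chunks[0].page_content[:80])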
def plot_labels_distribution(dataset, save_as_filename=None):
    plt.figure(figsize=(10, 6))
    # Bin edges are centered on the integer label ids 0, 1 and 2.
    freqs, bins, _ = plt.hist([
        dataset['train']['labels'],
        dataset['val']['labels'],
        dataset['test']['labels']
    ], label=['80% - train', '10% - val', '10% - test'], bins=[-.25, .25, .75, 1.25, 1.75, 2.25])
    plt.legend(loc='upper left')
    plt.xticks([b - .25 for b in bins], ['', 'Buy', '', 'Hold', '', 'Sell'], fontsize=16)
    bin_centers = np.diff(bins) * .5 + bins[:-1]
    for offset, freq in zip([-.135, 0, .135], freqs):
        for fr, x in zip(freq, bin_centers):
            height = int(fr)
            if height:
                plt.annotate("{}".format(height),
                             xy=(x + offset, height),
                             xytext=(0, .2),
                             textcoords="offset points",
                             ha='center', va='bottom')
    plt.title('Labels distribution')
    if save_as_filename:
        plt.savefig(save_as_filename)
    plt.show()
def plot_training_metrics(losses, accuracies, show=False, save_as_filename=None):
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(losses['train'], label='Training Loss')
    plt.plot(losses['val'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(accuracies['train'], label='Training Accuracy')
    plt.plot(accuracies['val'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy over Epochs')
    plt.legend()
    plt.tight_layout()
    if save_as_filename:
        plt.savefig(save_as_filename)
    if show:
        plt.show()
def train_model(model, criterion, optimizer, lr_scheduler, train_loader, val_loader, train_data, val_data, epochs):
    print_every = max(1, epochs // 8)  # avoid modulo-by-zero when epochs < 8
    losses = {'train': [], 'val': []}
    accuracies = {'train': [], 'val': []}
    for epoch in range(epochs):
        # Training pass.
        model.train()
        total_loss = 0.0
        correct_predictions = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
        losses['train'].append(total_loss / len(train_loader))
        accuracies['train'].append(correct_predictions / len(train_data))
        if epoch % print_every == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}, Accuracy: {correct_predictions / len(train_data):.4f}")

        # Validation pass; no gradients needed.
        model.eval()
        total_loss = 0.0
        correct_predictions = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct_predictions += (predicted == labels).sum().item()
        losses['val'].append(total_loss / len(val_loader))
        accuracies['val'].append(correct_predictions / len(val_data))
        if epoch % print_every == 0:
            print(f"Validation Loss: {total_loss / len(val_loader):.4f}, Accuracy: {correct_predictions / len(val_data):.4f}")
        lr_scheduler.step(total_loss / len(val_loader))
    return losses, accuracies
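
# Minimal smoke test for the training loop on random data (hypothetical
# shapes and hyperparameters, not part of the original experiment):
#   from torch.utils.data import DataLoader, TensorDataset
#   import torch.optim as optim
#   data = TensorDataset(torch.randn(64, 768), torch.randint(0, 3, (64,)))
#   loader = DataLoader(data, batch_size=16)
#   m = MLP()
#   opt = optim.Adam(m.parameters(), lr=1e-3)
#   sched = optim.lr_scheduler.ReduceLROnPlateau(opt)
#   train_model(m, m.get_loss_fn(), opt, sched, loader, loader, data, data, epochs=8)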
def eval_model(model, criterion, test_loader, test_data, show=False, save_as_filename=None):
    total_loss = 0.0
    correct_predictions = 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        model.eval()
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            # argmax over logits; softmax is monotonic, so it would not
            # change the predicted class.
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            all_labels.extend(labels.tolist())
            all_predictions.extend(predicted.tolist())
    loss, accuracy = total_loss / len(test_loader), correct_predictions / len(test_data)
    print(f'Model test loss: {loss:.2f}, test accuracy: {accuracy * 100:.1f}%')
    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, average='weighted')
    recall = recall_score(all_labels, all_predictions, average='weighted')
    f1 = f1_score(all_labels, all_predictions, average='weighted')
    confusion_mat = confusion_matrix(all_labels, all_predictions, normalize='true')
    print("Accuracy:", accuracy)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)
    # Label order follows the ids used throughout: 0 = buy, 1 = hold, 2 = sell.
    labels = ['buy', 'hold', 'sell']
    # Build the figure unconditionally so it can be saved even when show=False.
    plt.figure(figsize=(8, 6))
    sns.heatmap(confusion_mat, annot=True, fmt='.2%', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.title('Confusion Matrix')
    if save_as_filename:
        plt.savefig(save_as_filename)
    if show:
        plt.show()
    return loss, accuracy
if __name__ == '__main__':
    from datasets import load_dataset
    from sentence_transformers import SentenceTransformer
    from datetime import datetime
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    from safetensors.torch import load_model, save_model
    from sklearn.utils.class_weight import compute_class_weight
    import warnings

    warnings.filterwarnings("ignore")
    model_name = 'all-distilroberta-v1'
    # model_name = 'all-MiniLM-L6-v2'
    # The encoder is loaded for reference only: the dataset below already
    # stores precomputed embeddings, and `model` is reassigned to the MLP head.
    model = SentenceTransformer(model_name)
    dataset = load_dataset("CabraVC/vector_dataset_stratified_ttv_split_2023-12-05_21-07")
    # plot_labels_distribution(dataset
    # # , save_as_filename=f'plots/labels_distribution_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    # )
    input_size = len(dataset['train']['embeddings'][0])
    hidden_size = 256
    dropout_rate = 0.2
    learning_rate = 2e-4
    batch_size = 256
    epochs = 100

    # Square-rooted 'balanced' weights temper the re-weighting: minority
    # classes are still boosted, but less aggressively than with raw weights.
    class_weights = torch.tensor(
        compute_class_weight('balanced', classes=[0, 1, 2], y=dataset['train']['labels']),
        dtype=torch.float
    ) ** .5
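    # Illustration with made-up numbers: raw balanced weights [0.6, 1.1, 2.4]
    # become roughly [0.77, 1.05, 1.55] after the square root, shrinking the
    # gap between majority and minority classes.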
    model = MLP(input_size=input_size, hidden_size=hidden_size, dropout_rate=dropout_rate, class_weights=class_weights)
    criterion = model.get_loss_fn()
    # criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=8e-2)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.25, patience=10, threshold=5e-5, min_lr=1e-7, verbose=True)

    train_data = TensorDataset(torch.tensor(dataset['train']['embeddings']), torch.tensor(dataset['train']['labels']))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_data = TensorDataset(torch.tensor(dataset['val']['embeddings']), torch.tensor(dataset['val']['labels']))
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)  # no need to shuffle for evaluation

    losses, accuracies = train_model(model, criterion, optimizer, lr_scheduler, train_loader, val_loader, train_data, val_data, epochs)
    plot_training_metrics(losses, accuracies
        # , save_as_filename=f'plots/training_metrics_plot_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )
    test_data = TensorDataset(torch.tensor(dataset['test']['embeddings']), torch.tensor(dataset['test']['labels']))
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)  # no need to shuffle for evaluation
    loss, accuracy = eval_model(model, criterion, test_loader, test_data,
        # save_as_filename=f'plots/confusion_matrix_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )

    # torch.save(model.state_dict(), f'models/head_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.pth')
    # save_model(model, f'models/head_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.safetensors')
    # load_model(model, f'models/head_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.safetensors')
    # print(model)
    # dataset.push_to_hub(f'CabraVC/vector_dataset_stratified_ttv_split_{datetime.now().strftime("%Y-%m-%d_%H-%M")}', private=True)