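# Evaluate the SetFit classifier on the financial-report test set: each document is
# chunked, every chunk is classified, the chunk predictions are majority-voted into a
# document-level label, and accuracy/precision/recall/F1 plus a normalized confusion
# matrix are reported.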
from create_setfit_model import model
from time import perf_counter
import os
import sys
from statistics import mean
from langchain.text_splitter import RecursiveCharacterTextSplitter
import torch
from collections import Counter
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

start = perf_counter()
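# Load ground-truth labels (csvs/) and raw report texts (txts/) from the shared
# financial_dataset directory two levels up.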
dataset_dir = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'financial_dataset'))
sys.path.append(dataset_dir)
from load_test_data import get_labels_df, get_texts

labels_dir = dataset_dir + '/csvs/'
df = get_labels_df(labels_dir)
texts_dir = dataset_dir + '/txts/'
texts = get_texts(texts_dir)

# df = df.iloc[:20, :]
# print(df.loc[:, 'Label'])
# texts = [texts[0]] + [texts[13]] + [texts[113]]
# texts = texts[:20]
print(f'{len(df)} label rows, {len(texts)} documents')
print(f'mean document length: {mean(map(len, texts)):.0f} characters')
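# Split each document into overlapping ~3200-character chunks; the per-chunk
# predictions are aggregated into a single document-level label below.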
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=3200,
    chunk_overlap=200,
    length_function=len,
    separators=[" ", ",", "\n"],
)
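# Per-document prediction loop: each DataFrame row carries (year, label, company) and
# is assumed to be aligned, in order, with the corresponding entry in `texts`.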
labels = []
pred_labels = []
for text, (idx, (year, label, company)) in tqdm(zip(texts, df.iterrows()), total=len(df)):
    documents = text_splitter.create_documents([text])
    chunks = [document.page_content for document in documents]  # don't shadow the outer `texts` list
    with torch.no_grad():
        model.model_head.eval()  # classification head is assumed to be a torch module
        chunk_pred_labels = model(chunks)
    # Majority vote over the chunk-level predictions.
    pred_labels_counter = Counter(chunk_pred_labels)
    pred_label = pred_labels_counter.most_common(1)[0][0]
    labels.append(label)
    pred_labels.append(pred_label)
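# Document-level metrics; weighted averaging weights each class by its support.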
accuracy = accuracy_score(labels, pred_labels)
precision = precision_score(labels, pred_labels, average='weighted')
recall = recall_score(labels, pred_labels, average='weighted')
f1 = f1_score(labels, pred_labels, average='weighted')
confusion_mat = confusion_matrix(labels, pred_labels, normalize='true')

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)
# Tick labels assume the classes appear in this order in confusion_mat; if the labels
# are the literal strings, pass labels=class_names to confusion_matrix above to pin
# the ordering (sklearn otherwise sorts the classes).
class_names = ['hold', 'buy', 'sell']
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_mat, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
plt.show()
elapsed = perf_counter() - start
print(f'It took me: {elapsed // 60:.0f} mins {elapsed % 60:.0f} secs')