Spaces:
Paused
Paused
def read_and_split_file(filename, chunk_size=1200, chunk_overlap=200,
                        separators=(" ", ",", "\n"), encoding="utf-8"):
    """Read a text file and split it into overlapping LangChain documents.

    Args:
        filename: Path of the text file to read.
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: Number of characters shared by consecutive chunks.
        separators: Separators tried in order by the recursive splitter; the
            default keeps the original ``[" ", ",", "\\n"]`` order.
        encoding: Encoding used to decode the file. UTF-8 by default, matching
            how the Streamlit front-end decodes uploaded text files (the
            original relied on the platform's locale encoding).

    Returns:
        The documents produced by ``RecursiveCharacterTextSplitter.create_documents``
        for the file's full text.
    """
    # Imported here: this file's other imports live under
    # ``if __name__ == '__main__'``, so without a local import the name is
    # undefined whenever this function is imported from another module.
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    with open(filename, encoding=encoding) as f:
        text = f.read()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=list(separators),
    )
    return text_splitter.create_documents([text])
| if __name__ == '__main__': | |
| # Comments and ideas to implement: | |
| # 1. Try sending list of inputs to the Inference API. | |
| import streamlit as st | |
| from sys import exit | |
| from pprint import pprint | |
| from collections import Counter | |
| from itertools import zip_longest | |
| from random import choice | |
| import requests | |
| from re import sub | |
| from rouge import Rouge | |
| from time import sleep, perf_counter | |
| import os | |
| from textwrap import wrap | |
| from multiprocessing import Pool, freeze_support | |
| from tqdm import tqdm | |
| from stqdm import stqdm | |
| from langchain.document_loaders import TextLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.schema.document import Document | |
| # from langchain.schema import Document | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.llms import OpenAI | |
| from langchain.schema import AIMessage, HumanMessage, SystemMessage | |
| from langchain.prompts import PromptTemplate | |
| from datasets import Dataset, load_dataset | |
| from sklearn.preprocessing import LabelEncoder | |
| from test_models.train_classificator import MLP | |
| from safetensors.torch import load_model, save_model | |
| from sentence_transformers import SentenceTransformer | |
| from torch.utils.data import DataLoader, TensorDataset | |
| import torch.nn.functional as F | |
| import torch | |
| import torch.nn as nn | |
| import sys | |
| sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/'))) | |
| sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/financial-roberta'))) | |
| st.set_page_config( | |
| page_title="Financial advisor", | |
| page_icon="๐ณ๐ฐ", | |
| layout="wide", | |
| ) | |
| # st.session_state.summarized = False | |
| with st.sidebar: | |
| "# How to use๐" | |
| """ | |
| โจThis is a holiday version of the web-UI with the magic ๐, allowing you to unwrap | |
| label predictions for a company based on its financial report text! ๐โจ The prediction | |
| enchantment is performed using the sophisticated embedding classifier approach. ๐๐ฎ | |
| """ | |
| center_style = "<h3 style='text-align: center; color: black;'>{} </h3>" | |
| st.markdown(center_style.format('Load the financial report'), unsafe_allow_html=True) | |
| upload_types = ['Text input', 'File upload'] | |
| upload_captions = ['Paste the text', 'Upload a text file'] | |
| upload_type = st.radio('Select how to upload the financial report', upload_types, | |
| captions=upload_captions) | |
| match upload_type: | |
| case 'Text input': | |
| financial_report_text = st.text_area('Something', label_visibility='collapsed', | |
| placeholder='Financial report as TEXT') | |
| case 'File upload': | |
| uploaded_files = st.file_uploader("Choose a a text file", type=['.txt', '.docx'], | |
| label_visibility='collapsed', accept_multiple_files=True) | |
| if not bool(uploaded_files): | |
| st.stop() | |
| financial_report_text = '' | |
| for uploaded_file in uploaded_files: | |
| if uploaded_file.name.endswith("docx"): | |
| document = Document(uploaded_file) | |
| document.save('./utils/texts/' + uploaded_file.name) | |
| document = Document(uploaded_file.name) | |
| financial_report_text += "".join([paragraph.text for paragraph in document.paragraphs]) + '\n' | |
| else: | |
| financial_report_text += "".join([line.decode() for line in uploaded_file]) + '\n' | |
| # with open('./utils/texts/financial_report_text.txt', 'w') as file: | |
| # file.write(financial_report_text) | |
| if st.button('Get label'): | |
| with st.spinner("Thinking..."): | |
| text_splitter = RecursiveCharacterTextSplitter( | |
| chunk_size=3200, chunk_overlap=200, | |
| length_function = len, separators=[" ", ",", "\n"] | |
| ) | |
| # st.write(f'Financial report char len: {len(financial_report_text)}') | |
| documents = text_splitter.create_documents([financial_report_text]) | |
| # st.write(f'Num chunks: {len(documents)}') | |
| texts = [document.page_content for document in documents] | |
| # st.write(f'Each chunk char length: {[len(text) for text in texts]}') | |
| # predicted_label = get_label_prediction(texts) | |
| from test_models.create_setfit_model import model | |
| with torch.no_grad(): | |
| model.model_head.eval() | |
| predicted_labels = model(texts) | |
| # st.write(predicted_labels) | |
| predicted_labels_counter = Counter(predicted_labels) | |
| predicted_label = predicted_labels_counter.most_common(1)[0][0] | |
| font_style = 'The predicted label is<span style="font-size: 32px"> **{}**</span>.' | |
| st.markdown(font_style.format(predicted_label), unsafe_allow_html=True) |