Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import os | |
| import numpy as np | |
| model_path = 'srcs/model_modify.pth' | |
| # токенизатор | |
| tokenizer = AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity') | |
| model = AutoModelForSequenceClassification.from_pretrained('cointegrated/rubert-tiny-toxicity', num_labels=1, ignore_mismatched_sizes=True) | |
| # весов модифицированной модели | |
| model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False) | |
| image = Image.open("media/oritoxic.jpg") | |
| df = pd.read_csv("media/Toxic_labeled.csv") | |
| loss_values = [0.4063596375772262, 0.402279906166038, 0.3998144585561736, 0.39567733055365567, | |
| 0.3921396666608141, 0.38956182373070186, 0.3866641920902114, 0.3879134839351564, | |
| 0.38288725781591604, 0.38198364493999004] | |
| #Боковая панель | |
| selected_option = st.sidebar.selectbox("Выберите из списка", ["Определение токсичность текста", "Информация о датасете", "Информация о модели"]) | |
| #st.title("Главная страница") | |
| if selected_option == "Определение токсичность текста": | |
| st.markdown("<h1 style='text-align: center;'>Приложение для определения токсичности текста</h1>", | |
| unsafe_allow_html=True) | |
| st.image(image, use_column_width=True) | |
| user_input = st.text_area("") | |
| # Функция предсказания токсичности | |
| def predict_toxicity(text): | |
| inputs = tokenizer(text, return_tensors="pt") | |
| outputs = model(**inputs) | |
| logits = outputs.logits | |
| probability = torch.sigmoid(logits).item() | |
| prediction = "токсичный" if probability >= 0.5 else "не токсичный" | |
| return prediction, probability | |
| # Тык на кнопу | |
| if st.button("Оценить токсичность"): | |
| if user_input: | |
| prediction, toxicity_probability = predict_toxicity(user_input) | |
| st.write(f'Вероятность токсичности: {toxicity_probability:.4f}') | |
| # Прогресс бар | |
| if 'toxicity_probability' in locals(): | |
| progress_percentage = int(toxicity_probability * 100) | |
| progress_bar_color = f'linear-gradient(to right, rgba(0, 0, 255, 0.5) {progress_percentage}%, rgba(255, 0, 0, 0.5) {progress_percentage}%)' | |
| st.markdown(f'<div style="background: {progress_bar_color}; height: 20px; border-radius: 5px;"></div>', | |
| unsafe_allow_html=True) | |
| elif selected_option == "Информация о датасете": | |
| st.header("Информация о датасете:") | |
| st.dataframe(df.head()) | |
| st.write(f"Объем выборки: 14412") | |
| st.subheader("Баланс классов в датасете:") | |
| st.write(f"Количество записей в классе 0.0: {len(df[df['toxic'] == 0.0])}") | |
| st.write(f"Количество записей в классе 1.0: {len(df[df['toxic'] == 1.0])}") | |
| fig, ax = plt.subplots() | |
| df['toxic'].value_counts().plot(kind='bar', ax=ax, color=['skyblue', 'orange']) | |
| ax.set_xticklabels(['Не токсичный', 'Токсичный'], rotation=0) | |
| ax.set_xlabel('Класс') | |
| ax.set_ylabel('Количество записей') | |
| ax.set_title('Распределение по классам') | |
| st.pyplot(fig) | |
| elif selected_option == "Информация о модели": | |
| st.subheader("Информация о модели:") | |
| st.write(f"Модель: Rubert tiny toxicity") | |
| st.subheader("Информация о процессе обучения") | |
| # график лосса | |
| #st.subheader("График потерь в процессе обучения") | |
| #st.line_chart([0.5181976270121774, 0.4342067330899996, 0.41386983832460666]) # Замените данными из ваших эпох | |
| for epoch, loss in enumerate(loss_values, start=1): | |
| st.write(f"<b>Epoch {epoch}/{len(loss_values)}, Loss:</b> {loss}<br>", unsafe_allow_html=True) | |
| st.markdown( | |
| """ | |
| <b>Количество эпох:</b> 10 | |
| <b>Размер батча:</b> 8 | |
| <b>Оптимизатор:</b> Adam | |
| <b>Функция потерь:</b> BCEWithLogitsLoss | |
| <b>learning rate:</b> 0.00001 | |
| """, | |
| unsafe_allow_html=True | |
| ) | |
| st.subheader("Метрики модели:") | |
| st.write(f"Accuracy: {0.8366:.4f}") | |
| st.write(f"Precision: {0.8034:.4f}") | |
| st.write(f"Recall: {0.6777:.4f}") | |
| st.write(f"F1 Score: {0.7352:.4f}") | |
| st.subheader("Код") | |
| bert_model_code = """ | |
| model = BertModel( | |
| embeddings=BertEmbeddings( | |
| word_embeddings=Embedding(29564, 312, padding_idx=0), | |
| position_embeddings=Embedding(512, 312), | |
| token_type_embeddings=Embedding(2, 312), | |
| LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True), | |
| dropout=Dropout(p=0.1, inplace=False), | |
| ), | |
| encoder=BertEncoder( | |
| layer=ModuleList( | |
| BertLayer( | |
| attention=BertAttention( | |
| self=BertSelfAttention( | |
| query=Linear(in_features=312, out_features=312, bias=True), | |
| key=Linear(in_features=312, out_features=312, bias=True), | |
| value=Linear(in_features=312, out_features=312, bias=True), | |
| dropout=Dropout(p=0.1, inplace=False), | |
| ), | |
| output=BertSelfOutput( | |
| dense=Linear(in_features=312, out_features=312, bias=True), | |
| LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True), | |
| dropout=Dropout(p=0.1, inplace=False), | |
| ), | |
| ), | |
| intermediate=BertIntermediate( | |
| dense=Linear(in_features=312, out_features=600, bias=True), | |
| intermediate_act_fn=GELUActivation(), | |
| ), | |
| output=BertOutput( | |
| dense=Linear(in_features=600, out_features=312, bias=True), | |
| LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True), | |
| dropout=Dropout(p=0.1, inplace=False), | |
| ), | |
| ) | |
| ) | |
| ), | |
| pooler=BertPooler( | |
| dense=Linear(in_features=312, out_features=312, bias=True), | |
| activation=Tanh(), | |
| ), | |
| dropout=Dropout(p=0.1, inplace=False), | |
| classifier=Linear(in_features=312, out_features=1, bias=True), | |
| ) | |
| """ | |
| # Отображение кода в Streamlit | |
| st.code(bert_model_code, language="python") |