Spaces:
Runtime error
Runtime error
| from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments | |
| from datasets import load_dataset | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import gradio as gr | |
| # ❗ Загрузка датасета ZhenDOS/alpha_bank_data | |
| dataset = load_dataset("ZhenDOS/alpha_bank_data") | |
| # ✔️ Загрузка базовой модели и токенайзера | |
| tokenizer = BertTokenizerFast.from_pretrained("DeepPavlov/rubert-base-cased") | |
| model = BertForSequenceClassification.from_pretrained("DeepPavlov/rubert-base-cased", num_labels=len(dataset["train"].features["label"].names)) | |
| # ➕ Токенизация входных данных | |
| def tokenize_function(examples): | |
| return tokenizer(examples["text"], padding="max_length", truncation=True) | |
| tokenized_datasets = dataset.map(tokenize_function, batched=True) | |
| # 🏃♂️ Настройки обучения | |
| training_args = TrainingArguments( | |
| output_dir="./results", | |
| evaluation_strategy="epoch", | |
| learning_rate=2e-5, | |
| per_device_train_batch_size=16, | |
| per_device_eval_batch_size=64, | |
| num_train_epochs=3, | |
| weight_decay=0.01, | |
| ) | |
| # 💨 Процесс обучения | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=tokenized_datasets["train"], | |
| eval_dataset=tokenized_datasets["validation"], | |
| ) | |
| trainer.train() | |
| # 📊 Функционал для демонстрации через Gradio | |
| def classify_question(question): | |
| tokens = tokenizer(question, return_tensors="pt") | |
| outputs = model(**tokens) | |
| probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
| pred_label_idx = torch.argmax(probabilities, dim=1).item() | |
| categories = dataset["train"].features["label"].names | |
| return { | |
| "Вероятности классов": dict(zip(categories, probabilities.detach().numpy()[0])), | |
| "Прогнозируемый класс": categories[pred_label_idx], | |
| } | |
| # 🖥️ Графический интерфейс Gradio | |
| demo = gr.Interface( | |
| fn=classify_question, | |
| inputs="text", | |
| outputs=[ | |
| gr.Label(label="Категории"), | |
| gr.Textbox(label="Прогнозируемый класс"), | |
| ], | |
| examples=[ | |
| ["Как перевести деньги между картами?"], | |
| ["Что такое кредитная история?"], | |
| ["Почему моя карта заблокирована?"], | |
| ], | |
| title="Классификация клиентских запросов банка", | |
| description="Приложение помогает определить категорию клиентского запроса и оценить вероятность принадлежности каждого класса.", | |
| ) | |
| demo.launch() | |