Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import transformers | |
| from transformers import AutoTokenizer,AutoModel | |
| import numpy as np | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import pandas as pd | |
| import re | |
# Teencode dictionary: a tab-separated file with two unnamed columns, read as
# 'teencode' and 'map' (presumably slang form -> standard form; confirm
# against the data file).
teencode_df = pd.read_csv('teencode.txt', sep='\t', names=['teencode', 'map'])
# Plain Python lists of each column, in file order.
# NOTE(review): map_list is not referenced anywhere else in this file, so the
# teencode -> standard-form mapping is never applied during preprocessing;
# only teencode_list is used (to exempt words from repeated-letter collapsing).
teencode_list, map_list = (teencode_df[column].tolist() for column in ('teencode', 'map'))
class BCNN(nn.Module):
    """Transformer encoder -> BiLSTM -> two parallel 1-D convolutions -> max-pool -> linear head.

    Args:
        embedding_dim: Feature size of the encoder's token outputs (the LSTM input size).
        output_dim: Number of output logits.
        dropout: Dropout probability applied to the concatenated pooled features.
        bidirectional_units: Hidden size of each LSTM direction; the LSTM output
            therefore has ``2 * bidirectional_units`` features per token.
        conv_filters: Output channel counts for the kernel-size-4 and
            kernel-size-5 convolutions (only the first two entries are used).
        encoder: Optional pre-built encoder module. When ``None`` (default),
            ``vinai/phobert-base-v2`` is loaded with ``AutoModel.from_pretrained``.
            The encoder is called as ``encoder(input_ids, attention_mask=mask)``
            and element ``[0]`` of its output must be the token-level hidden states.
    """

    def __init__(self, embedding_dim, output_dim,
                 dropout, bidirectional_units, conv_filters, encoder=None):
        super().__init__()
        if encoder is None:
            encoder = AutoModel.from_pretrained('vinai/phobert-base-v2')
        # Attribute name stays ``bert`` so existing checkpoints keep matching.
        self.bert = encoder
        self.bidirectional_lstm = nn.LSTM(
            embedding_dim, bidirectional_units, bidirectional=True, batch_first=True
        )
        self.conv1 = nn.Conv1d(in_channels=2 * bidirectional_units,
                               out_channels=conv_filters[0], kernel_size=4)
        self.conv2 = nn.Conv1d(in_channels=2 * bidirectional_units,
                               out_channels=conv_filters[1], kernel_size=5)
        # The head consumes both max-pooled conv outputs concatenated, so its input
        # size is the total filter count (64 for the [32, 32] configuration).
        self.fc = nn.Linear(conv_filters[0] + conv_filters[1], output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, b_input_ids, b_input_mask):
        """Compute classification logits.

        Args:
            b_input_ids: Token ids, presumably (batch, seq_len).
            b_input_mask: Attention mask matching ``b_input_ids``.

        Returns:
            Logits of shape (batch, output_dim).

        Note: seq_len must be at least 5 (the largest convolution kernel).
        """
        encoded = self.bert(b_input_ids, attention_mask=b_input_mask)[0]
        # batch_first LSTM: (batch, seq_len, 2 * bidirectional_units)
        embedded, _ = self.bidirectional_lstm(encoded)
        # Conv1d expects channels first: (batch, features, seq_len)
        embedded = embedded.permute(0, 2, 1)
        conved_1 = F.relu(self.conv1(embedded))
        conved_2 = F.relu(self.conv2(embedded))
        # conved_n = [batch, n_filters, seq_len - kernel_size_n + 1]
        # Global max-pool over the whole remaining sequence axis.
        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)
        # pooled_n = [batch, n_filters]
        cat = self.dropout(torch.cat((pooled_1, pooled_2), dim=1))
        # cat = [batch, conv_filters[0] + conv_filters[1]]
        return self.fc(cat)
class TextClassificationApp:
    """Streamlit page that classifies a single input text with a pickled BCNN model."""

    def __init__(self, model_path, class_names, model_name='vinai/phobert-base-v2'):
        """
        Initialize Streamlit Text Classification App

        Args:
            model_path (str): Path to the .pt file holding the whole pickled model
                (loaded with ``torch.load``, not a state_dict).
            class_names (list): Labels indexed by the model's output class index.
            model_name (str): Hugging Face model name for tokenization
        """
        # Must be the first Streamlit command executed in the script run.
        st.set_page_config(
            page_title="⚖️ Text Justice Classifier",
            page_icon="⚖️",
            layout="wide"
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): Streamlit re-executes the whole script on every widget
        # interaction, so the tokenizer and model below are reloaded on every
        # rerun. Consider moving the loading into a module-level function
        # decorated with @st.cache_resource.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = self._load_model(model_path, self.device)
        self.class_names = class_names
        # Every input is padded/truncated to exactly this many tokens.
        self.max_length = 128

    @staticmethod
    def _load_model(model_path, device):
        """Load the pickled model onto ``device`` and switch it to eval mode."""
        # The checkpoint holds the full module, so there is no need to build a
        # fresh BCNN (which downloads PhoBERT) first. map_location=device keeps the
        # model on the same device preprocess_text() moves the inputs to.
        # SECURITY: weights_only=False unpickles arbitrary objects - only load
        # trusted checkpoint files.
        model = torch.load(model_path, map_location=device, weights_only=False)
        model.to(device)
        model.eval()
        return model

    def remove_dub_char(self, sentence):
        """Collapse runs of a repeated letter inside each word (e.g. "heyyyy" -> "hey").

        Matching uses the regex ``([A-Z])\\1+`` with ``re.IGNORECASE``, so only
        ASCII letters are affected. Words present in ``teencode_list`` are kept
        verbatim. Whitespace between words is normalised to single spaces.
        """
        words = []
        for word in str(sentence).split():
            if word in teencode_list:
                words.append(word)
            else:
                words.append(re.sub(r'([A-Z])\1+', lambda m: m.group(1), word, flags=re.IGNORECASE))
        return ' '.join(words)

    def preprocess_text(self, text):
        """
        Normalise and tokenize the input text for model prediction.

        Args:
            text (str): Input text to classify

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(input_ids, attention_mask)``,
            each of shape (1, max_length), on ``self.device``.
        """
        text = self.remove_dub_char(text)
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)
        return input_ids, attention_mask

    def predict(self, text):
        """
        Make a top-1 prediction on the input text.

        Args:
            text (str): Input text to classify

        Returns:
            tuple: ``(class_indices, probabilities)`` as NumPy arrays of length 1
            holding the most likely class index and its softmax probability.
        """
        input_ids, attention_mask = self.preprocess_text(text)
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            probabilities = torch.softmax(logits, dim=1)
            top_probs, top_classes = torch.topk(probabilities, k=1)
        return top_classes[0].cpu().numpy(), top_probs[0].cpu().numpy()

    def run(self):
        """
        Render the page: text area, classify button, and prediction results.
        """
        st.title("📄 Toxic Classification")
        st.write("Enter text to classify")
        text_input = st.text_area(
            "Paste your text here",
            height=250,
            placeholder="Enter the text you want to classify..."
        )
        if not st.button("Classify Text"):
            return
        if not text_input.strip():
            st.warning("Please enter some text to classify")
            return

        top_classes, top_probs = self.predict(text_input)
        st.subheader("Classification Results")
        # One column per returned prediction, rather than a fixed 3 that would
        # leave empty columns (or raise IndexError if more than 3 were returned).
        cols = st.columns(len(top_classes))
        for i, (col, cls, prob) in enumerate(zip(cols, top_classes, top_probs)):
            with col:
                st.metric(
                    label=f"Top {i+1} Prediction",
                    value=f"{self.class_names[int(cls)]}",
                    delta=f"{prob:.2%}"
                )
        with st.expander("Input Text Details"):
            st.write("**Original Text:**")
            st.write(text_input)
            st.write(f"**Text Length:** {len(text_input)} characters")
def main():
    """Create the toxic-text classifier page and render it."""
    # labels[i] is the display name for model output class i.
    labels = [
        'Non-toxic',
        'Toxic'
    ]
    TextClassificationApp('toxic.pt', labels).run()


if __name__ == "__main__":
    main()