from utils.model import BiLSTMAttentionBERT, BiLSTMConfig
import time

import torch
import streamlit as st
from huggingface_hub import hf_hub_download
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModel
def load_model_for_prediction():
    """Load the BiLSTM+BERT classifier, label encoder, and tokenizer.

    Returns (model, label_encoder, tokenizer), or (None, None, None) on failure.
    """
    try:
        st.write("Starting model loading...")

        # Initialize the pre-trained BioBERT encoder first
        bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

        # Initialize config and model
        config = BiLSTMConfig(
            hidden_dim=128,
            num_classes=22,
            num_layers=2,
            dropout=0.5
        )
        model = BiLSTMAttentionBERT(config)
        model.bert = bert  # Attach the pre-trained BERT encoder

        # Download the checkpoint holding the custom (non-BERT) layer weights
        model_path = hf_hub_download(
            repo_id="joko333/BiLSTM_v01",
            filename="model_epoch8_acc72.53.pt"
        )
        checkpoint = torch.load(model_path, map_location='cpu')

        # Debug: show the checkpoint structure
        st.write("Checkpoint keys:", checkpoint.keys())

        if 'model_state_dict' in checkpoint:
            # Keep only the custom layer weights; BERT weights come from the hub
            state_dict = checkpoint['model_state_dict']
            custom_state_dict = {
                key: value for key, value in state_dict.items()
                if not key.startswith('bert.')
            }
            # strict=False because the BERT weights are intentionally absent
            model.load_state_dict(custom_state_dict, strict=False)
            st.write("Model loaded successfully")
        else:
            st.error("Invalid checkpoint format")
            return None, None, None

        # Restore the label encoder classes saved in the checkpoint
        label_encoder = LabelEncoder()
        if 'label_encoder_classes' in checkpoint:
            label_encoder.classes_ = checkpoint['label_encoder_classes']
        else:
            st.error("Label encoder data not found in checkpoint")
            return None, None, None

        # Load the matching tokenizer
        tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        return model, label_encoder, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None
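
# Caching sketch (illustrative, not in the original file): Streamlit reruns the
# whole script on every interaction, so calling load_model_for_prediction()
# directly would rebuild BioBERT on each rerun. Assuming a Streamlit version
# that provides st.cache_resource, a cached wrapper could look like this;
# the name get_cached_model is our own.
@st.cache_resource
def get_cached_model():
    return load_model_for_prediction()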
def predict_sentence(model, sentence, tokenizer, label_encoder):
    """Make a prediction for a single sentence, validating inputs first."""
    start_time = time.time()

    # Validation checks
    st.write("🔍 Starting prediction process...")
    if model is None:
        st.error("Error: Model not loaded")
        return "Error: Model not loaded", 0.0
    if tokenizer is None:
        st.error("Error: Tokenizer not loaded")
        return "Error: Tokenizer not loaded", 0.0
    if label_encoder is None:
        st.error("Error: Label encoder not loaded")
        return "Error: Label encoder not loaded", 0.0

    # Force CPU device
    st.write("⚙️ Preparing model...")
    device = torch.device('cpu')
    model = model.to(device)
    model.eval()

    try:
        st.write(f"📝 Processing text: {sentence[:50]}...")
        # Tokenize the input sentence
        encoding = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)

        st.write("🤖 Running model inference...")
        with torch.no_grad():
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)
            prob, pred_idx = torch.max(probabilities, dim=1)
            predicted_label = label_encoder.classes_[pred_idx.item()]

        elapsed_time = time.time() - start_time
        st.write(f"✅ Prediction completed in {elapsed_time:.2f} seconds")
        return predicted_label, prob.item()
    except Exception as e:
        st.error(f"❌ Prediction error: {str(e)}")
        return f"Error: {str(e)}", 0.0
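
# Wiring sketch (illustrative assumption, not shown in the original file): a
# minimal single-page UI tying the loader and predictor together. The widget
# labels and the run_demo name are our own.
def run_demo():
    model, label_encoder, tokenizer = get_cached_model()
    sentence = st.text_input("Enter a sentence to classify:")
    if st.button("Predict") and sentence and model is not None:
        label, confidence = predict_sentence(model, sentence, tokenizer, label_encoder)
        st.write(f"Predicted label: {label} (confidence: {confidence:.2%})")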
def print_labels(label_encoder, show_counts=False):
    """Print all labels and their corresponding indices to stdout."""
    print("\nAvailable labels:")
    print("-" * 40)
    for idx, label in enumerate(label_encoder.classes_):
        print(f"Index {idx}: {label}")
    print("-" * 40)
    print(f"Total number of classes: {len(label_encoder.classes_)}\n")
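
# Local debugging sketch (illustrative): print_labels writes to stdout rather
# than the Streamlit page, so it is mainly useful when running outside the app:
#
#     _, label_encoder, _ = load_model_for_prediction()
#     if label_encoder is not None:
#         print_labels(label_encoder)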