import re import torch import gradio as gr from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification from PIL import Image, ImageEnhance, ImageOps import torchvision.transforms as T # ===================== CONFIG ===================== VINTERN_PATH = "5CD-AI/Vintern-1B-v3_5" PHOBERT_PATH = "DuyKien016/phobert-scam-detector" # ===================== LOAD MODELS ===================== print("🔄 Loading Vintern model...") vintern_model = AutoModel.from_pretrained( VINTERN_PATH, trust_remote_code=True, torch_dtype="auto", device_map="auto", low_cpu_mem_usage=True ).eval() vintern_tokenizer = AutoTokenizer.from_pretrained( VINTERN_PATH, trust_remote_code=True ) print("✅ Vintern loaded.") print("🔄 Loading PhoBERT model...") phobert_tokenizer = AutoTokenizer.from_pretrained(PHOBERT_PATH, use_fast=False) phobert_model = AutoModelForSequenceClassification.from_pretrained(PHOBERT_PATH).eval().to( "cuda" if torch.cuda.is_available() else "cpu" ) print("✅ PhoBERT loaded.") # ===================== FUNCTIONS ===================== def ocr_vintern(image): img = image.convert("RGB") max_size = (448, 448) img.thumbnail(max_size, Image.Resampling.LANCZOS) img = ImageOps.pad(img, max_size, color=(255, 255, 255)) img = ImageEnhance.Contrast(img).enhance(1.5) transform = T.Compose([ T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) pixel_values = transform(img).unsqueeze(0).to(vintern_model.device) prompt = """ Hãy đọc nội dung trong ảnh chụp màn hình tin nhắn và xuất ra kết quả **chỉ** gồm các tin nhắn. 📌 Quy tắc: 1. Mỗi ô chat = 1 tin nhắn. 2. Không giữ lại thời gian, tên người, emoji, icon hệ thống. 3. Chỉ có văn bản thuần. 4. Không thêm bình luận hoặc giải thích. 📋 Định dạng: Tin nhắn 1: ... Tin nhắn 2: ... Tin nhắn 3: ... """ response, _ = vintern_model.chat( tokenizer=vintern_tokenizer, pixel_values=pixel_values, question=prompt, generation_config=dict(max_new_tokens=1024, do_sample=False, num_beams=3), history=None, return_history=True ) messages = re.findall(r"Tin nhắn \d+: (.+?)(?=\nTin nhắn|\Z)", response, re.S) cleaned_messages = [re.sub(r"\s+", " ", msg.strip()) for msg in messages if msg.strip()] return cleaned_messages def predict_phobert(texts): results = [] for text in texts: encoded = phobert_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256) encoded = {k: v.to(phobert_model.device) for k, v in encoded.items()} with torch.no_grad(): logits = phobert_model(**encoded).logits probs = torch.softmax(logits, dim=1).squeeze() label = torch.argmax(probs).item() results.append({ "text": text, "prediction": "LỪA ĐẢO" if label == 1 else "BÌNH THƯỜNG", "confidence": f"{probs[label]*100:.2f}%" }) return results # ===================== GRADIO INTERFACE ===================== def detect(image, text): if image is not None: extracted_texts = ocr_vintern(image) if not extracted_texts: return "❌ Không đọc được nội dung từ ảnh" results = predict_phobert(extracted_texts) elif text.strip() != "": results = predict_phobert([text]) else: return "⚠️ Vui lòng nhập văn bản hoặc tải ảnh" output_str = "\n".join([f"{r['text']} → {r['prediction']} ({r['confidence']})" for r in results]) return output_str demo = gr.Interface( fn=detect, inputs=[ gr.Image(type="pil", label="Tải ảnh tin nhắn"), gr.Textbox(label="Hoặc nhập văn bản") ], outputs=gr.Textbox(label="Kết quả"), title="🛡️ Bộ phát hiện lừa đảo", description="Nhập văn bản hoặc tải ảnh chụp màn hình tin nhắn để kiểm tra." ) if __name__ == "__main__": demo.launch()