Spaces:
Runtime error
Runtime error
| # app.py | |
| import os | |
| import re | |
| import io | |
| import torch | |
| from typing import List, Optional | |
| from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification | |
| from PIL import Image, ImageEnhance, ImageOps | |
| import torchvision.transforms as T | |
| import gradio as gr | |
| from fastapi import Request | |
| from starlette.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
# ========== LOAD MODELS (once) ==========
# Both models are loaded eagerly at import time and shared by every request.
_VINTERN_REPO_ID = "5CD-AI/Vintern-1B-v3_5"

print("Loading VinTern model...")
vintern_model = AutoModel.from_pretrained(
    _VINTERN_REPO_ID,
    trust_remote_code=True,  # the repository ships its own modeling code
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)
vintern_model.eval()  # eval(): switch off training-only behaviour such as dropout
vintern_tokenizer = AutoTokenizer.from_pretrained(_VINTERN_REPO_ID, trust_remote_code=True)
print("VinTern loaded!")

print("Loading PhoBERT model...")
phobert_path = "DuyKien016/phobert-scam-detector"
phobert_tokenizer = AutoTokenizer.from_pretrained(phobert_path, use_fast=False)
# PhoBERT is placed on the GPU when one is visible, otherwise on the CPU.
_phobert_device = "cuda" if torch.cuda.is_available() else "cpu"
phobert_model = AutoModelForSequenceClassification.from_pretrained(phobert_path)
phobert_model = phobert_model.eval().to(_phobert_device)
print("PhoBERT loaded!")
# ========== UTILS ==========
def process_image_pil(pil_img: Image.Image):
    """Turn a chat screenshot into a single-tile VinTern input tensor.

    Steps:
      1. Force RGB.
      2. Boost contrast (x1.8) and sharpness (x1.3) to make small chat text
         easier to read.
      3. Shrink to fit inside 448x448, keeping the aspect ratio.
      4. Pad to exactly 448x448 on a light-grey (245, 245, 245) background.
      5. Convert to a tensor normalized with the ImageNet mean/std below.

    Args:
        pil_img: the screenshot, in any PIL mode.

    Returns:
        A tensor of shape (1, 3, 448, 448) on ``vintern_model``'s device and
        in ``vintern_model``'s parameter dtype.
    """
    img = pil_img.convert("RGB")
    img = ImageEnhance.Contrast(img).enhance(1.8)
    img = ImageEnhance.Sharpness(img).enhance(1.3)
    max_size = (448, 448)
    img.thumbnail(max_size, Image.Resampling.LANCZOS)
    img = ImageOps.pad(img, max_size, color=(245, 245, 245))
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # ToTensor yields float32, but the model is loaded with torch_dtype="auto"
    # and may hold half/bfloat16 weights. Moving only the device would then
    # feed float32 activations into reduced-precision layers (dtype mismatch).
    return transform(img).unsqueeze(0).to(
        device=vintern_model.device, dtype=vintern_model.dtype
    )
# Prompt asking VinTern to list every chat bubble as "Tin nhắn N: <content>".
_EXTRACT_PROMPT = """<image>
Đọc từng tin nhắn trong ảnh và xuất ra định dạng:
Tin nhắn 1: [nội dung]
Tin nhắn 2: [nội dung]
Tin nhắn 3: [nội dung]
Quy tắc:
- Mỗi ô chat = 1 tin nhắn
- Chỉ lấy nội dung văn bản
- Bỏ thời gian, tên người, emoji
- Đọc từ trên xuống dưới
Bắt đầu:"""

# One "Tin nhắn <n>: <content>" entry, with content running up to the next
# header line or the end of the reply.
# - Spacing around the number/colon is tolerated.
# - Content may start on the following line.
# - The negative lookahead and the leading \S stop an empty entry from
#   swallowing the next header.
_MESSAGE_RE = re.compile(
    r"Tin nhắn[ \t]*\d+[ \t]*:\s*(?!Tin nhắn)(\S.*?)(?=\s*\n[ \t]*Tin nhắn|\s*\Z)",
    re.S,
)
_WHITESPACE_RE = re.compile(r"\s+")
# A leading list enumerator such as "1. ", "2) " or "3 - ". The trailing
# whitespace is required so real content that starts with digits
# ("500.000đ", "0912.345.678", "2 triệu") is left untouched.
_ENUM_PREFIX_RE = re.compile(r"^\d+\s*[.)\-]\s+")


def _clean_message(msg: str) -> str:
    """Collapse internal whitespace and drop a leading list enumerator."""
    msg = _WHITESPACE_RE.sub(" ", msg.strip())
    msg = _ENUM_PREFIX_RE.sub("", msg)
    return msg.strip()


def extract_messages(pixel_values) -> List[str]:
    """Read the individual chat messages out of a preprocessed screenshot.

    Prompts VinTern (greedy decoding, up to 1024 new tokens) to list every
    chat bubble as ``Tin nhắn N: <content>`` lines, then parses that listing.

    Args:
        pixel_values: image tensor as produced by ``process_image_pil``.

    Returns:
        The cleaned, non-empty message texts, in the order they appear in the
        model's reply. Empty if the reply contains no ``Tin nhắn N:`` entries.
    """
    response, *_ = vintern_model.chat(
        tokenizer=vintern_tokenizer,
        pixel_values=pixel_values,
        question=_EXTRACT_PROMPT,
        # Greedy decoding. `early_stopping` is omitted: it only affects beam
        # search, and with num_beams=1 it merely triggers a warning.
        generation_config=dict(max_new_tokens=1024, do_sample=False, num_beams=1),
        history=None,
        return_history=True,
    )
    cleaned = (_clean_message(m) for m in _MESSAGE_RE.findall(response))
    # Filter after cleaning: an entry can become empty once its enumerator
    # and whitespace are removed.
    return [m for m in cleaned if m]
def predict_phobert(texts: List[str]):
    """Score every text with the PhoBERT scam classifier.

    Each text is encoded on its own (truncated to 256 tokens) and classified
    independently. Class index 1 is reported as scam ("LỪA ĐẢO"); any other
    index is reported as normal ("BÌNH THƯỜNG").

    Args:
        texts: the strings to classify.

    Returns:
        One dict per input, in input order, holding:
          * ``text``: the original text;
          * ``prediction``: the label;
          * ``confidence``: the label's probability as a percentage string
            with two decimals.
    """
    target_device = phobert_model.device

    def classify_one(sample):
        inputs = phobert_tokenizer(
            sample,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )
        inputs = {key: tensor.to(target_device) for key, tensor in inputs.items()}
        with torch.no_grad():
            scores = phobert_model(**inputs).logits
        distribution = torch.softmax(scores, dim=1).squeeze()
        winner = torch.argmax(distribution).item()
        return {
            "text": sample,
            "prediction": "BÌNH THƯỜNG" if winner != 1 else "LỪA ĐẢO",
            "confidence": f"{distribution[winner]*100:.2f}%",
        }

    return [classify_one(sample) for sample in texts]
# ========== CORE HANDLER ==========
def handle_inference(text: Optional[str], pil_image: Optional[Image.Image]):
    """Run scam detection on a chat screenshot or on raw text.

    If ``pil_image`` is given, messages are read from it with VinTern and each
    one is classified; ``text`` is ignored in that case. Otherwise ``text``
    may be a single string or a list/tuple of strings, and blank
    (whitespace-only) entries are skipped.

    Returns:
        A ``(payload, http_status)`` tuple:
          * ``({"messages": [...]}, 200)`` on success;
          * ``({"error": ...}, 400)`` when there is nothing to classify or
            ``text`` has an unsupported shape.
    """
    if not text and pil_image is None:
        return {"error": "No valid input provided"}, 400

    if pil_image is not None:
        # NOTE(review): any accompanying text is dropped here; confirm callers
        # never expect both inputs to be classified together.
        pixel_values = process_image_pil(pil_image)
        messages = extract_messages(pixel_values)
        return {"messages": predict_phobert(messages)}, 200

    texts = [text] if isinstance(text, str) else text
    # Reject non-sequence payloads (e.g. a JSON number or object) and
    # sequences with non-string items, so they yield 400 rather than a
    # tokenizer exception.
    if not isinstance(texts, (list, tuple)) or not all(isinstance(t, str) for t in texts):
        return {"error": "Invalid input format"}, 400

    # Whitespace-only strings carry nothing to classify.
    texts = [t for t in texts if t.strip()]
    if not texts:
        return {"error": "No valid input provided"}, 400
    return {"messages": predict_phobert(texts)}, 200
# ========== GRADIO APP (UI + API) ==========
def ui_process(text, image):
    """Gradio click handler: run inference and show only the payload, not the status."""
    payload, _status = handle_inference(text, image)
    return payload


with gr.Blocks() as demo:
    gr.Markdown("## dunkingscam backend (HF Space) — test nhanh")
    with gr.Row():
        txt = gr.Textbox(label="Text (tùy chọn)")
        img = gr.Image(label="Ảnh chat (tùy chọn)", type="pil")
    out = gr.JSON(label="Kết quả")
    btn = gr.Button("Process")
    btn.click(fn=ui_process, inputs=[txt, img], outputs=out)
# Standalone FastAPI application carrying the CORS policy and the custom REST
# route(s). Reading `demo.server_app` at import time does not work: Gradio
# only populates that attribute inside `demo.launch()`, so the lookup failed
# before the app could start.
from fastapi import FastAPI

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # must stay open so the Replit frontend can call the API
    # NOTE(review): with a wildcard origin plus credentials, Starlette echoes
    # back the caller's Origin, so any site may send credentialed requests.
    # Nothing here reads cookies; consider False if credentials are unused.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE: to serve the Gradio UI from this same app, follow two rules:
#   1. Call `app = gr.mount_gradio_app(app, demo, path="/")` only AFTER all
#      custom routes (e.g. /process) are registered. A mount at "/" matches
#      every path, so routes added after it are shadowed.
#   2. Run `app` with an ASGI server such as uvicorn.
# Custom REST endpoint /process (multipart FormData or JSON).
import logging

from starlette.concurrency import run_in_threadpool

_logger = logging.getLogger(__name__)

# Sent on every response so callers still get a CORS header even when the
# route is served without the CORS middleware.
_CORS_HEADERS = {"Access-Control-Allow-Origin": "*"}


def _json_response(data, status):
    """Wrap ``data`` in a JSONResponse that carries the permissive CORS header."""
    return JSONResponse(content=data, status_code=status, headers=dict(_CORS_HEADERS))


def _decode_image(content: bytes):
    """Decode uploaded bytes into a fully loaded PIL image, or None if undecodable."""
    try:
        image = Image.open(io.BytesIO(content))
        # Image.open is lazy; load() forces the decode now, so corrupt or
        # truncated files are caught here (-> 400) instead of mid-inference.
        image.load()
    except (OSError, Image.DecompressionBombError):
        return None
    return image


async def process_endpoint(request: Request):
    """Handle POST /process.

    Accepts either:
      * ``multipart/form-data`` with an optional ``text`` field and an
        optional ``image`` file upload, or
      * ``application/json`` with a ``text`` key (a string or list of strings).

    ``handle_inference`` runs in a worker thread so its blocking model calls
    do not stall the asyncio event loop for other requests.

    Returns:
        A JSONResponse with the handler's payload and status:
          * 400 for malformed requests (unsupported Content-Type, invalid or
            non-object JSON, undecodable image);
          * 500 with the exception message for any unexpected error.
    """
    try:
        content_type = request.headers.get("content-type", "")
        if "multipart/form-data" in content_type:
            form = await request.form()
            text = form.get("text")
            # Either an UploadFile, a plain str (empty/non-file field), or None.
            upload = form.get("image")
            pil_image = None
            if upload is not None and not isinstance(upload, str):
                content = await upload.read()
                if content:  # an empty file part is treated as "no image"
                    pil_image = _decode_image(content)
                    if pil_image is None:
                        return _json_response({"error": "Invalid image file"}, 400)
            data, status = await run_in_threadpool(handle_inference, text, pil_image)
        elif "application/json" in content_type:
            try:
                payload = await request.json()
            except ValueError:  # malformed JSON or undecodable body
                return _json_response({"error": "Invalid JSON body"}, 400)
            if not isinstance(payload, dict):
                return _json_response({"error": "Invalid input format"}, 400)
            data, status = await run_in_threadpool(
                handle_inference, payload.get("text"), None
            )
        else:
            data, status = {"error": "Unsupported Content-Type"}, 400
        return _json_response(data, status)
    except Exception as e:
        _logger.exception("Unhandled error in /process")
        return _json_response({"error": f"Server error: {str(e)}"}, 500)


# Attach the handler to `app`; without this the endpoint is unreachable.
# If Gradio is later mounted at "/", that mount must come after this line.
app.add_api_route("/process", process_endpoint, methods=["POST"])