Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Create app.py
Browse files
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,186 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # app.py
         
     | 
| 2 | 
         
            +
            import os
         
     | 
| 3 | 
         
            +
            import re
         
     | 
| 4 | 
         
            +
            import io
         
     | 
| 5 | 
         
            +
            import torch
         
     | 
| 6 | 
         
            +
            from typing import List, Optional
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
         
     | 
| 9 | 
         
            +
            from PIL import Image, ImageEnhance, ImageOps
         
     | 
| 10 | 
         
            +
            import torchvision.transforms as T
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            import gradio as gr
         
     | 
| 13 | 
         
            +
            from fastapi import Request
         
     | 
| 14 | 
         
            +
            from starlette.responses import JSONResponse
         
     | 
| 15 | 
         
            +
            from fastapi.middleware.cors import CORSMiddleware
         
     | 
| 16 | 
         
            +
             
     | 
# ========== LOAD MODELS (once) ==========
# Both models are created at import time so every request reuses them.

# VinTern: the model whose ``chat`` method is used to read chat screenshots.
vintern_path = "5CD-AI/Vintern-1B-v3_5"

print("Loading VinTern model...")
vintern_model = AutoModel.from_pretrained(
    vintern_path,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)
vintern_model = vintern_model.eval()
vintern_tokenizer = AutoTokenizer.from_pretrained(vintern_path, trust_remote_code=True)
print("VinTern loaded!")

# PhoBERT: sequence classifier applied to each extracted message.
print("Loading PhoBERT model...")
phobert_path = "DuyKien016/phobert-scam-detector"
phobert_tokenizer = AutoTokenizer.from_pretrained(phobert_path, use_fast=False)
phobert_model = AutoModelForSequenceClassification.from_pretrained(phobert_path)
# Inference-only mode, on the GPU when one is available.
phobert_model = phobert_model.eval().to("cuda" if torch.cuda.is_available() else "cpu")
print("PhoBERT loaded!")
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
            # ========== UTILS ==========
         
     | 
def process_image_pil(pil_img: Image.Image):
    """Convert a PIL image into the pixel tensor passed to ``vintern_model``.

    The image is converted to RGB, contrast- and sharpness-enhanced, shrunk to
    fit inside 448x448, padded to exactly 448x448 with a light-grey border,
    normalized with the mean/std constants below and given a leading batch
    dimension, i.e. the result has shape ``(1, 3, 448, 448)``.

    Args:
        pil_img: Source image in any PIL mode.

    Returns:
        A tensor on the model's device and in the model's dtype.
    """
    target_size = (448, 448)

    img = pil_img.convert("RGB")
    img = ImageEnhance.Contrast(img).enhance(1.8)
    img = ImageEnhance.Sharpness(img).enhance(1.3)
    # thumbnail() only ever shrinks (in place); pad() then fills the canvas.
    img.thumbnail(target_size, Image.Resampling.LANCZOS)
    img = ImageOps.pad(img, target_size, color=(245, 245, 245))

    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    pixel_values = transform(img).unsqueeze(0)

    # The model is loaded with torch_dtype="auto", so its weights may be in a
    # reduced-precision dtype; ToTensor() always yields float32.  Cast to the
    # model's dtype as well as its device to avoid a dtype mismatch at
    # inference time.
    return pixel_values.to(device=vintern_model.device, dtype=vintern_model.dtype)
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
             
     | 
# One "Tin nhắn N: ..." entry, up to the next entry or the end of the reply.
_MESSAGE_RE = re.compile(r"Tin nhắn \d+: (.+?)(?=\nTin nhắn|\Z)", re.S)
_WHITESPACE_RE = re.compile(r"\s+")
_LEADING_NUMBER_RE = re.compile(r'^\d+[\.\)\-\s]+')


def _clean_message(msg: str) -> str:
    """Collapse whitespace runs and drop a leading list number such as ``1.``."""
    msg = _WHITESPACE_RE.sub(" ", msg.strip())
    msg = _LEADING_NUMBER_RE.sub('', msg)
    return msg.strip()


def extract_messages(pixel_values) -> List[str]:
    """Ask VinTern to transcribe the chat messages in an image and parse the reply.

    The (Vietnamese) prompt tells the model to emit one line per chat bubble
    in the form ``Tin nhắn <n>: <content>``, text only, top to bottom, without
    timestamps, names or emoji.  Those entries are pulled out of the reply
    and cleaned.

    Args:
        pixel_values: Image tensor as produced by ``process_image_pil``.

    Returns:
        The cleaned message texts in reply order.  Entries that are empty
        after cleaning are dropped, so no blank string is returned.
    """
    prompt = """<image>
Đọc từng tin nhắn trong ảnh và xuất ra định dạng:

Tin nhắn 1: [nội dung]
Tin nhắn 2: [nội dung]
Tin nhắn 3: [nội dung]

Quy tắc:
- Mỗi ô chat = 1 tin nhắn
- Chỉ lấy nội dung văn bản
- Bỏ thời gian, tên người, emoji
- Đọc từ trên xuống dưới

Bắt đầu:"""
    # return_history=True makes chat() return more than the reply; only the
    # first element (the reply text) is needed here.
    response, *_ = vintern_model.chat(
        tokenizer=vintern_tokenizer,
        pixel_values=pixel_values,
        question=prompt,
        generation_config=dict(max_new_tokens=1024, do_sample=False, num_beams=1, early_stopping=True),
        history=None,
        return_history=True
    )

    # Filter AFTER cleaning: an entry such as "1." is non-blank before
    # cleaning but becomes empty once the leading number is stripped.
    cleaned = (_clean_message(raw) for raw in _MESSAGE_RE.findall(response))
    return [msg for msg in cleaned if msg]
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
             
     | 
def predict_phobert(texts: List[str]):
    """Classify every text with PhoBERT.

    Args:
        texts: Messages to classify; each one is run through the model on its own.

    Returns:
        One dict per input, in input order, with the keys ``"text"`` (the
        input), ``"prediction"`` ("LỪA ĐẢO" when the arg-max class index is 1,
        otherwise "BÌNH THƯỜNG") and ``"confidence"`` (the winning class
        probability as a percentage string with two decimals).
    """

    def classify(text):
        # Tokenize one message (truncated to 256 tokens) and move the
        # tensors to wherever the classifier lives.
        batch = phobert_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
        batch = {name: tensor.to(phobert_model.device) for name, tensor in batch.items()}

        with torch.no_grad():
            logits = phobert_model(**batch).logits
            probs = torch.softmax(logits, dim=1).squeeze()
            label = torch.argmax(probs).item()

        verdict = "LỪA ĐẢO" if label == 1 else "BÌNH THƯỜNG"
        return {
            "text": text,
            "prediction": verdict,
            "confidence": f"{probs[label]*100:.2f}%"
        }

    return [classify(text) for text in texts]
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
            # ========== CORE HANDLER ==========
         
     | 
def handle_inference(text: Optional[str], pil_image: "Optional[Image.Image]"):
    """Run scam detection on an image or on raw text.

    When an image is given it takes precedence: its messages are extracted
    and classified, and ``text`` is ignored.  Otherwise ``text`` is
    classified; it may be a single string or a list of strings.

    Args:
        text: Message text (or list of message texts), or ``None``.
        pil_image: Chat screenshot, or ``None``.

    Returns:
        A ``(payload, status)`` tuple.  On success ``payload`` is
        ``{"messages": [...]}`` with status 200.  When there is nothing to
        classify, or ``text`` is neither a string nor a list of strings,
        ``payload`` is ``{"error": ...}`` with status 400.
    """
    if pil_image is not None:
        pixel_values = process_image_pil(pil_image)
        messages = extract_messages(pixel_values)
        return {"messages": predict_phobert(messages)}, 200

    if not text:
        return {"error": "No valid input provided"}, 400

    texts = [text] if isinstance(text, str) else text
    # Reject anything that is not a list of strings here; otherwise the
    # tokenizer fails later and the caller sees an opaque server error.
    if not isinstance(texts, list) or not all(isinstance(item, str) for item in texts):
        return {"error": "Invalid input format"}, 400

    # Whitespace-only entries carry no content worth classifying.
    texts = [item for item in texts if item.strip()]
    if not texts:
        return {"error": "No valid input provided"}, 400

    return {"messages": predict_phobert(texts)}, 200
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
            # ========== GRADIO APP (UI + API) ==========
         
     | 
def ui_process(text, image):
    """Gradio click handler: run inference and return only the JSON payload.

    The status code returned by ``handle_inference`` is discarded; errors show
    up in the UI as the ``{"error": ...}`` payload itself.
    """
    payload, _status = handle_inference(text, image)
    return payload


# Minimal test UI: optional text box and optional chat image side by side,
# a JSON result panel underneath, and a button that triggers ui_process.
with gr.Blocks() as demo:
    gr.Markdown("## dunkingscam backend (HF Space) — test nhanh")
    with gr.Row():
        txt = gr.Textbox(label="Text (tùy chọn)")
        img = gr.Image(label="Ảnh chat (tùy chọn)", type="pil")
    out = gr.JSON(label="Kết quả")
    btn = gr.Button("Process")
    btn.click(fn=ui_process, inputs=[txt, img], outputs=out)
| 143 | 
         
            +
             
     | 
# FastAPI app that carries CORS, the custom REST route and the Gradio UI.
# NOTE(review): `demo.server_app` and `demo.add_server_route` do not appear to
# be available on an unlaunched `gr.Blocks` (confirm against the installed
# Gradio version), so the app is built explicitly and the UI is attached with
# the documented `gr.mount_gradio_app` helper instead.
from fastapi import FastAPI  # fastapi is already a dependency of this file

app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; /process reads no cookies,
# so credentials could probably be disabled -- confirm client needs first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # must stay open for the Replit client
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _json_response(data, status):
    """Build the JSON reply used by /process, with the permissive CORS header."""
    return JSONResponse(
        content=data,
        status_code=status,
        headers={"Access-Control-Allow-Origin": "*"}
    )


# Custom REST endpoint /process (multipart FormData or JSON)
@app.post("/process")
async def process_endpoint(request: Request):
    """Classify the text or chat image sent to ``POST /process``.

    Accepted bodies:
      * ``multipart/form-data`` with an optional ``text`` field and an
        optional ``image`` file;
      * ``application/json`` object with a ``text`` member.

    Returns a JSON response with the payload and status produced by
    ``handle_inference``.  Malformed requests (unsupported content type,
    invalid JSON, non-object JSON, non-file ``image`` field, undecodable
    image) yield 400; any other failure yields 500.
    """
    try:
        ct = request.headers.get("content-type", "")
        if "multipart/form-data" in ct:
            form = await request.form()
            text = form.get("text")
            upload = form.get("image")  # UploadFile, or None when absent
            pil_image = None
            if upload is not None:
                # A plain text field named "image" arrives as str, not a file.
                if not hasattr(upload, "read"):
                    return _json_response({"error": "Invalid input format"}, 400)
                content = await upload.read()
                # An empty file part is treated as "no image" so that the
                # text field, if any, is still processed.
                if content:
                    try:
                        pil_image = Image.open(io.BytesIO(content))
                        pil_image.load()  # decode now so bad data fails here
                    except OSError:
                        return _json_response({"error": "Invalid image"}, 400)
            data, status = handle_inference(text, pil_image)
        elif "application/json" in ct:
            try:
                payload = await request.json()
            except ValueError:
                return _json_response({"error": "Invalid JSON body"}, 400)
            if not isinstance(payload, dict):
                return _json_response({"error": "Invalid input format"}, 400)
            data, status = handle_inference(payload.get("text"), None)
        else:
            data, status = {"error": "Unsupported Content-Type"}, 400

        return _json_response(data, status)
    except Exception as e:
        # Last-resort boundary: report the failure as JSON instead of
        # letting the framework emit an HTML error page.
        return _json_response({"error": f"Server error: {str(e)}"}, 500)


# Mount the Gradio UI last so the explicit /process route is matched first.
app = gr.mount_gradio_app(app, demo, path="/")

# NOTE(review): nothing in this file starts a server.  `app` has to be run by
# an ASGI server (for example `uvicorn app:app --host 0.0.0.0 --port 7860`);
# confirm how the Space launches it.