# 🎬 Multilingual Video Classification (Beautiful + Voice Icon)
import os, base64
import gradio as gr
import torch, cv2, numpy as np
from PIL import Image
from gtts import gTTS
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    AutoTokenizer, AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM
)
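# Pipeline overview: sample FRAMES_PER_VIDEO frames from the uploaded clip,
# caption each frame with BLIP, classify the (English) captions with the small
# text classifier, average the per-caption probabilities, then localize the
# winning label and speak it with gTTS.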
# ---------- CONFIG ----------
MODEL_ID = "magedsar7an/caption-cls-en-small"  # ← your HF model repo
FRAMES_PER_VIDEO = 6
FRAME_SIZE = 384
device = "cuda" if torch.cuda.is_available() else "cpu"

SUPPORTED_LANGS = {
    "en": "English", "ar": "Arabic", "fr": "French", "tr": "Turkish",
    "es": "Spanish", "de": "German", "hi": "Hindi", "id": "Indonesian"
}

MARIAN_TO_EN = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "tr": "Helsinki-NLP/opus-mt-tr-en",
    "es": "Helsinki-NLP/opus-mt-es-en",
    "de": "Helsinki-NLP/opus-mt-de-en",
    "hi": "Helsinki-NLP/opus-mt-hi-en",
    "id": "Helsinki-NLP/opus-mt-id-en",
}
LABEL_TRANSLATIONS = {
    "ar": {"clap": "تصفيق", "drink": "يشرب", "hug": "عناق", "kick_ball": "ركل الكرة",
           "kiss": "قبلة", "run": "يجري", "sit": "يجلس", "wave": "يلوح"},
    "tr": {"clap": "alkış", "drink": "içmek", "hug": "sarılmak", "kick_ball": "topa tekme",
           "kiss": "öpücük", "run": "koşmak", "sit": "oturmak", "wave": "el sallamak"},
    "fr": {"clap": "applaudir", "drink": "boire", "hug": "embrasser", "kick_ball": "frapper le ballon",
           "kiss": "baiser", "run": "courir", "sit": "s'asseoir", "wave": "saluer"},
    "es": {"clap": "aplaudir", "drink": "beber", "hug": "abrazar", "kick_ball": "patear la pelota",
           "kiss": "besar", "run": "correr", "sit": "sentarse", "wave": "saludar"},
    "de": {"clap": "klatschen", "drink": "trinken", "hug": "umarmen", "kick_ball": "den Ball treten",
           "kiss": "küssen", "run": "laufen", "sit": "sitzen", "wave": "winken"},
    "hi": {"clap": "ताली बजाना", "drink": "पीना", "hug": "गले लगाना", "kick_ball": "गेंद को मारना",
           "kiss": "चूमना", "run": "दौड़ना", "sit": "बैठना", "wave": "हाथ हिलाना"},
    "id": {"clap": "bertepuk tangan", "drink": "minum", "hug": "berpelukan", "kick_ball": "menendang bola",
           "kiss": "cium", "run": "berlari", "sit": "duduk", "wave": "melambaikan tangan"},
}
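# A label missing from a language's map simply falls back to the English
# label id (see the .get(lang, {}).get(pred, pred) lookup in classify_video).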
# ---------- LOAD MODELS ----------
print("Loading BLIP captioner...")
blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device).eval()

print("Loading English classifier from HF Hub...")
tok = AutoTokenizer.from_pretrained(MODEL_ID)
cls = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(device).eval()

# id2label from model config (you embedded it during upload)
cfg_map = getattr(cls.config, "id2label", None)
if not cfg_map:
    raise RuntimeError("id2label not found in config.json; add it to your HF model.")
# normalize keys to int (config.json serializes dict keys as strings)
id2label = {int(k): v for k, v in (cfg_map.items() if isinstance(cfg_map, dict) else enumerate(cfg_map))}
print("✅ Models loaded successfully!")
# ---------- HELPERS ----------
def _resolve_video_path(video):
    """Normalize the shapes Gradio may hand us (str / dict / tempfile object)."""
    if isinstance(video, str):
        return video if os.path.exists(video) else None
    if isinstance(video, dict):
        p = video.get("path") or video.get("name")
        return p if (isinstance(p, str) and os.path.exists(p)) else None
    name = getattr(video, "name", None)
    if isinstance(name, str) and os.path.exists(name):
        return name
    return None
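# Frames are sampled at evenly spaced indices across the clip; when OpenCV
# cannot report a frame count, indices 0..240 are tried as a best-effort guess.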
def extract_frames(video_path, k=6, size=384):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return []
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    idxs = np.linspace(0, max(total - 1, 0), num=k, dtype=int) if total > 0 else np.linspace(0, 240, num=k, dtype=int)
    frames = []
    for i in idxs:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, frame = cap.read()
        if not ok or frame is None:
            continue
        h, w = frame.shape[:2]
        if h <= 0 or w <= 0:
            continue
        # resize so the shorter side equals `size`, preserving aspect ratio
        if h < w:
            new_h = size; new_w = int(w * (size / h))
        else:
            new_w = size; new_h = int(h * (size / w))
        frame = cv2.resize(frame, (new_w, new_h))
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    cap.release()
    return frames
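# NOTE: BLIP generates English captions regardless of the UI language, which
# is why they can be fed straight to the English-only classifier below.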
def blip_caption(img):
    inputs = blip_proc(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        out = blip.generate(**inputs, max_new_tokens=30)
    return blip_proc.decode(out[0], skip_special_tokens=True).strip()
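# Utility kept for genuinely non-English source text (the default flow feeds
# English BLIP captions directly to the classifier). Note that the Marian
# model is downloaded and initialized on every call; caching one tokenizer
# and model per language would be a reasonable optimization.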
def translate_to_en(texts, lang):
    if lang == "en":
        return texts
    model_name = MARIAN_TO_EN.get(lang)
    if not model_name:
        return texts
    try:
        tok_tr = AutoTokenizer.from_pretrained(model_name)
        mt = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device).eval()
        outs = []
        for i in range(0, len(texts), 16):  # translate in batches of 16
            batch = texts[i:i + 16]
            enc = tok_tr(batch, return_tensors="pt", padding=True, truncation=True).to(device)
            with torch.no_grad():
                gen = mt.generate(**enc, max_new_tokens=120)
            outs.extend(tok_tr.batch_decode(gen, skip_special_tokens=True))
        return outs
    except Exception as e:
        print(f"⚠️ Translation failed: {e}")
        return texts
def classify(texts):
    enc = tok(texts, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = cls(**enc).logits
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs
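# classify_video always returns an HTML fragment; failures are rendered as
# styled error cards rather than raised, so the UI never shows a traceback.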
# ---------- MAIN FN ----------
def classify_video(video, lang):
    try:
        if not video:
            return "<div style='color:orange;'>⚠️ Please upload a video first.</div>"
        video_path = _resolve_video_path(video)
        if not video_path:
            return "<div style='color:red;'>❌ Could not find uploaded video path from Gradio input.</div>"
        frames = extract_frames(video_path, FRAMES_PER_VIDEO, FRAME_SIZE)
        if not frames:
            return "<div style='color:red;'>❌ Could not extract frames. OpenCV could not decode the video.</div>"
        captions = [blip_caption(f) for f in frames]
        # BLIP captions are already English, so they go straight to the
        # English-only classifier; `lang` only controls label localization
        # and TTS below. (Routing English text through an X→en Marian model
        # would degrade it, so translate_to_en is not called here.)
        probs = classify(captions)
        pred = id2label[int(np.argmax(probs.mean(axis=0)))]
        localized = LABEL_TRANSLATIONS.get(lang, {}).get(pred, pred)

        # 🔊 TTS (fail-soft if blocked)
        audio_b64 = ""
        try:
            tts = gTTS(localized, lang=lang if lang in SUPPORTED_LANGS else "en")
            audio_path = "pred_voice.mp3"
            tts.save(audio_path)
            with open(audio_path, "rb") as f:
                audio_b64 = base64.b64encode(f.read()).decode()
        except Exception as e:
            print(f"⚠️ TTS failed: {e}")

        # 🎨 Card
        lang_name = SUPPORTED_LANGS.get(lang, "Unknown")
        btn = f"<button onclick=\"new Audio('data:audio/mp3;base64,{audio_b64}').play()\" style='background:#00b4d8;color:white;border:none;border-radius:50%;width:70px;height:70px;cursor:pointer;font-size:1.8em;box-shadow:0 2px 10px rgba(0,180,216,0.5);'>🔊</button>" if audio_b64 else ""
        html = f"""
        <div style='background: linear-gradient(135deg,#141e30,#243b55);border-radius:16px;padding:35px;color:white;text-align:center;font-family:"Poppins",sans-serif;box-shadow:0 4px 20px rgba(0,0,0,0.3);'>
            <h2 style='color:#00b4d8;font-weight:600;margin-bottom:10px;'>🎬 Action Detected</h2>
            <h1 style='font-size:2.5em;margin:12px 0;'>{localized}</h1>
            {btn}
            <p style='opacity:0.8;margin-top:14px;font-size:1.1em;'>({lang_name})</p>
        </div>
        """
        return html
    except Exception as e:
        import traceback; traceback.print_exc()
        return f"<div style='color:red;font-weight:bold;'>❌ Error:<br>{e}</div>"
# ---------- GRADIO UI ----------
custom_css = """
.gradio-container {
    background: linear-gradient(135deg,#0f2027,#203a43,#2c5364);
    color: white;
}
h1,h2,h3,label,p,.description {color: white !important;}
footer {display:none !important;}
"""

title = "🎬 Multilingual Video Classification (Beautiful + Voice Icon)"
description = """
👉 Click the 🔊 icon to **hear the word pronounced** in that language.
"""
iface = gr.Interface(
    fn=classify_video,
    inputs=[
        gr.Video(label="🎥 Upload Video", sources=["upload"], format="mp4"),
        gr.Radio(choices=list(SUPPORTED_LANGS.keys()), value="en", label="🌐 Choose Language"),
    ],
    outputs=gr.HTML(label="✨ Prediction Result"),
    title=title,
    description=description,
    theme="gradio/soft",
    css=custom_css,
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
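# Local smoke test (a sketch; assumes this file is saved as app.py and that a
# short clip exists at ./sample.mp4 — both names are placeholders):
#   python -c "from app import classify_video; print(classify_video('sample.mp4', 'en'))"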