import gradio as gr
from transformers import pipeline
import os
import numpy as np
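# spaces provides the GPU decorator used on Hugging Face ZeroGPU Spaces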
import spaces

print("=== Application Starting ===")

# map the model's label names to human-readable dialect names
dialect_mapping = {
    "MSA": "Modern Standard Arabic (MSA) - العربية الفصحى الحديثة",
    "Egyptian": "Egyptian Arabic - اللهجة المصرية العامية",
    "Gulf": "Peninsular Arabic - لهجة الجزيرة العربية",
    "Levantine": "Levantine Arabic - لهجة بلاد الشام",
    "Maghrebi": "Maghrebi Arabic - اللهجة المغاربية"
}

@spaces.GPU
def predict_dialect(audio):
    # bail out early so the model is not loaded for an empty request
    if audio is None:
        return {"Error": 1.0}

    # load the model inside the GPU-decorated function
    print("Loading model on GPU...")
    model_id = "badrex/mms-300m-arabic-dialect-identifier"
    classifier = pipeline("audio-classification", model=model_id)  # no device specified
    print("Model loaded successfully")

    # gr.Audio (numpy mode) passes a (sample_rate, np.ndarray) tuple
    sr, audio_array = audio
    
    # downmix stereo recordings to mono by averaging the channels
    if len(audio_array.shape) > 1:
        audio_array = audio_array.mean(axis=1)

    # convert to float32; 16-bit PCM is scaled into the [-1.0, 1.0] range
    if audio_array.dtype != np.float32:
        if audio_array.dtype == np.int16:
            audio_array = audio_array.astype(np.float32) / 32768.0
        else:
            audio_array = audio_array.astype(np.float32)
    
    print(f"Processing audio: sample rate={sr}, shape={audio_array.shape}")
    
    # classify the dialect
    predictions = classifier({"sampling_rate": sr, "raw": audio_array})
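    # the pipeline returns a list of {"label": ..., "score": ...} dicts, highest score first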
    
    # format results
    results = {}
    for pred in predictions:
        dialect_name = dialect_mapping.get(pred['label'], pred['label'])
        results[dialect_name] = float(pred['score'])
    
    return results

# prepare examples
examples = []
examples_dir = "examples"
if os.path.exists(examples_dir):
    for filename in os.listdir(examples_dir):
        if filename.endswith((".wav", ".mp3", ".ogg")):
            examples.append([os.path.join(examples_dir, filename)])
    print(f"Found {len(examples)} example files")

description = """
By <a href="https://badrex.github.io/">Badr Alabsi</a> with ❤️🤍💚

This demo uses a Transformer-based model for Spoken Arabic Dialect Identification. 
Upload an audio file or record yourself speaking to identify the Arabic dialect!
"""

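# build the interface: audio in (upload or microphone), top-5 dialect probabilities out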
demo = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(),
    outputs=gr.Label(num_top_classes=5, label="Predicted Dialect"),
    title="Tamyïz 🍉 Arabic Dialect Identification in Speech",
    description=description,
    examples=examples if examples else None,
    cache_examples=False,
    flagging_mode="never"  # explicitly disable flagging (None falls back to the default)
)

print("=== Launching demo ===")
demo.launch()