import gradio as gr
import zipfile
import tempfile
import traceback
import gc
from pathlib import Path

# Local MAPSS modules
from engine import compute_mapss_measures
from models import get_model_config, cleanup_all_models
from config import DEFAULT_ALPHA
from utils import clear_gpu_memory


def process_audio_files(zip_file, model_name, layer, alpha):
    """
    Process an uploaded ZIP file containing audio mixtures.

    Expected ZIP structure:
    - references/: contains N reference audio files
    - outputs/: contains N output audio files
    """
    if zip_file is None:
        return None, "Please upload a ZIP file"

    # Create a temporary directory for extraction and processing
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        try:
            # Extract the ZIP file. gr.File(type="filepath") passes a str path;
            # older Gradio versions pass a tempfile object with a .name attribute.
            zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
            extract_path = temp_path / "extracted"
            extract_path.mkdir(exist_ok=True)
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
            # Find the references and outputs directories
            refs_dir = None
            outs_dir = None

            # Check for the standard structure at the archive root
            for item in extract_path.iterdir():
                if item.is_dir():
                    if item.name.lower() in ['references', 'refs', 'reference']:
                        refs_dir = item
                    elif item.name.lower() in ['outputs', 'outs', 'output', 'separated']:
                        outs_dir = item

            # If not found at the root, check one level deeper
            if refs_dir is None or outs_dir is None:
                for item in extract_path.iterdir():
                    if item.is_dir():
                        for subitem in item.iterdir():
                            if subitem.is_dir():
                                if subitem.name.lower() in ['references', 'refs', 'reference']:
                                    refs_dir = subitem
                                elif subitem.name.lower() in ['outputs', 'outs', 'output', 'separated']:
                                    outs_dir = subitem

            if refs_dir is None or outs_dir is None:
                return None, "Could not find 'references' and 'outputs' directories in the ZIP file"
            # Collect the audio files
            ref_files = sorted(refs_dir.glob("*.wav"))
            out_files = sorted(outs_dir.glob("*.wav"))
            if len(ref_files) == 0:
                return None, "No reference WAV files found"
            if len(out_files) == 0:
                return None, "No output WAV files found"
            # Build a single-mixture manifest for the engine
            manifest = [{
                "mixture_id": "uploaded_mixture",
                "references": [str(f) for f in ref_files],
                "systems": {
                    "uploaded_system": [str(f) for f in out_files]
                }
            }]
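            # Note: this single-entry manifest (mixture_id / references / systems)
            # is the shape compute_mapss_measures consumes here; additional
            # mixtures or systems could be scored in one call by appending entries.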
            # Validate the model name
            allowed_models = set(get_model_config(0).keys())
            if model_name not in allowed_models:
                return None, f"Invalid model. Allowed: {', '.join(sorted(allowed_models))}"

            # Set a default layer if needed
            if model_name == "raw":
                layer_final = 0
            else:
                model_defaults = {
                    "wavlm": 24, "wav2vec2": 24, "hubert": 24,
                    "wavlm_base": 12, "wav2vec2_base": 12, "hubert_base": 12,
                    "wav2vec2_xlsr": 24, "ast": 12
                }
                layer_final = layer if layer is not None else model_defaults.get(model_name, 12)
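            # The defaults above follow the usual transformer depths: 24 encoder
            # layers for the Large/XLSR checkpoints, 12 for the Base and AST
            # checkpoints, matching the 0-24 range of the layer slider in the UI.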
            # Run the evaluation with compute_mapss_measures
            results_dir = compute_mapss_measures(
                models=[model_name],
                mixtures=manifest,
                layer=layer_final,
                alpha=alpha,
                verbose=True,
                max_gpus=1,    # limit to 1 GPU for the HF Space
                add_ci=False   # disable confidence intervals for faster processing
            )
            # Package the results as a ZIP. Write it outside temp_dir so the
            # file still exists when Gradio serves the download; temp_dir is
            # deleted as soon as this `with` block exits.
            output_dir = Path(tempfile.mkdtemp(prefix="mapss_results_"))
            output_zip = output_dir / "results.zip"
            results_path = Path(results_dir)
            with zipfile.ZipFile(output_zip, 'w') as zipf:
                # Add all CSV files from the results
                for csv_file in results_path.rglob("*.csv"):
                    arcname = str(csv_file.relative_to(results_path.parent))
                    zipf.write(csv_file, arcname)
                # Add params.json
                params_file = results_path / "params.json"
                if params_file.exists():
                    zipf.write(params_file, str(params_file.relative_to(results_path.parent)))
                # Add the canonical manifest
                manifest_file = results_path / "manifest_canonical.json"
                if manifest_file.exists():
                    zipf.write(manifest_file, str(manifest_file.relative_to(results_path.parent)))

            return str(output_zip), "Processing completed successfully!"
        except Exception as e:
            error_msg = f"Error processing files: {e}\n{traceback.format_exc()}"
            return None, error_msg
        finally:
            # Always release model and GPU memory, even on failure
            cleanup_all_models()
            clear_gpu_memory()
            gc.collect()
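
# Example of invoking the handler outside the UI (hypothetical paths):
#   zip_path, status = process_audio_files("my_mixture.zip", "wav2vec2_base", 12, DEFAULT_ALPHA)
#   print(status)
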
# Build the Gradio interface
def create_interface():
    with gr.Blocks(title="MAPSS - Multi-source Audio Perceptual Separation Scores") as demo:
        gr.Markdown("""
        # MAPSS: Multi-source Audio Perceptual Separation Scores

        This tool evaluates audio source separation quality using Perceptual Similarity (PS) and Perceptual Matching (PM) metrics.

        ## How to use:
        1. **Prepare your audio files**: Create a ZIP file with the following structure:
           ```
           your_mixture.zip
           ├── references/          # Original clean sources
           │   ├── speaker1.wav
           │   ├── speaker2.wav
           │   └── ...
           └── outputs/             # Separated outputs from your algorithm
               ├── separated1.wav
               ├── separated2.wav
               └── ...
           ```
        2. **Upload the ZIP file** using the file uploader below
        3. **Select model and parameters**
        4. **Click "Process"** to run the evaluation
        5. **Download the results** as a ZIP file containing CSV files with PS/PM scores

        ## Models available:
        - **raw**: Raw waveform features (no model)
        - **wavlm**: WavLM Large model (best overall performance)
        - **wav2vec2**: Wav2Vec2 Large model
        - **hubert**: HuBERT Large model
        - **wavlm_base**: WavLM Base model (faster, good performance)
        - **wav2vec2_base**: Wav2Vec2 Base model
        - **hubert_base**: HuBERT Base model
        - **wav2vec2_xlsr**: Wav2Vec2 XLSR-53 model (multilingual)
        - **ast**: Audio Spectrogram Transformer
        """)
        with gr.Row():
            with gr.Column():
                file_input = gr.File(
                    label="Upload ZIP file with audio mixtures",
                    file_types=[".zip"],
                    type="filepath"
                )
                model_dropdown = gr.Dropdown(
                    choices=["raw", "wavlm", "wav2vec2", "hubert",
                             "wavlm_base", "wav2vec2_base", "hubert_base",
                             "wav2vec2_xlsr", "ast"],
                    value="wav2vec2_base",
                    label="Select embedding model"
                )
                layer_slider = gr.Slider(
                    minimum=0,
                    maximum=24,
                    step=1,
                    value=12,
                    label="Embedding layer (base models and AST have 12 layers; large models have 24)"
                )
                alpha_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=DEFAULT_ALPHA,
                    label="Diffusion maps alpha parameter"
                )
                process_btn = gr.Button("Process Audio Files", variant="primary")
            with gr.Column():
                output_file = gr.File(
                    label="Download Results (ZIP)",
                    type="filepath"
                )
                status_text = gr.Textbox(
                    label="Status",
                    lines=3,
                    max_lines=10
                )
| gr.Markdown(""" | |
| ## Output format: | |
| The results ZIP will contain: | |
| - `ps_scores_{model}.csv`: Perceptual Similarity scores for each speaker/source | |
| - `pm_scores_{model}.csv`: Perceptual Matching scores for each speaker/source | |
| - `params.json`: Experiment parameters | |
| - `manifest_canonical.json`: Processed file manifest | |
| ## Score interpretation: | |
| - **PS (Perceptual Similarity)**: 0-1 score, higher is better. Measures how well the separated output matches the reference compared to other sources. | |
| - **PM (Perceptual Matching)**: 0-1 score, higher is better. Measures robustness to audio distortions. | |
| ## Notes: | |
| - Processing may take several minutes depending on the audio length and model | |
| - Audio files are automatically resampled to 16kHz | |
| - The tool automatically matches outputs to references based on correlation | |
| - For best results, ensure equal number of reference and output files | |
| ## Citation: | |
| If you use this tool in your research, please cite our paper (details coming soon). | |
| """) | |
        # Wire the button to the processing function
        process_btn.click(
            fn=process_audio_files,
            inputs=[file_input, model_dropdown, layer_slider, alpha_slider],
            outputs=[output_file, status_text]
        )
        # Example inputs can be wired up here once sample ZIPs are available, e.g.:
        # gr.Examples(examples=["examples/demo_mixture.zip"], inputs=[file_input])

    return demo


# Create and launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
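    # Note: a single MAPSS run can take several minutes; on Hugging Face Spaces
    # it may be worth routing requests through Gradio's queue instead (a sketch,
    # assuming a Gradio version that provides Blocks.queue()):
    #   demo.queue(max_size=8).launch()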