Spaces:
Running
Running
from typing import List, Optional, Tuple, Union

import librosa
import numpy as np
import torch
import torchaudio
from transformers import pipeline
class MusicGenreClassifier:
    """Predict a music genre from either an audio file or a block of lyrics.

    Audio is scored by the ``mit/ast-finetuned-audioset-10-10-0.4593``
    audio-classification pipeline and its labels are mapped onto a fixed
    genre list; lyrics are scored by zero-shot classification
    (``facebook/bart-large-mnli``) over the same genre list.
    """

    # Rate every waveform is resampled to before it reaches the audio model.
    SAMPLE_RATE = 16000

    # File suffixes that make ``predict`` treat its input as an audio path.
    AUDIO_EXTENSIONS = ('.mp3', '.wav', '.ogg', '.flac')

    # Audio-model labels that only say "this is music" rather than naming a
    # genre. They are used as a fallback when no specific genre label shows
    # up, instead of competing on raw score with the specific labels.
    GENERIC_MUSIC_LABELS = frozenset({"Music"})

    def __init__(self):
        # Zero-shot text classifier used for lyrics.
        self.text_classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli"
        )
        # Audio classifier fine-tuned on AudioSet labels.
        self.audio_classifier = pipeline(
            "audio-classification",
            model="mit/ast-finetuned-audioset-10-10-0.4593"
        )
        # Standard genres returned by both classifiers.
        self.genres = [
            "rock", "pop", "hip hop", "country", "jazz",
            "classical", "electronic", "blues", "reggae", "metal"
        ]
        # Mapping from audio-model output labels to the standard genres.
        self.label_mapping = {
            "Music": "pop",  # Generic label; only used as a fallback.
            "Rock music": "rock",
            "Pop music": "pop",
            "Hip hop music": "hip hop",
            "Country": "country",
            "Jazz": "jazz",
            "Classical music": "classical",
            "Electronic music": "electronic",
            "Blues": "blues",
            "Reggae": "reggae",
            "Heavy metal": "metal"
        }

    def process_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file as a float tensor resampled to ``SAMPLE_RATE``.

        Args:
            audio_path: Path to an audio file librosa can decode.

        Returns:
            A float32 tensor of shape ``(1, num_samples)``.

        Raises:
            ValueError: If loading or conversion fails.
        """
        try:
            # librosa resamples to the requested rate and (by default) mixes
            # down to mono, so the result is 1-D.
            waveform, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE)
            waveform = torch.from_numpy(waveform).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}") from e

    def map_label_to_genre(self, label: str) -> str:
        """Map an audio-model label to a standard genre ("pop" if unknown)."""
        return self.label_mapping.get(label, "pop")

    def _select_genre(self, predictions) -> Tuple[str, float]:
        """Reduce raw audio-pipeline output to a single ``(genre, score)``.

        Specific genre labels win over generic ones; if neither kind is
        present, ``("pop", 0.5)`` is returned.
        """
        # A single input yields a flat list of {"label", "score"} dicts; a
        # batched call yields a list of such lists. Only unwrap the latter --
        # unwrapping the flat form would leave a single dict.
        if predictions and isinstance(predictions[0], list):
            predictions = predictions[0]

        specific: List[Tuple[str, float]] = []
        generic: List[Tuple[str, float]] = []
        for pred in predictions:
            label = pred['label']
            if label not in self.label_mapping:
                continue
            entry = (self.map_label_to_genre(label), pred['score'])
            if label in self.GENERIC_MUSIC_LABELS:
                generic.append(entry)
            else:
                specific.append(entry)

        candidates = specific or generic
        if not candidates:
            # No music-related label at all: fall back to the default.
            return "pop", 0.5
        return max(candidates, key=lambda item: item[1])

    def classify_audio(self, audio_path: str, top_k: int = 20) -> Tuple[str, float]:
        """Classify the genre of an audio file.

        Args:
            audio_path: Path to the audio file.
            top_k: How many of the audio model's ranked labels to inspect.
                The model predicts AudioSet classes, most of which are not
                genres, so looking past the first few raises the chance that
                a genre label is among them.

        Returns:
            ``(genre, score)`` where ``genre`` is one of the standard genres.

        Raises:
            ValueError: If loading or classification fails.
        """
        try:
            waveform = self.process_audio(audio_path)
            # The audio-classification pipeline expects a single-channel 1-D
            # array; process_audio returns (channels, samples), so average
            # the channel axis away and pass the rate explicitly.
            if waveform.dim() > 1:
                waveform = waveform.mean(dim=0)
            audio = waveform.cpu().numpy()
            predictions = self.audio_classifier(
                {"raw": audio, "sampling_rate": self.SAMPLE_RATE},
                top_k=top_k,
            )
            return self._select_genre(predictions)
        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}") from e

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify the genre of a lyrics string via zero-shot classification.

        Returns:
            ``(genre, score)`` for the highest-scoring entry of ``self.genres``.

        Raises:
            ValueError: If ``lyrics`` is empty/blank or classification fails.
        """
        try:
            if not lyrics or not lyrics.strip():
                raise ValueError("Lyrics text is empty")
            # Each candidate genre is substituted into this NLI hypothesis.
            hypothesis_template = "This text contains {} music lyrics."
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template=hypothesis_template
            )
            # The zero-shot pipeline returns labels sorted by descending score.
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}") from e

    def predict(self, input_data: str, input_type: Optional[str] = None) -> dict:
        """
        Main prediction method that handles both audio and text inputs.

        Args:
            input_data: Path to audio file or lyrics text
            input_type: Optional, 'audio' or 'text'. If None, it is inferred
                from the file suffix (see ``AUDIO_EXTENSIONS``); any value
                other than 'audio' is treated as text.

        Returns:
            dict with keys 'genre', 'confidence' (float) and 'input_type'.

        Raises:
            ValueError: If classification fails.
        """
        if input_type is None:
            input_type = (
                'audio'
                if input_data.lower().endswith(self.AUDIO_EXTENSIONS)
                else 'text'
            )
        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)
            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}") from e