Spaces:
Running
Running
| import torch | |
| import torchaudio | |
| import librosa | |
| import numpy as np | |
| from transformers import pipeline | |
| from typing import Union, Tuple, List | |
class MusicGenreClassifier:
    """Predict a music genre from either an audio file or a lyrics string.

    Text inputs are labelled with a zero-shot NLI model against ``self.genres``.
    Audio inputs are labelled by an audio-classification model. For audio, the
    returned label is whatever that model emits; it is NOT mapped onto
    ``self.genres``.
    """

    # Rate (Hz) at which audio is loaded and then reported to the audio pipeline.
    SAMPLE_RATE = 16000
    # File suffixes (matched case-insensitively) that predict() treats as audio.
    AUDIO_EXTENSIONS = ('.mp3', '.wav', '.ogg', '.flac')
    # Candidate labels for zero-shot lyrics classification.
    DEFAULT_GENRES = (
        "rock", "pop", "hip hop", "country", "jazz",
        "classical", "electronic", "blues", "reggae", "metal",
    )

    def __init__(
        self,
        text_model: str = "facebook/bart-large-mnli",
        audio_model: str = "superb/wav2vec2-base-superb-gc",
        genres: Union[List[str], None] = None,
    ):
        """Build the text and audio classification pipelines.

        Args:
            text_model: Checkpoint for the zero-shot text pipeline.
            audio_model: Checkpoint for the audio-classification pipeline.
                NOTE(review): confirm this default checkpoint exists on the Hub
                and that its labels are genres.
            genres: Candidate labels for lyrics classification. Defaults to
                DEFAULT_GENRES.
        """
        self.text_classifier = pipeline(
            "zero-shot-classification",
            model=text_model,
        )
        self.audio_classifier = pipeline(
            "audio-classification",
            model=audio_model,
        )
        # Copy so that neither the class default nor the caller's list is shared.
        self.genres = list(genres) if genres is not None else list(self.DEFAULT_GENRES)

    def process_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file as a float tensor resampled to SAMPLE_RATE.

        Args:
            audio_path: Path to an audio file readable by librosa.

        Returns:
            A float32 tensor of shape ``(1, num_samples)``.

        Raises:
            ValueError: If the file cannot be loaded or converted.
        """
        try:
            # librosa is used because it decodes more container formats;
            # sr= forces resampling to the rate the pipeline is told about.
            samples, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE)
            waveform = torch.from_numpy(samples).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}") from e

    def classify_audio(self, audio_path: str) -> Tuple[str, float]:
        """Classify the genre of an audio file.

        Args:
            audio_path: Path to an audio file.

        Returns:
            ``(label, score)`` for the highest-scoring prediction of the audio
            model.

        Raises:
            ValueError: If loading or classification fails.
        """
        try:
            waveform = self.process_audio(audio_path).detach().cpu()
            # The pipeline requires 1-D mono samples, but process_audio returns
            # (channels, N). Averaging over the channel axis is exact for a
            # single channel and a plain downmix otherwise.
            if waveform.dim() > 1:
                waveform = waveform.mean(dim=0)
            samples = waveform.numpy()
            predictions = self.audio_classifier(
                {"raw": samples, "sampling_rate": self.SAMPLE_RATE},
                top_k=1,
            )
            # A single input yields a flat list of {"label", "score"} dicts.
            # Unwrap only the batched list-of-lists form; indexing a flat list
            # would leave a bare dict and break max() below.
            if predictions and isinstance(predictions[0], list):
                predictions = predictions[0]
            if not predictions:
                raise ValueError("classifier returned no predictions")
            top_pred = max(predictions, key=lambda p: p['score'])
            return top_pred['label'], top_pred['score']
        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}") from e

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify the genre of a lyrics string via zero-shot NLI.

        Args:
            lyrics: Lyrics text; must contain at least one non-space character.

        Returns:
            ``(genre, score)`` for the top-ranked entry of ``self.genres``.

        Raises:
            ValueError: If the lyrics are empty or classification fails.
        """
        try:
            if not lyrics or not lyrics.strip():
                raise ValueError("lyrics must be a non-empty string")
            # Each candidate genre is substituted into this hypothesis for NLI.
            hypothesis_template = "This text contains {} music lyrics."
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template=hypothesis_template,
            )
            # The zero-shot pipeline returns labels sorted by descending score.
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}") from e

    def predict(self, input_data: str, input_type: Union[str, None] = None) -> dict:
        """Classify the genre of an audio file or a lyrics string.

        Args:
            input_data: Path to an audio file, or lyrics text.
            input_type: 'audio' or 'text'. If None, inputs ending in one of
                AUDIO_EXTENSIONS (case-insensitive) are treated as audio and
                everything else as text. Any value other than 'audio' is
                treated as text.

        Returns:
            dict with keys 'genre', 'confidence' (float) and 'input_type'.

        Raises:
            ValueError: Wrapping any failure from the chosen classifier.
        """
        if input_type is None:
            looks_like_audio = input_data.lower().endswith(self.AUDIO_EXTENSIONS)
            input_type = 'audio' if looks_like_audio else 'text'
        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)
            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type,
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}") from e