Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import julius | |
| import pesq | |
| import torch | |
| import torchmetrics | |
class PesqMetric(torchmetrics.Metric):
    """Metric for Perceptual Evaluation of Speech Quality.

    (https://doi.org/10.5281/zenodo.6549559)

    Accumulates the sum of per-sample PESQ scores and the number of scored
    samples across ``update`` calls; ``compute`` returns their ratio (the
    mean score). Both states are reduced with ``"sum"`` across processes.
    """

    sum_pesq: torch.Tensor
    total: torch.Tensor

    # Sample rate handed to ``pesq.pesq``; inputs recorded at any other rate
    # are resampled to this rate before scoring.
    _PESQ_SAMPLE_RATE = 16000

    def __init__(self, sample_rate: int):
        """Create the metric.

        Args:
            sample_rate: Sample rate (Hz) of the audio passed to ``update``.
        """
        super().__init__()
        self.sr = sample_rate
        self.add_state("sum_pesq", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Score a batch and add the results to the running totals.

        Only index 0 of the second dimension is scored for each batch item
        (presumably inputs are (batch, channels, time) -- TODO confirm).

        Args:
            preds: Degraded / generated audio, sampled at ``self.sr``.
            targets: Reference audio, sampled at ``self.sr``.
        """
        target_sr = self._PESQ_SAMPLE_RATE
        if self.sr != target_sr:
            preds = julius.resample_frac(preds, self.sr, target_sr)
            targets = julius.resample_frac(targets, self.sr, target_sr)

        # Convert once per batch instead of twice per sample: a single
        # detach / host transfer for each tensor.
        degraded = preds[:, 0].detach().cpu().numpy()
        reference = targets[:, 0].detach().cpu().numpy()

        for ii in range(preds.size(0)):
            try:
                score = pesq.pesq(target_sr, reference[ii], degraded[ii])
            except pesq.NoUtterancesError:
                # This error can happen when the sample doesn't contain
                # speech; such samples are skipped and not counted.
                continue
            else:
                self.sum_pesq += score
                self.total += 1

    def compute(self) -> torch.Tensor:
        """Return the mean PESQ over all scored samples.

        Returns:
            ``sum_pesq / total``, or a zero tensor when no sample has been
            scored. The zero is created with ``zeros_like`` so it has the
            same device and dtype as the accumulated state.
        """
        if (self.total != 0).item():
            return self.sum_pesq / self.total
        return torch.zeros_like(self.sum_pesq)
