Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	File size: 3,560 Bytes
			
			| 9d0d223 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 | # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from audiocraft.models.loaders import load_audioseal_models
class WMModel(ABC, nn.Module):
    """Abstract interface shared by watermarking models.

    Used to plug different watermarking models into training or evaluation
    code. It combines :class:`abc.ABC` with :class:`torch.nn.Module`, so a
    concrete subclass is an ordinary PyTorch module. That subclass must
    implement both :meth:`detect_watermark` and :meth:`get_watermark`; the
    class itself cannot be instantiated.
    """

    @abstractmethod
    def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
        """Run watermark detection on an audio signal.

        Args:
            x: Audio signal of size batch x frames.

        Returns:
            A tensor of shape (B, 2 + n, frames). Channels [0, 2) hold the
            detection results (B, 2, frames). The remaining n channels hold
            the decoded message (B, n, frames).
        """
        ...

    @abstractmethod
    def get_watermark(
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
    ) -> torch.Tensor:
        """Produce the watermark for audio ``x`` carrying ``message``.

        Args:
            x: Audio tensor to watermark.
            message: Bits to embed. Per this interface's contract, passing
                ``None`` means the implementation generates a random
                n-bit message with values in {0, 1}.
            sample_rate: Sample rate of ``x`` (default 16_000).

        Returns:
            The watermark tensor.
        """
        ...
class AudioSeal(WMModel):
    """Wrap AudioSeal (https://github.com/facebookresearch/audioseal) for
    training and evaluation. The generator and detector are trained jointly.

    Args:
        generator: Watermark generator module. It must provide
            ``get_watermark``. When ``nbits`` is 0 it must also provide
            ``msg_processor.nbits``.
        detector: Detector wrapper module. Its ``detector`` attribute is
            called directly by :meth:`detect_watermark`.
        nbits: Message size in bits. A non-zero value overrides the
            generator's message size. 0 (the default) inherits
            ``generator.msg_processor.nbits``.
    """

    def __init__(
        self,
        generator: nn.Module,
        detector: nn.Module,
        nbits: int = 0,
    ):
        super().__init__()
        self.generator = generator  # type: ignore
        self.detector = detector  # type: ignore
        # An explicit non-zero nbits takes precedence over the generator's own
        # message size. This allows re-training an existing model with a
        # different message configuration. Otherwise, fall back to the
        # generator's configured nbits.
        self.nbits = nbits if nbits else self.generator.msg_processor.nbits

    def get_watermark(
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
    ) -> torch.Tensor:
        """Return the watermark for ``x`` by delegating to the generator.

        Args:
            x: Audio tensor to watermark.
            message: Message to embed; passed through unchanged to
                ``generator.get_watermark``.
            sample_rate: Sample rate of ``x``; passed through unchanged.

        Returns:
            Whatever ``generator.get_watermark`` returns (the watermark tensor).
        """
        return self.generator.get_watermark(x, message=message, sample_rate=sample_rate)

    def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
        """
        Detect the watermarks from the audio signal.

        The first two output channels are used for detection. The remaining
        channels are used to decode the message. If the audio is not
        watermarked, the decoded message will be random.

        Args:
            x: Audio signal, size batch x frames.

        Returns:
            torch.Tensor: Detection + decoding results of shape (B, 2+nbits, T).
        """
        # Raw detector output: 2 detection logits followed by nbits message channels.
        result = self.detector.detector(x)  # b x 2+nbits
        # Turn only the 2 detection channels into probabilities (softmax over
        # dim=1), writing them back in place. The message channels are
        # returned unchanged.
        result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
        return result

    def forward(  # generator
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Apply the watermark to the audio signal ``x``.

        Args:
            x: Audio to watermark.
            message: Message to embed; forwarded to :meth:`get_watermark`.
            sample_rate: Sample rate of ``x``; forwarded to :meth:`get_watermark`.
            alpha: Tune-down ratio that scales the watermark before it is
                added (default 1.0).

        Returns:
            ``x + alpha * watermark``.
        """
        # Forward sample_rate so the caller's value reaches the generator,
        # rather than always falling back to get_watermark's 16_000 default.
        wm = self.get_watermark(x, message=message, sample_rate=sample_rate)
        return x + alpha * wm

    @staticmethod
    def get_pretrained(name="base", device=None) -> WMModel:
        """Load pretrained AudioSeal weights via ``load_audioseal_models``.

        Args:
            name: Passed as ``filename`` to ``load_audioseal_models`` for the
                ``"facebook/audioseal"`` source (default ``"base"``).
            device: Target device. When ``None``, uses ``"cuda"`` if at least
                one CUDA device is visible, otherwise ``"cpu"``.

        Returns:
            The model returned by ``load_audioseal_models`` (annotated as
            ``WMModel``; NOTE(review): confirm the loader's return type).
        """
        if device is None:
            device = "cuda" if torch.cuda.device_count() else "cpu"
        return load_audioseal_models("facebook/audioseal", filename=name, device=device)
 | 
