# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from audiocraft.models.loaders import load_audioseal_models


class WMModel(ABC, nn.Module):
    """
    A wrapper interface to different watermarking models for
    training or evaluation purposes.
    """

    @abstractmethod
    def get_watermark(
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
    ) -> torch.Tensor:
        """Get the watermark from an audio tensor and a message.
        If the input message is None, a random message of
        n bits {0, 1} will be generated.
        """

    @abstractmethod
    def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
        """Detect the watermarks from the audio signal.

        Args:
            x: Audio signal, size batch x frames

        Returns:
            Tensor of size (B, 2+n, frames), where the first 2 units along
            dim 1 are the detection results and the remaining n units are
            the message decoding results.
        """


class AudioSeal(WMModel):
    """Wraps AudioSeal (https://github.com/facebookresearch/audioseal) for
    training and evaluation. The generator and detector are jointly trained.
    """

    def __init__(
        self,
        generator: nn.Module,
        detector: nn.Module,
        nbits: int = 0,
    ):
        super().__init__()
        self.generator = generator  # type: ignore
        self.detector = detector  # type: ignore

        # Allow re-training an n-bit model with a new 0-bit message
        self.nbits = nbits if nbits else self.generator.msg_processor.nbits

    def get_watermark(
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
    ) -> torch.Tensor:
        return self.generator.get_watermark(x, message=message, sample_rate=sample_rate)

    def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
        """
        Detect the watermarks from the audio signal. The first two units of the
        output are used for detection; the rest are used to decode the message.
        If the audio is not watermarked, the decoded message will be random.

        Args:
            x: Audio signal, size batch x frames

        Returns:
            torch.Tensor: Detection + decoding results of shape (B, 2+nbits, T).
        """
        # Get the raw detection and decoding outputs from the detector
        result = self.detector.detector(x)  # b x (2+nbits) x frames
        # Hardcoded softmax on the first two units, which carry the detection scores
        result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
        return result
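
    # Note on interpreting the output (a sketch, not part of the original API):
    # after the softmax above, result[:, 1, :] can be read as the per-frame
    # probability that the audio is watermarked (assuming index 1 is the positive
    # class), while result[:, 2:, :] carries the per-frame message decoding scores.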

    def forward(  # generator
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Apply the watermark to the audio signal x with a tune-down ratio alpha (default 1.0)."""
        wm = self.get_watermark(x, message)
        return x + alpha * wm


def get_pretrained(name="base", device=None) -> WMModel:
    if device is None:
        if torch.cuda.device_count():
            device = "cuda"
        else:
            device = "cpu"
    return load_audioseal_models("facebook/audioseal", filename=name, device=device)
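

# Example usage (a minimal sketch, not part of the original module): load the
# pretrained wrapper, watermark one second of dummy audio with a random message,
# and run detection on the result. The "base" checkpoint name, the 16 kHz sample
# rate, and the (batch, channels, frames) input shape are assumptions; adjust
# them to match the checkpoint you actually load.
if __name__ == "__main__":
    model = get_pretrained(name="base")
    audio = torch.randn(1, 1, 16_000)  # dummy audio, assumed (batch, channels, frames)
    watermarked = model(audio, sample_rate=16_000, alpha=1.0)  # message=None -> random message
    result = model.detect_watermark(watermarked)
    # result[:, :2, :] holds the per-frame detection probabilities,
    # result[:, 2:, :] the per-frame message decoding scores.
    print(result.shape)  # (1, 2 + nbits, frames)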