Spaces:
Runtime error
Runtime error
import argparse
import os
import random
import sys

import numpy as np
import soundfile as sf
import torch

from dia.model import Dia
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN into deterministic mode and turns off its
    auto-tuning benchmark so that repeated runs with the same seed pick
    the same kernels.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    # CPU-side generators: stdlib, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators are only touched when a CUDA device is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # cuDNN flags are set regardless of whether cuDNN ends up being used.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the Dia audio generation command line.

    Numeric defaults in the help text are rendered with argparse's
    ``%(default)s`` placeholder so the documented default always matches the
    value actually used.
    """
    parser = argparse.ArgumentParser(description="Generate audio using the Dia model.")

    parser.add_argument("text", type=str, help="Input text for speech generation.")
    parser.add_argument(
        "--output", type=str, required=True, help="Path to save the generated audio file (e.g., output.wav)."
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        default="nari-labs/Dia-1.6B",
        help="Hugging Face repository ID (e.g., nari-labs/Dia-1.6B).",
    )
    parser.add_argument(
        "--local-paths", action="store_true", help="Load model from local config and checkpoint files."
    )
    parser.add_argument(
        "--config", type=str, help="Path to local config.json file (required if --local-paths is set)."
    )
    parser.add_argument(
        "--checkpoint", type=str, help="Path to local model checkpoint .pth file (required if --local-paths is set)."
    )
    parser.add_argument(
        "--audio-prompt", type=str, default=None, help="Path to an optional audio prompt WAV file for voice cloning."
    )

    gen_group = parser.add_argument_group("Generation Parameters")
    gen_group.add_argument(
        "--max-tokens",
        type=int,
        default=None,
        help="Maximum number of audio tokens to generate (defaults to config value).",
    )
    gen_group.add_argument(
        "--cfg-scale", type=float, default=3.0, help="Classifier-Free Guidance scale (default: %(default)s)."
    )
    gen_group.add_argument(
        "--temperature",
        type=float,
        default=1.3,
        help="Sampling temperature (higher is more random, default: %(default)s).",
    )
    gen_group.add_argument(
        "--top-p", type=float, default=0.95, help="Nucleus sampling probability (default: %(default)s)."
    )

    infra_group = parser.add_argument_group("Infrastructure")
    infra_group.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
    infra_group.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run inference on (e.g., 'cuda', 'cpu', default: auto).",
    )
    return parser


def _validate_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject inconsistent or dangling path arguments via ``parser.error``.

    All checks run before the model is loaded so that a typo in a path fails
    immediately instead of after an expensive load.
    """
    if args.local_paths:
        if not args.config:
            parser.error("--config is required when --local-paths is set.")
        if not args.checkpoint:
            parser.error("--checkpoint is required when --local-paths is set.")
        if not os.path.exists(args.config):
            parser.error(f"Config file not found: {args.config}")
        if not os.path.exists(args.checkpoint):
            parser.error(f"Checkpoint file not found: {args.checkpoint}")

    if args.audio_prompt and not os.path.exists(args.audio_prompt):
        parser.error(f"Audio prompt file not found: {args.audio_prompt}")


def _load_model(args: argparse.Namespace, device: torch.device):
    """Load the Dia model from local files or the Hugging Face Hub.

    Prints the failure to stderr and exits with status 1 if loading raises.
    """
    if args.local_paths:
        print(f"Loading from local paths: config='{args.config}', checkpoint='{args.checkpoint}'")
        try:
            return Dia.from_local(args.config, args.checkpoint, device=device)
        except Exception as e:  # top-level CLI boundary: report and exit
            print(f"Error loading local model: {e}", file=sys.stderr)
            sys.exit(1)

    print(f"Loading from Hugging Face Hub: repo_id='{args.repo_id}'")
    try:
        return Dia.from_pretrained(args.repo_id, device=device)
    except Exception as e:  # top-level CLI boundary: report and exit
        print(f"Error loading model from Hub: {e}", file=sys.stderr)
        sys.exit(1)


def main(argv=None) -> None:
    """Command-line entry point: synthesize speech from text with Dia.

    Parses the arguments, validates path options, optionally seeds the random
    number generators, loads the model, runs generation and writes the result
    to the ``--output`` file.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so calling ``main()`` behaves as
            before.

    Raises:
        SystemExit: With status 2 on invalid arguments (argparse convention),
            or status 1 when model loading, generation or saving fails.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    _validate_args(parser, args)

    # Turn an invalid device string into a usage error rather than a traceback.
    try:
        device = torch.device(args.device)
    except RuntimeError as e:
        parser.error(f"Invalid device '{args.device}': {e}")

    # Set seed if provided
    if args.seed is not None:
        set_seed(args.seed)
        print(f"Using random seed: {args.seed}")

    print(f"Using device: {device}")

    # Load model
    print("Loading model...")
    model = _load_model(args, device)
    print("Model loaded.")

    # Generate audio
    print("Generating audio...")
    try:
        sample_rate = 44100  # Default assumption
        output_audio = model.generate(
            text=args.text,
            audio_prompt=args.audio_prompt,
            max_tokens=args.max_tokens,
            cfg_scale=args.cfg_scale,
            temperature=args.temperature,
            top_p=args.top_p,
        )
        print("Audio generation complete.")

        print(f"Saving audio to {args.output}...")
        # dirname is "" for a bare filename; fall back to the current directory.
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        sf.write(args.output, output_audio, sample_rate)
        print(f"Audio successfully saved to {args.output}")
    except Exception as e:  # top-level CLI boundary: report and exit
        print(f"Error during audio generation or saving: {e}", file=sys.stderr)
        sys.exit(1)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()