Spaces:
Running
Running
| import torch | |
| import gradio as gr | |
| import soundfile as sf | |
| import numpy as np | |
| import random, os | |
| import spaces | |
| from consistencytta import ConsistencyTTA | |
def seed_all(seed):
    """Seed every random number generator this script relies on.

    The same integer is handed, in order, to Python's ``random`` module,
    NumPy, and PyTorch (CPU and CUDA seeding functions). It is then stored in
    the ``PYTHONHASHSEED`` environment variable, and cuDNN is switched to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Seed value; converted with ``int()`` before use.
    """
    value = int(seed)

    # Order matters only in the failure case: if one seeder rejects the
    # value, the ones after it (and the settings below) are not applied.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        torch.cuda.random.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(value)

    os.environ['PYTHONHASHSEED'] = str(value)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
# Pick the compute device: the first CUDA GPU if available, otherwise Apple
# MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Sample rate passed to the model and used when writing the output WAV file.
sr = 16000

# Build the ConsistencyTTA model once at import time, move it to the chosen
# device, and put it in inference mode with gradients disabled.
consistencytta = ConsistencyTTA()
consistencytta = consistencytta.to(device)
consistencytta.eval()
consistencytta.requires_grad_(False)
def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
    """Generate audio from a text prompt and return the path of the WAV file.

    Args:
        prompt (str): Text prompt to generate audio from.
        seed (str, optional): Random seed. Any value accepted by ``int()``
            works (the UI sends a string). Defaults to '', which means no
            seed; ``None`` and whitespace-only strings are treated the same
            way. A value that cannot be used as a seed is reported and
            ignored, so generation still runs (unseeded).
        cfg_weight (float, optional): Classifier-free guidance strength,
            forwarded to the model as ``cfg_scale_input``. Defaults to 4.

    Returns:
        str: Path to a newly created 16-bit PCM WAV file with the audio.
    """
    import tempfile

    if isinstance(seed, str):
        seed = seed.strip()
    if seed is not None and seed != '':
        try:
            seed_all(int(seed))
        except (TypeError, ValueError, OverflowError, RuntimeError) as err:
            # Best effort: a bad seed must not break generation, but the
            # user should be able to find out why the result is not
            # reproducible.
            print(f"Ignoring unusable random seed {seed!r}: {err}")

    # bfloat16 autocast is only requested when the model actually lives on a
    # CUDA device; on MPS/CPU the context is disabled instead of warning on
    # every call.
    use_autocast = device.type == "cuda"
    with torch.no_grad(), torch.autocast(
        device_type="cuda", dtype=torch.bfloat16, enabled=use_autocast
    ):
        wav = consistencytta(
            [prompt], num_steps=1, cfg_scale_input=cfg_weight, cfg_scale_post=1., sr=sr
        )

    # Each call writes to its own file so that concurrent requests cannot
    # overwrite one another's audio. The file is intentionally left on disk
    # for the caller (Gradio) to read.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    sf.write(out_path, wav.T, samplerate=sr, subtype='PCM_16')
    return out_path
# Smoke test: generate one clip before starting the UI. An exception raised
# here stops the script before the interface is launched.
print("Generating test audio...")
generate("A dog barks as a train passes by.", seed=1)
print("Test audio generated successfully! Starting Gradio interface...")

# Build and launch the Gradio interface. The three inputs map positionally to
# generate(prompt, seed, cfg_weight).
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Text",
            value="Several people cheer and scream and speak as water flows hard.",
        ),
        gr.Textbox(label="Random Seed (Optional)", value=''),
        gr.Slider(
            minimum=0., maximum=8., value=3.5, label="Classifier-Free Guidance Strength"
        ),
    ],
    outputs="audio",
    title=(
        "ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio "
        "Generation with Consistency Distillation"
    ),
    # The link uses target='_blank' (plain quotes, leading underscore) so the
    # project page opens in a new tab.
    description=(
        "This is the official demo page for <a href='https://consistency-tta.github."
        "io' target='_blank'>ConsistencyTTA</a>, a model that accelerates "
        "diffusion-based text-to-audio generation hundreds of times with consistency "
        "models. <br> Here, the audio is generated within a single non-autoregressive "
        "forward pass from the CLAP-finetuned ConsistencyTTA checkpoint. <br> Since "
        "the training dataset does not include speech, the model is not expected to "
        "generate coherent speech. <br> Have fun!"
    ),
)
iface.launch()