Ganbatte commited on
Commit
0e9ef5e
·
verified ·
1 Parent(s): e422552

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -22
app.py CHANGED
@@ -1,30 +1,40 @@
1
  import torch
2
- from transformers import AutoProcessor, AutoModelForTextToWaveform
 
3
  import gradio as gr
4
- import scipy.io.wavfile
5
 
6
- # โหลด processor และโมเดล
7
- processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
8
- model = AutoModelForTextToWaveform.from_pretrained("nvidia/parakeet-tdt-0.6b-v2").to("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
- # รายชื่อ speaker ที่รองรับ (mock list - ต้องดูจาก actual model config ด้วย)
11
- speakers = ["emma", "ryan", "brian", "karen", "amy", "john"]
12
 
13
- def synthesize(text, speaker):
14
- inputs = processor(text, speaker=speaker, return_tensors="pt").to(model.device)
15
- with torch.no_grad():
16
- waveform = model(**inputs).waveform
17
- waveform = waveform.squeeze().cpu().numpy()
18
- return (24000, waveform) # sample rate 24kHz
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # สร้าง Gradio Interface
21
  gr.Interface(
22
- fn=synthesize,
23
  inputs=[
24
- gr.Textbox(label="Enter text to synthesize"),
25
- gr.Dropdown(choices=speakers, label="Select speaker")
26
- ],
27
- outputs=gr.Audio(label="Generated Speech"),
28
- title="🗣️ NVIDIA Parakeet TTS Demo",
29
- description="Text-to-Speech using NVIDIA Parakeet-TDT-0.6B-v2 model"
30
- ).launch()
 
1
  import torch
2
+ import torchaudio
3
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
4
  import gradio as gr
 
5
 
6
+ model_name = "ibm-granite/granite-speech-3.3-8b"
 
 
7
 
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
9
 
10
+ processor = AutoProcessor.from_pretrained(model_name)
11
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
12
+
13
+ def transcribe(audio, translate_to=None):
14
+ # audio: (sampling rate, numpy array) from Gradio
15
+ sr, audio_data = audio
16
+ waveform = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, seq)
17
+ # Resample if not 16kHz
18
+ if sr != 16000:
19
+ waveform = torchaudio.functional.resample(waveform, sr, 16000)
20
+
21
+ inputs = processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
22
+ # Set beam size >1 แนะนำ beam=5
23
+ outputs = model.generate(**inputs, num_beams=5, max_new_tokens=512)
24
+ text = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
25
+
26
+ # ถ้ามี translate_to, เติม prompt: e.g. "<|translate_to=es|>"
27
+ if translate_to:
28
+ text = f"<|translate_to={translate_to}|> " + text
29
+ inputs2 = processor(text, return_tensors="pt").to(device)
30
+ outputs2 = model.generate(**inputs2, num_beams=5)
31
+ text = processor.tokenizer.batch_decode(outputs2, skip_special_tokens=True)[0]
32
+
33
+ return text
34
+
35
+ translator_options = [None, "fr", "es", "it", "de", "pt", "ja", "zh"]
36
 
 
37
  gr.Interface(
38
+ fn=transcribe,
39
  inputs=[
40
+ gr.Audio(source="microphone", type="numpy", label="Upload