Ganbatte committed
Commit c401595 · verified · 1 Parent(s): 00024bb

Update app.py

Files changed (1)
  1. app.py +25 -14
app.py CHANGED
@@ -1,40 +1,51 @@
 import torch
 import torchaudio
+import tempfile
+import requests
 from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 import gradio as gr
 
 model_name = "ibm-granite/granite-speech-3.3-8b"
-
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 processor = AutoProcessor.from_pretrained(model_name)
 model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
 
-def transcribe(audio, translate_to=None):
-    # audio: (sampling rate, numpy array) from Gradio
-    sr, audio_data = audio
-    waveform = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, seq)
-    # Resample if not 16kHz
+def download_audio_from_url(url):
+    response = requests.get(url)
+    if response.status_code != 200:
+        raise Exception("Failed to download file from URL.")
+    tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
+    tmp.write(response.content)
+    tmp.close()
+    waveform, sr = torchaudio.load(tmp.name)
+    return waveform, sr
+
+def transcribe_from_url(audio_url, translate_to=None):
+    waveform, sr = download_audio_from_url(audio_url)
+    # Resample if needed
     if sr != 16000:
         waveform = torchaudio.functional.resample(waveform, sr, 16000)
-
+
     inputs = processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
-    # Set beam size >1; beam=5 recommended
     outputs = model.generate(**inputs, num_beams=5, max_new_tokens=512)
     text = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
-    # If translate_to is given, prepend the prompt, e.g. "<|translate_to=es|>"
     if translate_to:
         text = f"<|translate_to={translate_to}|> " + text
         inputs2 = processor(text, return_tensors="pt").to(device)
         outputs2 = model.generate(**inputs2, num_beams=5)
         text = processor.tokenizer.batch_decode(outputs2, skip_special_tokens=True)[0]
-
-    return text
 
-translator_options = [None, "fr", "es", "it", "de", "pt", "ja", "zh"]
+    return text
 
 gr.Interface(
-    fn=transcribe,
+    fn=transcribe_from_url,
     inputs=[
-        gr.Audio(source="microphone", type="numpy", label="Upload
+        gr.Textbox(label="🎧 Audio File URL (.mp3, .wav)", placeholder="Paste Google Drive direct link or other audio URL"),
+        gr.Dropdown(choices=[None, "fr", "es", "it", "de", "pt", "ja", "zh"], label="Translate to (optional)")
+    ],
+    outputs=gr.Textbox(label="📝 Transcription / Translation"),
+    title="Granite Speech 3.3-8B - Audio from URL",
+    description="Paste a direct URL to an audio file (Google Drive with 'uc?export=download' format or any MP3/WAV link)"
+).launch()
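
The new description field assumes a Google Drive link that is already in direct-download ("uc?export=download") form. For reference, a minimal sketch of deriving that form from an ordinary share link; the to_drive_direct_url helper and the share-URL shape it parses are illustrative assumptions, not part of this commit:

import re

def to_drive_direct_url(share_url):
    # Illustrative helper (not in the commit): extract the file ID from a
    # typical https://drive.google.com/file/d/<ID>/view?... share link.
    m = re.search(r"/d/([\w-]+)", share_url)
    if not m:
        raise ValueError("No Drive file ID found in URL.")
    # Rebuild it in the 'uc?export=download' form the app expects.
    return f"https://drive.google.com/uc?export=download&id={m.group(1)}"

For example, to_drive_direct_url("https://drive.google.com/file/d/abc123/view?usp=sharing") returns "https://drive.google.com/uc?export=download&id=abc123".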
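
Likewise, a quick way to exercise the new code path outside the UI, assuming the functions above are already defined in the current session and the placeholder URL points at a small public audio file (both assumptions for illustration):

# Hypothetical smoke test, run in the same session where app.py's functions
# are defined (importing app.py directly would also call .launch()).
url = "https://example.com/sample.wav"  # placeholder, not a real endpoint

text = transcribe_from_url(url)                        # transcription only
text_es = transcribe_from_url(url, translate_to="es")  # adds the <|translate_to=es|> prompt
print(text, text_es, sep="\n")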