|
|
import os |
|
|
import subprocess |
|
|
import random |
|
|
import numpy as np |
|
|
import json |
|
|
from datetime import timedelta |
|
|
import tempfile |
|
|
import re |
|
|
import gradio as gr |
|
|
import groq |
|
|
from groq import Groq |
|
|
import io |
|
|
|
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
|
|
|
# Shared Groq client for the file-transcription, translation and chat tabs.
# The key is read from the "Groq_Api_Key" environment variable (None if unset).
client = Groq(api_key=os.getenv("Groq_Api_Key"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transcribe_audio(audio):
    """Transcribe a microphone recording with Groq's English Whisper model.

    NOTE: this definition is shadowed by the later
    ``transcribe_audio(audio_file_path, model, prompt, language,
    auto_detect_language)`` in this module; once the module has loaded, calls
    to ``transcribe_audio`` reach that later function, not this one.

    Args:
        audio: The value of a ``gr.Audio(type="numpy")`` input -- a
            ``(sample_rate, samples)`` pair -- or ``None`` when nothing was
            recorded.

    Returns:
        The transcription text, ``""`` when *audio* is ``None``, or an
        ``"Error in transcription: ..."`` message if the API call fails.
    """
    if audio is None:
        return ""

    sample_rate, samples = audio[0], audio[1]

    # Encode the raw samples as an in-memory WAV file for upload.
    buffer = io.BytesIO()
    sf.write(buffer, samples, sample_rate, format='wav')
    buffer.seek(0)

    try:
        # Reuse the module-level client (same environment key) instead of
        # constructing a fresh client on every call.
        return client.audio.transcriptions.create(
            model="distil-whisper-large-v3-en",
            file=("audio.wav", buffer),
            response_format="text",
        )
    except Exception as e:
        return f"Error in transcription: {str(e)}"
|
|
|
|
|
def generate_response(transcription, api_key):
    """Answer a voice transcription with one non-streaming llama3-70b completion.

    NOTE: this definition is shadowed by the later streaming
    ``generate_response(prompt, history, model, ...)`` in this module; once the
    module has loaded, that later function is the one bound to this name.

    Args:
        transcription: Text sent as the user turn.
        api_key: Groq API key used for this request.

    Returns:
        The assistant's reply text, a "try again" prompt when *transcription*
        is empty, or an ``"Error in response generation: ..."`` message.
    """
    if not transcription:
        return "No transcription available. Please try speaking again."

    llm_client = groq.Client(api_key=api_key)

    try:
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": transcription},
        ]
        reply = llm_client.chat.completions.create(
            model="llama3-70b-8192",
            messages=conversation,
        )
        return reply.choices[0].message.content
    except Exception as exc:
        return f"Error in response generation: {str(exc)}"
|
|
|
|
|
def _transcribe_voice_clip(voice_client, audio):
    """Upload a ``(sample_rate, samples)`` recording as WAV and return Whisper's text."""
    sample_rate, samples = audio[0], audio[1]
    buffer = io.BytesIO()
    sf.write(buffer, samples, sample_rate, format="wav")
    buffer.seek(0)
    return voice_client.audio.transcriptions.create(
        model="distil-whisper-large-v3-en",
        file=("audio.wav", buffer),
        response_format="text",
    )


def _reply_to_transcription(voice_client, transcription):
    """Ask llama3-70b to answer *transcription* and return the reply text."""
    completion = voice_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": transcription},
        ],
    )
    return completion.choices[0].message.content


def process_audio(audio, api_key):
    """Transcribe a spoken question and answer it, using the user's Groq key.

    This function is self-contained on purpose. The module later rebinds
    ``transcribe_audio`` and ``generate_response`` to functions with different
    signatures, so delegating to those names fails with a TypeError.

    Args:
        audio: Value of a ``gr.Audio(type="numpy")`` input, i.e. a
            ``(sample_rate, samples)`` pair, or ``None`` if nothing was recorded.
        api_key: The Groq API key typed by the user.

    Returns:
        A ``(transcription, response)`` pair of strings for the two output
        textboxes. Failures are reported as text in those strings rather than
        raised.
    """
    api_key = (api_key or "").strip()
    if not api_key:
        return "Please enter your Groq API key.", "API key is required."

    no_speech = "No transcription available. Please try speaking again."
    if audio is None:
        return "", no_speech

    try:
        voice_client = groq.Client(api_key=api_key)
        transcription = _transcribe_voice_clip(voice_client, audio)
    except Exception as e:
        # Do not forward the error text to the LLM as if it were the question.
        return (
            f"Error in transcription: {str(e)}",
            "No response generated because transcription failed.",
        )

    if not transcription or not transcription.strip():
        return transcription, no_speech

    try:
        response = _reply_to_transcription(voice_client, transcription)
    except Exception as e:
        response = f"Error in response generation: {str(e)}"
    return transcription, response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_groq_error_payload(e):
    """Return the structured payload embedded in a Groq exception, if parseable.

    The exception's first argument is inspected, or ``str(e)`` when the
    exception has no arguments. If that value is a string containing a
    ``{...}`` span, the span is parsed first as JSON and then as a Python
    literal. When neither parse succeeds, the unparsed value is returned
    unchanged.
    """
    import ast  # only needed on this error path

    error_data = e.args[0] if e.args else str(e)
    if not isinstance(error_data, str):
        return error_data

    match = re.search(r"(\{.*\})", error_data, re.DOTALL)
    if not match:
        return error_data
    payload = match.group(1)

    try:
        return json.loads(payload)
    except ValueError:
        pass
    try:
        # Single-quoted payloads (a Python dict repr) are not valid JSON.
        # literal_eval parses them without executing code. Unlike swapping
        # quote characters, it also copes with apostrophes inside messages.
        return ast.literal_eval(payload)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return error_data


def handle_groq_error(e, model_name):
    """Re-raise a Groq SDK exception as a user-facing ``gr.Error``.

    For ``groq.RateLimitError`` whose payload contains
    ``{"error": {"message": ...}}``, that message is shown verbatim. Every
    other case gets a generic message that includes the exception text. This
    covers rate-limit errors without a parseable message and all
    non-rate-limit errors. This function never returns normally.

    Args:
        e: The exception raised by a Groq client call.
        model_name: Name of the model that was called. Currently unused; kept
            so existing call sites keep working.

    Raises:
        gr.Error: Always.
    """
    error_data = _parse_groq_error_payload(e)

    if isinstance(e, groq.RateLimitError) and isinstance(error_data, dict):
        error_body = error_data.get("error")
        if isinstance(error_body, dict) and "message" in error_body:
            raise gr.Error(error_body["message"]) from e

    # Raise unconditionally so callers never continue as if the call had
    # succeeded.
    raise gr.Error(f"Error during Groq API call: {e}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Largest signed 32-bit integer; upper bound for randomly drawn chat seeds.
MAX_SEED = np.iinfo("int32").max
|
|
|
|
|
def update_max_tokens(model):
    """Return a slider update whose maximum matches *model*'s token limit.

    Args:
        model: A model id from the chat tab's model dropdown.

    Returns:
        ``gr.update(maximum=...)`` for known models. For any other value it
        returns an empty ``gr.update()``, which leaves the slider untouched;
        returning ``None`` would instead overwrite the slider's value.
    """
    token_limits = {
        "llama3-70b-8192": 8192,
        "llama3-8b-8192": 8192,
        "gemma-7b-it": 8192,
        "gemma2-9b-it": 8192,
        "mixtral-8x7b-32768": 32768,
    }
    limit = token_limits.get(model)
    if limit is None:
        return gr.update()
    return gr.update(maximum=limit)
|
|
|
|
|
def create_history_messages(history):
    """Convert chat history into an ordered list of chat-completion messages.

    Turns are emitted in conversation order (user, assistant, user, ...). The
    previous version emitted all user turns first and then all assistant
    turns, which scrambled the conversation sent to the model.

    Args:
        history: Either a sequence of ``[user, assistant]`` pairs (Gradio's
            tuple history format) or a sequence of ``{"role", "content"}``
            dicts (messages format). ``None`` is treated as empty.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts. Pair entries
        whose side is ``None`` (e.g. an assistant reply that never arrived)
        are skipped.
    """
    history_messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Keep only the fields the API expects; drop UI-only extras.
            history_messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_content, assistant_content = turn[0], turn[1]
        if user_content is not None:
            history_messages.append({"role": "user", "content": user_content})
        if assistant_content is not None:
            history_messages.append({"role": "assistant", "content": assistant_content})
    return history_messages
|
|
|
|
|
def generate_response(prompt, history, model, temperature, max_tokens, top_p, seed):
    """Stream a chat completion for the Chat tab's ``gr.ChatInterface``.

    Args:
        prompt: The user's new message.
        history: Prior turns as supplied by ``gr.ChatInterface``; converted
            with ``create_history_messages``.
        model: Groq chat model id.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.
        top_p: Nucleus-sampling probability mass.
        seed: Sampling seed; ``0`` requests a random seed in ``[1, MAX_SEED]``.

    Yields:
        The accumulated response text after each non-empty streamed chunk, so
        the UI can render the reply progressively.

    Raises:
        gr.Error: Via ``handle_groq_error`` when the Groq API call fails.
    """
    messages = create_history_messages(history)
    messages.append({"role": "user", "content": prompt})

    if seed == 0:
        seed = random.randint(1, MAX_SEED)

    try:
        stream = client.chat.completions.create(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            seed=seed,
            stop=None,
            stream=True,
        )

        response = ""
        for chunk in stream:
            # Defensive: skip any chunk that carries no choices.
            if not chunk.choices:
                continue
            delta_content = chunk.choices[0].delta.content
            if delta_content is not None:
                response += delta_content
                yield response
    # ``Groq`` is the client class and has no ``GroqApiException`` attribute.
    # The SDK's exception base class lives on the ``groq`` module.
    except groq.APIError as e:
        handle_groq_error(e, model)
|
|
|
|
|
|
|
|
|
|
|
# Extensions (without the dot, lowercase) accepted for audio uploads.
ALLOWED_FILE_EXTENSIONS = "mp3 mp4 mpeg mpga m4a wav webm".split()

# Size limit, in MB of 1024 * 1024 bytes, above which an upload is downsampled.
MAX_FILE_SIZE_MB = 25

# Size, in the same MB unit, of each piece when a file must be split.
CHUNK_SIZE_MB = 25
|
|
|
|
|
# Display name -> language code passed to the transcription endpoint.
# The keys are shown to users as dropdown labels; the values are sent to the API.
LANGUAGE_CODES = {
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Māori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Panjabi": "pa",
    "Sinhala": "si",
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Norwegian Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    # Whisper's code for Javanese is "jw" (not ISO 639-1 "jv"). Only the
    # user-facing label changed; the code sent to the API did not.
    "Javanese": "jw",
    "Sundanese": "su",
}
|
|
|
|
|
|
|
|
def split_audio(audio_file_path, chunk_size_mb):
    """Split a file into consecutive parts of at most *chunk_size_mb* MB.

    The split happens on raw byte boundaries, not on audio frame boundaries.
    Each part is written next to the source as ``<stem>_partNNN.mp3``
    (1-based, zero-padded to three digits), whatever the source's extension.

    Args:
        audio_file_path: Path of the file to split.
        chunk_size_mb: Maximum part size in MB of 1024 * 1024 bytes.

    Returns:
        The list of part paths, in order; empty for an empty file.
    """
    bytes_per_part = chunk_size_mb * 1024 * 1024
    stem = os.path.splitext(audio_file_path)[0]
    part_paths = []

    with open(audio_file_path, 'rb') as source:
        pieces = iter(lambda: source.read(bytes_per_part), b"")
        for part_number, data in enumerate(pieces, start=1):
            part_path = f"{stem}_part{part_number:03}.mp3"
            with open(part_path, 'wb') as target:
                target.write(data)
            part_paths.append(part_path)

    return part_paths
|
|
|
|
|
def merge_audio(chunks, output_file_path):
    """Concatenate audio chunk files with ffmpeg, then delete the chunks.

    Args:
        chunks: Paths of the chunk files, in playback order.
        output_file_path: Destination of the merged file; overwritten if it
            already exists (``-y``).

    Raises:
        gr.Error: If ffmpeg exits with a non-zero status. The chunk files are
            left in place in that case.
    """
    # Use a per-call list file. A fixed "temp_list.txt" in the working
    # directory is shared by concurrent requests, which then overwrite each
    # other's lists.
    list_fd, list_path = tempfile.mkstemp(suffix=".txt", text=True)
    try:
        with os.fdopen(list_fd, "w", encoding="utf-8") as list_file:
            for chunk in chunks:
                # The concat demuxer resolves relative entries against the list
                # file's directory, so write absolute paths. A literal single
                # quote inside a single-quoted entry must be written as '\''.
                escaped = os.path.abspath(chunk).replace("'", "'\\''")
                list_file.write(f"file '{escaped}'\n")

        subprocess.run(
            [
                "ffmpeg",
                "-f", "concat",
                "-safe", "0",
                "-i", list_path,
                "-c", "copy",
                "-y",
                output_file_path,
            ],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during audio merging: {e}") from e
    finally:
        os.remove(list_path)

    for chunk in chunks:
        os.remove(chunk)
|
|
|
|
|
|
|
|
def check_file(audio_file_path):
    """Validate an uploaded audio file and shrink it if it is over the size limit.

    Oversized uploads are re-encoded with ffmpeg to 16 kHz mono MP3 at
    128 kbps. If the result is still too large, it is byte-split into
    ``CHUNK_SIZE_MB`` pieces with ``split_audio``.

    Args:
        audio_file_path: Path of the uploaded file; falsy if nothing was uploaded.

    Returns:
        ``(path, None)`` when a single file can be sent: the original upload
        or its downsampled copy. ``(chunk_paths, "split")`` when even the
        downsampled copy exceeded ``MAX_FILE_SIZE_MB``.

    Raises:
        gr.Error: If no file was given, the extension is not allowed, or
            ffmpeg fails.
    """
    if not audio_file_path:
        raise gr.Error("Please upload an audio file.")

    bytes_per_mb = 1024 * 1024
    file_size_mb = os.path.getsize(audio_file_path) / bytes_per_mb
    file_extension = audio_file_path.split(".")[-1].lower()

    if file_extension not in ALLOWED_FILE_EXTENSIONS:
        raise gr.Error(f"Invalid file type (.{file_extension}). Allowed types: {', '.join(ALLOWED_FILE_EXTENSIONS)}")

    # Guard clause: small enough to upload unchanged.
    if file_size_mb <= MAX_FILE_SIZE_MB:
        return audio_file_path, None

    gr.Warning(
        f"File size too large ({file_size_mb:.2f} MB). Attempting to downsample to 16kHz MP3 128kbps. Maximum size allowed: {MAX_FILE_SIZE_MB} MB"
    )

    output_file_path = os.path.splitext(audio_file_path)[0] + "_downsampled.mp3"
    ffmpeg_command = [
        "ffmpeg",
        "-i", audio_file_path,
        "-ar", "16000",
        "-ab", "128k",
        "-ac", "1",
        "-f", "mp3",
        "-y",
        output_file_path,
    ]

    try:
        subprocess.run(ffmpeg_command, check=True)

        downsampled_size_mb = os.path.getsize(output_file_path) / bytes_per_mb
        if downsampled_size_mb > MAX_FILE_SIZE_MB:
            gr.Warning(f"File still too large after downsampling ({downsampled_size_mb:.2f} MB). Splitting into {CHUNK_SIZE_MB} MB chunks.")
            return split_audio(output_file_path, CHUNK_SIZE_MB), "split"

        return output_file_path, None
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during downsampling: {e}")
|
|
|
|
|
|
|
|
def _transcribe_file(path, model, prompt, language, auto_detect_language):
    """Send one audio file to Groq's transcription endpoint and return its text."""
    with open(path, "rb") as file:
        return client.audio.transcriptions.create(
            file=(os.path.basename(path), file.read()),
            model=model,
            prompt=prompt,
            response_format="text",
            # In auto-detect mode no language hint is sent.
            language=None if auto_detect_language else language,
            temperature=0.0,
        )


def transcribe_audio(audio_file_path, model, prompt, language, auto_detect_language):
    """Transcribe an uploaded audio file, chunking it when it is too large.

    Args:
        audio_file_path: Path of the uploaded file.
        model: Groq Whisper model id.
        prompt: Optional context or spelling hints for the model.
        language: Language code used when *auto_detect_language* is false.
        auto_detect_language: If true, no language is sent to the API.

    Returns:
        ``(text, None)`` for an upload sent as a single file. For a split
        upload, ``(text, merged_path)``, where *merged_path* is the re-joined
        audio of the chunks that were transcribed. If the rate limit is hit
        part-way through a split upload, *text* covers only the chunks
        processed so far and a warning is shown.

    Raises:
        gr.Error: For invalid uploads (via ``check_file``) and Groq API errors
            (via ``handle_groq_error``).
    """
    processed_path, split_status = check_file(audio_file_path)

    if split_status != "split":
        try:
            text = _transcribe_file(processed_path, model, prompt, language, auto_detect_language)
        except groq.APIError as e:
            handle_groq_error(e, model)
            raise  # only reached if handle_groq_error returned without raising
        return text, None

    transcripts = []
    processed_chunks = []
    for index, chunk_path in enumerate(processed_path):
        try:
            text = _transcribe_file(chunk_path, model, prompt, language, auto_detect_language)
        except groq.RateLimitError as e:
            if not processed_chunks:
                # Nothing to salvage: surface the rate-limit message.
                handle_groq_error(e, model)
                raise
            gr.Warning(f"API limit reached during chunk {index + 1}. Returning processed chunks only.")
            break
        except groq.APIError as e:
            handle_groq_error(e, model)
            raise
        transcripts.append(text.strip())
        processed_chunks.append(chunk_path)

    if not processed_chunks:
        return "", None

    # Write next to the chunks instead of a fixed path in the working
    # directory, which concurrent requests would overwrite.
    merged_path = os.path.join(os.path.dirname(processed_chunks[0]), "merged_output.mp3")
    merge_audio(processed_chunks, merged_path)
    # Separate chunk texts so words at chunk boundaries are not glued together.
    return " ".join(transcripts), merged_path
|
|
|
|
|
def _translate_file(path, model, prompt):
    """Send one audio file to Groq's translation endpoint and return the English text."""
    with open(path, "rb") as file:
        return client.audio.translations.create(
            file=(os.path.basename(path), file.read()),
            model=model,
            prompt=prompt,
            response_format="text",
            temperature=0.0,
        )


def translate_audio(audio_file_path, model, prompt):
    """Translate the speech in an uploaded audio file to English text.

    Args:
        audio_file_path: Path of the uploaded file.
        model: Groq Whisper model id.
        prompt: Optional context or spelling hints for the model.

    Returns:
        The English translation. For a split upload that hits the rate limit
        part-way through, an ``"API limit reached. Partial translation: ..."``
        string holding the chunks translated so far.

    Raises:
        gr.Error: For invalid uploads (via ``check_file``) and Groq API errors
            (via ``handle_groq_error``).
    """
    processed_path, split_status = check_file(audio_file_path)

    if split_status != "split":
        try:
            return _translate_file(processed_path, model, prompt)
        # ``Groq`` is the client class; the SDK exceptions live on ``groq``.
        except groq.APIError as e:
            handle_groq_error(e, model)
            raise  # only reached if handle_groq_error returned without raising

    translations = []
    try:
        for index, chunk_path in enumerate(processed_path):
            try:
                translations.append(_translate_file(chunk_path, model, prompt).strip())
            except groq.RateLimitError as e:
                if not translations:
                    # Nothing to salvage: surface the rate-limit message.
                    handle_groq_error(e, model)
                    raise
                gr.Warning(f"API limit reached during chunk {index + 1}. Returning partial translation.")
                return f"API limit reached. Partial translation: {' '.join(translations)}"
            except groq.APIError as e:
                handle_groq_error(e, model)
                raise
        # Separate chunk texts so words at chunk boundaries are not glued together.
        return " ".join(translations)
    finally:
        # The chunk files were created by check_file/split_audio for this
        # request only; this function does not read them after translating.
        for chunk_path in processed_path:
            try:
                os.remove(chunk_path)
            except OSError:
                pass
|
|
|
|
|
|
|
|
with gr.Blocks(theme="Hev832/niceandsimple") as interface:
    # Page header: title, rate-limit advice and credits.
    gr.Markdown(
        """
        # Groq API UI

        Inference by Groq API

        If you are having API Rate Limit issues, you can retry later based on the [rate limits](https://console.groq.com/docs/rate-limits) or <a href="https://huggingface.co/spaces/Nick088/Fast-Subtitle-Maker?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank"> <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a> with <a href=https://console.groq.com/keys>your own API Key</a>

        Hugging Face Space by [Nick088](https://linktr.ee/Nick088)

        <br> <a href="https://discord.gg/osai"> <img src="https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge" alt="Discord"> </a>
        """
    )

    with gr.Tabs():
        with gr.TabItem("Speech To Text"):
            with gr.Tabs():
                with gr.TabItem("Transcription"):
                    gr.Markdown("Transcribe audio from files to text!")
                    with gr.Row():
                        audio_input = gr.File(
                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
                        )
                        model_choice_transcribe = gr.Dropdown(
                            choices=["whisper-large-v3"],
                            value="whisper-large-v3",
                            label="Model",
                        )
                    with gr.Row():
                        transcribe_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                        with gr.Column():
                            # (label, value) pairs: users see names, the callback gets codes.
                            language = gr.Dropdown(
                                choices=list(LANGUAGE_CODES.items()),
                                value="en",
                                label="Language",
                            )
                            auto_detect_language = gr.Checkbox(label="Auto Detect Language")
                    transcribe_button = gr.Button("Transcribe")
                    transcription_output = gr.Textbox(label="Transcription")
                    merged_audio_output = gr.File(label="Merged Audio (if chunked)")
                    transcribe_button.click(
                        transcribe_audio,
                        inputs=[audio_input, model_choice_transcribe, transcribe_prompt, language, auto_detect_language],
                        outputs=[transcription_output, merged_audio_output],
                    )

                with gr.TabItem("Translation"):
                    gr.Markdown("Transcribe audio from files and translate it to English text!")
                    with gr.Row():
                        audio_input_translate = gr.File(
                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
                        )
                        model_choice_translate = gr.Dropdown(
                            choices=["whisper-large-v3"],
                            value="whisper-large-v3",
                            label="Audio Speech Recognition (ASR) Model",
                        )
                    with gr.Row():
                        translate_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                    translate_button = gr.Button("Translate")
                    translation_output = gr.Textbox(label="Translation")
                    translate_button.click(
                        translate_audio,
                        inputs=[audio_input_translate, model_choice_translate, translate_prompt],
                        outputs=translation_output,
                    )

        with gr.TabItem("LLMs"):
            with gr.Tab("Chat"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=250):
                        model = gr.Dropdown(
                            choices=[
                                "llama3-70b-8192",
                                "llama3-8b-8192",
                                "mixtral-8x7b-32768",
                                "gemma-7b-it",
                                "gemma2-9b-it",
                            ],
                            value="llama3-70b-8192",
                            label="Model",
                        )
                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.5,
                            label="Temperature",
                            info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative.",
                        )
                        max_tokens = gr.Slider(
                            minimum=1,
                            maximum=8192,
                            step=1,
                            value=4096,
                            label="Max Tokens",
                            info="The maximum number of tokens that the model can process in a single response.<br>Maximums: 8k for gemma 7b it, gemma2 9b it, llama3 8b & 70b, 32k for mixtral 8x7b.",
                        )
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.5,
                            label="Top P",
                            info="A method of text generation where a model will only consider the most probable next tokens that make up the probability p.",
                        )
                        seed = gr.Number(
                            precision=0, value=42, label="Seed", info="A starting point to initiate generation, use 0 for random"
                        )
                        # Registered once; it was previously attached twice,
                        # so the handler ran twice per model change.
                        model.change(update_max_tokens, inputs=[model], outputs=max_tokens)

                    with gr.Column(scale=1, min_width=400):
                        chatbot = gr.ChatInterface(
                            fn=generate_response,
                            chatbot=None,
                            additional_inputs=[
                                model,
                                temperature,
                                max_tokens,
                                top_p,
                                seed,
                            ],
                        )

            with gr.Tab("Voice-Powered AI Assistant"):
                # process_audio takes the user's Groq key as its second input.
                # Without this component the click wiring below raised
                # NameError and the app failed to start.
                api_key_input = gr.Textbox(
                    label="Groq API Key",
                    type="password",
                    placeholder="Enter your Groq API key",
                )
                with gr.Row():
                    # Distinct names so the Transcription tab's components are not shadowed.
                    voice_audio_input = gr.Audio(label="Speak!", type="numpy")
                with gr.Row():
                    voice_transcription_output = gr.Textbox(label="Transcription")
                    response_output = gr.Textbox(label="AI Assistant Response")
                submit_button = gr.Button("Process", variant="primary")
                submit_button.click(
                    process_audio,
                    inputs=[voice_audio_input, api_key_input],
                    outputs=[voice_transcription_output, response_output],
                )

interface.launch(share=True)