Spaces: Running on Zero
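
# A Gradio chat app: Qwen2.5-0.5B-Instruct generates short replies and an
# in-browser Matcha-TTS ONNX model speaks each assistant turn aloud.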
import spaces
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr

text_generator = None
is_hugging_face = False  # set True when deployed as a ZeroGPU Space

# config lives at module level so generate_text can reload the model on ZeroGPU
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
device = "auto"  # let device_map place the model; torch.device("cuda" if torch.cuda.is_available() else "cpu") also works
dtype = torch.bfloat16
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
tokenizer = None

def init():
    global text_generator, tokenizer
    if not huggingface_token:
        print("HUGGINGFACE_TOKEN is not set; add it as a Space secret if the model requires authentication")
        # raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
    print(model_id, device, dtype)

    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=huggingface_token, torch_dtype=dtype, device_map=device
    )
    # pipeline objects have no .to(device); placement is handled by device_map
    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype, device_map=device)
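    # device_map="auto" requires the accelerate package; it places the model
    # automatically, so no explicit model.to("cuda") call is needed.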
    if not is_hugging_face:
        if next(model.parameters()).is_cuda:
            print("The model is on a GPU")
        else:
            print("The model is on a CPU")

        # print(f"text_generator.device='{text_generator.device}'")
        if str(text_generator.device).startswith("cuda"):  # device strings look like "cuda:0"
            print("The pipeline is using a GPU")
        else:
            print("The pipeline is using a CPU")

    print("initialized")
def generate_text(messages):
    global text_generator
    if is_hugging_face:  # ZeroGPU detaches the GPU between calls, so reload every time
        model = AutoModelForCausalLM.from_pretrained(
            model_id, token=huggingface_token, torch_dtype=dtype, device_map=device
        )
        text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype, device_map=device)

    result = text_generator(messages, max_new_tokens=32, do_sample=True, temperature=0.7)

    generated_output = result[0]["generated_text"]
    if isinstance(generated_output, list):
        for message in reversed(generated_output):
            if message.get("role") == "assistant":
                return message.get("content", "No content found.")
        return "No assistant response found."
    else:
        return "Unexpected output format."
def call_generate_text(message, history):
    # seed a system prompt on the first turn (the original appended to the
    # message string itself, which would raise AttributeError)
    if len(history) == 0:
        history.append({"role": "system", "content": "you respond in around 10 words"})

    print(message)
    print(history)

    messages = history + [{"role": "user", "content": message}]
    try:
        text = generate_text(messages)
        messages += [{"role": "assistant", "content": text}]
        return "", messages
    except RuntimeError as e:
        print(f"An unexpected error occurred: {e}")
        return "", history
head = '''
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.webgpu.min.js"></script>
<script type="module">
import { matcha_tts, env } from "https://akjava.github.io/Matcha-TTS-Japanese/js-esm/v001-20240921/matcha_tts_onnx_en.js";
window.MatchaTTSEn = matcha_tts
</script>
'''
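
# The head snippet above loads ONNX Runtime Web and the Matcha-TTS ES module,
# exposing window.MatchaTTSEn; the chatbot's JS callback below calls it to
# speak each assistant reply directly in the browser.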
with gr.Blocks(title="LLM with TTS", head=head) as demo:
    gr.Markdown("## Please be patient: the first response may take up to 20 seconds while the models load.")
    gr.Markdown("**Qwen2.5-0.5B-Instruct / LJSpeech-q8** (extremely slow: Gradio Spaces need an absolute ONNX URL and such URLs get blocked quickly, so the slower GitHub-hosted version is used here). LLM and TTS models may change without notice.")
| js = """ | |
| function(chatbot){ | |
| text = (chatbot[chatbot.length -1])["content"] | |
| window.MatchaTTSEn(text) | |
| } | |
| """ | |
| chatbot = gr.Chatbot(type="messages") | |
| chatbot.change(None,[chatbot],[],js=js) | |
| msg = gr.Textbox() | |
| clear = gr.ClearButton([msg, chatbot]) | |
| gr.HTML(""" | |
| <br> | |
| <div id="footer"> | |
| <b>Spaces</b><br> | |
| <a href="https://huggingface.co/spaces/Akjava/matcha-tts_vctk-onnx" style="font-size: 9px" target="link">Match-TTS VCTK-ONNX</a> | | |
| <a href="https://huggingface.co/spaces/Akjava/matcha-tts-onnx-benchmarks" style="font-size: 9px" target="link">Match-TTS ONNX-Benchmark</a> | | |
| <br><br> | |
| <b>Credits</b><br> | |
| <a href="https://github.com/akjava/Matcha-TTS-Japanese" style="font-size: 9px" target="link">Matcha-TTS-Japanese</a> | | |
| <a href = "http://www.udialogue.org/download/cstr-vctk-corpus.html" style="font-size: 9px" target="link">CSTR VCTK Corpus</a> | | |
| <a href = "https://github.com/cmusphinx/cmudict" style="font-size: 9px" target="link">CMUDict</a> | | |
| <a href = "https://huggingface.co/docs/transformers.js/index" style="font-size: 9px" target="link">Transformer.js</a> | | |
| <a href = "https://huggingface.co/cisco-ai/mini-bart-g2p" style="font-size: 9px" target="link">mini-bart-g2p</a> | | |
| <a href = "https://onnxruntime.ai/docs/get-started/with-javascript/web.html" style="font-size: 9px" target="link">ONNXRuntime-Web</a> | | |
| <a href = "https://github.com/akjava/English-To-IPA-Collections" style="font-size: 9px" target="link">English-To-IPA-Collections</a> | | |
| <a href ="https://huggingface.co/papers/2309.03199" style="font-size: 9px" target="link">Matcha-TTS Paper</a> | |
| </div> | |
| """) | |
    msg.submit(call_generate_text, [msg, chatbot], [msg, chatbot])

if __name__ == "__main__":
    init()
    demo.launch(share=True)