Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import soundfile as sf | |
| import gradio as gr | |
| import jax | |
| import numpy as np | |
| from PIL import Image | |
| import random | |
| import sox | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForCTC | |
| from transformers import pipeline | |
| from dalle_mini import DalleBart, DalleBartProcessor | |
| from vqgan_jax.modeling_flax_vqgan import VQModel | |
# Torch device handle: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hebrew speech-recognition checkpoint (processor + CTC model share one id).
# "imvladikon/wav2vec2-xls-r-300m-hebrew" was referenced here previously and
# is loaded with the same two calls.
ASR_MODEL_ID = "imvladikon/wav2vec2-xls-r-1b-hebrew"
asr_processor = AutoProcessor.from_pretrained(ASR_MODEL_ID)
asr_model = AutoModelForCTC.from_pretrained(ASR_MODEL_ID)

# Hebrew -> English translation pipeline.
he_en_translator = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-tc-big-he-en",
)

# Model references.
# DALL-E mini checkpoint: can be a wandb artifact, a Hugging Face Hub repo,
# a local folder or a Google bucket. The mega variant
# ("dalle-mini/dalle-mini/mega-1-fp16:latest") is too large, so the mini
# checkpoint is used.
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN model used to decode generated image tokens, pinned to a commit.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Load the image-generation stack once at import time.
model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
def generate_image(text, *, temperature=0.85, cond_scale=3.0, top_k=None,
                   top_p=None, seed=None):
    """Generate a single image for a text prompt.

    The prompt is tokenized with the module-level ``processor``, image tokens
    are sampled with the module-level ``model`` and decoded to pixels with
    the module-level ``vqgan``.

    Args:
        text: Prompt describing the image to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        cond_scale: Forwarded to ``model.generate`` as ``condition_scale``.
        top_k: Forwarded to ``model.generate``; ``None`` keeps the original
            behaviour.
        top_p: Forwarded to ``model.generate``; ``None`` keeps the original
            behaviour.
        seed: Integer seed for the JAX PRNG key. When ``None`` a fresh random
            seed in ``[0, 10**7]`` is drawn for every call.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded image.
    """
    if seed is None:
        # The bounds must be ints: a float bound such as 1e7 is deprecated
        # since Python 3.10 and raises TypeError on Python 3.12+.
        seed = random.randint(0, 10**7)

    tokenized_prompt = processor([text])
    encoded_images = model.generate(
        **tokenized_prompt,
        prng_key=jax.random.PRNGKey(seed),
        params=model.params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )

    # Drop the first position of every generated sequence before decoding
    # (presumably the BOS token -- confirm against dalle_mini).
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = vqgan.decode_code(encoded_images, vqgan.params)
    # Clamp to [0, 1] and lay the batch out as 256x256 three-channel images.
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    first_image = decoded_images[0]
    # Scale to the 0-255 range and convert to 8-bit for PIL.
    return Image.fromarray(np.asarray(first_image * 255, dtype=np.uint8))
def convert(inputfile, outfile):
    """Re-encode an audio file as 16 kHz, mono, 16-bit signed-integer WAV.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path the converted WAV file is written to.
    """
    output_format = {
        "file_type": "wav",
        "rate": 16000,
        "bits": 16,
        "channels": 1,
        "encoding": "signed-integer",
    }
    transformer = sox.Transformer()
    transformer.set_output_format(**output_format)
    transformer.build(inputfile, outfile)
def parse_transcription(wav_file):
    """Turn a recorded Hebrew utterance into text, a translation and an image.

    Pipeline: convert the recording to 16 kHz mono WAV, transcribe it with
    the module-level ASR model, translate the transcript to English and
    generate an image from the translation.

    Args:
        wav_file: File-like object exposing a ``name`` attribute holding the
            path of the recording (what the microphone input passes in).

    Returns:
        A ``(transcription, translated, image)`` tuple: the Hebrew transcript,
        its English translation and the generated image.

    Raises:
        ValueError: If ``wav_file`` is ``None`` (nothing was recorded).
    """
    import os

    # The microphone input is declared optional, so the callback can be
    # invoked without a recording; fail with an explicit message.
    if wav_file is None:
        raise ValueError("No audio received: please record a prompt with the microphone.")

    # splitext strips only the final extension. Splitting on the first "."
    # would truncate any path whose directories or file name contain a dot.
    base_path, _ = os.path.splitext(wav_file.name)
    resampled_path = base_path + "16k.wav"
    convert(wav_file.name, resampled_path)
    speech, _ = sf.read(resampled_path)

    # Transcribe to Hebrew. Inference only, so skip autograd bookkeeping.
    input_values = asr_processor(
        speech, sampling_rate=16_000, return_tensors="pt"
    ).input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    print(transcription)

    # Translate to English.
    translated = he_en_translator(transcription)[0]['translation_text']
    print(translated)

    # Generate the image from the English prompt.
    image = generate_image(translated)
    return transcription, translated, image
# Output widgets, in the order parse_transcription returns its values:
# Hebrew transcript, English translation, generated image.
outputs = [
    gr.outputs.Textbox(label="transcript"),
    gr.outputs.Textbox(label="translated prompt"),
    gr.outputs.Image(label=''),
]

# type="file" hands the callback a file object; parse_transcription reads
# its .name to locate the recording.
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)

# Build the UI and start serving it.
gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=outputs,
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
).launch(inline=False)