Commit b52bed7 · Parent(s): 2d191f6

Added inference

- app.py +54 -1
- requirements.txt +2 -0
    	
app.py CHANGED

@@ -7,6 +7,7 @@ import torch.nn as nn
 from model import Projections
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 import gradio as gr
+import librosa
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 projections = Projections(512, 3072)
@@ -47,7 +48,59 @@ whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
 whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
 
 def infer(message, history):
-    
+    max_generate_length = 100
+    combined_embeds = []
+
+    with torch.no_grad():
+        if message['file']:
+            projected_image_embeds = None
+            audio_text_embeds = None
+            for path in message['file']:
+
+                if path.endswith(('.jpg', '.png', '.jpeg')):
+                    image = clip_preprocess(Image.open(path)).unsqueeze(0).to(device)
+                    image_features = clip_model.encode_image(image)
+                    projected_image_embeds = projections(image_features.to(torch.bfloat16)).unsqueeze(0)
+
+                elif path.endswith(('.mp3', '.wav')):
+                    # Load and preprocess the audio
+                    speech, rate = librosa.load(path, sr=16000)
+                    input_features = whisper_processor(speech, return_tensors="pt", sampling_rate=16000).input_features
+                    predicted_ids = whisper_model.generate(input_features)
+                    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
+                    prompt = tokenizer.apply_chat_template([{"from": "human", "value": transcription}], tokenize=False, add_generation_prompt=True)
+                    prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids']
+                    audio_text_embeds = model.get_input_embeddings()(prompt_tokens)
+
+            if projected_image_embeds:
+                combined_embeds.append(projected_image_embeds)
+
+            if audio_text_embeds:
+                combined_embeds.append(audio_text_embeds)
+
+        if message['text']:
+            prompt = tokenizer.apply_chat_template([{"from": "human", "value": transcription}], tokenize=False, add_generation_prompt=True)
+            prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids']
+            text_embeds = model.get_input_embeddings()(prompt_tokens)
+            combined_embeds.append(text_embeds)
+
+        combined_embeds = torch.cat(combined_embeds, dim=1)
+
+        #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
+        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)
+
+        for g in range(max_generate_length):
+            phi_output_logits = model(inputs_embeds=combined_embeds)['logits'] # 4, 69, 51200
+            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
+            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1) # 4,1
+            predicted_caption[:, g] = predicted_word_token.view(1, -1)
+            next_token_embeds = model.get_input_embeddings()(prompt_tokens) # 4,1,2560
+            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+
+        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, ignore_index=50256)[0]
+
+    return predicted_captions_decoded
+
 
 examples=[{'text':"I am planning to buy a dog and a cat. Suggest some breeds that get along with each other"},
          {'text':"Explain biased coin flip"},
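Review note: a few spots in the new infer() look likely to fail at runtime, so the snippet below is a hedged sketch (not the committed code) of the text branch and the decoding loop with the probable fixes. It reuses names already defined in the diff (model, tokenizer, combined_embeds, max_generate_length, device) and keeps the commit's assumption that 50256 is the pad/EOS id. Specifically: the text branch templates transcription (which is only set in the audio branch) instead of message['text']; the loop re-embeds prompt_tokens every step instead of the token it just predicted; tokenizer.batch_decode() takes skip_special_tokens rather than ignore_index; and truth-testing a multi-element tensor (if projected_image_embeds:) raises an ambiguity error, so is-not-None checks are safer. Depending on the pinned Gradio version, the attachments key is also usually message['files'] rather than message['file'].

# Hedged sketch only, continuing the names from the diff above.
if message['text']:
    # Template the user's typed text, not `transcription` (that variable only
    # exists when an audio file was attached).
    prompt = tokenizer.apply_chat_template(
        [{"from": "human", "value": message['text']}],
        tokenize=False, add_generation_prompt=True)
    prompt_tokens = tokenizer(prompt, padding=True, truncation=True,
                              max_length=2048, return_tensors="pt")['input_ids'].to(device)
    combined_embeds.append(model.get_input_embeddings()(prompt_tokens))

combined_embeds = torch.cat(combined_embeds, dim=1)
predicted_caption = torch.full((1, max_generate_length), 50256).to(device)

for g in range(max_generate_length):
    logits = model(inputs_embeds=combined_embeds)['logits']            # (1, seq, vocab)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (1, 1)
    predicted_caption[:, g] = next_token.view(1, -1)
    # Feed back the embedding of the token just predicted, not prompt_tokens.
    next_token_embeds = model.get_input_embeddings()(next_token)       # (1, 1, hidden)
    combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)

# batch_decode takes skip_special_tokens, not ignore_index.
predicted_captions_decoded = tokenizer.batch_decode(
    predicted_caption, skip_special_tokens=True)[0]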
    	
requirements.txt CHANGED

@@ -3,6 +3,8 @@ clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8
 colorama==0.4.6
 datasets==3.0.0
 dill==0.3.8
+gradio==5.0.2
+librosa==0.10.2
 multiprocess==0.70.16
 numpy==1.26.4
 pandas==2.2.2
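The two new pins support the code above: librosa==0.10.2 backs the audio loading in infer(), and gradio==5.0.2 supplies the chat UI. The launch wiring itself is outside this diff, but a minimal sketch of how an infer(message, history) callback receiving {'text': ..., 'files': [...]} messages is typically hooked up is shown below; the demo name and the multimodal=True ChatInterface are assumptions, not taken from the commit.

import gradio as gr

# Minimal sketch, assuming app.py wires infer() into a multimodal chat UI;
# none of these lines appear in the diff itself.
demo = gr.ChatInterface(
    fn=infer,           # infer(message, history) defined in app.py above
    multimodal=True,    # message arrives as {"text": ..., "files": [...]}
    examples=examples,  # the example prompts defined in app.py
)

if __name__ == "__main__":
    demo.launch()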