	Create app.py
app.py ADDED
@@ -0,0 +1,121 @@
import streamlit as st
from PIL import Image
import requests
from io import BytesIO
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    pipeline,
    AutoFeatureExtractor,
    AutoModelForObjectDetection,
    CLIPTokenizerFast,
    CLIPTextModel
)
import torch
from torchvision.transforms import functional as F
import emoji

# Load models once and cache them across Streamlit reruns
@st.cache_resource
def load_models():
    age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
    age_transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')

    gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
    gender_transforms = ViTFeatureExtractor.from_pretrained('rizvandwiki/gender-classification-2')

    emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
    emotion_transforms = ViTFeatureExtractor.from_pretrained('dima806/facial_emotions_image_detection')

    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")

    action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    action_transforms = ViTFeatureExtractor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')

    prompt_generator = pipeline("text2text-generation", model="succinctly/text2image-prompt-generator")

    clip_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_generator,
            clip_tokenizer, clip_model)

models = load_models()
(age_model, age_transforms, gender_model, gender_transforms,
 emotion_model, emotion_transforms, object_detector,
 action_model, action_transforms, prompt_generator,
 clip_tokenizer, clip_model) = models

# Run a single image classifier and return the index of the most probable class
def predict(image, model, transforms):
    inputs = transforms(image, return_tensors='pt')
    output = model(**inputs)
    proba = output.logits.softmax(1)
    return proba.argmax(1).item()

# Classify age, gender, emotion and action, and detect objects in the image
def detect_attributes(image):
    age = predict(image, age_model, age_transforms)
    gender = predict(image, gender_model, gender_transforms)
    emotion = predict(image, emotion_model, emotion_transforms)
    action = predict(image, action_model, action_transforms)

    objects = object_detector(image)

    return {
        'age': age_model.config.id2label[age],
        'gender': gender_model.config.id2label[gender],
        'emotion': emotion_model.config.id2label[emotion],
        'action': action_model.config.id2label[action],
        'objects': [obj['label'] for obj in objects]
    }

# Build a natural-language description from the detected attributes
def generate_prompt(attributes):
    prompt = f"A {attributes['age']} {attributes['gender']} person feeling {attributes['emotion']} "
    prompt += f"while {attributes['action']}. "
    if attributes['objects']:
        prompt += f"Surrounded by {', '.join(attributes['objects'])}. "
    return prompt

# Encode the prompt with CLIP's text encoder and map the embedding to an emoji
def generate_emoji(prompt):
    inputs = clip_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    outputs = clip_model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1)

    # Simple emoji mapping based on embedding features
    if embedding[0, 0] > 0:
        return emoji.emojize(":grinning_face:")
    elif embedding[0, 1] > 0:
        return emoji.emojize(":smiling_face_with_heart-eyes:")
    elif embedding[0, 2] > 0:
        return emoji.emojize(":face_with_tears_of_joy:")
    elif embedding[0, 3] > 0:
        return emoji.emojize(":thinking_face:")
    else:
        return emoji.emojize(":neutral_face:")

st.title("Image Attribute Detection and Emoji Generation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Convert to RGB so RGBA or grayscale uploads don't break the feature extractors
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Analyze and Generate Emoji'):
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)

        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")

        with st.spinner('Generating prompt...'):
            prompt = generate_prompt(attributes)
        st.write("Generated Prompt:")
        st.write(prompt)

        with st.spinner('Generating emoji...'):
            emoji_result = generate_emoji(prompt)
        st.write("Generated Emoji:")
        st.write(emoji_result)
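To reproduce this Space locally, the imports above imply roughly these dependencies (a sketch only; the commit does not pin versions, so treat the exact list as an assumption): streamlit, transformers, torch, torchvision, Pillow, requests, and emoji. With those installed, the app can be launched with streamlit run app.py.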