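# app.py — Streamlit demo: detect age, gender, emotion, action, and objects
# in an uploaded image, build a descriptive text prompt from the detected
# attributes, and map a CLIP text embedding of that prompt to an emoji.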
import streamlit as st
from PIL import Image
from transformers import (
    ViTImageProcessor,  # replaces the deprecated ViTFeatureExtractor
    ViTForImageClassification,
    pipeline,
    CLIPTokenizerFast,
    CLIPTextModel,
)
import torch
import emoji
# Unused imports removed: requests, BytesIO, AutoFeatureExtractor,
# AutoModelForObjectDetection, torchvision.transforms.functional
# Load models once per process. Without caching, Streamlit re-instantiates
# every model on each rerun, which can exhaust a Space's memory and is a
# common cause of a "Runtime error" status.
@st.cache_resource
def load_models():
    age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
    age_transforms = ViTImageProcessor.from_pretrained('nateraw/vit-age-classifier')
    gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
    gender_transforms = ViTImageProcessor.from_pretrained('rizvandwiki/gender-classification-2')
    emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
    emotion_transforms = ViTImageProcessor.from_pretrained('dima806/facial_emotions_image_detection')
    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    action_transforms = ViTImageProcessor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    # Loaded but never called below; generate_prompt() builds the prompt from
    # a template instead. Kept to preserve the original behavior.
    prompt_generator = pipeline("text2text-generation", model="succinctly/text2image-prompt-generator")
    clip_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_generator,
            clip_tokenizer, clip_model)

models = load_models()
(age_model, age_transforms, gender_model, gender_transforms,
 emotion_model, emotion_transforms, object_detector,
 action_model, action_transforms, prompt_generator,
 clip_tokenizer, clip_model) = models
def predict(image, model, transforms):
    inputs = transforms(image, return_tensors='pt')
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        output = model(**inputs)
    proba = output.logits.softmax(-1)
    return proba.argmax(-1).item()
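# A small variant (a sketch, not part of the original app) that returns the
# human-readable label together with its confidence score:
def predict_with_confidence(image, model, transforms):
    inputs = transforms(image, return_tensors='pt')
    with torch.no_grad():
        proba = model(**inputs).logits.softmax(-1)
    idx = proba.argmax(-1).item()
    return model.config.id2label[idx], proba[0, idx].item()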
def detect_attributes(image):
    age = predict(image, age_model, age_transforms)
    gender = predict(image, gender_model, gender_transforms)
    emotion = predict(image, emotion_model, emotion_transforms)
    action = predict(image, action_model, action_transforms)
    objects = object_detector(image)
    return {
        'age': age_model.config.id2label[age],
        'gender': gender_model.config.id2label[gender],
        'emotion': emotion_model.config.id2label[emotion],
        'action': action_model.config.id2label[action],
        'objects': [obj['label'] for obj in objects]
    }
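# Note: the object-detection pipeline returns one entry per detected box, so
# 'objects' may repeat labels (e.g. two people -> ['person', 'person']);
# use list(dict.fromkeys(objects)) if distinct labels are preferred.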
def generate_prompt(attributes):
    prompt = f"A {attributes['age']} {attributes['gender']} person feeling {attributes['emotion']} "
    prompt += f"while {attributes['action']}. "
    if attributes['objects']:
        prompt += f"Surrounded by {', '.join(attributes['objects'])}. "
    return prompt
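# Illustrative only — the actual strings depend on each model's id2label,
# and the attribute values below are hypothetical:
# generate_prompt({'age': '20-29', 'gender': 'male', 'emotion': 'happy',
#                  'action': 'dancing', 'objects': ['person', 'chair']})
# -> "A 20-29 male person feeling happy while dancing. Surrounded by person, chair. "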
def generate_emoji(prompt):
    inputs = clip_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1)
    # Crude heuristic: branch on the signs of the first few embedding
    # dimensions. Individual CLIP dimensions carry no fixed semantics, so
    # this mapping is essentially arbitrary (see the alternative below).
    if embedding[0, 0] > 0:
        return emoji.emojize(":grinning_face:")
    elif embedding[0, 1] > 0:
        return emoji.emojize(":smiling_face_with_heart-eyes:")
    elif embedding[0, 2] > 0:
        return emoji.emojize(":face_with_tears_of_joy:")
    elif embedding[0, 3] > 0:
        return emoji.emojize(":thinking_face:")
    else:
        return emoji.emojize(":neutral_face:")
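# A more principled alternative (a sketch, not part of the original app,
# assuming mean-pooled CLIPTextModel states work as a rough sentence
# embedding; the names and descriptions below are hypothetical): pick the
# candidate emoji whose description is most similar to the prompt.
EMOJI_CANDIDATES = {
    ":grinning_face:": "a happy smiling person",
    ":smiling_face_with_heart-eyes:": "a person who loves what they see",
    ":face_with_tears_of_joy:": "a person laughing very hard",
    ":thinking_face:": "a person thinking about something",
    ":neutral_face:": "a person with a neutral expression",
}

def generate_emoji_by_similarity(prompt):
    def embed(text):
        tokens = clip_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            return clip_model(**tokens).last_hidden_state.mean(dim=1).squeeze(0)

    prompt_emb = embed(prompt)
    scores = {name: torch.nn.functional.cosine_similarity(prompt_emb, embed(desc), dim=0).item()
              for name, desc in EMOJI_CANDIDATES.items()}
    return emoji.emojize(max(scores, key=scores.get))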
st.title("Image Attribute Detection and Emoji Generation")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Convert to RGB so RGBA or grayscale uploads don't break the
    # 3-channel ViT image processors.
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image', use_container_width=True)  # use_column_width is deprecated
    if st.button('Analyze and Generate Emoji'):
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)
        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")
        with st.spinner('Generating prompt...'):
            prompt = generate_prompt(attributes)
        st.write("Generated Prompt:")
        st.write(prompt)
        with st.spinner('Generating emoji...'):
            emoji_result = generate_emoji(prompt)
        st.write("Generated Emoji:")
        st.write(emoji_result)
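# Note: a Space "Runtime error" often comes from missing dependencies rather
# than the code itself. An assumed requirements.txt for this app (package
# list inferred from the imports above, versions unpinned):
#   streamlit
#   transformers
#   torch
#   Pillow
#   emoji
#   timm  # needed by the facebook/detr-resnet-50 object-detection pipeline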