Spaces:
Runtime error
Runtime error
| import os | |
| from transformers import pipeline | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from langchain_core.prompts import PromptTemplate | |
| from PIL import Image | |
class StoryGenerator:
    """Turn an image into a short children's story.

    A Hugging Face ``image-to-text`` pipeline captions the image, and the
    caption is inserted into a story prompt that is sent to a text-generation
    model reached through ``HuggingFaceEndpoint``.
    """

    # End-of-sequence marker that may be left at the end of generated text.
    _EOS_MARKER = "</s>"

    # Prompt sent to the text model; ``{scenario}`` is filled with the caption.
    _STORY_PROMPT = (
        "\n"
        "You are a kids story writer. Provide a coherent story for kids\n"
        "using this simple instruction: {scenario}. The story should have a clear\n"
        "beginning, middle, and end. The story should be interesting and engaging for\n"
        "kids. The story should be maximum 200 words long. Do not include\n"
        "any adult or polemic content.\n"
        "Story:\n"
    )

    def __init__(self, image_model="Salesforce/blip-image-captioning-base"):
        """Load the captioning pipeline and prepare the story prompt.

        Args:
            image_model: Hugging Face model id used for image captioning.
        """
        self.image_model = image_model
        # Loaded eagerly, so constructing the generator downloads/loads the model.
        self.image_to_text = pipeline("image-to-text", model=self.image_model)
        # Display name -> Hugging Face repo id of the text-generation model.
        self.text_models = {
            "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
            "FLAN-T5": "google/flan-t5-large",
            "MPT-7B": "mosaicml/mpt-7b-instruct",
            "Falcon-7B": "tiiuae/falcon-7b-instruct",
        }
        self.prompt_template = PromptTemplate.from_template(self._STORY_PROMPT)

    def get_llm(self, model_name):
        """Return a streaming ``HuggingFaceEndpoint`` for a configured text model.

        Args:
            model_name: One of the keys of ``self.text_models``.

        Raises:
            KeyError: If ``model_name`` is not a configured model; the message
                lists the valid names.
        """
        try:
            repo_id = self.text_models[model_name]
        except KeyError:
            available = ", ".join(sorted(self.text_models))
            raise KeyError(
                f"Unknown text model {model_name!r}; expected one of: {available}"
            ) from None
        return HuggingFaceEndpoint(
            repo_id=repo_id,
            temperature=0.5,
            streaming=True,
        )

    def img2txt(self, image_path):
        """Caption an image with the Hugging Face image-to-text pipeline.

        Args:
            image_path: Anything the pipeline accepts as an image, e.g. a file
                path/URL string or a PIL image.

        Returns:
            The caption text of the first pipeline result.
        """
        text = self.image_to_text(image_path)[0]["generated_text"]
        print(f"Image caption: {text}")
        return text

    @classmethod
    def _clean_output(cls, text):
        """Strip whitespace and any trailing ``</s>`` marker(s) from model output.

        ``str.rstrip('</s>')`` would treat ``'</s>'`` as a *set* of characters
        and also eat legitimate trailing letters (``"cats"`` -> ``"cat"``), so
        the marker is removed only as an exact suffix.
        """
        text = text.strip()
        while text.endswith(cls._EOS_MARKER):
            text = text[: -len(cls._EOS_MARKER)].rstrip()
        return text

    def generate_story(self, scenario, model_name):
        """Generate a kids' story for a scenario with the chosen text model.

        Args:
            scenario: Short description (typically an image caption) that is
                substituted into the story prompt.
            model_name: Key of ``self.text_models`` selecting the LLM.

        Returns:
            The generated story with surrounding whitespace and a trailing
            ``</s>`` marker removed.
        """
        llm = self.get_llm(model_name)
        chain = self.prompt_template | llm
        raw_story = chain.invoke({"scenario": scenario})
        return self._clean_output(raw_story)

    def generate_story_from_image(self, image, model_name):
        """Generate a story from an image.

        Args:
            image: A file path/URL string or a PIL image.
            model_name: Key of ``self.text_models`` selecting the LLM.

        Returns:
            The generated story text.
        """
        print(f"Received image: {image}")
        print(f"Image type: {type(image)}")
        # The image-to-text pipeline accepts paths and PIL images directly, so
        # the image is passed through instead of being written to a shared temp
        # file. A fixed path is racy across concurrent calls, and JPEG cannot
        # store RGBA/P-mode images.
        scenario = self.img2txt(image)
        return self.generate_story(scenario, model_name)
# Example usage: run this module directly to caption the bundled sample image
# and print a story for it.
if __name__ == "__main__":
    story_generator = StoryGenerator()
    sample_path = os.path.join("assets", "image.jpg")
    if not os.path.exists(sample_path):
        print(f"Example image not found at {sample_path}")
    else:
        print(story_generator.generate_story_from_image(sample_path, "Mistral-7B"))