import os

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPModel, CLIPTokenizer

def train_image_generation_model(image_folder, text_folder, model_name="image_generation_model"):
    """Fine-tunes a CLIP image-text model on the provided dataset.

    Args:
        image_folder (str): Path to the folder containing training images.
        text_folder (str): Path to the folder containing one .txt prompt file per image.
        model_name (str, optional): Name for the saved model file. Defaults to "image_generation_model".

    Returns:
        str: Path to the saved model file.
    """
    class ImageTextDataset(Dataset):
        def __init__(self, image_folder, text_folder, transform=None):
            # Sort both listings so image i pairs with prompt i; os.listdir
            # returns entries in arbitrary order.
            self.image_paths = sorted(
                os.path.join(image_folder, f)
                for f in os.listdir(image_folder)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))
            )
            self.text_paths = sorted(
                os.path.join(text_folder, f)
                for f in os.listdir(text_folder)
                if f.lower().endswith(".txt")
            )
            if len(self.image_paths) != len(self.text_paths):
                raise ValueError("Each image needs exactly one matching .txt prompt file.")
            self.transform = transform

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            image = Image.open(self.image_paths[idx]).convert("RGB")
            if self.transform:
                image = self.transform(image)
            with open(self.text_paths[idx], "r") as f:
                text = f.read().strip()
            return image, text
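
    # Assumed pairing convention (an assumption, enforced only by the count
    # check above): e.g. images/0001.jpg, images/0002.jpg, ... match
    # prompts/0001.txt, prompts/0002.txt, ... after sorting by filename.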

    # Load the pretrained CLIP checkpoint and move it to the available device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    # Define image transformations (224x224 input with CLIP's own RGB mean/std)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ])
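    # Alternative sketch (not used here): CLIPProcessor ships the matching
    # preprocessing with the checkpoint, avoiding the hard-coded statistics:
    #   from transformers import CLIPProcessor
    #   processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")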
    # Create dataset and dataloader
    dataset = ImageTextDataset(image_folder, text_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
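    # Note: the default collate_fn stacks the image tensors into one batch
    # tensor but leaves the prompts as a plain list of Python strings, which
    # is why `texts` is re-tokenized on every step below.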
    # Define optimizer and loss function
    optimizer = torch.optim.Adam(clip_model.parameters(), lr=1e-5)
    loss_fn = nn.CrossEntropyLoss()
    # Fine-tune with CLIP's contrastive objective: within a batch, the i-th
    # image should score highest against the i-th prompt.
    for epoch in range(10):
        for i, (images, texts) in enumerate(dataloader):
            optimizer.zero_grad()
            images = images.to(device)
            tokens = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt").to(device)
            image_features = clip_model.get_image_features(pixel_values=images)
            text_features = clip_model.get_text_features(**tokens)
            # Normalize the embeddings and scale the similarities, as in CLIP training
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            logits = clip_model.logit_scale.exp() * image_features @ text_features.T
            labels = torch.arange(images.size(0), device=device)
            # Symmetric cross-entropy over image-to-text and text-to-image directions
            loss = (loss_fn(logits, labels) + loss_fn(logits.T, labels)) / 2
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch} | Iteration: {i} | Loss: {loss.item():.4f}")
    # Save the fine-tuned weights
    model_path = os.path.join(os.getcwd(), model_name + ".pt")
    torch.save(clip_model.state_dict(), model_path)
    return model_path
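

# Hedged sketch (not wired into the app): how the weights saved above could be
# reloaded for inference; `load_finetuned_clip` is an illustrative helper, not
# part of the original Space.
def load_finetuned_clip(model_path):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    return model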


# gr.File(file_count="directory") hands the function a list of uploaded file
# paths rather than a folder path, so recover each parent directory first.
def train_from_uploads(image_files, text_files, model_name):
    image_folder = os.path.dirname(image_files[0])
    text_folder = os.path.dirname(text_files[0])
    return train_image_generation_model(image_folder, text_folder, model_name or "image_generation_model")


# Define Gradio interface
iface = gr.Interface(
    fn=train_from_uploads,
    inputs=[
        gr.File(label="Image Folder", file_count="directory"),
        gr.File(label="Text Prompts Folder", file_count="directory"),
        gr.Textbox(label="Model Name"),
    ],
    outputs=gr.File(label="Model File"),
    title="Image Generation Model Trainer",
    description=(
        "Upload a folder of images and a folder of matching text prompts to "
        "fine-tune a CLIP model. The image folder should contain .png/.jpg/.jpeg "
        "files; the prompts folder should contain one .txt file per image, "
        "paired by sorted filename."
    ),
)
iface.launch()