import pandas as pd
import numpy as np
from PIL import Image
import torch
import torchvision
import clip
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Any of the released CLIP checkpoints works here.
model_name = 'ViT-B/16'  #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
model, preprocess = clip.load(model_name)
# Keep the weights in fp32 so they match the fp32 color tensor optimized below
# (clip.load leaves the model in fp16 on CUDA).
model = model.to(DEVICE).float().eval()

# CLIP expects a fixed square input; the resizer scales a 1x1 color up to that resolution.
resolution = model.visual.input_resolution
resizer = torchvision.transforms.Resize(size=(resolution, resolution))
def create_rgb_tensor(color):
    """Turn a color like [1, 0, 0] into a (1, 3, 1, 1) image tensor."""
    return torch.tensor(color, device=DEVICE).reshape((1, 3, 1, 1))

def encode_color(color):
    """Embed a color like [1, 0, 0] with CLIP's image encoder."""
    rgb = create_rgb_tensor(color)
    return model.encode_image(resizer(rgb))

def encode_text(text):
    """Embed a text prompt with CLIP's text encoder."""
    tokenized_text = clip.tokenize(text).to(DEVICE)
    return model.encode_text(tokenized_text)
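# Illustrative only (not used by the app): the two encoders above can be combined
# to score how well a color matches a prompt via cosine similarity.
# with torch.no_grad():
#     red = encode_color([1.0, 0.0, 0.0])
#     prompt = encode_text("A solid red square")
#     print(torch.cosine_similarity(red, prompt, dim=-1).item())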
class RGBModel(torch.nn.Module):
    """Holds a single learnable RGB color, initialized to mid-gray."""

    def __init__(self, device):
        super().__init__()
        self.color = torch.nn.Parameter(torch.ones((1, 3, 1, 1), device=device) / 2)

    def forward(self):
        # Clamp the color to the closed interval [0, 1] before returning it.
        self.color.data = self.color.data.clamp(0, 1)
        return self.color
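# Illustrative only: forward() always returns values already clamped to [0, 1],
# starting from mid-gray.
# demo = RGBModel(device=DEVICE)
# print(demo().shape)               # torch.Size([1, 3, 1, 1])
# print(demo().flatten().tolist())  # [0.5, 0.5, 0.5]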
# Gradio input components (Gradio 3+ API: components live directly under gr and take `value=` instead of `default=`).
text_input = gr.Textbox(lines=1, label="Text Prompt", value='A solid red square')
steps_input = gr.Slider(minimum=1, maximum=100, step=1, value=11, label="Training Steps")
lr_input = gr.Number(value=0.06, label="Adam Optimizer Learning Rate")
decay_input = gr.Number(value=0.01, label="Adam Optimizer Weight Decay")
def gradio_fn(text_prompt, adam_learning_rate, adam_weight_decay, n_iterations=50):
    """Optimize a single RGB color so its CLIP embedding matches the text prompt."""
    rgb_model = RGBModel(device=DEVICE)
    opt = torch.optim.AdamW(rgb_model.parameters(), lr=adam_learning_rate,
                            weight_decay=adam_weight_decay)

    # The target text embedding is fixed, so compute it once without tracking gradients.
    with torch.no_grad():
        tokenized_text = clip.tokenize(text_prompt).to(DEVICE)
        target_embedding = model.encode_text(tokenized_text).detach().clone()

    def training_step():
        opt.zero_grad()
        color = rgb_model()
        color_img = resizer(color)  # scale the 1x1 color up to CLIP's input resolution
        image_embedding = model.encode_image(color_img)
        # Maximize cosine similarity by minimizing its negative.
        loss = -1 * torch.cosine_similarity(target_embedding, image_embedding, dim=-1)
        loss.backward()
        opt.step()

    # Record the color before training and after every step.
    steps = [rgb_model().cpu().detach().numpy()]
    for iteration in range(n_iterations):
        training_step()
        steps.append(rgb_model().cpu().detach().numpy())

    # Render the trajectory as a 1-pixel-tall strip (one column per recorded color),
    # then upscale it with nearest-neighbor resampling (resample=0) for display.
    steps = np.stack(steps)                                        # (n_iterations + 1, 1, 3, 1, 1)
    strip = (steps[None, :, 0, :, 0, 0] * 255).astype(np.uint8)    # (1, n_iterations + 1, 3)
    img_train = Image.fromarray(strip).resize((400, 100), 0)
    return img_train
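# Illustrative only: the function can also be called directly, bypassing the UI.
# img = gradio_fn("A solid red square", adam_learning_rate=0.06,
#                 adam_weight_decay=0.01, n_iterations=50)
# img.save("red_trajectory.png")  # hypothetical output path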
# The input order must match gradio_fn's positional arguments:
# (text_prompt, adam_learning_rate, adam_weight_decay, n_iterations).
iface = gr.Interface(fn=gradio_fn, inputs=[text_input, lr_input, decay_input, steps_input], outputs="image")
iface.launch()