Spaces:
Runtime error
Runtime error
Commit
·
bf71575
1
Parent(s):
de9a113
misc
Browse files- app.py +3 -2
- config.py +1 -1
- last_epoch_ckpt/diffusion_pytorch_model.safetensors +1 -1
- lightning_app_deprecated.py +460 -0
app.py
CHANGED
|
@@ -106,14 +106,15 @@ def get_user_emb(embs, ys):
|
|
| 106 |
if len(positives) == 0:
|
| 107 |
positives = torch.zeros_like(im_emb)[None]
|
| 108 |
else:
|
| 109 |
-
|
|
|
|
| 110 |
positives = torch.stack(embs, 1)
|
| 111 |
|
| 112 |
negs = [e for e, ys in zip(embs, ys) if ys == 0]
|
| 113 |
if len(negs) == 0:
|
| 114 |
negatives = torch.zeros_like(im_emb)[None]
|
| 115 |
else:
|
| 116 |
-
negative_embs = random.sample(negs, min(4, len(negs))) + negs[-4:]
|
| 117 |
negatives = torch.stack(negative_embs, 1)
|
| 118 |
# if random.random() < .5:
|
| 119 |
# negatives = torch.zeros_like(negatives)
|
|
|
|
| 106 |
if len(positives) == 0:
|
| 107 |
positives = torch.zeros_like(im_emb)[None]
|
| 108 |
else:
|
| 109 |
+
# take the last 8 positives plus up to k-8 random ones. TODO: verify the list is chronological (it should be).
|
| 110 |
+
embs = random.sample(positives, k=min(k-8, len(positives))) + positives[-8:]
|
| 111 |
positives = torch.stack(embs, 1)
|
| 112 |
|
| 113 |
negs = [e for e, ys in zip(embs, ys) if ys == 0]
|
| 114 |
if len(negs) == 0:
|
| 115 |
negatives = torch.zeros_like(im_emb)[None]
|
| 116 |
else:
|
| 117 |
+
negative_embs = random.sample(negs, min(k-4, len(negs))) + negs[-4:]
|
| 118 |
negatives = torch.stack(negative_embs, 1)
|
| 119 |
# if random.random() < .5:
|
| 120 |
# negatives = torch.zeros_like(negatives)
|
config.py
CHANGED
|
@@ -12,5 +12,5 @@ batch_size = 16
|
|
| 12 |
number_k_clip_embed = 16 # divide by this to determine bundling together of sequences -> CLIP
|
| 13 |
num_workers = 32
|
| 14 |
seed = 107
|
| 15 |
-
k =
|
| 16 |
# TODO config option to swap to diffusion?
|
|
|
|
| 12 |
number_k_clip_embed = 16 # divide by this to determine bundling together of sequences -> CLIP
|
| 13 |
num_workers = 32
|
| 14 |
seed = 107
|
| 15 |
+
k = 16
|
| 16 |
# TODO config option to swap to diffusion?
|
last_epoch_ckpt/diffusion_pytorch_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 136790920
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33d7ca8a1d0f179ade0aa00cf9d622b0ac60ea2b58c79933a9212c54b5d6f719
|
| 3 |
size 136790920
|
lightning_app_deprecated.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import glob
|
| 7 |
+
|
| 8 |
+
import config
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
from diffusers import EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, AutoPipelineForText2Image
|
| 11 |
+
from transformers import CLIPVisionModelWithProjection
|
| 12 |
+
from safetensors.torch import load_file
|
| 13 |
+
|
| 14 |
+
from model import get_model_and_tokenizer
|
| 15 |
+
|
| 16 |
+
model, tokenizer = get_model_and_tokenizer(config.model_path, 'cuda', torch.bfloat16)
|
| 17 |
+
|
| 18 |
+
del model.kandinsky_pipe
|
| 19 |
+
del tokenizer
|
| 20 |
+
|
| 21 |
+
torch.set_float32_matmul_precision('high')
|
| 22 |
+
|
| 23 |
+
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 24 |
+
sdxl_lightening = "ByteDance/SDXL-Lightning"
|
| 25 |
+
ckpt = "sdxl_lightning_8step_unet.safetensors"
|
| 26 |
+
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet", low_cpu_mem_usage=True, device_map='cuda').to(torch.float16)
|
| 27 |
+
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt)))
|
| 28 |
+
|
| 29 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map='cuda')
|
| 30 |
+
pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder, low_cpu_mem_usage=True)
|
| 31 |
+
pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl.bin')))
|
| 32 |
+
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
| 33 |
+
pipe.register_modules(image_encoder = image_encoder)
|
| 34 |
+
pipe.set_ip_adapter_scale(0.8)
|
| 35 |
+
|
| 36 |
+
#pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
| 37 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
| 38 |
+
|
| 39 |
+
pipe.to(device='cuda').to(dtype=config.dtype)
|
| 40 |
+
output_hidden_state = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# TODO unify/merge origin and this
|
| 44 |
+
# TODO save & restart from (if it exists) dataframe parquet
|
| 45 |
+
|
| 46 |
+
device = "cuda"
|
| 47 |
+
|
| 48 |
+
k = config.k
|
| 49 |
+
|
| 50 |
+
import spaces
|
| 51 |
+
import matplotlib.pyplot as plt
|
| 52 |
+
|
| 53 |
+
import os
|
| 54 |
+
import gradio as gr
|
| 55 |
+
import pandas as pd
|
| 56 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
| 57 |
+
|
| 58 |
+
import random
|
| 59 |
+
import time
|
| 60 |
+
from PIL import Image
|
| 61 |
+
# from safety_checker_improved import maybe_nsfw
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
torch.set_grad_enabled(False)
|
| 65 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 66 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 67 |
+
|
| 68 |
+
prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id', 'text', 'gemb'])
|
| 69 |
+
|
| 70 |
+
import spaces
|
| 71 |
+
start_time = time.time()
|
| 72 |
+
|
| 73 |
+
####################### Setup Model
|
| 74 |
+
from diffusers import EulerDiscreteScheduler
|
| 75 |
+
from PIL import Image
|
| 76 |
+
import uuid
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@spaces.GPU()
def generate_gpu(in_im_embs, prompt='the scene'):
    """Generate one image from a pair of image embeddings and re-embed it.

    Args:
        in_im_embs: tensor that is moved to CUDA and viewed as ``(2, 1, -1)``.
            Presumably row 0 is the negative and row 1 the positive
            embedding (the previous version bound them to such names but
            never used them) -- TODO confirm against ``get_user_emb``.
        prompt: text prompt handed to the pipeline.

    Returns:
        ``(image, im_emb)``: the first image produced by ``pipe`` and its
        embedding from ``pipe.encode_image``, detached, on CPU, as float32.
    """
    with torch.no_grad():
        # One device transfer is enough; the pipeline receives the pair as a
        # single ip-adapter entry.
        in_im_embs = in_im_embs.to('cuda').view(2, 1, -1)
        images = pipe(
            prompt=prompt,
            guidance_scale=4,
            added_cond_kwargs={},
            ip_adapter_image_embeds=[in_im_embs],
            num_inference_steps=8,
        ).images[0]
        im_emb, _ = pipe.encode_image(
            images, 'cuda', 1, output_hidden_state
        )
        im_emb = im_emb.detach().to('cpu').to(torch.float32)
    return images, im_emb
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def generate(in_im_embs, ):
    """Generate an image, save it under /tmp and return its path and embedding.

    Returns:
        ``(path, im_emb)``; ``path`` is ``None`` when the NSFW flag is set
        (the flag is currently hard-coded to ``False``).
    """
    picture, im_emb = generate_gpu(in_im_embs)

    # The safety checker is disabled; the flag is kept as a hook for it.
    nsfw = False  # maybe_nsfw(output.images[0])
    if nsfw:
        gr.Warning("NSFW content detected.")
        # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
        return None, im_emb

    # uuid4().hex is the dash-free form of the UUID string.
    path = f"/tmp/{uuid.uuid4().hex}.png"
    picture.save(path)
    return path, im_emb
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
#######################
|
| 113 |
+
|
| 114 |
+
@spaces.GPU()
def sample_embs(prompt_embeds):
    """Predict an image embedding from a set of conditioning embeddings.

    ``prompt_embeds`` is zero-padded at the end of its second-to-last
    dimension until ``shape[1]`` equals the module-level ``k``
    (assumes a 3-D input so that this is dimension 1 -- TODO confirm).
    A random latent of shape ``(shape[0], 1, shape[-1])`` is drawn and both
    are passed to ``model`` on CUDA.

    Returns:
        The ``predicted_image_embedding`` attribute of the model output.
    """
    n_batch, width = prompt_embeds.shape[0], prompt_embeds.shape[-1]
    latent = torch.randn(n_batch, 1, width)

    missing = k - prompt_embeds.shape[1]
    if missing > 0:
        prompt_embeds = torch.nn.functional.pad(prompt_embeds, [0, 0, 0, missing])
    assert prompt_embeds.shape[1] == k, f"The model is set to take `k`` cond image embeds but is shape {prompt_embeds.shape}"

    prediction = model(latent.to('cuda'), prompt_embeds.to('cuda'))
    return prediction.predicted_image_embedding
|
| 122 |
+
|
| 123 |
+
@spaces.GPU()
def get_user_emb(embs, ys):
    """Build the (negative, positive) embedding pair for a user.

    Args:
        embs: list of embeddings of the rated items.
        ys: ratings aligned with ``embs``; ``1`` marks a liked item and
            ``0`` a disliked one.

    Returns:
        ``torch.stack([sample_embs(negatives), sample_embs(positives)])``.

    For each polarity up to ``k - 4`` random items plus the last four are
    stacked along dimension 1; when a polarity has no items a zero tensor
    shaped like the module-level ``im_emb`` (with a leading axis added) is
    used instead.
    """
    liked = [e for e, y in zip(embs, ys) if y == 1]
    if len(liked) == 0:
        positives = torch.zeros_like(im_emb)[None]
    else:
        # Keep the selection in its own name: the previous version rebound
        # `embs` here, so the negatives below were filtered from the sampled
        # positives instead of from the full rating history.
        chosen = random.sample(liked, min(k-4, len(liked))) + liked[-4:]
        positives = torch.stack(chosen, 1)

    disliked = [e for e, y in zip(embs, ys) if y == 0]
    if len(disliked) == 0:
        negatives = torch.zeros_like(im_emb)[None]
    else:
        negative_embs = random.sample(disliked, min(k-4, len(disliked))) + disliked[-4:]
        negatives = torch.stack(negative_embs, 1)
        # if random.random() < .5:
        #     negatives = torch.zeros_like(negatives)

    image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])

    return image_embeds
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def background_next_image():
    """Generate one new image for every user that has enough ratings.

    Does nothing (after a short sleep) until at least four rows of
    ``prevs_df`` carry a rating.  For each user found in
    ``latest_user_to_rate`` of the rated rows, a user is skipped when ten or
    more unrated images spawned for them are already queued, or when they
    have rated fewer than four rows.  Otherwise an image is generated from
    their ratings and appended to ``prevs_df``; once the frame exceeds 500
    rows the row at position 6 is dropped and its file deleted.
    """
    global prevs_df, glob_idx

    def _select(frame, keep):
        # Row-wise boolean mask, evaluated through iterrows.
        return frame[[keep(row) for _, row in frame.iterrows()]]

    # only let it get N (maybe 3) ahead of the user
    anyone_rated = _select(prevs_df, lambda row: row['user:rating'] != {' ': ' '})
    if len(anyone_rated) < 4:
        time.sleep(.1)
        return

    for uid in set(anyone_rated['latest_user_to_rate'].to_list()):
        rated_rows = _select(prevs_df, lambda row: row['user:rating'].get(uid, None) is not None)
        not_rated_rows = _select(prevs_df, lambda row: row['user:rating'].get(uid, None) is None)

        # Unrated media that was spawned from this user's own embedding.
        unrated_from_user = _select(not_rated_rows, lambda row: row['from_user_id'] == uid)

        # we don't compute more after n are in the queue for them
        if len(unrated_from_user) >= 10:
            continue

        if len(rated_rows) < 4:
            continue

        glob_idx += 1

        ems = rated_rows['embeddings'].to_list()
        ys = [rating[uid][0] for rating in rated_rows['user:rating'].to_list()]

        img, embs = generate(get_user_emb(ems, ys))
        if not img:
            continue

        tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
        tmp_df['paths'] = [img]
        tmp_df['embeddings'] = [embs.to(torch.float32).to('cpu')]
        tmp_df['user:rating'] = [{' ': ' '}]
        tmp_df['from_user_id'] = [uid]
        tmp_df['text'] = ['']
        prevs_df = pd.concat((prevs_df, tmp_df))

        # we can free up storage by deleting the image
        if len(prevs_df) > 500:
            # NOTE(review): position 6 presumably is the first row after the
            # calibration images -- confirm the calibration set has six files.
            oldest_path = prevs_df.iloc[6]['paths']
            if os.path.isfile(oldest_path):
                os.remove(oldest_path)
            else:
                # If it fails, inform the user.
                print("Error: %s file not found" % oldest_path)
            # drop that row while keeping the rows before it
            prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
|
| 200 |
+
|
| 201 |
+
def pluck_img(user_id):
    """Return the path of the unrated image most similar to the user's taste.

    The user's embedding is computed from the rows of ``prevs_df`` they have
    rated; the call then polls every 0.1 s until at least one row without a
    rating from ``user_id`` exists, and picks the row whose embedding has the
    highest cosine similarity with the user embedding.

    Args:
        user_id: key used in the per-row ``user:rating`` dicts.

    Returns:
        The ``paths`` value of the best-matching unrated row.
    """
    # TODO pluck images based on similarity but also based on diversity by cluster every few times.
    rated_rows = prevs_df[[row['user:rating'].get(user_id, None) is not None for _, row in prevs_df.iterrows()]]
    ems = rated_rows['embeddings'].to_list()
    ys = [rating[user_id][0] for rating in rated_rows['user:rating'].to_list()]
    # Move to CPU once instead of once per candidate row.
    user_emb = get_user_emb(ems, ys).detach().to('cpu')

    def _unrated():
        # Re-reads the module-level frame on every call so that rows added
        # by the background job become visible.
        return prevs_df[[row['user:rating'].get(user_id, 'gone') == 'gone' for _, row in prevs_df.iterrows()]]

    not_rated_rows = _unrated()
    while len(not_rated_rows) == 0:
        not_rated_rows = _unrated()
        time.sleep(.1)
    # TODO optimize this lol

    # NOTE could opt for only showing their own or prioritizing their own media.

    best_sim = -10000000
    best_row = None
    for _, row in not_rated_rows.iterrows():
        sim = torch.cosine_similarity(row['embeddings'].detach().to('cpu'), user_emb, -1)
        # With more than one similarity, entry 1 is used (presumably the
        # positive half of the user embedding -- TODO confirm).
        if len(sim) > 1:
            sim = sim[1]
        if sim.squeeze() > best_sim:
            best_sim = sim
            best_row = row

    if best_row is None:
        # No similarity beat the sentinel (e.g. NaN scores); fall back to
        # the first candidate instead of failing on an unbound name.
        best_row = not_rated_rows.iloc[0]
    return best_row['paths']
|
| 227 |
+
|
| 228 |
+
def next_image(calibrate_prompts, user_id):
    """Pick the next image to show: a calibration image or a similarity pick.

    While fewer than five entries have been taken out of
    ``calibrate_prompts`` (relative to the full ``m_calibrate`` list) a
    random remaining calibration path is popped and returned; afterwards the
    image comes from ``pluck_img``.

    Args:
        calibrate_prompts: list of calibration paths not shown yet; it is
            mutated in place when an entry is popped.
        user_id: forwarded to ``pluck_img``.

    Returns:
        ``(image, calibrate_prompts)``.
    """
    with torch.no_grad():
        still_calibrating = len(m_calibrate) - len(calibrate_prompts) < 5
        # The emptiness check covers calibration sets with fewer than five
        # images: previously random.randint(0, -1) raised ValueError once
        # the list ran out.
        if still_calibrating and calibrate_prompts:
            cal_video = calibrate_prompts.pop(random.randint(0, len(calibrate_prompts)-1))
            image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
        else:
            # we switch to just getting media by similarity.
            image = pluck_img(user_id)
    return image, calibrate_prompts
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def start(_, calibrate_prompts, user_id, request: gr.Request):
    """Begin a session: assign a user id, enable the buttons, show an image.

    The incoming ``user_id`` is ignored; a new integer id is derived from the
    last seven characters of ``str(time.time())`` with the dot removed.

    Returns:
        A list of six button updates followed by the first image, the
        updated ``calibrate_prompts`` and the new ``user_id``.
    """
    clock_text = str(time.time())
    user_id = int(clock_text[-7:].replace('.', ''))
    image, calibrate_prompts = next_image(calibrate_prompts, user_id)

    def _enabled(label, **extra):
        return gr.Button(value=label, interactive=True, **extra)

    # NOTE(review): the two single-glyph labels look identical in this copy,
    # presumably mis-encoded distinct emoji -- confirm against the repository.
    return [
        _enabled('π'),
        _enabled('Neither (Space)', visible=False),
        _enabled('π'),
        gr.Button(value='Start', interactive=False),
        _enabled('π Content', visible=False),
        _enabled('π Style', visible=False),
        image,
        calibrate_prompts,
        user_id,
    ]
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
    """Record the user's rating of ``img`` and return the next image.

    Args:
        img: path of the image that was shown, or ``None``.
        choice: label of the button that was pressed.
        calibrate_prompts: remaining calibration paths, forwarded to
            ``next_image``.
        user_id: key under which the rating is stored.
        request: unused here.

    Returns:
        ``(img, calibrate_prompts)`` for the next image.

    Raises:
        AssertionError: if ``choice`` is not a known button label.
    """
    global prevs_df

    # NOTE(review): the first and third labels are identical in this copy
    # (presumably mis-encoded thumbs-up / thumbs-down), which makes the
    # [0, 0] branch unreachable as written -- confirm against the repository.
    if choice == 'π':
        choice = [1, 1]
    elif choice == 'Neither (Space)':
        img, calibrate_prompts = next_image(calibrate_prompts, user_id)
        return img, calibrate_prompts
    elif choice == 'π':
        choice = [0, 0]
    elif choice == 'π Style':
        choice = [0, 1]
    elif choice == 'π Content':
        choice = [1, 0]
    else:
        assert False, f'choice is {choice}'

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating & just continue
    if img is None:
        print('NSFW -- choice is disliked')
        choice = [0, 0]
        # No path to match on: `... in None` used to raise TypeError here.
        row_mask = [False] * len(prevs_df)
    else:
        # Match on the file name only, as a substring of the shown path.
        row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]

    # if it's still in the dataframe, add the choice
    matched = prevs_df.loc[row_mask, 'user:rating']
    if len(matched) > 0:
        # The cells hold dicts, so updating them in place updates the frame;
        # this avoids the label lookup `[0]`, which relied on the row index.
        for rating in matched:
            rating[user_id] = choice
        prevs_df.loc[row_mask, 'latest_user_to_rate'] = user_id
    else:
        print('Image apparently removed', img)
    img, calibrate_prompts = next_image(calibrate_prompts, user_id)
    return img, calibrate_prompts
|
| 292 |
+
|
| 293 |
+
css = '''.gradio-container{max-width: 700px !important}
|
| 294 |
+
#description{text-align: center}
|
| 295 |
+
#description h1, #description h3{display: block}
|
| 296 |
+
#description p{margin-top: 0}
|
| 297 |
+
.fade-in-out {animation: fadeInOut 3s forwards}
|
| 298 |
+
@keyframes fadeInOut {
|
| 299 |
+
0% {
|
| 300 |
+
background: var(--bg-color);
|
| 301 |
+
}
|
| 302 |
+
100% {
|
| 303 |
+
background: var(--button-secondary-background-fill);
|
| 304 |
+
}
|
| 305 |
+
}
|
| 306 |
+
'''
|
| 307 |
+
js_head = '''
|
| 308 |
+
<script>
|
| 309 |
+
document.addEventListener('keydown', function(event) {
|
| 310 |
+
if (event.key === 'a' || event.key === 'A') {
|
| 311 |
+
// Trigger click on 'dislike' if 'A' is pressed
|
| 312 |
+
document.getElementById('dislike').click();
|
| 313 |
+
} else if (event.key === ' ' || event.keyCode === 32) {
|
| 314 |
+
// Trigger click on 'neither' if Spacebar is pressed
|
| 315 |
+
document.getElementById('neither').click();
|
| 316 |
+
} else if (event.key === 'l' || event.key === 'L') {
|
| 317 |
+
// Trigger click on 'like' if 'L' is pressed
|
| 318 |
+
document.getElementById('like').click();
|
| 319 |
+
}
|
| 320 |
+
});
|
| 321 |
+
function fadeInOut(button, color) {
|
| 322 |
+
button.style.setProperty('--bg-color', color);
|
| 323 |
+
button.classList.remove('fade-in-out');
|
| 324 |
+
void button.offsetWidth; // This line forces a repaint by accessing a DOM property
|
| 325 |
+
|
| 326 |
+
button.classList.add('fade-in-out');
|
| 327 |
+
button.addEventListener('animationend', () => {
|
| 328 |
+
button.classList.remove('fade-in-out'); // Reset the animation state
|
| 329 |
+
}, {once: true});
|
| 330 |
+
}
|
| 331 |
+
document.body.addEventListener('click', function(event) {
|
| 332 |
+
const target = event.target;
|
| 333 |
+
if (target.id === 'dislike') {
|
| 334 |
+
fadeInOut(target, '#ff1717');
|
| 335 |
+
} else if (target.id === 'like') {
|
| 336 |
+
fadeInOut(target, '#006500');
|
| 337 |
+
} else if (target.id === 'neither') {
|
| 338 |
+
fadeInOut(target, '#cccccc');
|
| 339 |
+
}
|
| 340 |
+
});
|
| 341 |
+
|
| 342 |
+
</script>
|
| 343 |
+
'''
|
| 344 |
+
|
| 345 |
+
with gr.Blocks(head=js_head, css=css) as demo:
|
| 346 |
+
gr.Markdown('''# The Other Tiger
|
| 347 |
+
### Generative Recommenders for Exporation of Possible Images
|
| 348 |
+
|
| 349 |
+
Explore the latent space using binary feedback.
|
| 350 |
+
|
| 351 |
+
[rynmurdock.github.io](https://rynmurdock.github.io/)
|
| 352 |
+
''', elem_id="description")
|
| 353 |
+
user_id = gr.State()
|
| 354 |
+
# calibration videos -- this is a misnomer now :D
|
| 355 |
+
calibrate_prompts = gr.State( glob.glob('image_init/*') )
|
| 356 |
+
def l():
|
| 357 |
+
return None
|
| 358 |
+
|
| 359 |
+
with gr.Row(elem_id='output-image'):
|
| 360 |
+
img = gr.Image(
|
| 361 |
+
label='Lightning',
|
| 362 |
+
interactive=False,
|
| 363 |
+
elem_id="output_im",
|
| 364 |
+
type='filepath',
|
| 365 |
+
height=700,
|
| 366 |
+
width=700,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
with gr.Row(equal_height=True):
|
| 372 |
+
b3 = gr.Button(value='π', interactive=False, elem_id="dislike")
|
| 373 |
+
|
| 374 |
+
b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
|
| 375 |
+
|
| 376 |
+
b1 = gr.Button(value='π', interactive=False, elem_id="like")
|
| 377 |
+
with gr.Row(equal_height=True):
|
| 378 |
+
b6 = gr.Button(value='π Style', interactive=False, elem_id="dislike like", visible=False)
|
| 379 |
+
|
| 380 |
+
b5 = gr.Button(value='π Content', interactive=False, elem_id="like dislike", visible=False)
|
| 381 |
+
|
| 382 |
+
b1.click(
|
| 383 |
+
choose,
|
| 384 |
+
[img, b1, calibrate_prompts, user_id],
|
| 385 |
+
[img, calibrate_prompts, ],
|
| 386 |
+
)
|
| 387 |
+
b2.click(
|
| 388 |
+
choose,
|
| 389 |
+
[img, b2, calibrate_prompts, user_id],
|
| 390 |
+
[img, calibrate_prompts, ],
|
| 391 |
+
)
|
| 392 |
+
b3.click(
|
| 393 |
+
choose,
|
| 394 |
+
[img, b3, calibrate_prompts, user_id],
|
| 395 |
+
[img, calibrate_prompts, ],
|
| 396 |
+
)
|
| 397 |
+
b5.click(
|
| 398 |
+
choose,
|
| 399 |
+
[img, b5, calibrate_prompts, user_id],
|
| 400 |
+
[img, calibrate_prompts, ],
|
| 401 |
+
)
|
| 402 |
+
b6.click(
|
| 403 |
+
choose,
|
| 404 |
+
[img, b6, calibrate_prompts, user_id],
|
| 405 |
+
[img, calibrate_prompts, ],
|
| 406 |
+
)
|
| 407 |
+
with gr.Row():
|
| 408 |
+
b4 = gr.Button(value='Start')
|
| 409 |
+
b4.click(start,
|
| 410 |
+
[b4, calibrate_prompts, user_id],
|
| 411 |
+
[b1, b2, b3, b4, b5, b6, img, calibrate_prompts, user_id, ]
|
| 412 |
+
)
|
| 413 |
+
with gr.Row():
|
| 414 |
+
html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several images and then roam. When your media is generating, you may encounter others'.</ div><br><br><br>
|
| 415 |
+
|
| 416 |
+
<br><br>
|
| 417 |
+
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
|
| 418 |
+
</ div>''')
|
| 419 |
+
|
| 420 |
+
# TODO quiet logging
|
| 421 |
+
scheduler = BackgroundScheduler()
|
| 422 |
+
scheduler.add_job(func=background_next_image, trigger="interval", seconds=.2)
|
| 423 |
+
scheduler.start()
|
| 424 |
+
|
| 425 |
+
# TODO shouldn't call this before gradio launch, yeah?
|
| 426 |
+
@spaces.GPU()
def encode_space(x):
    """Embed an image with the prior pipeline's image encoder.

    ``x`` is run through ``model.prior_pipe.image_processor``; the first
    entry of ``pixel_values`` is given a leading axis again, cast to the
    encoder's dtype and moved to ``device`` before encoding.

    Returns:
        The ``"image_embeds"`` output, detached, on CPU, as float32.
    """
    prior = model.prior_pipe
    processed = prior.image_processor(x, return_tensors="pt")
    pixels = processed.pixel_values[0].unsqueeze(0)
    pixels = pixels.to(dtype=prior.image_encoder.dtype, device=device)

    embedding = prior.image_encoder(pixels)["image_embeds"]
    return embedding.detach().to('cpu').to(torch.float32)
|
| 436 |
+
|
| 437 |
+
# NOTE:
|
| 438 |
+
# media is moved into a random tmp folder so we need to parse filenames carefully.
|
| 439 |
+
# do not have any cases where a file name is the same or could be `in` another filename
|
| 440 |
+
# you also maybe can't use jpegs lmao
|
| 441 |
+
|
| 442 |
+
# prep our calibration videos
|
| 443 |
+
m_calibrate = glob.glob('image_init/*')
|
| 444 |
+
for im in m_calibrate:
|
| 445 |
+
tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb', 'from_user_id'])
|
| 446 |
+
tmp_df['paths'] = [im]
|
| 447 |
+
image = Image.open(im).convert('RGB')
|
| 448 |
+
im_emb = encode_space(image)
|
| 449 |
+
|
| 450 |
+
tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
|
| 451 |
+
tmp_df['user:rating'] = [{' ': ' '}]
|
| 452 |
+
tmp_df['text'] = ['']
|
| 453 |
+
|
| 454 |
+
# seems to break things...
|
| 455 |
+
tmp_df['from_user_id'] = [0]
|
| 456 |
+
tmp_df['latest_user_to_rate'] = [0]
|
| 457 |
+
prevs_df = pd.concat((prevs_df, tmp_df))
|
| 458 |
+
|
| 459 |
+
glob_idx = 0
|
| 460 |
+
demo.launch(share=True,)
|