import json
from collections import defaultdict
import safetensors.torch  # the torch submodule must be imported explicitly for safetensors.torch.load_model below
import timm
from transformers import AutoProcessor
import gradio as gr
import torch
import time
from florence2_implementation.modeling_florence2 import Florence2ForConditionalGeneration
from torchvision.transforms import InterpolationMode
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision.transforms import transforms
import random
import csv
import os

torch.set_grad_enabled(False)

# HF now (Feb 20, 2025) imposes a storage limit of 1GB. Will have to pull JTP from other places.
os.system("wget -nv https://huggingface.co/RedRocket/JointTaggerProject/resolve/main/JTP_PILOT2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors")
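
# Tag category IDs (e621-style) mapped to readable names; category 3 (copyright) is left commented out.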
category_id_to_str = {
    "0": "general",
    # 3 copyright
    "4": "character",
    "5": "species",
    "7": "meta",
    "8": "lore",
    "1": "artist",
}
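

# Prunes raw tagger output down to the allowed tag set and assembles the Florence-2 caption prompt.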
class Pruner:
    def __init__(self, path_to_tag_list_csv):
        species_tags = set()
        allowed_tags = set()
        with open(path_to_tag_list_csv, "r") as f:
            reader = csv.reader(f)
            header = next(reader)
            name_index = header.index("name")
            category_index = header.index("category")
            post_count_index = header.index("post_count")
            for row in reader:
                if int(row[post_count_index]) > 20:
                    category = row[category_index]
                    name = row[name_index]
                    if category == "5":
                        species_tags.add(name)
                        allowed_tags.add(name)
                    elif category == "0":
                        allowed_tags.add(name)
                    elif category == "7":
                        allowed_tags.add(name)
        self.species_tags = species_tags
        self.allowed_tags = allowed_tags

    def _prune_not_allowed_tags(self, raw_tags):
        this_allowed_tags = set()
        for tag in raw_tags:
            if tag in self.allowed_tags:
                this_allowed_tags.add(tag)
        return this_allowed_tags

    def _find_and_format_species_tags(self, tag_set):
        this_specie_tags = []
        for tag in tag_set:
            if tag in self.species_tags:
                this_specie_tags.append(tag)
        formatted_tags = f"species: {' '.join(this_specie_tags)}\n"
        return formatted_tags, this_specie_tags

    def prompt_construction_pipeline_florence2(self, tags, length):
        if isinstance(tags, str):
            tags = tags.split(" ")
        random.shuffle(tags)
        tags = self._prune_not_allowed_tags(tags)
        formatted_species_tags, this_specie_tags = self._find_and_format_species_tags(tags)
        non_species_tags = [t for t in tags if t not in this_specie_tags]
        prompt = f"{' '.join(non_species_tags)}\n{formatted_species_tags}\nlength: {length}\n\nSTYLE1 FURRY CAPTION:"
        return prompt
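

# Resizes an image to fit within `bounds` while preserving aspect ratio, optionally padding to the exact size.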
class Fit(torch.nn.Module):
    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None,
    ):
        super().__init__()
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        wimg, himg = img.size
        hbound, wbound = self.bounds
        hscale = hbound / himg
        wscale = wbound / wimg
        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)
        scale = min(hscale, wscale)
        if scale == 1.0:
            return img
        hnew = min(round(himg * scale), hbound)
        wnew = min(round(wimg * scale), wbound)
        img = TF.resize(img, (hnew, wnew), self.interpolation)
        if self.pad is None:
            return img
        hpad = hbound - hnew
        wpad = wbound - wnew
        tpad = hpad // 2
        bpad = hpad - tpad
        lpad = wpad // 2
        rpad = wpad - lpad
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(" +
            f"bounds={self.bounds}, " +
            f"interpolation={self.interpolation.value}, " +
            f"grow={self.grow}, " +
            f"pad={self.pad})"
        )
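

# Composites RGBA inputs onto a solid background color so the tagger always receives 3-channel RGB tensors.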
class CompositeAlpha(torch.nn.Module):
    def __init__(
        self,
        background: tuple[float, float, float] | float,
    ):
        super().__init__()
        self.background = (background, background, background) if isinstance(background, float) else background
        self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-3] == 3:
            return img
        alpha = img[..., 3, None, :, :]
        img[..., :3, :, :] *= alpha
        background = self.background.expand(-1, img.shape[-2], img.shape[-1])
        if background.ndim == 1:
            background = background[:, None, None]
        elif background.ndim == 2:
            background = background[None, :, :]
        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(" +
            f"background={self.background})"
        )
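

# Sigmoid-gated linear classification head, matching the JTP PILOT2 checkpoint's output layer.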
class GatedHead(torch.nn.Module):
    def __init__(
        self,
        num_features: int,
        num_classes: int,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, num_classes * 2)
        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        x = self.act(x[:, :self.num_classes]) * self.gate(x[:, self.num_classes:])
        return x
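

# Load the Furrence2 captioner and its processor from the bundled Florence-2 implementation.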
model_id = "lodestone-horizon/furrence2-large"
model = Florence2ForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained("./florence2_implementation/", trust_remote_code=True)
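
# Tag implication tree (consequent -> antecedent names) loaded from a tag_implications dump.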
tree = defaultdict(list)
with open('tag_implications-2024-05-05.csv', 'rt') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        if row["status"] == "active":
            tree[row["consequent_name"]].append(row["antecedent_name"])

title = """<h1 align="center">Furrence2 Captioner Demo</h1>"""
description = (
    """<br> The captioner is prompted with tags from the JTP PILOT2 tagger. You may use hand-curated tags to get better results.
    <br> This demo is running on CPU. For faster inference without waiting in the queue, you may duplicate the Space and upgrade to a GPU in Settings."""
)

tagger_transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])
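
# JTP PILOT2 tagger: a SigLIP ViT-SO400M backbone with a gated head over 9083 tags.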
THRESHOLD = 0.2
tagger_model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083,
)  # type: VisionTransformer
tagger_model.head = GatedHead(min(tagger_model.head.weight.shape), 9083)
safetensors.torch.load_model(tagger_model, "JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors")
tagger_model.eval()

with open("JTP_PILOT2_tags.json", "r") as file:
    tags = json.load(file)  # type: dict
allowed_tags = list(tags.keys())

pruner = Pruner("tags-2024-05-05.csv")
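

# Tag the input image with JTP PILOT2 and turn the predicted tags into a Florence-2 prompt.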
def generate_prompt(image, expected_caption_length):
    global THRESHOLD, tree, model, tagger_model, tagger_transform
    tagger_input = tagger_transform(image.convert('RGBA')).unsqueeze(0)
    probabilities = tagger_model(tagger_input)
    for prob in probabilities:
        indices = torch.where(prob > THRESHOLD)[0]
        sorted_indices = torch.argsort(prob[indices], descending=True)
        final_tags = []
        for i in sorted_indices:
            final_tags.append(allowed_tags[indices[i]])
    final_tags = " ".join(final_tags)
    task_prompt = pruner.prompt_construction_pipeline_florence2(final_tags, expected_caption_length)
    return task_prompt
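

# End-to-end inference: tag the image, build the prompt, then generate the caption with beam search (num_beams=3, no sampling).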
def inference_caption(image, expected_caption_length, seq_len=512):
    start_time = time.time()
    prompt_input = generate_prompt(image, expected_caption_length)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Finished tagging in {execution_time:.3f} seconds")
    try:
        pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"]
        encoder_inputs = processor.tokenizer(
            text=prompt_input,
            return_tensors="pt",
            # padding = "max_length",
            # truncation = True,
            # max_length = 256,
            # don't add these; they will cause problems when doing inference
        )
        start_time = time.time()
        generated_ids = model.generate(
            input_ids=encoder_inputs["input_ids"],
            attention_mask=encoder_inputs["attention_mask"],
            pixel_values=pixel_values,
            max_new_tokens=seq_len,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        end_time = time.time()
        execution_time = end_time - start_time
        print(f"Finished captioning in {execution_time:.3f} seconds")
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text
    except Exception as e:
        print("error message:", e)
        return "An error occurred."
def main():
    with gr.Blocks() as iface:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil")
                seq_len = gr.Number(
                    value=512, label="Output Cutoff Length", precision=0,
                    interactive=True,
                )
                expected_length = gr.Number(
                    minimum=50, maximum=200,
                    value=100, label="Expected Caption Length", precision=0,
                    interactive=True,
                )
            with gr.Column(scale=1):
                with gr.Column():
                    caption_button = gr.Button(
                        value="Caption it!", interactive=True, variant="primary",
                    )
                    caption_output = gr.Textbox(lines=1, label="Caption Output")
        caption_button.click(
            inference_caption,
            [
                image_input,
                expected_length,
                seq_len,
            ],
            [caption_output],
        )
    iface.launch(share=False)


if __name__ == "__main__":
    main()