|
|
from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM |
|
|
from .configuration_deepseek_v2 import DeepseekV2Config |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
from typing import List, Optional, Tuple, Union |
|
|
from transformers.cache_utils import Cache |
|
|
import requests |
|
|
from PIL import Image, ImageOps, ImageDraw, ImageFont |
|
|
from io import BytesIO |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import CrossEntropyLoss |
|
|
from torchvision import transforms |
|
|
from torchvision.transforms.functional import InterpolationMode |
|
|
import os |
|
|
from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector |
|
|
from addict import Dict |
|
|
from transformers import TextStreamer |
|
|
from .conversation import get_conv_template |
|
|
from abc import ABC |
|
|
import math |
|
|
import re |
|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
import time |
|
|
|
|
|
|
|
|
def load_image(image_path):
    """Open an image file and apply its EXIF orientation.

    Args:
        image_path: Anything accepted by ``PIL.Image.open``.

    Returns:
        The EXIF-transposed image. If opening or transposing fails, the error
        is printed and the image is re-opened without correction. If that
        also fails, ``None`` is returned.
    """
    try:
        image = Image.open(image_path)
        # Rotate/flip according to the EXIF orientation tag.
        corrected_image = ImageOps.exif_transpose(image)
        return corrected_image
    # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
    # SystemExit are not swallowed.
    except Exception as e:
        print(f"error: {e}")
        try:
            return Image.open(image_path)
        except Exception:
            return None
|
|
|
|
|
|
|
|
def re_match(text):
    """Collect ``<|ref|>label<|/ref|><|det|>coords<|/det|>`` groundings in *text*.

    Returns:
        ``(matches, image_refs, other_refs)`` where ``matches`` is the raw
        ``re.findall`` result (``(full_match, label, coords)`` tuples),
        ``image_refs`` are the full-match strings containing the
        ``<|ref|>image<|/ref|>`` tag and ``other_refs`` are all remaining
        full-match strings, both in order of appearance.
    """
    grounding_re = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(grounding_re, text, re.DOTALL)

    image_tag = '<|ref|>image<|/ref|>'
    full_texts = [hit[0] for hit in matches]
    image_refs = [chunk for chunk in full_texts if image_tag in chunk]
    other_refs = [chunk for chunk in full_texts if image_tag not in chunk]
    return matches, image_refs, other_refs
|
|
|
|
|
|
|
|
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Parse the label and coordinate list out of one ``re_match`` match.

    Args:
        ref_text: A ``(full_match, label, coords)`` tuple as produced by the
            ``re.findall`` call in ``re_match``. ``coords`` is the raw text
            between the ``<|det|>`` tags.
        image_width, image_height: Unused; kept for API compatibility.

    Returns:
        ``(label, coords)`` on success, or ``None`` (after printing the error)
        if the tuple is malformed or ``coords`` is not a Python literal.
    """
    import ast

    try:
        label_type = ref_text[1]
        # The coordinate text is model output. Parse it strictly as a literal;
        # the previous ``eval`` would execute any expression the model emitted.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
|
|
|
|
|
|
|
|
def draw_bounding_boxes(image, refs, ouput_path):
    """Draw labelled grounding boxes on a copy of *image*.

    Box coordinates are read as values on a 0..999 grid and scaled to the
    image size. Boxes labelled ``image`` are also cropped from the original
    image and saved as ``{ouput_path}/images/{n}.jpg``, where ``n`` counts up
    over every ``image`` box, including ones whose crop failed.

    Args:
        image: PIL image to annotate (left unmodified).
        refs: Match tuples as returned by ``re_match`` (its first element).
        ouput_path: Directory whose ``images`` subfolder receives the crops.

    Returns:
        The annotated copy with a translucent overlay pasted over the boxes.
    """
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Fills go on a separate RGBA layer that is pasted with its own alpha at
    # the end, so boxes tint the page instead of covering it.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    img_idx = 0

    for ref in refs:
        result = extract_coordinates_and_label(ref, image_width, image_height)
        if not result:
            continue
        label_type, points_list = result

        color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
        color_a = color + (20, )

        try:
            for points in points_list:
                x1, y1, x2, y2 = points

                # Coordinates are normalised to a 0..999 grid.
                x1 = int(x1 / 999 * image_width)
                y1 = int(y1 / 999 * image_height)
                x2 = int(x2 / 999 * image_width)
                y2 = int(y2 / 999 * image_height)

                if label_type == 'image':
                    try:
                        cropped = image.crop((x1, y1, x2, y2))
                        cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                    except Exception as e:
                        print(e)
                    img_idx += 1

                try:
                    # Titles get a thicker outline than other regions.
                    outline_width = 4 if label_type == 'title' else 2
                    draw.rectangle([x1, y1, x2, y2], outline=color, width=outline_width)
                    draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)

                    text_x = x1
                    text_y = max(0, y1 - 15)

                    text_bbox = draw.textbbox((0, 0), label_type, font=font)
                    text_width = text_bbox[2] - text_bbox[0]
                    text_height = text_bbox[3] - text_bbox[1]
                    draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                   fill=(255, 255, 255, 30))

                    draw.text((text_x, text_y), label_type, font=font, fill=color)
                except Exception:
                    # Best effort: a degenerate box (e.g. x2 < x1) is skipped.
                    pass
        except Exception:
            # Malformed coordinates (wrong arity, non-numeric values, ...):
            # skip the remaining boxes of this ref, as before.
            continue

    img_draw.paste(overlay, (0, 0), overlay)

    return img_draw
|
|
|
|
|
|
|
|
def process_image_with_refs(image, ref_texts, output_path):
    """Annotate *image* with the grounding boxes in *ref_texts*.

    Delegates to ``draw_bounding_boxes``, which also writes ``image``-labelled
    crops under ``output_path/images``, and returns its annotated image.
    """
    return draw_bounding_boxes(image, ref_texts, output_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid whose ``cols / rows`` is nearest *aspect_ratio*.

    Candidates are scanned in order. A strictly smaller difference always
    wins. On an exact tie, the later candidate replaces the current pick only
    if ``width * height`` exceeds half of ``image_size**2 * cols * rows``.

    Returns:
        The chosen candidate, or ``(1, 1)`` when *target_ratios* is empty.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    source_area = width * height
    for grid in target_ratios:
        cols, rows = grid[0], grid[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = grid
        elif gap == smallest_gap and source_area > 0.5 * image_size * image_size * cols * rows:
            chosen = grid
    return chosen
|
|
|
|
|
|
|
|
def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
    """Cut *image* into a grid of ``image_size`` squares matching its aspect ratio.

    Every ``(cols, rows)`` with ``min_num <= cols * rows <= max_num`` is a
    candidate. ``find_closest_aspect_ratio`` picks one. The image is then
    resized to ``(cols * image_size, rows * image_size)`` and cropped into
    tiles in row-major order.

    Args:
        image: PIL image to tile.
        min_num, max_num: Inclusive bounds on the number of tiles.
        image_size: Side length of each tile in pixels.
        use_thumbnail: If true and more than one tile is produced, also append
            an ``image_size`` x ``image_size`` resize of the whole image.

    Returns:
        ``(tiles, chosen_grid)``, where ``chosen_grid`` is the ``(cols, rows)``
        tuple selected above.
    """
    src_w, src_h = image.size
    src_aspect = src_w / src_h

    candidates = set(
        (c, r)
        for total in range(min_num, max_num + 1)
        for c in range(1, total + 1)
        for r in range(1, total + 1)
        if min_num <= c * r <= max_num
    )
    candidates = sorted(candidates, key=lambda grid: grid[0] * grid[1])

    chosen_grid = find_closest_aspect_ratio(src_aspect, candidates, src_w, src_h, image_size)
    canvas_w = image_size * chosen_grid[0]
    canvas_h = image_size * chosen_grid[1]
    tile_count = chosen_grid[0] * chosen_grid[1]

    canvas = image.resize((canvas_w, canvas_h))
    per_row = canvas_w // image_size

    tiles = []
    for n in range(tile_count):
        col_idx, row_idx = n % per_row, n // per_row
        left, top = col_idx * image_size, row_idx * image_size
        right, bottom = (col_idx + 1) * image_size, (row_idx + 1) * image_size
        tiles.append(canvas.crop((left, top, right, bottom)))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles, chosen_grid
|
|
|
|
|
|
|
|
|
|
|
def normalize_transform(mean, std):
    """Build a ``transforms.Normalize`` tolerating one missing statistic.

    A missing ``mean`` becomes zeros and a missing ``std`` becomes ones, sized
    like the other argument.

    Returns:
        ``None`` when both ``mean`` and ``std`` are ``None``, otherwise a
        ``transforms.Normalize(mean=..., std=...)``.
    """
    if mean is None and std is None:
        return None
    if mean is None:
        mean = [0.] * len(std)
    elif std is None:
        std = [1.] * len(mean)
    return transforms.Normalize(mean=mean, std=std)
|
|
|
|
|
|
|
|
|
|
|
def format_messages(
    conversations: List[Dict[str, str]],
    sft_format: str = "deepseek",
    system_prompt: str = "",
):
    """Render a conversation into one prompt string using an SFT template.

    Args:
        conversations: Messages, each with ``"role"`` and ``"content"`` keys.
            Content is stripped before it is appended.
        sft_format: Template name passed to ``get_conv_template``.
        system_prompt: System message installed on the template.

    Returns:
        The template's rendered prompt, stripped of surrounding whitespace.
    """
    template = get_conv_template(sft_format)
    template.set_system_message(system_prompt)
    for msg in conversations:
        role, content = msg["role"], msg["content"]
        template.append_message(role, content.strip())
    return template.get_prompt().strip()
|
|
|
|
|
|
|
|
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize *text* without the tokenizer's own special tokens.

    Optionally prepends the hard-coded BOS id ``0`` and/or appends the
    hard-coded EOS id ``1``.

    Returns:
        The token id list.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    bos_token, eos_token = 0, 1
    token_ids = [bos_token] + token_ids if bos else token_ids
    return token_ids + [eos_token] if eos else token_ids
|
|
|
|
|
def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """Load every image referenced by the conversation as an RGB PIL image.

    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the images in message order,
        converted to RGB. Messages without an ``"images"`` key contribute
        nothing.

    Raises:
        ValueError: if an image path cannot be loaded.
    """
    pil_images = []
    for message in conversations:
        if "images" not in message:
            continue
        for image_path in message["images"]:
            pil_img = load_image(image_path)
            if pil_img is None:
                # load_image reports failures by returning None. Fail here with
                # the offending path instead of an AttributeError on .convert().
                raise ValueError(f"failed to load image: {image_path!r}")
            pil_images.append(pil_img.convert("RGB"))
    return pil_images
|
|
|
|
|
|
|
|
class BaseTransform(ABC):
    """Placeholder base class for callable transforms.

    ``set_rng`` and ``__call__`` are no-ops that return ``None``.
    ``default_shape`` raises ``NotImplementedError`` until a subclass
    provides it.
    """

    def set_rng(self, *args, **kwargs):
        """Seed any randomness the transform uses; does nothing here."""
        return None

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Apply the transform; the base version returns ``None``."""
        return None

    @property
    def default_shape(self):
        """Shape of the transform's output; subclasses must supply it."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
class BasicImageTransform(BaseTransform):
    """``transforms.ToTensor`` optionally followed by a normalisation step.

    Args:
        mean, std: Per-channel statistics passed to ``normalize_transform``.
            Both are also stored on the instance (``infer`` reads ``mean`` to
            choose its padding colour).
        normalize: When false, an ``nn.Identity`` is appended in place of the
            normalisation step.
    """

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        steps = [transforms.ToTensor()]
        norm_step = normalize_transform(mean, std) if normalize else nn.Identity()
        # normalize_transform returns None when both statistics are missing.
        if norm_step is not None:
            steps.append(norm_step)
        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Run the composed pipeline on *x*."""
        return self.transform(x)
|
|
|
|
|
class NoEOSTextStreamer(TextStreamer):
    """``TextStreamer`` that prints the EOS token's text as a line break.

    Each finalized chunk is written to stdout immediately, with no trailing
    newline. Every occurrence of the decoded EOS token is replaced by a
    newline character.
    """

    def on_finalized_text(self, text: str, stream_end: bool = False):
        eos_id = self.tokenizer.eos_token_id
        eos_piece = self.tokenizer.decode([eos_id], skip_special_tokens=False)
        print(text.replace(eos_piece, "\n"), flush=True, end="")
|
|
|
|
|
|
|
|
class DeepseekOCRConfig(DeepseekV2Config):
    """``DeepseekV2Config`` tagged with the ``DeepseekOCR`` model type."""

    model_type = "DeepseekOCR"
|
|
|
|
|
class DeepseekOCRModel(DeepseekV2Model):
    """DeepseekV2 decoder with a SAM + CLIP vision front-end.

    Image features are projected to the decoder width and written into the
    token embeddings at the positions flagged by ``images_seq_mask`` before
    the language model runs.
    """

    config_class = DeepseekOCRConfig

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCRModel, self).__init__(config)

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()

        n_embed = 1280
        # Linear projection of the concatenated CLIP + SAM features
        # (input_dim=2048) to n_embed.
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        # Learned separators: image_newline is appended after every row of a
        # feature grid, view_seperator after each image's full feature sequence.
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

    def _encode_view(self, sam_model, vision_model, pixels):
        """Run SAM, then CLIP conditioned on SAM's output, and project.

        CLIP tokens, minus the first one (presumably the class token), are
        concatenated on the last dim with SAM's flattened features. The result
        is unpacked as 3-D by the caller.
        """
        sam_features = sam_model(pixels)
        clip_features = vision_model(pixels, sam_features)
        features = torch.cat(
            (clip_features[:, 1:], sam_features.flatten(2).permute(0, 2, 1)), dim=-1
        )
        return self.projector(features)

    @staticmethod
    def _append_newlines(grid, newline):
        """Append *newline* as an extra column of a ``(rows, cols, dim)`` grid.

        Returns the result flattened to ``(rows * (cols + 1), dim)``.
        """
        rows, _, dim = grid.shape
        grid = torch.cat([grid, newline[None, None, :].expand(rows, 1, dim)], dim=1)
        return grid.view(-1, dim)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed tokens, splice in image features, and run the decoder.

        Args:
            images: One ``(patches, global_view)`` pair per batch element.
                ``patches`` summing to zero means "no tiles". ``global_view``
                of the first element summing to zero disables vision entirely.
            images_seq_mask: Per batch element, a boolean mask over sequence
                positions that receive image features (in order).
            images_spatial_crop: Per batch element, ``(width_crop_num,
                height_crop_num)`` describing the tile grid.
            Other arguments are forwarded to ``DeepseekV2Model.forward``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        sam_model = getattr(self, 'sam_model', None)
        vision_model = getattr(self, 'vision_model', None)

        # Vision runs on multi-token passes (prefill) or during training, and
        # only when real images were supplied. Use inputs_embeds for the
        # sequence length so the check also works when input_ids is None.
        run_vision = (
            sam_model is not None
            and images is not None
            and len(images) > 0
            and (inputs_embeds.shape[1] != 1 or self.training)
            and torch.sum(images[0][1]).item() != 0
        )

        if run_vision:
            for idx, (image, crop_shape) in enumerate(zip(images, images_spatial_crop)):
                patches = image[0]
                image_ori = image[1]

                with torch.no_grad():
                    if torch.sum(patches).item() != 0:
                        local_features = self._encode_view(sam_model, vision_model, patches)
                        global_features = self._encode_view(sam_model, vision_model, image_ori)

                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('PATCHES: ', local_features.shape)
                        print('=====================')

                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        _, hw2, n_dim2 = local_features.shape
                        h2 = w2 = int(hw2 ** 0.5)

                        width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]

                        global_features = self._append_newlines(
                            global_features.view(h, w, n_dim), self.image_newline
                        )

                        # Stitch the row-major tiles back into one large feature grid.
                        local_features = (
                            local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2)
                            .permute(0, 2, 1, 3, 4)
                            .reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
                        )
                        local_features = self._append_newlines(local_features, self.image_newline)

                        global_local_features = torch.cat(
                            [local_features, global_features, self.view_seperator[None, :]], dim=0
                        )
                    else:
                        global_features = self._encode_view(sam_model, vision_model, image_ori)

                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('NO PATCHES')
                        print('=====================')

                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        global_features = self._append_newlines(
                            global_features.view(h, w, n_dim), self.image_newline
                        )
                        global_local_features = torch.cat(
                            [global_features, self.view_seperator[None, :]], dim=0
                        )

                # Move the mask to the embeddings' device instead of assuming CUDA.
                inputs_embeds[idx].masked_scatter_(
                    images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device), global_local_features
                )

        return super(DeepseekOCRModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
|
|
|
|
|
|
|
|
class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
    """Causal-LM head on ``DeepseekOCRModel`` plus a single-request ``infer`` helper."""

    config_class = DeepseekOCRConfig

    def __init__(self, config):
        # Calls the grandparent initializer, bypassing
        # DeepseekV2ForCausalLM.__init__ (presumably so it does not build a
        # plain DeepseekV2Model that would be replaced right below).
        super(DeepseekV2ForCausalLM, self).__init__(config)
        self.model = DeepseekOCRModel(config)

        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_model(self):
        """Return the wrapped ``DeepseekOCRModel``."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal model and the LM head.

        When ``labels`` is given, the next-token cross-entropy loss is
        computed on float32 logits (logits shifted left by one position
        against the labels).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Token t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim inputs to the uncached suffix and pass the image kwargs through.

        ``images``, ``images_seq_mask`` and ``images_spatial_crop`` are
        forwarded on every step. ``DeepseekOCRModel.forward`` skips the vision
        path for single-token passes.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only unprocessed tokens. If the mask is longer than
            # input_ids, part of the input lives only in the cache.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # Otherwise drop the already-cached prefix.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            # Crop the mask if the next step would exceed a bounded cache.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Positions count only attended tokens (left padding gets 1).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # NOTE: the unused ``cache_position`` computation was removed; it
        # dereferenced position_ids and crashed when no attention_mask was given.

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "images_seq_mask": kwargs.get("images_seq_mask", None),
                "images_spatial_crop": kwargs.get("images_spatial_crop", None),
            }
        )
        return model_inputs

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        Note: this patches ``reset_parameters`` on ``torch.nn.Linear`` and
        ``torch.nn.LayerNorm`` globally, for the whole process.
        """
        import torch
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
        """Run generation for one prompt and an optional image.

        Args:
            tokenizer: Tokenizer providing ``encode``/``decode``/``eos_token_id``.
            prompt: User prompt (required). ``<image>`` marks where image
                tokens go.
            image_file: Optional path of the image to attach.
            output_path: Directory for saved results. Created, together with
                an ``images`` subfolder, when non-empty. Required when
                ``save_results`` is true.
            base_size: Side of the padded global view in pixels.
            image_size: Side of each tile (crop mode) or of the single view
                (non-crop mode).
            crop_mode: Tile images larger than ``image_size`` in either
                dimension.
            test_compress: Print image-token vs. output-token statistics.
            save_results: Write ``result.mmd``, crops, ``result_with_boxes.jpg``
                and, for geometry answers, ``geo.jpg``.
            eval_mode: Disable streaming and return the decoded answer.

        Returns:
            The decoded, stripped answer when ``eval_mode`` is set and the
            prompt contains ``<image>``; otherwise ``None``.

        Raises:
            ValueError: if ``prompt`` is empty, or if ``save_results`` is set
                without an ``output_path``.
        """
        self.disable_torch_init()

        if not prompt:
            raise ValueError('prompt is none!')
        if save_results and not output_path:
            raise ValueError('output_path is required when save_results=True')
        # os.makedirs('') raises, so only create directories for a real path.
        if output_path:
            os.makedirs(output_path, exist_ok=True)
            os.makedirs(f'{output_path}/images', exist_ok=True)

        user_turn = {"role": "<|User|>", "content": f'{prompt}'}
        if image_file:
            user_turn["images"] = [f'{image_file}']
        conversation = [user_turn, {"role": "<|Assistant|>", "content": ""}]

        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0
        ratio = 1
        image_draw = None
        w = h = 0
        # Text-only prompts have no image; don't index images[0] blindly.
        if images:
            image_draw = images[0].copy()
            w, h = image_draw.size
            # 1.0 for a square image, smaller the more elongated it is. Used
            # to estimate how much of the padded global view is real image
            # (only reported by test_compress).
            ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
        pad_color = tuple(int(x * 255) for x in image_transform.mean)

        image_token = '<image>'
        image_token_id = 128815
        text_splits = prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):
            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:
                images_crop_raw = []
                if image.size[0] <= image_size and image.size[1] <= image_size:
                    crop_ratio = [1, 1]
                else:
                    # Tile with the same size the token count below assumes.
                    images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=image_size)

                # Global view: the whole image padded to base_size.
                global_view = ImageOps.pad(image, (base_size, base_size), color=pad_color)

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])

                if width_crop_num > 1 or height_crop_num > 1:
                    # Local views: the tiles.
                    for crop in images_crop_raw:
                        images_crop_list.append(image_transform(crop).to(torch.bfloat16))

                if image_size == 640:
                    valid_img_tokens += len(images_crop_list) * 100

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                # Reserve one placeholder per feature vector the model scatters
                # in: each grid row plus a newline, one view separator, then
                # the tile grid (rows of tokens plus newlines) if tiled.
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if width_crop_num > 1 or height_crop_num > 1:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
            else:
                # Single global view of side image_size.
                if image_size <= 640:
                    print('directly resize')
                    image = image.resize((image_size, image_size))

                global_view = ImageOps.pad(image, (image_size, image_size), color=pad_color)
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += 100
                elif base_size == 512:
                    valid_img_tokens += 64

                width_crop_num, height_crop_num = 1, 1
                images_spatial_crop.append([width_crop_num, height_crop_num])

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]

            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)

        # Trailing text after the last image.
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        bos_id = 0
        tokenized_str = [bos_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask

        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            # All-zero placeholders: the model skips vision for these.
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, base_size, base_size))

        # Use the model's device rather than the default CUDA device.
        device = self.device
        streamer = None if eval_mode else NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
        with torch.autocast(device.type, dtype=torch.bfloat16):
            with torch.no_grad():
                output_ids = self.generate(
                    input_ids.unsqueeze(0).to(device),
                    images=[(images_crop.to(device), images_ori.to(device))],
                    images_seq_mask=images_seq_mask.unsqueeze(0).to(device),
                    images_spatial_crop=images_spatial_crop,
                    temperature=0.0,
                    eos_token_id=tokenizer.eos_token_id,
                    streamer=streamer,
                    max_new_tokens=8192,
                    no_repeat_ngram_size=35 if eval_mode else 20,
                    use_cache=True
                )

        has_image_tag = '<image>' in conversation[0]['content']
        prompt_len = input_ids.shape[0]
        stop_str = '<|end▁of▁sentence|>'

        if has_image_tag and eval_mode:
            outputs = tokenizer.decode(output_ids[0, prompt_len:])
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
            return outputs

        if has_image_tag and test_compress:
            outputs = tokenizer.decode(output_ids[0, prompt_len:])
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('='*50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            if valid_img_tokens:
                print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
            else:
                print('compression ratio: ', 'n/a')
            print('='*50)

        if has_image_tag and save_results:
            outputs = tokenizer.decode(output_ids[0, prompt_len:])
            print('='*15 + 'save results:' + '='*15)
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
            self._save_results(outputs, image_draw, output_path)

    def _save_results(self, outputs, image_draw, output_path):
        """Write ``result.mmd`` and, where applicable, box and geometry images."""
        matches_ref, matches_images, mathes_other = re_match(outputs)

        result = None
        if image_draw is not None:
            result = process_image_with_refs(image_draw, matches_ref, output_path)

        # Replace image groundings with links to the crops saved above.
        for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
            outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

        for a_match_other in tqdm(mathes_other, desc="other"):
            outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

        with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
            afile.write(outputs)

        if 'line_type' in outputs:
            self._save_geometry_plot(outputs, output_path)

        if result is not None:
            result.save(f"{output_path}/result_with_boxes.jpg")

    @staticmethod
    def _save_geometry_plot(outputs, output_path):
        """Plot the ``Line`` section of a geometry answer to ``geo.jpg``.

        *outputs* is expected to be a Python-literal dict containing
        ``['Line']['line' | 'line_type' | 'line_endpoint']``.
        """
        import ast
        import matplotlib.pyplot as plt

        # The text is model output; parse it as a literal instead of eval-ing it.
        geometry = ast.literal_eval(outputs)['Line']
        lines = geometry['line']
        line_type = geometry['line_type']
        endpoints = geometry['line_endpoint']

        fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
        ax.set_xlim(-15, 15)
        ax.set_ylim(-15, 15)

        for idx, line in enumerate(lines):
            try:
                p0 = ast.literal_eval(line.split(' -- ')[0])
                p1 = ast.literal_eval(line.split(' -- ')[-1])

                if line_type[idx] == '--':
                    ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                else:
                    # NOTE(review): drawn identically to '--'; another style
                    # (e.g. dashed) may have been intended.
                    ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')

                ax.scatter(p0[0], p0[1], s=5, color='k')
                ax.scatter(p1[0], p1[1], s=5, color='k')
            except Exception:
                # Skip malformed segments.
                pass

        for endpoint in endpoints:
            label = endpoint.split(': ')[0]
            (x, y) = ast.literal_eval(endpoint.split(': ')[1])
            ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                        fontsize=5, fontweight='light')

        plt.savefig(f'{output_path}/geo.jpg')
        plt.close()
|
|
|