Spaces: Runtime error
use mg-llava instead of llava in AutoConfig.register

ml_mgie/mgie_llava.py  CHANGED  (+91 -47)
@@ -12,12 +12,12 @@ import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 
 from transformers import AutoConfig, AutoModelForCausalLM, \
+    LlamaConfig, LlamaModel, LlamaForCausalLM, \
+    CLIPVisionModel, CLIPImageProcessor
 
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
+import os
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -26,7 +26,7 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
 
 
 class LlavaConfig(LlamaConfig):
-    model_type = "llava"
+    model_type = "mg-llava"
 
 
 class LlavaLlamaModel(LlamaModel):
@@ -37,11 +37,13 @@ class LlavaLlamaModel(LlamaModel):
 
         if hasattr(config, "mm_vision_tower"):
             # HACK: for FSDP
+            self.vision_tower = [
+                CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
             # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
 
         if hasattr(config, "use_mm_proj"):
+            self.mm_projector = nn.Linear(
+                config.mm_hidden_size, config.hidden_size)
 
     def get_vision_tower(self):
         vision_tower = getattr(self, 'vision_tower', None)
@@ -67,18 +69,22 @@ class LlavaLlamaModel(LlamaModel):
         self.vision_tower = vision_tower
 
         vision_config = vision_tower.config
+        num_patches = (vision_config.image_size //
+                       vision_config.patch_size) ** 2
 
         self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_config.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
 
         if not hasattr(self, 'mm_projector'):
+            self.mm_projector = nn.Linear(
+                vision_config.hidden_size, self.config.hidden_size)
 
         if pretrain_mm_mlp_adapter is not None:
+            mm_projector_weights = torch.load(
+                pretrain_mm_mlp_adapter, map_location='cpu')
+            self.mm_projector.load_state_dict(
+                {k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
 
         return dict(
             image_processor=image_processor,
@@ -117,21 +123,28 @@ class LlavaLlamaModel(LlamaModel):
                 # variable length images
                 image_features = []
                 for image in images:
+                    image_forward_out = vision_tower(
+                        image.unsqueeze(0), output_hidden_states=True)
+                    select_hidden_state_layer = getattr(
+                        self.config, "mm_vision_select_layer", -1)
                     select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
                     image_feature = select_hidden_state[:, 1:]
                     image_features.append(image_feature)
             else:
+                image_forward_outs = vision_tower(
+                    images.to(vision_tower.dtype), output_hidden_states=True)
+                select_hidden_state_layer = getattr(
+                    self.config, "mm_vision_select_layer", -1)
                 select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
+                image_features = select_hidden_state[:, 1:].to(
+                    images.dtype)
             if type(images) is list:
+                image_features = [self.mm_projector(
+                    image_feature)[0] for image_feature in image_features]
             else:
                 image_features = self.mm_projector(image_features)
+            dummy_image_features = torch.zeros(
+                256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
             dummy_image_features = self.mm_projector(dummy_image_features)
 
             new_input_embeds = []
@@ -139,7 +152,8 @@ class LlavaLlamaModel(LlamaModel):
             for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                 if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
                     # multimodal LLM, but the current sample is not multimodal
+                    cur_input_embeds = cur_input_embeds + \
+                        (0. * dummy_image_features).sum()
                     new_input_embeds.append(cur_input_embeds)
                     cur_image_idx += 1
                     continue
@@ -147,32 +161,43 @@ class LlavaLlamaModel(LlamaModel):
                     cur_image_features = image_features[cur_image_idx]
                     num_patches = cur_image_features.shape[0]
                     if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
+                        raise ValueError(
+                            "The number of image start tokens and image end tokens should be the same.")
+                    image_start_tokens = torch.where(
+                        cur_input_ids == vision_tower.config.im_start_token)[0]
                     for image_start_token_pos in image_start_tokens:
+                        cur_image_features = image_features[cur_image_idx].to(
+                            device=cur_input_embeds.device)
                         num_patches = cur_image_features.shape[0]
                         if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
+                            raise ValueError(
+                                "The image end token should follow the image start token.")
                        if orig_embeds_params is not None:
+                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
+                                                              cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                        else:
+                            cur_new_input_embeds = torch.cat(
+                                (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                         cur_image_idx += 1
                     new_input_embeds.append(cur_new_input_embeds)
                 else:
                     cur_image_features = image_features[cur_image_idx]
                     num_patches = cur_image_features.shape[0]
                     if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
+                        raise ValueError(
+                            "The number of image patch tokens should be the same as the number of image patches.")
+                    masked_indices = torch.where(
+                        cur_input_ids == vision_tower.config.im_patch_token)[0]
                     mask_index_start = masked_indices[0]
                     if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
+                        raise ValueError(
+                            "The image patch tokens should be consecutive.")
                     if orig_embeds_params is not None:
+                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(
+                        ), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
                     else:
+                        cur_new_input_embeds = torch.cat(
+                            (cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
                     new_input_embeds.append(cur_new_input_embeds)
                     cur_image_idx += 1
             inputs_embeds = torch.stack(new_input_embeds, dim=0)
@@ -184,6 +209,7 @@ class LlavaLlamaModel(LlamaModel):
             return_dict=return_dict
         )
 
+
 class EditMapper(nn.Module):
     def __init__(self):
         super().__init__()
@@ -202,6 +228,7 @@ class EditMapper(nn.Module):
 
         return feat
 
+
 class LlavaLlamaForCausalLM(LlamaForCausalLM):
     config_class = LlavaConfig
 
@@ -209,7 +236,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
         super(LlamaForCausalLM, self).__init__(config)
         self.model = LlavaLlamaModel(config)
 
+        self.lm_head = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)
 
         self.edit_head = EditMapper()
 
@@ -292,12 +320,15 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
         if labels is not None:
             llm = []
             for i in range(labels.shape[0]):
+                try:
+                    p = labels[i].data.cpu().tolist().index(32003)-1
+                except:
+                    p = len(labels[i])-9
                 p = min(len(hidden_states[i])-9, p)
                 llm.append(hidden_states[i][p:p+8].unsqueeze(0))
             llm = torch.cat(llm, dim=0)
+            hid_edit = self.edit_head(
+                llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
 
             B, DROP = labels.shape[0], 0.05
 
@@ -305,24 +336,30 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                 self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
 
             with torch.no_grad():
+                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample(
+                )*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
                 lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
                                     torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
 
             noise = torch.randn_like(lat_ans)
+            ts = torch.randint(
+                0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
             lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
 
             prob = torch.rand(B, device=lat_ans.device)
-            mask = (prob<(DROP*2)).reshape(B, 1, 1)
+            mask = (prob < (DROP*2)).reshape(B, 1, 1)
             hid_edit = torch.where(mask, hid_null, hid_edit)
+            mask = (1.0-((prob >= DROP).to(lat_inp.dtype) *
+                         (prob < (DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
             lat_inp *= mask
 
+            out = self.unet(
+                torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
 
+            loss_ce, loss_edit = loss, nn.functional.mse_loss(
+                out, noise, reduction='mean')
+            if int(os.environ['LOCAL_RANK']) == 0:
+                print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
             loss = loss_ce+loss_edit*0.5
 
         if not return_dict:
@@ -367,9 +404,11 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
         self.resize_token_embeddings(len(tokenizer))
 
         if mm_use_im_start_end:
+            num_new_tokens = tokenizer.add_tokens(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
             self.resize_token_embeddings(len(tokenizer))
+            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
 
             if num_new_tokens > 0:
                 input_embeddings = self.get_input_embeddings().weight.data
@@ -384,14 +423,16 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                 output_embeddings[-num_new_tokens:] = output_embeddings_avg
 
             if tune_mm_mlp_adapter:
+                self.get_model().orig_embeds_params = [
+                    self.get_input_embeddings().weight.data.clone().to(device=device)]
                 for p in self.get_input_embeddings().parameters():
                     p.requires_grad = True
                 for p in self.get_output_embeddings().parameters():
                     p.requires_grad = False
 
             if pretrain_mm_mlp_adapter:
+                mm_projector_weights = torch.load(
+                    pretrain_mm_mlp_adapter, map_location='cpu')
                 embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                 assert num_new_tokens == 2
                 if input_embeddings.shape == embed_tokens_weight.shape:
@@ -399,9 +440,12 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                 elif embed_tokens_weight.shape[0] == num_new_tokens:
                     input_embeddings[-num_new_tokens:] = embed_tokens_weight
                 else:
+                    raise ValueError(
+                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+
-        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
+        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
+            [DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
-AutoConfig.register("llava", LlavaConfig)
+AutoConfig.register("mg-llava", LlavaConfig)
 AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
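Beyond line re-wrapping, the substantive change is the registration key: LlavaConfig.model_type and the AutoConfig.register call at the bottom of the file both move from "llava" to "mg-llava". Recent transformers releases ship a built-in LLaVA integration that already claims the "llava" model type, so registering a second config under that key raises a ValueError at import time, which is a likely cause of this Space's runtime error; AutoConfig.register also requires the key to match the config class's model_type, which is why the two lines change together. Below is a minimal sketch of the pattern with a stand-in config class (MgLlavaConfig is illustrative, not the repo's class); it assumes a transformers version that already defines "llava".

# Minimal sketch of the registration pattern (stand-in class name, not the
# repo's code). Assumes a transformers release that already defines "llava".
from transformers import AutoConfig, LlamaConfig


class MgLlavaConfig(LlamaConfig):
    model_type = "mg-llava"  # must match the key passed to AutoConfig.register


# AutoConfig.register("llava", MgLlavaConfig) would be rejected: "llava" is
# already claimed by transformers' built-in LLaVA config, and the key would
# not match MgLlavaConfig.model_type anyway. An unclaimed, matching key works:
AutoConfig.register("mg-llava", MgLlavaConfig)

cfg = AutoConfig.for_model("mg-llava")
print(type(cfg).__name__, cfg.model_type)  # MgLlavaConfig mg-llava

The AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) call stays unchanged, since that mapping is keyed by the config class rather than by the model_type string.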