Add files using upload-large-folder tool
- config.json +0 -8
- config_molmoe.py +48 -291
- example.py +55 -0
- modeling_molmoe.py +14 -94
config.json
CHANGED

@@ -95,14 +95,6 @@
   "rope_theta": 10000.0,
   "scale_logits": false,
   "system_prompt_kind": "demo_or_style",
-  "tokenizer": {
-    "identifier": "allenai/gpt-neox-olmo-dolma-v1_5",
-    "olmo_bos_token_id": null,
-    "olmo_eos_token_id": null,
-    "tokenizer_adds_space": false,
-    "tokenizer_dir": null,
-    "truncate_direction": "right"
-  },
   "transformers_version": "4.45.0.dev0",
   "unconditioned": false,
   "use_cache": true,
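The hunk above only drops the nested "tokenizer" object from config.json; the surrounding keys are unchanged. A minimal sanity check, assuming it is run from the repository root after pulling this commit:

import json

# Load the updated config.json and confirm the nested "tokenizer" block is gone.
with open("config.json") as f:
    cfg = json.load(f)

print("tokenizer" in cfg)                   # expected: False after this change
print(cfg["rope_theta"], cfg["use_cache"])  # 10000.0 True, per the kept keys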
config_molmoe.py
CHANGED

@@ -2,7 +2,9 @@ from __future__ import annotations
 
 import logging
 from dataclasses import asdict, dataclass, field
+from enum import Enum
 from glob import glob
+from os import PathLike
 from pathlib import Path
 from typing import (
     Any,
@@ -17,168 +19,36 @@ from typing import (
     cast,
 )
 
-import torch
 from transformers import PretrainedConfig
-
-from omegaconf import OmegaConf as om
-from omegaconf.errors import OmegaConfBaseException
-from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
-import gin
-
-#from olmo.aliases import PathOrStr
-from .aliases import PathOrStr
-#from olmo.exceptions import OLMoConfigurationError
-from .exceptions import OLMoConfigurationError
-#from olmo.util import StrEnum, resource_path
-from .util import StrEnum, resource_path
-
-#from olmo.mm_data.data_utils import build_tokenizer
-from .data_utils import build_tokenizer
-#from olmo.multimodal_preprocessor import MultiModalPreprocessor
-from .multimodal_preprocessor import MultiModalPreprocessor
-
-__all__ = [
-    "ActivationType",
-    "ActivationCheckpointingStrategy",
-    "BlockType",
-    "LayerNormType",
-    "VisionBackboneType",
-    "VisionBackboneConfig",
-    "InitFnType",
-    "ModelConfig",
-    "OptimizerType",
-    "OptimizerConfig",
-    "SchedulerType",
-    "SchedulerConfig",
-    "DataConfig",
-    "InstanceFilterConfig",
-    "EvaluatorConfig",
-    "TokenizerConfig",
-    "TrainConfig",
-    "PaddingDirection",
-    "TruncationDirection",
-    "SpeedMonitorConfig",
-    "WandbConfig",
-    "CompilerConfig",
-    "WandbConfig",
-    "FSDPPrecision",
-    "FSDPWrapStrategy",
-    "FSDPConfig",
-    "CheckpointType",
-]
+
 
 C = TypeVar("C", bound="BaseConfig")
 D = TypeVar("D", bound="DictConfig|ListConfig")
 
 
+PathOrStr = Union[str, PathLike]
+
+
+class StrEnum(str, Enum):
+    """
+    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
+    We include this here for compatibility with older version of Python.
+    """
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __repr__(self) -> str:
+        return f"'{str(self)}'"
+
+
+
 class AttentionType(StrEnum):
     sdpa = "sdpa"
     direct = "direct"
     flash = "flash"
 
 
-class BaseConfig:
-    @classmethod
-    def _register_resolvers(cls, validate_paths: bool = True):
-        # Expands path globs into a list.
-        def path_glob(*paths) -> List[str]:
-            out = []
-            for path in paths:
-                matches = sorted(glob(path))
-                if not matches and validate_paths:
-                    raise FileNotFoundError(f"{path} does not match any files or dirs")
-                out.extend(matches)
-            return out
-
-        # Chooses the first path in the arguments that exists.
-        def path_choose(*paths) -> str:
-            from .util import is_url
-
-            for path in paths:
-                if is_url(path) or Path(path).exists():
-                    return path
-            if validate_paths:
-                raise FileNotFoundError(", ".join(paths))
-            else:
-                return ""
-
-        # Finds the latest checkpoint in a folder.
-        def path_last_checkpoint(path) -> str:
-            from .util import find_latest_checkpoint
-
-            latest_checkpoint = find_latest_checkpoint(path)
-            if latest_checkpoint is None:
-                if validate_paths:
-                    raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
-                else:
-                    return ""
-            else:
-                return str(latest_checkpoint)
-
-        om.register_new_resolver("path.glob", path_glob, replace=True)
-        om.register_new_resolver("path.choose", path_choose, replace=True)
-        om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
-
-    @classmethod
-    def update_legacy_settings(cls, config: D) -> D:
-        """
-        Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
-        """
-        return config
-
-    @classmethod
-    def new(cls: Type[C], **kwargs) -> C:
-        cls._register_resolvers()
-        conf = om.structured(cls)
-        try:
-            if kwargs:
-                conf = om.merge(conf, kwargs)
-            return cast(C, om.to_object(conf))
-        except OmegaConfBaseException as e:
-            raise OLMoConfigurationError(str(e))
-
-    @classmethod
-    def load(
-        cls: Type[C],
-        path: PathOrStr,
-        overrides: Optional[List[str]] = None,
-        key: Optional[str] = None,
-        validate_paths: bool = True,
-    ) -> C:
-        """Load from a YAML file."""
-        cls._register_resolvers(validate_paths=validate_paths)
-        schema = om.structured(cls)
-        try:
-            raw = om.load(str(path))
-
-            # Backwards compatibility hack, we need this here not in `update_legacy_settings`
-            # since it has to be applied before selecting with `key`
-            if "tokenizer" in raw and "model" in raw:
-                raw["model"]["tokenizer"] = raw.pop("tokenizer")
-
-            if key is not None:
-                raw = raw[key]  # type: ignore
-            raw = cls.update_legacy_settings(raw)
-            conf = om.merge(schema, raw)
-            if overrides:
-                conf = om.merge(conf, om.from_dotlist(overrides))
-            return cast(C, om.to_object(conf))
-        except OmegaConfBaseException as e:
-            raise OLMoConfigurationError(str(e))
-
-    def save(self, path: PathOrStr) -> None:
-        """Save to a YAML file."""
-        om.save(config=self, f=str(path))
-
-    def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
-        out = asdict(self)  # type: ignore
-        if exclude is not None:
-            for name in exclude:
-                if name in out:
-                    del out[name]
-        return out
-
-
 class LayerNormType(StrEnum):
     default = "default"
     """
@@ -290,7 +160,7 @@ class ImageProjectType(StrEnum):
 
 
 @dataclass
-class VisionBackboneConfig
+class VisionBackboneConfig:
     image_model_type: VisionBackboneType = VisionBackboneType.openai
     image_default_input_size: Tuple[int, int] = (336, 336)
     image_patch_size: int = 14
@@ -328,18 +198,7 @@ class TruncationDirection(StrEnum):
 
 
 @dataclass
-class
-    identifier: str = "gpt2"
-    truncate_direction: TruncationDirection = TruncationDirection.right
-    # Does the tokenizer automatically start input text with a space
-    tokenizer_adds_space: Optional[bool] = False
-    tokenizer_dir: Optional[str] = None  # tokenizer directory if using a seqio tokenizer
-    olmo_bos_token_id: Optional[int] = None
-    olmo_eos_token_id: Optional[int] = None
-
-
-@dataclass
-class ModelConfig(BaseConfig):
+class ModelConfig:
     """
     OLMo (model) configuration.
     """
@@ -429,11 +288,6 @@ class ModelConfig(BaseConfig):
 
     rope_impl: str = "cockatoo"
 
-    vision_backbone: Optional[VisionBackboneConfig] = None
-    """
-    Vision backbone settings for multi-modal models.
-    """
-
    vit_load_path: Optional[str] = None
    """
    Use this to load the vit model.
@@ -749,129 +603,10 @@ class ModelConfig(BaseConfig):
     Used for Gemma-2.
     """
 
-    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
-    """
-    Tokenizer configuration.
-    """
-
     loss_token_weighting: Optional[str] = None
 
     gin_bindings: Optional[str] = None
 
-    def get_tokenizer(self):
-        tokenizer_cfg = self.tokenizer
-        assert tokenizer_cfg.identifier.startswith("mm:")
-        kargs = {}
-        if tokenizer_cfg.identifier[3:].startswith("olmo-"):
-            kargs["olmo_bos_token_id"] = tokenizer_cfg.olmo_bos_token_id
-            kargs["olmo_eos_token_id"] = tokenizer_cfg.olmo_eos_token_id
-        return build_tokenizer(
-            tokenizer_cfg.identifier[3:],
-            adds_space=tokenizer_cfg.tokenizer_adds_space,
-            tokenizer_dir=tokenizer_cfg.tokenizer_dir,
-            pad_tokenizer_to=self.vocab_size if self.pad_tokenizer else None,
-            **kargs
-        )
-
-    def get_preprocessor(self):
-        vision_cfg = self.vision_backbone
-        h, w = self.llm_patches_per_crop()
-
-        return MultiModalPreprocessor(
-            loss_token_weighting=self.loss_token_weighting,
-            always_start_with_space=self.always_start_with_space,
-            tokenizer=self.get_tokenizer(),
-            prompt_override=self.prompt_override,
-            fix_image_input_idx=self.fix_image_input_idx,
-            prompt_templates=self.prompt_type,
-            system_prompt=self.system_prompt_kind,
-            default_inference_len=self.default_inference_len,
-            message_format=self.message_formatting,
-            unconditioned=self.unconditioned,
-            crop_mode=self.crop_mode,
-            max_crops=self.max_crops,
-            do_random_scale=self.do_random_scale,
-            base_image_input_size=vision_cfg.image_default_input_size,
-            image_patch_size=vision_cfg.image_patch_size,
-            image_token_length_h=h,
-            image_token_length_w=w,
-            use_col_tokens=self.use_col_tokens,
-            overlap_margins=self.overlap_margins,
-            image_padding_mask=self.image_padding_embed is not None
-        )
-
-    def __post_init__(self):
-        self.vit_layers = tuple(self.vit_layers)  # type: ignore[assignment]
-
-    @classmethod
-    def update_legacy_settings(cls, config: D) -> D:
-        """
-        Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
-        """
-        if "flash_attention" in config:
-            is_flash = config.flash_attention
-            del config.flash_attention
-            config.attention_type = AttentionType.flash if is_flash else AttentionType.sdpa
-
-        if "bos_token_id" in config:
-            config.tokenizer.olmo_bos_token_id = config.pop("bos_token_id")
-            config.tokenizer.olmo_eos_token_id = config.pop("eos_token_id")
-
-        if "image_padding_mask" in config:
-            assert not config["image_padding_mask"]
-            del config["image_padding_mask"]
-            config["image_padding_embed"] = None
-        elif "image_padding_embed" not in config:
-            config["image_padding_embed"] = None
-        return config
-
-    @property
-    def effective_n_kv_heads(self) -> int:
-        if self.n_kv_heads is None:
-            if self.multi_query_attention is True:
-                return 1
-            else:
-                return self.n_heads
-        else:
-            if self.multi_query_attention is None:
-                return self.n_kv_heads
-            if self.multi_query_attention:
-                n_kv_heads_should_be = 1
-            else:
-                n_kv_heads_should_be = self.n_heads
-            if self.n_kv_heads == n_kv_heads_should_be:
-                return n_kv_heads_should_be
-            else:
-                raise OLMoConfigurationError(
-                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
-                )
-
-    @property
-    def image_num_patch(self):
-        assert self.vision_backbone is not None
-        return self.vision_backbone.image_num_patch
-
-    @property
-    def image_patch_size(self):
-        assert self.vision_backbone is not None
-        return self.visoin_backbone.image_patch_size
-
-    def llm_patches_per_crop(self):
-        h, w = self.image_num_patch
-        # Round up in case we need to pad the image features for pooling
-        h = (h + self.image_pooling_h - 1) // self.image_pooling_h
-        w = (w + self.image_pooling_w - 1) // self.image_pooling_w
-        return h, w
-
-    def get_max_crops(self) -> int:
-        """Max numbers of that can be built for one image"""
-        if self.crop_mode == "resize":
-            return 1
-        elif "resize" in self.crop_mode:
-            return 1 + self.max_crops
-        else:
-            return self.max_crops
-
 
 class MolmoConfig(PretrainedConfig):
     model_type = "molmo"
@@ -879,7 +614,7 @@ class MolmoConfig(PretrainedConfig):
 
     def __init__(self, use_cache: bool = False, **kwargs):
         model_config = ModelConfig()
-        all_kwargs =
+        all_kwargs = asdict(model_config)
         all_kwargs.update(kwargs)
         all_kwargs.update({"use_cache": use_cache})
         all_kwargs.update(
@@ -901,8 +636,8 @@ class MolmoConfig(PretrainedConfig):
 
     @property
     def image_num_patch(self):
-
-        return
+        h, w = (336, 336)
+        return h // 14, w // 14
 
     @property
     def llm_patches_per_crop(self):
@@ -910,4 +645,26 @@ class MolmoConfig(PretrainedConfig):
         # Round up in case we need to pad the image features for pooling
         h = (h + self.image_pooling_h - 1) // self.image_pooling_h
         w = (w + self.image_pooling_w - 1) // self.image_pooling_w
-        return h, w
+        return h, w
+
+    @property
+    def effective_n_kv_heads(self) -> int:
+        if self.n_kv_heads is None:
+            if self.multi_query_attention is True:
+                return 1
+            else:
+                return self.n_heads
+        else:
+            if self.multi_query_attention is None:
+                return self.n_kv_heads
+            if self.multi_query_attention:
+                n_kv_heads_should_be = 1
+            else:
+                n_kv_heads_should_be = self.n_heads
+            if self.n_kv_heads == n_kv_heads_should_be:
+                return n_kv_heads_should_be
+            else:
+                raise ValueError(
+                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
+                )
+
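One functional change above is easy to miss: MolmoConfig.__init__ now seeds its keyword arguments with asdict(model_config), i.e. the plain dataclasses helper, before applying caller overrides. A minimal sketch of that pattern, using toy field names rather than the real ModelConfig fields:

from dataclasses import asdict, dataclass


@dataclass
class ToyModelConfig:
    # Hypothetical fields standing in for the real ModelConfig defaults.
    d_model: int = 1024
    max_crops: int = 12


# Seed kwargs from the dataclass defaults, then let caller kwargs override them,
# mirroring the new MolmoConfig.__init__ flow (all_kwargs = asdict(model_config)).
all_kwargs = asdict(ToyModelConfig())
all_kwargs.update({"d_model": 2048, "use_cache": False})
print(all_kwargs)  # {'d_model': 2048, 'max_crops': 12, 'use_cache': False}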
example.py
ADDED

@@ -0,0 +1,55 @@
+from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
+from PIL import Image
+import requests
+
+
+def main():
+    load_path = "."
+
+    # load the processor
+    print("Loading processor")
+    processor = AutoProcessor.from_pretrained(
+        load_path,
+        trust_remote_code=True,
+        torch_dtype='auto',
+        device_map='auto'
+    )
+
+    # load the model
+    print("Loading model")
+    model = AutoModelForCausalLM.from_pretrained(
+        load_path,
+        trust_remote_code=True,
+        torch_dtype='auto',
+        device_map='auto'
+    )
+
+    # process the image and text
+    print("Processing...")
+    inputs = processor.process(
+        images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
+        text="Describe this image."
+    )
+
+    # move inputs to the correct device and make a batch of size 1
+    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
+
+    # generate output; maximum 200 new tokens; stop generation when <|endoftext|> is generated
+    print("Generating....")
+    output = model.generate_from_batch(
+        inputs,
+        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+        tokenizer=processor.tokenizer
+    )
+
+    # only get generated tokens; decode them to text
+    generated_tokens = output[0,inputs['input_ids'].size(1):]
+    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+    # print the generated text
+    print(generated_text)
+
+
+
+if __name__ == '__main__':
+    main()
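Assuming the checkpoint and processor files sit next to the script (load_path = "."), the example can be run directly with python example.py after installing transformers, torch, Pillow, and requests; the trust_remote_code=True flags are what let AutoProcessor and AutoModelForCausalLM pick up the custom Molmo classes shipped alongside the weights.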
modeling_molmoe.py
CHANGED

@@ -27,7 +27,7 @@ from typing import (
     Set,
     Tuple,
     cast,
-    Union,
+    Union, Any,
 )
 from copy import deepcopy
 import torch
@@ -36,17 +36,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
 import einops
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationConfig, Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-from .
-from .beam_search import (
-    BeamSearch,
-    Constraint,
-    FinalSequenceScorer,
-    Sampler
-)
-from .config import (
+from .config_molmoe import (
     ActivationType,
     BlockType,
     LayerNormType,
@@ -56,10 +49,10 @@ from .config import (
     AttentionType,
 )
 
-
+
 from .config_molmoe import (
     MolmoConfig,
-    VisionBackboneConfig
+    VisionBackboneConfig, ModelConfig
 )
 
 if sys.version_info.minor > 8:
@@ -69,26 +62,14 @@ elif sys.version_info.minor == 8:
 else:
     raise SystemExit("This script supports Python 3.8 or higher")
 
-__all__ = [
-    "LayerNormBase",
-    "LayerNorm",
-    "RMSLayerNorm",
-    "RotaryEmbedding",
-    "Activation",
-    "GELU",
-    "ReLU",
-    "SwiGLU",
-    "OLMoBlock",
-    "OLMoSequentialBlock",
-    "OLMo",
-    "OLMoOutput",
-    "OLMoGenerateOutput",
-]
-
 
 log = logging.getLogger(__name__)
 
 
+class OLMoConfigurationError(Exception):
+    pass
+
+
 def activation_checkpoint_function(cfg: ModelConfig):
     preserve_rng_state = not (
         (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and
@@ -114,20 +95,6 @@ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: b
     x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
 
 
-def activation_checkpoint_function(cfg: MolmoConfig):
-    preserve_rng_state = not (
-        (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and
-        (cfg.residual_dropout == 0.0) and (cfg.response_residual_dropout == 0.0)
-    )
-    from torch.utils.checkpoint import checkpoint
-
-    return partial(
-        checkpoint,
-        preserve_rng_state=True,
-        use_reentrant=False,
-    )
-
-
 def vit_activation_checkpoint_function(cfg: MolmoConfig):
     v_cfg = cfg.vision_backbone
     preserve_rng_state = (
@@ -142,22 +109,6 @@ def vit_activation_checkpoint_function(cfg: MolmoConfig):
     )
 
 
-def should_checkpoint_block(strategy: Optional[ActivationCheckpointingStrategy], block_idx: int) -> bool:
-    if strategy is None:
-        return False
-    elif (
-        (strategy == ActivationCheckpointingStrategy.whole_layer)
-        or (strategy == ActivationCheckpointingStrategy.one_in_two and block_idx % 2 == 0)
-        or (strategy == ActivationCheckpointingStrategy.one_in_three and block_idx % 3 == 0)
-        or (strategy == ActivationCheckpointingStrategy.one_in_four and block_idx % 4 == 0)
-        or (strategy == ActivationCheckpointingStrategy.two_in_three and block_idx % 3 != 0)
-        or (strategy == ActivationCheckpointingStrategy.three_in_four and block_idx % 4 != 0)
-    ):
-        return True
-    else:
-        return False
-
-
 class BufferCache(dict, MutableMapping[str, torch.Tensor]):
     """
     Cache for attention biases and other things that would normally be stored as buffers.
@@ -1557,15 +1508,11 @@ class MolmoVisionBackbone(nn.Module):
         self.image_feature_dropout = Dropout(config.image_feature_dropout)
 
     @classmethod
-    def build(cls, config: MolmoConfig)
+    def build(cls, config: MolmoConfig):
         v_cfg = config.vision_backbone
         assert v_cfg is not None
         return MolmoPretrainedVisionBackbone(config)
 
-    @abstractmethod
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
-        raise NotImplementedError()
-
     def reset_parameters(self):
         if self.image_pooling_2d is not None:
             self.image_pooling_2d.reset_parameters()
@@ -1583,9 +1530,9 @@ class MolmoVisionBackbone(nn.Module):
 
 
 class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
-    def __init__(self, config:
+    def __init__(self, config: MolmoConfig):
         super().__init__(config)
-        v_cfg =
+        v_cfg = VisionBackboneConfig()
 
         if v_cfg.image_model_type == VisionBackboneType.openai:
             self.image_vit = VisionTransformer(config)
@@ -1640,11 +1587,6 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
         if self.config.use_cls_feature:
             nn.init.xavier_uniform_(self.cls_projector.weight)
 
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
-        self.grad_checkpointing = True
-        if strategy in (ActivationCheckpointingStrategy.whole_layer, ActivationCheckpointingStrategy.vit_only):
-            self.image_vit.set_grad_checkpointing()
-
     def encode_image(self, images: torch.Tensor) -> torch.Tensor:
         """
         : param images: (batch_size, num_crops, num_patch, n_pixels)
@@ -1802,9 +1744,6 @@ class MolmoModel(MolmoPretrainedModel):
                 "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
             )
 
-        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
-        self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
-
         if not (
             0 < self.config.block_group_size <= self.config.n_layers
             and self.config.n_layers % self.config.block_group_size == 0
@@ -1846,25 +1785,14 @@ class MolmoModel(MolmoPretrainedModel):
         ]
         self.transformer.update({"blocks": nn.ModuleList(layers)})
 
-        self.vision_backbone: Optional[
+        self.vision_backbone: Optional[MolmoVisionBackbone] = None
         if config.vision_backbone is not None:
            self.vision_backbone = MolmoVisionBackbone.build(config)
 
        if self.vision_backbone is not None:
            self.vision_backbone.reset_with_pretrained_weights()
 
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
-        self.activation_checkpointing_strategy = strategy
-        if self.config.block_group_size != 1:
-            for block_group in self.transformer.block_groups:
-                block_group.set_activation_checkpointing(strategy)
-        else:
-            for block in self.transformer.blocks:
-                block.set_activation_checkpointing(strategy)
 
-        if self.vision_backbone is not None:
-            self.vision_backbone.set_activation_checkpointing(strategy)
-
     @property
     def device(self) -> torch.device:
         device: torch.device = self.transformer.wte.weight.device  # type: ignore
@@ -1873,7 +1801,6 @@ class MolmoModel(MolmoPretrainedModel):
         else:
             return device
 
-
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -2069,14 +1996,7 @@ class MolmoModel(MolmoPretrainedModel):
                 all_hidden_states.append(x)
 
             layer_past = None if past_key_values is None else past_key_values[block_idx]
-
-                # shape: (batch_size, seq_len, d_model)
-                x, cache = self._activation_checkpoint_fn(
-                    layer, x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache
-                )
-            else:
-                # shape: (batch_size, seq_len, d_model)
-                x, cache = layer(x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache)
+            x, cache = layer(x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache)
 
             if attn_key_values is not None:
                 assert cache is not None