seanmanifest committed on
Commit dd72573 · verified · 1 Parent(s): 4edf638

Upload folder using huggingface_hub

__init__.py ADDED
@@ -0,0 +1,12 @@
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+ from transformers import GPT2Tokenizer, GPT2TokenizerFast
+ from .configuration_powercoder import PowerCoderConfig
+ from .modeling_powercoder import PowerCoderForCausalLM
+
+ # make HF aware of the new model
+ AutoConfig.register("powercoder", PowerCoderConfig)
+ AutoModelForCausalLM.register(PowerCoderConfig, PowerCoderForCausalLM)
+
+ AutoTokenizer.register(PowerCoderConfig, GPT2Tokenizer, GPT2TokenizerFast)
+
+ __all__ = ["PowerCoderConfig", "PowerCoderForCausalLM"]
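The registrations above let the stock `Auto*` classes resolve the custom `powercoder` model type. A minimal loading sketch, assuming a hypothetical repo id and that the configuration/modeling files in this commit are fetched with `trust_remote_code=True`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "org-name/powercoder"  # hypothetical repo id, not confirmed by this commit

# trust_remote_code=True allows transformers to import configuration_powercoder.py
# and modeling_powercoder.py that ship alongside the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```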
config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "architectures": [
+     "PowerCoderForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 50256,
+   "dtype": "float32",
+   "embedding_dropout": 0.0,
+   "eos_token_id": 50256,
+   "hidden_act": "gelu_pytorch_tanh",
+   "hidden_size": 3072,
+   "hybrid_exp": "none",
+   "initializer_range": 0.018042,
+   "intermediate_size": 12288,
+   "max_position_embeddings": 4096,
+   "model_type": "powercoder",
+   "norm_epsilon": 1e-05,
+   "num_attention_heads": 24,
+   "num_hidden_layers": 30,
+   "num_key_value_heads": 2,
+   "residual_dropout": 0.0,
+   "rope_scaling": null,
+   "rope_theta": 10000.0,
+   "sliding_window": null,
+   "transformers_version": "4.57.0.dev0",
+   "use_bias": true,
+   "use_cache": true,
+   "vocab_size": 49152
+ }
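A few quantities follow directly from the values above: 24 query heads share 2 key/value heads (grouped-query attention), and the per-head dimension is the hidden size divided by the head count. A small sketch of that arithmetic:

```python
# Values copied from config.json above.
hidden_size = 3072
num_attention_heads = 24
num_key_value_heads = 2

head_dim = hidden_size // num_attention_heads                          # 128
query_heads_per_kv_head = num_attention_heads // num_key_value_heads   # 12 (GQA group size)
print(head_dim, query_heads_per_kv_head)
```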
configuration_powercoder.py ADDED
@@ -0,0 +1,196 @@
+ """PowerCoder model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class PowerCoderConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`PowerCoderModel`]. It is used to instantiate a
+     PowerCoder model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the [bigcode/starcoder2-7b](https://huggingface.co/bigcode/starcoder2-7b) model.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 49152):
+             Vocabulary size of the PowerCoder model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`PowerCoderModel`]
+         hidden_size (`int`, *optional*, defaults to 3072):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 12288):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 30):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 24):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         num_key_value_heads (`int`, *optional*, defaults to 2):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details, check out [this
+             paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `2`.
+         hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+             The non-linear activation function (function or string) in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to 4096):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.018042):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         norm_epsilon (`float`, *optional*, defaults to 1e-05):
+             Epsilon value for the layer norm.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         bos_token_id (`int`, *optional*, defaults to 50256):
+             The id of the "beginning-of-sequence" token.
+         eos_token_id (`int`, *optional*, defaults to 50256):
+             The id of the "end-of-sequence" token.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         rope_scaling (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+             accordingly.
+             Expected contents:
+                 `rope_type` (`str`):
+                     The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                     'llama3'], with 'default' being the original RoPE implementation.
+                 `factor` (`float`, *optional*):
+                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                     most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                     original maximum pre-trained length.
+                 `original_max_position_embeddings` (`int`, *optional*):
+                     Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                     pretraining.
+                 `attention_factor` (`float`, *optional*):
+                     Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                     computation. If unspecified, it defaults to value recommended by the implementation, using the
+                     `factor` field to infer the suggested value.
+                 `beta_fast` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 32.
+                 `beta_slow` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 1.
+                 `short_factor` (`list[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                     `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                     size divided by the number of attention heads divided by 2
+                 `long_factor` (`list[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                     `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                     size divided by the number of attention heads divided by 2
+                 `low_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                 `high_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+         sliding_window (`int`, *optional*):
+             Sliding window attention window size. If not specified, will default to `None` (no sliding window).
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         residual_dropout (`float`, *optional*, defaults to 0.0):
+             Residual connection dropout value.
+         embedding_dropout (`float`, *optional*, defaults to 0.0):
+             Embedding dropout.
+         use_bias (`bool`, *optional*, defaults to `True`):
+             Whether to use bias term on linear layers of the model.
+
+     ```python
+     >>> from transformers import PowerCoderModel, PowerCoderConfig
+
+     >>> # Initializing a PowerCoder style configuration
+     >>> configuration = PowerCoderConfig()
+
+     >>> # Initializing a model from the PowerCoder style configuration
+     >>> model = PowerCoderModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "powercoder"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     # Default tensor parallel plan for base model `PowerCoder`
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.mlp.c_fc": "colwise",
+         "layers.*.mlp.c_proj": "colwise",
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+
+     def __init__(
+         self,
+         vocab_size=49152,
+         hidden_size=3072,
+         intermediate_size=12288,
+         num_hidden_layers=30,
+         num_attention_heads=24,
+         num_key_value_heads=2,
+         hidden_act="gelu_pytorch_tanh",
+         max_position_embeddings=4096,
+         initializer_range=0.018042,
+         norm_epsilon=1e-5,
+         use_cache=True,
+         bos_token_id=50256,
+         eos_token_id=50256,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         sliding_window=None,
+         attention_dropout=0.0,
+         residual_dropout=0.0,
+         embedding_dropout=0.0,
+         use_bias=True,
+         chunk_size=None,
+         switch_over_seq_len=None,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.sliding_window = sliding_window
+         self.use_bias = use_bias
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.norm_epsilon = norm_epsilon
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.attention_dropout = attention_dropout
+         self.residual_dropout = residual_dropout
+         self.embedding_dropout = embedding_dropout
+         self.chunk_size = chunk_size
+         self.switch_over_seq_len = switch_over_seq_len
+         # Validate the correctness of rotary position embeddings parameters
+         # BC: if there is a 'type' field, move it to 'rope_type'.
+         if self.rope_scaling is not None and "type" in self.rope_scaling:
+             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+         rope_config_validation(self)
+
+         super().__init__(
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs,
+         )
+
+
+ __all__ = ["PowerCoderConfig"]
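The constructor above keeps backward compatibility for RoPE scaling dictionaries that still use the legacy `"type"` key by mirroring it into `"rope_type"` before validation. A sketch, assuming the module is importable locally as `configuration_powercoder`:

```python
# Assumes configuration_powercoder.py is on the import path.
from configuration_powercoder import PowerCoderConfig

config = PowerCoderConfig(
    max_position_embeddings=16384,
    rope_scaling={"type": "dynamic", "factor": 4.0},  # legacy "type" key
)

# The legacy key is mirrored into "rope_type" before rope_config_validation runs.
print(config.rope_scaling["rope_type"])  # "dynamic"
print(config.num_key_value_heads)        # 2
```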
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 50256,
+   "eos_token_id": 50256,
+   "transformers_version": "4.57.0.dev0"
+ }
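These defaults mirror the `bos_token_id`/`eos_token_id` in config.json and are picked up automatically by `generate()`. They can also be inspected directly; a sketch with a hypothetical repo id:

```python
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("org-name/powercoder")  # hypothetical repo id
print(gen_config.bos_token_id, gen_config.eos_token_id)  # 50256 50256
```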
kvg_dynamic_cache.py ADDED
@@ -0,0 +1,752 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Iterable
6
+ from dataclasses import dataclass
7
+ from typing import Any, Optional, Union
8
+
9
+ import torch
10
+
11
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
12
+
13
+ from transformers.configuration_utils import PretrainedConfig
14
+ from transformers.utils import (
15
+ is_torch_greater_or_equal,
16
+ is_torchdynamo_compiling,
17
+ logging,
18
+ )
19
+
20
+ _is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class CacheLayerMixin(ABC):
27
+ """Base, abstract class for a single layer's cache."""
28
+
29
+ is_compileable = False
30
+
31
+ def __init__(self):
32
+ self.keys, self.values, self.gatings = None, None, None
33
+
34
+ @abstractmethod
35
+ def update(
36
+ self, key_states: torch.Tensor, value_states: torch.Tensor, gate_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
37
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...
38
+
39
+ @abstractmethod
40
+ def lazy_initialization(self, key_states: torch.Tensor): ...
41
+
42
+ @abstractmethod
43
+ def get_seq_length(self, cache_position=None) -> int: ...
44
+
45
+ @abstractmethod
46
+ def get_max_cache_shape(self) -> int: ...
47
+
48
+ @abstractmethod
49
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...
50
+
51
+ def offload(self):
52
+ """Offload this layer's data to CPU device."""
53
+ if self.keys is not None:
54
+ self.keys = self.keys.to("cpu", non_blocking=True)
55
+ self.values = self.values.to("cpu", non_blocking=True)
56
+ self.gatings = self.gatings.to("cpu", non_blocking=True)
57
+
58
+ def prefetch(self):
59
+ """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
60
+ if self.keys is not None and self.keys.device != self.device:
61
+ self.keys = self.keys.to(self.device, non_blocking=True)
62
+ self.values = self.values.to(self.device, non_blocking=True)
63
+ self.gatings = self.gatings.to(self.device, non_blocking=True)
64
+
65
+ def reset(self) -> None:
66
+ """Resets the cache values while preserving the objects"""
67
+ if self.keys is not None:
68
+ self.keys.zero_()
69
+ self.values.zero_()
70
+ self.gatings.zero_()
71
+
72
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
73
+ """Reorders this layer's cache for beam search."""
74
+ if self.keys.numel():
75
+ device = self.keys.device
76
+ self.keys = self.keys.index_select(0, beam_idx.to(device))
77
+ if self.values.numel():
78
+ device = self.values.device
79
+ self.values = self.values.index_select(0, beam_idx.to(device))
80
+ if self.gatings.numel():
81
+ device = self.gatings.device
82
+ self.gatings = self.gatings.index_select(0, beam_idx.to(device))
83
+
84
+
85
+ class DynamicLayer(CacheLayerMixin):
86
+ """
87
+ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
88
+ It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.
89
+
90
+ See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
91
+ """
92
+
93
+ is_sliding = False
94
+
95
+ def lazy_initialization(self, key_states: torch.Tensor):
96
+ self.dtype, self.device = key_states.dtype, key_states.device
97
+ self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
98
+ self.values = torch.tensor([], dtype=self.dtype, device=self.device)
99
+ self.gatings = torch.tensor([], dtype=torch.float32, device=self.device)
100
+
101
+ def update(
102
+ self,
103
+ key_states: torch.Tensor,
104
+ value_states: torch.Tensor,
105
+ gate_states: torch.Tensor,
106
+ cache_kwargs: Optional[dict[str, Any]] = None,
107
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
108
+ """
109
+ Updates the cache with the new `key_states` and `value_states`.
110
+
111
+ Parameters:
112
+ key_states (`torch.Tensor`):
113
+ The new key states to cache.
114
+ value_states (`torch.Tensor`):
115
+ The new value states to cache.
+ gate_states (`torch.Tensor`):
+ The new gate states to cache.
116
+ cache_kwargs (`dict[str, Any]`, *optional*):
117
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`.
118
+
119
+ Return:
120
+ A tuple containing the updated key, value, and gating states.
121
+ """
122
+ # Lazy initialization
123
+ if self.keys is None:
124
+ self.lazy_initialization(key_states)
125
+
126
+ self.keys = torch.cat([self.keys, key_states], dim=-2)
127
+ self.values = torch.cat([self.values, value_states], dim=-2)
128
+ self.gatings = torch.cat([self.gatings, gate_states], dim=-1)
129
+ return self.keys, self.values, self.gatings
130
+
131
+ def get_seq_length(self, cache_position=None) -> int:
132
+ """Returns the sequence length of the cached states."""
133
+ if self.keys is None or self.keys.numel() == 0:
134
+ return 0
135
+ return self.keys.shape[-2]
136
+
137
+ def get_max_cache_shape(self) -> int:
138
+ """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
139
+ return -1
140
+
141
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
142
+ """Reorders the cache for beam search, given the selected beam indices."""
143
+ if self.keys is not None and self.keys.numel():
144
+ self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
145
+ self.values = self.values.index_select(0, beam_idx.to(self.values.device))
146
+ self.gatings = self.gatings.index_select(0, beam_idx.to(self.gatings.device))
147
+
148
+ def crop(self, max_length: int) -> None:
149
+ """
150
+ Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
151
+ negative to remove `max_length` tokens.
152
+ """
153
+ if max_length < 0:
154
+ max_length = self.get_seq_length() - abs(max_length)
155
+
156
+ if self.get_seq_length() <= max_length:
157
+ return
158
+
159
+ if self.keys is not None and self.keys.numel():
160
+ self.keys = self.keys[..., :max_length, :]
161
+ self.values = self.values[..., :max_length, :]
162
+ self.gatings = self.gatings[..., :max_length]
163
+
164
+ def batch_repeat_interleave(self, repeats: int) -> None:
165
+ """Repeat the cache `repeats` times in the batch dimension."""
166
+ if self.keys is not None and self.keys.numel():
167
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
168
+ self.values = self.values.repeat_interleave(repeats, dim=0)
169
+ self.gatings = self.gatings.repeat_interleave(repeats, dim=0)
170
+
171
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
172
+ """Only keep the `indices` in the batch dimension of the cache."""
173
+ if self.keys is not None and self.keys.numel():
174
+ self.keys = self.keys[indices, ...]
175
+ self.values = self.values[indices, ...]
176
+ self.gatings = self.gatings[indices, ...]
177
+
178
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
179
+ """Return the length and offset of the cache, used to generate the mask"""
180
+ kv_offset = 0
181
+ query_length = cache_position.shape[0]
182
+ past_seen_tokens = self.get_seq_length()
183
+ kv_length = query_length + past_seen_tokens
184
+ return kv_length, kv_offset
185
+
186
+ @classmethod
187
+ def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor, gatings: torch.Tensor) -> "DynamicLayer":
188
+ """
189
+ Build a `DynamicLayer` instance from pre-existing key/value tensors.
190
+
191
+ Args:
192
+ keys (`torch.Tensor`):
193
+ Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
194
+ values (`torch.Tensor`):
195
+ Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
196
+ gatings (`torch.Tensor`):
197
+ Gating cache tensor of shape ``[batch_size, num_heads, seq_len]``.
198
+
199
+ Returns:
200
+ `DynamicLayer`: The newly constructed layer whose internal cache directly references
201
+ the supplied tensors.
202
+ """
203
+ layer = cls()
204
+ layer.dtype, layer.device = keys.dtype, keys.device
205
+ layer.keys = keys
206
+ layer.values = values
207
+ layer.gatings = gatings
208
+ return layer
209
+
210
+
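For reference, a minimal sketch of how the layer above accumulates key/value/gating states across a prefill step and a decode step (shapes follow the docstrings; `kvg_dynamic_cache` is assumed to be importable locally):

```python
import torch

from kvg_dynamic_cache import DynamicLayer  # assumes this file is on the import path

batch, heads, head_dim = 1, 2, 8
layer = DynamicLayer()

# Prefill with 4 tokens, then decode 1 more; gatings hold one value per position.
for seq_len in (4, 1):
    k = torch.randn(batch, heads, seq_len, head_dim)
    v = torch.randn(batch, heads, seq_len, head_dim)
    g = torch.randn(batch, heads, seq_len)
    keys, values, gatings = layer.update(k, v, g)

print(layer.get_seq_length())     # 5
print(keys.shape, gatings.shape)  # torch.Size([1, 2, 5, 8]) torch.Size([1, 2, 5])
```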
211
+ class StaticLayer(CacheLayerMixin):
212
+ """
213
+ A static cache layer that stores the Key and Value states as static tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.
214
+ It allocates its full backing tensors up-front and mutates them in-place. Built for `torch.compile` support.
215
+
216
+ See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
217
+ """
218
+
219
+ is_compileable = True
220
+ is_sliding = False
221
+
222
+ def __init__(self, max_cache_len: int):
223
+ """
224
+ Args:
225
+ max_cache_len (`int`):
226
+ Maximum number of tokens that can be stored, used for tensor preallocation.
227
+ """
228
+ super().__init__()
229
+ self.max_cache_len = max_cache_len
230
+
231
+ def lazy_initialization(self, key_states: torch.Tensor):
232
+ """
233
+ Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
234
+ num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
235
+ devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
236
+
237
+ If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
238
+ function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
239
+ internally don't compile the prefill, this is guaranteed to have been called already when compiling.
240
+ If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
241
+ it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
242
+ i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
243
+ not be compiled anyway for performance!
244
+ """
245
+ self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
246
+ self.dtype, self.device = key_states.dtype, key_states.device
247
+
248
+ self.keys = torch.zeros(
249
+ (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
250
+ dtype=self.dtype,
251
+ device=self.device,
252
+ )
253
+ self.values = torch.zeros(
254
+ (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
255
+ dtype=self.dtype,
256
+ device=self.device,
257
+ )
258
+ self.gatings = torch.zeros(
259
+ (self.max_batch_size, self.num_heads, self.max_cache_len),
260
+ dtype=torch.float32,
261
+ device=self.device,
262
+ )
263
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
264
+ # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
265
+ # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
266
+ # prefill explicitly, but this should be avoided!)
267
+ if not is_torchdynamo_compiling():
268
+ torch._dynamo.mark_static_address(self.keys)
269
+ torch._dynamo.mark_static_address(self.values)
270
+
271
+ def update(
272
+ self,
273
+ key_states: torch.Tensor,
274
+ value_states: torch.Tensor,
275
+ gate_states: torch.Tensor,
276
+ cache_kwargs: Optional[dict[str, Any]] = None,
277
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
278
+ """
279
+ Update the static cache tensors in place.
280
+
281
+ Args:
282
+ key_states (`torch.Tensor`): The new key states to cache.
283
+ value_states (`torch.Tensor`): The new value states to cache.
284
+ gate_states (`torch.Tensor`): The new gate states to cache.
285
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
286
+
287
+ Returns:
288
+ tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]: The updated key, value, and gate states.
289
+ """
290
+ # Lazy initialization
291
+ if self.keys is None:
292
+ self.lazy_initialization(key_states)
293
+
294
+ # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
295
+ # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
296
+ cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
297
+ cache_position = (
298
+ cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
299
+ )
300
+
301
+ # Update the cache
302
+ try:
303
+ self.keys.index_copy_(2, cache_position, key_states)
304
+ self.values.index_copy_(2, cache_position, value_states)
305
+ self.gatings.index_copy_(2, cache_position, gate_states)
306
+ except NotImplementedError:
307
+ # Fallback for devices like MPS where index_copy_ might not be supported.
308
+ self.keys[:, :, cache_position] = key_states
309
+ self.values[:, :, cache_position] = value_states
310
+ self.gatings[:, :, cache_position] = gate_states
311
+ return self.keys, self.values, self.gatings
312
+
313
+ def get_max_cache_shape(self) -> int:
314
+ """Return the maximum cache shape of the cache"""
315
+ return self.max_cache_len
316
+
317
+ def get_seq_length(self, cache_position=None) -> int:
318
+ """Returns the sequence length of the cached states."""
319
+ if cache_position is not None:
320
+ return int(cache_position[-1] + 1)
321
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
322
+ # limit the check to the first batch member and head dimension.
323
+ seq_length = (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0
324
+ return seq_length
325
+
326
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
327
+ """Reorders the cache for beam search, given the selected beam indices."""
328
+ dev = self.keys.device
329
+ beam_idx_dev = beam_idx.to(dev)
330
+ self.keys = self.keys.index_select(0, beam_idx_dev)
331
+ self.values = self.values.index_select(0, beam_idx_dev)
332
+ self.gatings = self.gatings.index_select(0, beam_idx_dev)
333
+
334
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
335
+ """Return the length and offset of the cache, used to generate the attention mask"""
336
+ kv_offset = 0
337
+ kv_length = self.max_cache_len
338
+ return kv_length, kv_offset
339
+
340
+
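A corresponding sketch for the preallocated variant: the first `update` lazily allocates the full `max_cache_len` backing tensors and writes in place at the positions given by `cache_position` (again assuming `kvg_dynamic_cache` is importable locally):

```python
import torch

from kvg_dynamic_cache import StaticLayer  # assumes this file is on the import path

batch, heads, head_dim = 1, 2, 8
layer = StaticLayer(max_cache_len=16)

k = torch.randn(batch, heads, 4, head_dim)
v = torch.randn(batch, heads, 4, head_dim)
g = torch.randn(batch, heads, 4)
pos = torch.arange(4)  # cache slots to fill

keys, values, gatings = layer.update(k, v, g, cache_kwargs={"cache_position": pos})
print(keys.shape)                 # torch.Size([1, 2, 16, 8]) -- full preallocated length
print(layer.get_mask_sizes(pos))  # (16, 0): masks always cover the static length
```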
341
+
342
+ class KeyValuesGatingWrapper:
343
+ """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache.
344
+ This allows for BC access and writing, e.g., cache.key_cache[idx] = ...
345
+ Deprecated in favor of Cache.layers[idx].keys/values. TODO: remove in v4.56.0"""
346
+
347
+ def __init__(self, layers, cache_type="keys"):
348
+ self.layers = layers
349
+ self.cache_type = cache_type
350
+
351
+ def __getitem__(self, idx):
352
+ if isinstance(idx, slice):
353
+ return [getattr(layer, self.cache_type) for layer in self.layers[idx]]
354
+ return getattr(self.layers[idx], self.cache_type)
355
+
356
+ def __setitem__(self, idx, value):
357
+ if isinstance(idx, slice):
358
+ for layer, val in zip(self.layers[idx], value):
359
+ setattr(layer, self.cache_type, val)
360
+ else:
361
+ setattr(self.layers[idx], self.cache_type, value)
362
+
363
+ def __len__(self):
364
+ return len(self.layers)
365
+
366
+ def __iter__(self):
367
+ for layer in self.layers:
368
+ yield getattr(layer, self.cache_type)
369
+
370
+ def __bool__(self):
371
+ return bool(self.layers)
372
+
373
+
374
+ class Cache:
375
+ """
376
+ A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
377
+ the Cache of each layer.
378
+
379
+ Parameters:
380
+ layers (`Optional`, *optional*):
381
+ A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will
382
+ be used.
383
+ layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*):
384
+ Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
385
+ and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
386
+ list of layers.
387
+ offloading (`bool`, *optional*, defaults to `False`):
388
+ Whether to perform offloading of the layers to `cpu`, to save GPU memory.
389
+ offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
390
+ If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
391
+ usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
392
+ """
393
+
394
+ def __init__(
395
+ self,
396
+ layers: Optional[list[CacheLayerMixin]] = None,
397
+ layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None,
398
+ offloading: bool = False,
399
+ offload_only_non_sliding: bool = True,
400
+ ):
401
+ if layers is not None and layer_class_to_replicate is not None:
402
+ raise ValueError(
403
+ "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
404
+ "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
405
+ "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
406
+ )
407
+ if layers is None and layer_class_to_replicate is None:
408
+ raise ValueError(
409
+ "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
410
+ )
411
+ self.layers = layers if layers is not None else []
412
+ self.layer_class_to_replicate = layer_class_to_replicate
413
+ self.offloading = offloading
414
+ if self.offloading:
415
+ self.only_non_sliding = offload_only_non_sliding
416
+ self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
417
+
418
+ def __repr__(self):
419
+ return f"{self.__class__.__name__}(layers={self.layers})"
420
+
421
+ def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
422
+ """
423
+ Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
424
+ which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
425
+ Note that we use a non-default stream for this, to avoid blocking.
426
+ """
427
+ if only_non_sliding:
428
+ # Try to find next non-sliding, starting at `layer_idx`
429
+ try:
430
+ layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
431
+ # In this case, we need to circle back to the beginning
432
+ except ValueError:
433
+ layer_idx = self.is_sliding.index(False)
434
+ else:
435
+ layer_idx = layer_idx if layer_idx < len(self.layers) else 0
436
+
437
+ # Prefetch
438
+ with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
439
+ self.layers[layer_idx].prefetch()
440
+
441
+ def offload(self, layer_idx: int, only_non_sliding: bool = True):
442
+ """
443
+ Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
444
+ non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
445
+ computation in the layer's `update` methods are finished.
446
+ """
447
+ if not (only_non_sliding and self.is_sliding[layer_idx]):
448
+ self.layers[layer_idx].offload()
449
+
450
+ def update(
451
+ self,
452
+ key_states: torch.Tensor,
453
+ value_states: torch.Tensor,
454
+ gate_states: torch.Tensor,
455
+ layer_idx: int,
456
+ cache_kwargs: Optional[dict[str, Any]] = None,
457
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
458
+ """
459
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
460
+
461
+ Parameters:
462
+ key_states (`torch.Tensor`):
463
+ The new key states to cache.
464
+ value_states (`torch.Tensor`):
465
+ The new value states to cache.
466
+ gate_states (`torch.Tensor`):
467
+ The new gate states to cache.
468
+ layer_idx (`int`):
469
+ The index of the layer to cache the states for.
470
+ cache_kwargs (`dict[str, Any]`, *optional*):
471
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
472
+ cache to be created.
473
+
474
+ Return:
475
+ A tuple containing the updated key, value, and gate states.
476
+ """
477
+ # In this case, the `layers` were not provided, and we must append as much as `layer_idx`
478
+ if self.layer_class_to_replicate is not None:
479
+ while len(self.layers) <= layer_idx:
480
+ self.layers.append(self.layer_class_to_replicate())
481
+
482
+ if self.offloading:
483
+ # Wait for the stream to finish if needed, and start prefetching the next layer
484
+ torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
485
+ self.prefetch(layer_idx + 1, self.only_non_sliding)
486
+
487
+ keys, values, gatings = self.layers[layer_idx].update(key_states, value_states, gate_states, cache_kwargs)
488
+
489
+ if self.offloading:
490
+ self.offload(layer_idx, self.only_non_sliding)
491
+
492
+ return keys, values, gatings
493
+
494
+ def early_initialization(
495
+ self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
496
+ ):
497
+ """
498
+ Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
499
+ This is useful for our `export` recipes, as `export` needs everything in advance.
500
+ """
501
+ # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
502
+ # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
503
+ # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
504
+ fake_keys_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device)
505
+ # Init all layers
506
+ for layer in self.layers:
507
+ layer.lazy_initialization(fake_keys_tensor)
508
+
509
+ def get_seq_length(self, layer_idx: int = 0, cache_position=None) -> int:
510
+ """Returns the sequence length of the cache for the given layer."""
511
+ if layer_idx >= len(self.layers):
512
+ return 0
513
+ return self.layers[layer_idx].get_seq_length(cache_position)
514
+
515
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
516
+ """
517
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
518
+ the given layer at `layer_idx`.
519
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
520
+ """
521
+ # For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is
522
+ # simply the shape of `cache_position`
523
+ if layer_idx >= len(self.layers):
524
+ return cache_position.shape[0], 0
525
+ return self.layers[layer_idx].get_mask_sizes(cache_position)
526
+
527
+ def get_max_cache_shape(self, layer_idx: int = 0) -> int:
528
+ """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
529
+ # For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1
530
+ # as DynamicLayer does
531
+ if layer_idx >= len(self.layers):
532
+ return -1
533
+ return self.layers[layer_idx].get_max_cache_shape()
534
+
535
+ def reset(self):
536
+ """Recursively reset all layers tensors"""
537
+ for layer_idx in range(len(self.layers)):
538
+ self.layers[layer_idx].reset()
539
+
540
+ def reorder_cache(self, beam_idx: torch.LongTensor):
541
+ """Reorder the cache for beam search"""
542
+ for layer_idx in range(len(self.layers)):
543
+ self.layers[layer_idx].reorder_cache(beam_idx)
544
+
545
+ def crop(self, max_length: int):
546
+ """Crop the cache to the given length"""
547
+ for layer_idx in range(len(self.layers)):
548
+ self.layers[layer_idx].crop(max_length)
549
+
550
+ def batch_repeat_interleave(self, repeats: int):
551
+ """Repeat and interleave the cache"""
552
+ for layer_idx in range(len(self.layers)):
553
+ self.layers[layer_idx].batch_repeat_interleave(repeats)
554
+
555
+ def batch_select_indices(self, indices: torch.Tensor):
556
+ """Select indices from the cache"""
557
+ for layer_idx in range(len(self.layers)):
558
+ self.layers[layer_idx].batch_select_indices(indices)
559
+
560
+ @property
561
+ def max_batch_size(self) -> int:
562
+ """Return the maximum batch size of the cache"""
563
+ values = [layer.max_batch_size for layer in self.layers]
564
+ if len(set(values)) > 1:
565
+ raise ValueError(f"Max batch size is not consistent across layers: {values}")
566
+ return values[0]
567
+
568
+ @property
569
+ def max_cache_len(self) -> int:
570
+ """Return the maximum cache length of the cache"""
571
+ values = [layer.max_cache_len for layer in self.layers]
572
+ return max(values)
573
+
574
+ @property
575
+ def is_compileable(self) -> bool:
576
+ """Return whether the cache is compileable"""
577
+ # For DynamicCache dispatching the layers lazily (otherwise, all([]) is True)
578
+ if len(self.layers) == 0:
579
+ return False
580
+ return all(layer.is_compileable for layer in self.layers)
581
+
582
+ @property
583
+ def is_sliding(self) -> list[bool]:
584
+ """Return whether the layers of the cache are sliding window"""
585
+ return [getattr(layer, "is_sliding", False) for layer in self.layers]
586
+
587
+ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
588
+ """
589
+ Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
590
+ sequence length.
591
+ """
592
+ if layer_idx < len(self.layers):
593
+ return self.layers[layer_idx].keys, self.layers[layer_idx].values, self.layers[layer_idx].gatings
594
+ else:
595
+ raise KeyError(
596
+ f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
597
+ )
598
+
599
+ def __iter__(self):
600
+ """
601
+ Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
602
+ keys and values
603
+ """
604
+ for layer_idx in range(len(self)):
605
+ yield (self.layers[layer_idx].keys, self.layers[layer_idx].values, self.layers[layer_idx].gatings)
606
+
607
+ def __len__(self):
608
+ """
609
+ This value corresponds to the number of layers in the model.
610
+ """
611
+ # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
612
+ # forward through all the layers
613
+ return len(self.layers)
614
+
615
+ @property
616
+ def key_cache(self) -> KeyValuesGatingWrapper:
617
+ """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`"""
618
+ logger.warning_once(
619
+ "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead."
620
+ )
621
+ return KeyValuesGatingWrapper(self.layers, "keys")
622
+
623
+ @property
624
+ def value_cache(self) -> KeyValuesGatingWrapper:
625
+ """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`"""
626
+ logger.warning_once(
627
+ "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead."
628
+ )
629
+ return KeyValuesGatingWrapper(self.layers, "values")
630
+
631
+ @property
632
+ def gating_cache(self) -> KeyValuesGatingWrapper:
633
+ """List-like object of gate cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].gatings`"""
634
+ logger.warning_once(
635
+ "`cache.gate_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].gatings` instead."
636
+ )
637
+ return KeyValuesGatingWrapper(self.layers, "gatings")
638
+
639
+ class DynamicCache(Cache):
640
+ """
641
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
642
+
643
+ It stores the Key, Value, and Gating states as a list of tensors, one for each layer. The expected shape for each tensor is
644
+ `[batch_size, num_heads, seq_len, head_dim]` for Key and Value, and `[batch_size, num_heads, seq_len]` for Gating.
645
+
646
+ See `Cache` for details on common methods that are implemented by all cache classes.
647
+
648
+ Example:
649
+
650
+ ```python
651
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
652
+
653
+ >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
654
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
655
+
656
+ >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
657
+
658
+ >>> # Prepare a cache class and pass it to model's forward
659
+ >>> past_key_values = DynamicCache()
660
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
661
+ >>> outputs.past_key_values # access cache filled with key/values from generation
662
+ DynamicCache()
663
+ ```
664
+ """
665
+
666
+ # Specialized constructor for DDP cache data, needed for BC
667
+ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None):
668
+ # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212
669
+ # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
670
+ # iterable contains the key and value states for a layer gathered across replicas by torch.distributed
671
+ # (shape=[global batch size, num_heads, seq_len, head_dim]).
672
+ if ddp_cache_data is not None:
673
+ layers = []
674
+ for key_states, value_states, gate_states in ddp_cache_data:
675
+ layers.append(DynamicLayer.from_tensors(key_states, value_states, gate_states))
676
+ super().__init__(layers=layers)
677
+ else:
678
+ super().__init__(layer_class_to_replicate=DynamicLayer)
679
+
680
+ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
681
+ """
682
+ Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
683
+ backward compatibility.
684
+ """
685
+ legacy_cache = ()
686
+ for layer in self.layers:
687
+ legacy_cache += ((layer.keys, layer.values, layer.gatings),)
688
+ return legacy_cache
689
+
690
+ @classmethod
691
+ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]) -> "Cache":
692
+ """
693
+ Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
694
+ backward compatibility.
695
+ """
696
+ cache = cls()
697
+ if past_key_values is not None:
698
+ for layer_idx in range(len(past_key_values)):
699
+ key_states, value_states, gate_states = past_key_values[layer_idx]
700
+ cache.update(key_states, value_states, gate_states, layer_idx)
701
+ return cache
702
+
703
+
704
+ # Utilities for `DynamicCache` <> torch.export support
705
+
706
+ if is_torch_greater_or_equal("2.3"):
707
+
708
+ def _get_cache_dict(cache: DynamicCache):
709
+ if any(not isinstance(layer, DynamicLayer) for layer in cache.layers):
710
+ raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
711
+
712
+ if not is_torch_greater_or_equal_than_2_6:
713
+ logger.warning_once(
714
+ "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
715
+ )
716
+
717
+ return {
718
+ "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
719
+ "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
720
+ "gating_cache": [layer.gatings for layer in cache.layers if layer.gatings is not None],
721
+ }
722
+
723
+ def _unflatten_dynamic_cache(
724
+ values,
725
+ context: torch.utils._pytree.Context,
726
+ ):
727
+ dictionary = torch.utils._pytree._dict_unflatten(values, context)
728
+ cache = DynamicCache()
729
+ # Reconstruct layers from keys and values lists
730
+ key_list = dictionary.get("key_cache", [])
731
+ value_list = dictionary.get("value_cache", [])
732
+ gating_list = dictionary.get("gating_cache", [])
733
+ for idx in range(max(len(key_list), len(value_list), len(gating_list))):
734
+ key = key_list[idx] if idx < len(key_list) else None
735
+ value = value_list[idx] if idx < len(value_list) else None
736
+ gating = gating_list[idx] if idx < len(gating_list) else None
737
+ cache.update(key, value, gating, idx)
738
+ return cache
739
+
740
+ torch.utils._pytree.register_pytree_node(
741
+ DynamicCache,
742
+ lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
743
+ _unflatten_dynamic_cache,
744
+ serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
745
+ flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
746
+ _get_cache_dict(dynamic_cache)
747
+ ),
748
+ )
749
+ # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
750
+ torch.fx._pytree.register_pytree_flatten_spec(
751
+ DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec)
752
+ )
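Putting the pieces together, a short sketch of driving the gated dynamic cache the way a model forward would, including the legacy-format round trip (assuming `kvg_dynamic_cache` is importable locally):

```python
import torch

from kvg_dynamic_cache import DynamicCache  # assumes this file is on the import path

cache = DynamicCache()
batch, heads, head_dim = 1, 2, 8

# Fill two layers, passing gate states alongside keys and values.
for layer_idx in range(2):
    k = torch.randn(batch, heads, 4, head_dim)
    v = torch.randn(batch, heads, 4, head_dim)
    g = torch.randn(batch, heads, 4)
    cache.update(k, v, g, layer_idx)

print(len(cache), cache.get_seq_length())  # 2 4

legacy = cache.to_legacy_cache()           # tuple of (keys, values, gatings) per layer
roundtrip = DynamicCache.from_legacy_cache(legacy)
print(roundtrip.get_seq_length())          # 4
```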
kvgs_dynamic_cache.py ADDED
@@ -0,0 +1,869 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Iterable
6
+ from dataclasses import dataclass
7
+ from typing import Any, Optional, Union
8
+
9
+ import torch
10
+
11
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
12
+
13
+ from transformers.configuration_utils import PretrainedConfig
14
+ from transformers.utils import (
15
+ is_torch_greater_or_equal,
16
+ is_torchdynamo_compiling,
17
+ logging,
18
+ )
19
+
20
+ _is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class CacheLayerMixin(ABC):
27
+ """Base, abstract class for a single layer's cache."""
28
+
29
+ is_compileable = False
30
+
31
+ def __init__(self):
32
+ self.keys, self.values, self.gatings, self.state, self.sum_of_keys = None, None, None, None, None
33
+
34
+ @abstractmethod
35
+ def update_kv(
36
+ self, key_states: torch.Tensor, value_states: torch.Tensor, gate_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
37
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...
38
+
39
+ @abstractmethod
40
+ def update_state(
41
+ self, state: torch.Tensor, sum_of_keys: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
42
+ ) -> tuple[torch.Tensor, torch.Tensor]: ...
43
+
44
+ @abstractmethod
45
+ def lazy_initialization(self, key_states: torch.Tensor): ...
46
+
47
+ @abstractmethod
48
+ def lazy_initialization_state(self, state: torch.Tensor): ...
49
+
50
+ @abstractmethod
51
+ def get_seq_length(self, cache_position=None) -> int: ...
52
+
53
+ @abstractmethod
54
+ def get_max_cache_shape(self) -> int: ...
55
+
56
+ @abstractmethod
57
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...
58
+
59
+ def offload(self):
60
+ """Offload this layer's data to CPU device."""
61
+ if self.keys is not None:
62
+ self.keys = self.keys.to("cpu", non_blocking=True)
63
+ self.values = self.values.to("cpu", non_blocking=True)
64
+ self.gatings = self.gatings.to("cpu", non_blocking=True)
65
+ self.state = self.state.to("cpu", non_blocking=True)
66
+ self.sum_of_keys = self.sum_of_keys.to("cpu", non_blocking=True)
67
+
68
+ def prefetch(self):
69
+ """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
70
+ if self.keys is not None and self.keys.device != self.device:
71
+ self.keys = self.keys.to(self.device, non_blocking=True)
72
+ self.values = self.values.to(self.device, non_blocking=True)
73
+ self.gatings = self.gatings.to(self.device, non_blocking=True)
74
+ self.state = self.state.to(self.device, non_blocking=True)
75
+ self.sum_of_keys = self.sum_of_keys.to(self.device, non_blocking=True)
76
+
77
+ def reset(self) -> None:
78
+ """Resets the cache values while preserving the objects"""
79
+ if self.keys is not None:
80
+ self.keys.zero_()
81
+ self.values.zero_()
82
+ self.gatings.zero_()
83
+ self.state.zero_()
84
+ self.sum_of_keys.zero_()
85
+
86
+ def clean_kv(self) -> None:
87
+ if self.keys is not None:
88
+ self.keys = None
89
+ self.values = None
90
+ self.gatings = None
91
+
92
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
93
+ """Reorders this layer's cache for beam search."""
94
+ if self.keys.numel():
95
+ device = self.keys.device
96
+ self.keys = self.keys.index_select(0, beam_idx.to(device))
97
+ if self.values.numel():
98
+ device = self.values.device
99
+ self.values = self.values.index_select(0, beam_idx.to(device))
100
+ if self.gatings.numel():
101
+ device = self.gatings.device
102
+ self.gatings = self.gatings.index_select(0, beam_idx.to(device))
103
+ if self.state.numel():
104
+ device = self.state.device
105
+ self.state = self.state.index_select(0, beam_idx.to(device))
106
+ if self.sum_of_keys.numel():
107
+ device = self.sum_of_keys.device
108
+ self.sum_of_keys = self.sum_of_keys.index_select(0, beam_idx.to(device))
109
+
110
+
111
+ class DynamicLayer(CacheLayerMixin):
112
+ """
113
+ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
114
+ It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.
115
+
116
+ See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
117
+ """
118
+
119
+ is_sliding = False
120
+
121
+ def lazy_initialization(self, key_states: torch.Tensor):
122
+ self.dtype, self.device = key_states.dtype, key_states.device
123
+ self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
124
+ self.values = torch.tensor([], dtype=self.dtype, device=self.device)
125
+ self.gatings = torch.tensor([], dtype=torch.float32, device=self.device)
126
+
127
+ def lazy_initialization_state(self, state: torch.Tensor):
128
+ self.state = torch.tensor([], dtype=torch.float32, device=self.device)
129
+ self.sum_of_keys = torch.tensor([], dtype=torch.float32, device=self.device)
130
+
131
+ def update_kv(
132
+ self,
133
+ key_states: torch.Tensor,
134
+ value_states: torch.Tensor,
135
+ gate_states: torch.Tensor,
136
+ cache_kwargs: Optional[dict[str, Any]] = None,
137
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
138
+ """
139
+ Updates the cache with the new `key_states` and `value_states`.
140
+
141
+ Parameters:
142
+ key_states (`torch.Tensor`):
143
+ The new key states to cache.
144
+ value_states (`torch.Tensor`):
145
+ The new value states to cache.
146
+ gate_states (`torch.Tensor`):
147
+ The new gate states to cache.
148
+ cache_kwargs (`dict[str, Any]`, *optional*):
149
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`.
150
+
151
+ Return:
152
+ A tuple containing the updated key, value, and gating states, plus the current state and sum of keys.
153
+ """
154
+ # Lazy initialization
155
+ if self.keys is None:
156
+ self.lazy_initialization(key_states)
157
+
158
+ self.keys = torch.cat([self.keys, key_states], dim=-2)
159
+ self.values = torch.cat([self.values, value_states], dim=-2)
160
+ self.gatings = torch.cat([self.gatings, gate_states], dim=-1)
161
+ return self.keys, self.values, self.gatings, self.state, self.sum_of_keys
162
+
163
+ def update_state(
164
+ self, state: torch.Tensor, sum_of_keys: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
165
+ ) -> tuple[torch.Tensor, torch.Tensor]:
166
+ # Lazy initialization
167
+ if self.state is None:
168
+ self.lazy_initialization_state(state)
169
+
170
+ self.state = state
171
+ self.sum_of_keys = sum_of_keys
172
+ return self.state, self.sum_of_keys
173
+
174
+ def get_seq_length(self, cache_position=None) -> int:
175
+ """Returns the sequence length of the cached states."""
176
+ if self.keys is None or self.keys.numel() == 0:
177
+ return 0
178
+ return self.keys.shape[-2]
179
+
180
+ def get_max_cache_shape(self) -> int:
181
+ """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
182
+ return -1
183
+
184
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
185
+ """Reorders the cache for beam search, given the selected beam indices."""
186
+ if self.keys is not None and self.keys.numel():
187
+ self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
188
+ self.values = self.values.index_select(0, beam_idx.to(self.values.device))
189
+ self.gatings = self.gatings.index_select(0, beam_idx.to(self.gatings.device))
190
+ self.state = self.state.index_select(0, beam_idx.to(self.state.device))
191
+ self.sum_of_keys = self.sum_of_keys.index_select(0, beam_idx.to(self.sum_of_keys.device))
192
+
193
+ def crop(self, max_length: int) -> None:
194
+ """
195
+ Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
196
+ negative to remove `max_length` tokens.
197
+ """
198
+ if max_length < 0:
199
+ max_length = self.get_seq_length() - abs(max_length)
200
+
201
+ if self.get_seq_length() <= max_length:
202
+ return
203
+
204
+ if self.keys is not None and self.keys.numel():
205
+ self.keys = self.keys[..., :max_length, :]
206
+ self.values = self.values[..., :max_length, :]
207
+ self.gatings = self.gatings[..., :max_length]
208
+
209
+ def batch_repeat_interleave(self, repeats: int) -> None:
210
+ """Repeat the cache `repeats` times in the batch dimension."""
211
+ if self.keys is not None and self.keys.numel():
212
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
213
+ self.values = self.values.repeat_interleave(repeats, dim=0)
214
+ self.gatings = self.gatings.repeat_interleave(repeats, dim=0)
215
+ self.state = self.state.repeat_interleave(repeats, dim=0)
216
+ self.sum_of_keys = self.sum_of_keys.repeat_interleave(repeats, dim=0)
217
+
218
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
219
+ """Only keep the `indices` in the batch dimension of the cache."""
220
+ if self.keys is not None and self.keys.numel():
221
+ self.keys = self.keys[indices, ...]
222
+ self.values = self.values[indices, ...]
223
+ self.gatings = self.gatings[indices, ...]
224
+ self.state = self.state[indices, ...]
225
+ self.sum_of_keys = self.sum_of_keys[indices, ...]
226
+
227
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
228
+ """Return the length and offset of the cache, used to generate the mask"""
229
+ kv_offset = 0
230
+ query_length = cache_position.shape[0]
231
+ past_seen_tokens = self.get_seq_length()
232
+ kv_length = query_length + past_seen_tokens
233
+ return kv_length, kv_offset
234
+
235
+ @classmethod
236
+ def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor, gatings: torch.Tensor, state: torch.Tensor, sum_of_keys: torch.Tensor) -> "DynamicLayer":
237
+ """
238
+ Build a `DynamicLayer` instance from pre-existing cache tensors.
239
+
240
+ Args:
241
+ keys (`torch.Tensor`):
242
+ Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
243
+ values (`torch.Tensor`):
244
+ Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
245
+ gatings (`torch.Tensor`):
246
+ Gating cache tensor of shape ``[batch_size, num_heads, seq_len]``.
+ state (`torch.Tensor`):
+ State cache tensor for the layer.
+ sum_of_keys (`torch.Tensor`):
+ Sum-of-keys cache tensor for the layer.
247
+
248
+ Returns:
249
+ `DynamicLayer`: The newly constructed layer whose internal cache directly references
250
+ the supplied tensors.
251
+ """
252
+ layer = cls()
253
+ layer.dtype, layer.device = keys.dtype, keys.device
254
+ layer.keys = keys
255
+ layer.values = values
256
+ layer.gatings = gatings
257
+ layer.state = state
258
+ layer.sum_of_keys = sum_of_keys
259
+ return layer
260
+
261
+
262
+ class StaticLayer(CacheLayerMixin):
263
+ """
264
+ A static cache layer that stores the Key, Value, and Gating states as static tensors (Key/Value with shape `[batch_size, num_heads, seq_len, head_dim]`, Gating with shape `[batch_size, num_heads, seq_len]`).
265
+ It allocates its full backing tensors up-front and mutates them in-place. Built for `torch.compile` support.
266
+
267
+ See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
268
+ """
269
+
270
+ is_compileable = True
271
+ is_sliding = False
272
+
273
+ def __init__(self, max_cache_len: int):
274
+ """
275
+ Args:
276
+ max_cache_len (`int`):
277
+ Maximum number of tokens that can be stored, used for tensor preallocation.
278
+ """
279
+ super().__init__()
280
+ self.max_cache_len = max_cache_len
281
+
282
+ def lazy_initialization(self, key_states: torch.Tensor):
283
+ """
284
+ Lazy initialization of the keys, values, and gatings tensors. This allows getting all properties (dtype, device,
285
+ num_heads in case of TP, etc.) at runtime directly, which is extremely practical as it avoids moving
286
+ devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
287
+
288
+ If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
289
+ function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
290
+ internally don't compile the prefill, this is guaranteed to have been called already when compiling.
291
+ If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
292
+ it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
293
+ i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
294
+ not be compiled anyway, for performance reasons!
295
+ """
296
+ self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
297
+ self.dtype, self.device = key_states.dtype, key_states.device
298
+
299
+ self.keys = torch.zeros(
300
+ (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
301
+ dtype=self.dtype,
302
+ device=self.device,
303
+ )
304
+ self.values = torch.zeros(
305
+ (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
306
+ dtype=self.dtype,
307
+ device=self.device,
308
+ )
309
+ self.gatings = torch.zeros(
310
+ (self.max_batch_size, self.num_heads, self.max_cache_len),
311
+ dtype=torch.float32,
312
+ device=self.device,
313
+ )
314
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
315
+ # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
316
+ # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
317
+ # prefill explicitly, but this should be avoided!)
318
+ if not is_torchdynamo_compiling():
319
+ torch._dynamo.mark_static_address(self.keys)
320
+ torch._dynamo.mark_static_address(self.values)
321
+ torch._dynamo.mark_static_address(self.gatings)
322
+
323
+
324
+ def lazy_initialization_state(self, state: torch.Tensor):
325
+ self.state = torch.zeros(
326
+ (self.max_batch_size, self.num_heads, state.shape[-2], self.head_dim),  # D is taken from the provided `state`
327
+ dtype=self.dtype,
328
+ device=self.device,
329
+ )
330
+ self.sum_of_keys = torch.zeros(
331
+ (self.max_batch_size, self.num_heads, self.max_cache_len),
332
+ dtype=torch.float32,
333
+ device=self.device,
334
+ )
335
+ if not is_torchdynamo_compiling():
336
+ torch._dynamo.mark_static_address(self.state)
337
+ torch._dynamo.mark_static_address(self.sum_of_keys)
338
+
339
+ def update_kv(
340
+ self,
341
+ key_states: torch.Tensor,
342
+ value_states: torch.Tensor,
343
+ gate_states: torch.Tensor,
344
+ cache_kwargs: Optional[dict[str, Any]] = None,
345
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
346
+ """
347
+ Update the static cache tensors in place.
348
+
349
+ Args:
350
+ key_states (`torch.Tensor`): The new key states to cache.
351
+ value_states (`torch.Tensor`): The new value states to cache.
352
+ gate_states (`torch.Tensor`): The new gate states to cache.
353
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
354
+
355
+ Returns:
356
+ tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`, `torch.Tensor`, `torch.Tensor`]: The updated key, value, and gate states, and current state and sum of keys.
357
+ """
358
+ # Lazy initialization
359
+ if self.keys is None:
360
+ self.lazy_initialization(key_states)
361
+
362
+ # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
363
+ # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
364
+ cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
365
+ cache_position = (
366
+ cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
367
+ )
368
+
369
+ # Update the cache
370
+ try:
371
+ self.keys.index_copy_(2, cache_position, key_states)
372
+ self.values.index_copy_(2, cache_position, value_states)
373
+ self.gatings.index_copy_(2, cache_position, gate_states)
374
+ except NotImplementedError:
375
+ # Fallback for devices like MPS where index_copy_ might not be supported.
376
+ self.keys[:, :, cache_position] = key_states
377
+ self.values[:, :, cache_position] = value_states
378
+ self.gatings[:, :, cache_position] = gate_states
379
+ return self.keys, self.values, self.gatings, self.state, self.sum_of_keys
380
+
381
+ def update_state(
382
+ self, state: torch.Tensor, sum_of_keys: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
383
+ ) -> tuple[torch.Tensor, torch.Tensor]:
384
+ # Lazy initialization
385
+ if self.state is None:
386
+ self.lazy_initialization_state(state)
387
+
388
+ self.state = state
389
+ self.sum_of_keys = sum_of_keys
390
+ return self.state, self.sum_of_keys
391
+
392
+ def get_max_cache_shape(self) -> int:
393
+ """Return the maximum cache shape of the cache"""
394
+ return self.max_cache_len
395
+
396
+ def get_seq_length(self, cache_position=None) -> int:
397
+ """Returns the sequence length of the cached states."""
398
+ if cache_position is not None:
399
+ return int(cache_position[-1] + 1)
400
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
401
+ # limit the check to the first batch member and head dimension.
402
+ seq_length = (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0
403
+ return seq_length
404
+
405
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
406
+ """Reorders the cache for beam search, given the selected beam indices."""
407
+ dev = self.keys.device
408
+ beam_idx_dev = beam_idx.to(dev)
409
+ self.keys = self.keys.index_select(0, beam_idx_dev)
410
+ self.values = self.values.index_select(0, beam_idx_dev)
411
+ self.gatings = self.gatings.index_select(0, beam_idx_dev)
412
+ self.state = self.state.index_select(0, beam_idx_dev)
413
+ self.sum_of_keys = self.sum_of_keys.index_select(0, beam_idx_dev)
414
+
415
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
416
+ """Return the length and offset of the cache, used to generate the attention mask"""
417
+ kv_offset = 0
418
+ kv_length = self.max_cache_len
419
+ return kv_length, kv_offset
420
+
421
+
422
+
423
+ class KeyValuesGatingStateWrapper:
424
+ """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache.
425
+ This allows for BC access and writing, e.g., cache.key_cache[idx] = ...
426
+ Deprecated in favor of Cache.layers[idx].keys/values. TODO: remove in v4.56.0"""
427
+
428
+ def __init__(self, layers, cache_type="keys"):
429
+ self.layers = layers
430
+ self.cache_type = cache_type
431
+
432
+ def __getitem__(self, idx):
433
+ if isinstance(idx, slice):
434
+ return [getattr(layer, self.cache_type) for layer in self.layers[idx]]
435
+ return getattr(self.layers[idx], self.cache_type)
436
+
437
+ def __setitem__(self, idx, value):
438
+ if isinstance(idx, slice):
439
+ for layer, val in zip(self.layers[idx], value):
440
+ setattr(layer, self.cache_type, val)
441
+ else:
442
+ setattr(self.layers[idx], self.cache_type, value)
443
+
444
+ def __len__(self):
445
+ return len(self.layers)
446
+
447
+ def __iter__(self):
448
+ for layer in self.layers:
449
+ yield getattr(layer, self.cache_type)
450
+
451
+ def __bool__(self):
452
+ return bool(self.layers)
453
+
454
+
455
+ class Cache:
456
+ """
457
+ A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
458
+ the Cache of each layer.
459
+
460
+ Parameters:
461
+ layers (`list[CacheLayerMixin]`, *optional*):
462
+ A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will
463
+ be used.
464
+ layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*):
465
+ Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
466
+ and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
467
+ list of layers.
468
+ offloading (`bool`, *optional*, defaults to `False`):
469
+ Whether to perform offloading of the layers to `cpu`, to save GPU memory.
470
+ offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
471
+ If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
472
+ usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
473
+ """
474
+
475
+ def __init__(
476
+ self,
477
+ layers: Optional[list[CacheLayerMixin]] = None,
478
+ layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None,
479
+ offloading: bool = False,
480
+ offload_only_non_sliding: bool = True,
481
+ ):
482
+ if layers is not None and layer_class_to_replicate is not None:
483
+ raise ValueError(
484
+ "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
485
+ "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
486
+ "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
487
+ )
488
+ if layers is None and layer_class_to_replicate is None:
489
+ raise ValueError(
490
+ "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
491
+ )
492
+ self.layers = layers if layers is not None else []
493
+ self.layer_class_to_replicate = layer_class_to_replicate
494
+ self.offloading = offloading
495
+ if self.offloading:
496
+ self.only_non_sliding = offload_only_non_sliding
497
+ self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
498
+
499
+ def __repr__(self):
500
+ return f"{self.__class__.__name__}(layers={self.layers})"
501
+
502
+ def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
503
+ """
504
+ Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
505
+ which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
506
+ Note that we use a non-default stream for this, to avoid blocking.
507
+ """
508
+ if only_non_sliding:
509
+ # Try to find next non-sliding, starting at `layer_idx`
510
+ try:
511
+ layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
512
+ # In this case, we need to circle back to the beginning
513
+ except ValueError:
514
+ layer_idx = self.is_sliding.index(False)
515
+ else:
516
+ layer_idx = layer_idx if layer_idx < len(self.layers) else 0
517
+
518
+ # Prefetch
519
+ with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
520
+ self.layers[layer_idx].prefetch()
521
+
522
+ def offload(self, layer_idx: int, only_non_sliding: bool = True):
523
+ """
524
+ Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
525
+ non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
526
+ computation in the layer's `update` methods is finished.
527
+ """
528
+ if not (only_non_sliding and self.is_sliding[layer_idx]):
529
+ self.layers[layer_idx].offload()
530
+
531
+ def update_kv(
532
+ self,
533
+ key_states: torch.Tensor,
534
+ value_states: torch.Tensor,
535
+ gate_states: torch.Tensor,
536
+ layer_idx: int,
537
+ cache_kwargs: Optional[dict[str, Any]] = None,
538
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
539
+ """
540
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
541
+
542
+ Parameters:
543
+ key_states (`torch.Tensor`):
544
+ The new key states to cache.
545
+ value_states (`torch.Tensor`):
546
+ The new value states to cache.
547
+ gate_states (`torch.Tensor`):
548
+ The new gate states to cache.
549
+ layer_idx (`int`):
550
+ The index of the layer to cache the states for.
551
+ cache_kwargs (`dict[str, Any]`, *optional*):
552
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
553
+ cache to be created.
554
+
555
+ Return:
556
+ A tuple containing the updated key, value, and gate states, and current state and sum of keys.
557
+ """
558
+ # In this case, the `layers` were not provided, and we must append new layers up to `layer_idx`
559
+ if self.layer_class_to_replicate is not None:
560
+ while len(self.layers) <= layer_idx:
561
+ self.layers.append(self.layer_class_to_replicate())
562
+
563
+ if self.offloading:
564
+ # Wait for the stream to finish if needed, and start prefetching the next layer
565
+ torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
566
+ self.prefetch(layer_idx + 1, self.only_non_sliding)
567
+
568
+ keys, values, gatings, state, sum_of_keys = self.layers[layer_idx].update_kv(key_states, value_states, gate_states, cache_kwargs)
569
+
570
+ if self.offloading:
571
+ self.offload(layer_idx, self.only_non_sliding)
572
+
573
+ return keys, values, gatings, state, sum_of_keys
574
+
575
+ def clean_kv(self, layer_idx: int) -> None:
576
+ self.layers[layer_idx].clean_kv()
577
+
578
+ def update_state(
579
+ self, state: torch.Tensor, sum_of_keys: torch.Tensor, layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None
580
+ ) -> tuple[torch.Tensor, torch.Tensor]:
581
+ # Unlike `update_kv`, this assumes the layer at `layer_idx` has already been created by a prior `update_kv` call
582
+
583
+ state, sum_of_keys = self.layers[layer_idx].update_state(state, sum_of_keys, cache_kwargs)
584
+
585
+ return state, sum_of_keys
586
+
587
+ def early_initialization(
588
+ self, batch_size: int, num_heads: int, head_dim: int, D: int, dtype: torch.dtype, device: torch.device
589
+ ):
590
+ """
591
+ Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
592
+ This is useful for our `export` recipes, as `export` needs everything in advance.
593
+ """
594
+ # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
595
+ # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
596
+ # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
597
+ fake_keys_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device)
598
+ fake_state_tensor = torch.zeros((batch_size, num_heads, D, head_dim), dtype=dtype, device=device)
599
+ # Init all layers
600
+ for layer in self.layers:
601
+ layer.lazy_initialization(fake_keys_tensor)
602
+ layer.lazy_initialization_state(fake_state_tensor)
603
+
604
+ def get_seq_length(self, layer_idx: int = 0, cache_position=None) -> int:
605
+ """Returns the sequence length of the cache for the given layer."""
606
+ if layer_idx >= len(self.layers):
607
+ return 0
608
+ return self.layers[layer_idx].get_seq_length(cache_position)
609
+
610
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
611
+ """
612
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
613
+ the given layer at `layer_idx`.
614
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
615
+ """
616
+ # For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is
617
+ # simply the shape of `cache_position`
618
+ if layer_idx >= len(self.layers):
619
+ return cache_position.shape[0], 0
620
+ return self.layers[layer_idx].get_mask_sizes(cache_position)
621
+
622
+ def get_max_cache_shape(self, layer_idx: int = 0) -> int:
623
+ """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
624
+ # For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1
625
+ # as DynamicLayer does
626
+ if layer_idx >= len(self.layers):
627
+ return -1
628
+ return self.layers[layer_idx].get_max_cache_shape()
629
+
630
+ def reset(self):
631
+ """Recursively reset all layers tensors"""
632
+ for layer_idx in range(len(self.layers)):
633
+ self.layers[layer_idx].reset()
634
+
635
+ def reorder_cache(self, beam_idx: torch.LongTensor):
636
+ """Reorder the cache for beam search"""
637
+ for layer_idx in range(len(self.layers)):
638
+ self.layers[layer_idx].reorder_cache(beam_idx)
639
+
640
+ def crop(self, max_length: int):
641
+ """Crop the cache to the given length"""
642
+ for layer_idx in range(len(self.layers)):
643
+ self.layers[layer_idx].crop(max_length)
644
+
645
+ def batch_repeat_interleave(self, repeats: int):
646
+ """Repeat and interleave the cache"""
647
+ for layer_idx in range(len(self.layers)):
648
+ self.layers[layer_idx].batch_repeat_interleave(repeats)
649
+
650
+ def batch_select_indices(self, indices: torch.Tensor):
651
+ """Select indices from the cache"""
652
+ for layer_idx in range(len(self.layers)):
653
+ self.layers[layer_idx].batch_select_indices(indices)
654
+
655
+ @property
656
+ def max_batch_size(self) -> int:
657
+ """Return the maximum batch size of the cache"""
658
+ values = [layer.max_batch_size for layer in self.layers]
659
+ if len(set(values)) > 1:
660
+ raise ValueError(f"Max batch size is not consistent across layers: {values}")
661
+ return values[0]
662
+
663
+ @property
664
+ def max_cache_len(self) -> int:
665
+ """Return the maximum cache length of the cache"""
666
+ values = [layer.max_cache_len for layer in self.layers]
667
+ return max(values)
668
+
669
+ @property
670
+ def is_compileable(self) -> bool:
671
+ """Return whether the cache is compileable"""
672
+ # For DynamicCache dispatching the layers lazily (otherwise, all([]) is True)
673
+ if len(self.layers) == 0:
674
+ return False
675
+ return all(layer.is_compileable for layer in self.layers)
676
+
677
+ @property
678
+ def is_sliding(self) -> list[bool]:
679
+ """Return whether the layers of the cache are sliding window"""
680
+ return [getattr(layer, "is_sliding", False) for layer in self.layers]
681
+
682
+ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
683
+ """
684
+ Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
685
+ sequence length.
686
+ """
687
+ if layer_idx < len(self.layers):
688
+ return self.layers[layer_idx].keys, self.layers[layer_idx].values, self.layers[layer_idx].gatings
689
+ else:
690
+ raise KeyError(
691
+ f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
692
+ )
693
+
694
+ def __iter__(self):
695
+ """
696
+ Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
697
+ keys and values
698
+ """
699
+ for layer_idx in range(len(self)):
700
+ yield (self.layers[layer_idx].keys, self.layers[layer_idx].values, self.layers[layer_idx].gatings)
701
+
702
+ def __len__(self):
703
+ """
704
+ This value corresponds to the number of layers in the model.
705
+ """
706
+ # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
707
+ # forward through all the layers
708
+ return len(self.layers)
709
+
710
+ @property
711
+ def key_cache(self) -> KeyValuesGatingStateWrapper:
712
+ """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`"""
713
+ logger.warning_once(
714
+ "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead."
715
+ )
716
+ return KeyValuesGatingStateWrapper(self.layers, "keys")
717
+
718
+ @property
719
+ def value_cache(self) -> KeyValuesGatingStateWrapper:
720
+ """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`"""
721
+ logger.warning_once(
722
+ "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead."
723
+ )
724
+ return KeyValuesGatingStateWrapper(self.layers, "values")
725
+
726
+ @property
727
+ def gating_cache(self) -> KeyValuesGatingStateWrapper:
728
+ """List-like object of gate cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].gatings`"""
729
+ logger.warning_once(
730
+ "`cache.gate_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].gatings` instead."
731
+ )
732
+ return KeyValuesGatingStateWrapper(self.layers, "gatings")
733
+
734
+ @property
735
+ def state_cache(self) -> KeyValuesGatingStateWrapper:
736
+ """List-like object of state cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].state`"""
737
+ logger.warning_once(
738
+ "`cache.state_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].state` instead."
739
+ )
740
+ return KeyValuesGatingStateWrapper(self.layers, "state")
741
+
742
+ @property
743
+ def sum_of_keys_cache(self) -> KeyValuesGatingStateWrapper:
744
+ """List-like object of sum of keys cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].sum_of_keys`"""
745
+ logger.warning_once(
746
+ "`cache.sum_of_keys_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].sum_of_keys` instead."
747
+ )
748
+ return KeyValuesGatingStateWrapper(self.layers, "sum_of_keys")
749
+
750
+ class DynamicCache(Cache):
751
+ """
752
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
753
+
754
+ It stores the Key, Value, and Gating states as a list of tensors, one for each layer. The expected shape for each tensor is
755
+ `[batch_size, num_heads, seq_len, head_dim]` for Key and Value, and `[batch_size, num_heads, seq_len]` for Gating.
756
+
757
+ See `Cache` for details on common methods that are implemented by all cache classes.
758
+
759
+ Example:
760
+
761
+ ```python
762
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
763
+
764
+ >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
765
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
766
+
767
+ >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
768
+
769
+ >>> # Prepare a cache class and pass it to model's forward
770
+ >>> past_key_values = DynamicCache()
771
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
772
+ >>> outputs.past_key_values # access cache filled with key/values from generation
773
+ DynamicCache()
774
+ ```
775
+ """
776
+
777
+ # Specialized constructor for DDP cache data, needed for BC
778
+ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]] = None):
779
+ # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212
780
+ # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
781
+ # iterable contains the key, value, gating, state, and sum-of-keys tensors for a layer gathered across replicas by torch.distributed
782
+ # (shape=[global batch size, num_heads, seq_len, head_dim]).
783
+ if ddp_cache_data is not None:
784
+ layers = []
785
+ for key_states, value_states, gate_states, state, sum_of_keys in ddp_cache_data:
786
+ layers.append(DynamicLayer.from_tensors(key_states, value_states, gate_states, state, sum_of_keys))
787
+ super().__init__(layers=layers)
788
+ else:
789
+ super().__init__(layer_class_to_replicate=DynamicLayer)
790
+
791
+ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], ...]:
792
+ """
793
+ Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
794
+ backward compatibility.
795
+ """
796
+ legacy_cache = ()
797
+ for layer in self.layers:
798
+ legacy_cache += ((layer.keys, layer.values, layer.gatings, layer.state, layer.sum_of_keys),)
799
+ return legacy_cache
800
+
801
+ @classmethod
802
+ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor], ...]) -> "Cache":
803
+ """
804
+ Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
805
+ backward compatibility.
806
+ """
807
+ cache = cls()
808
+ if past_key_values is not None:
809
+ for layer_idx in range(len(past_key_values)):
810
+ key_states, value_states, gate_states, state, sum_of_keys = past_key_values[layer_idx]
811
+ cache.update_kv(key_states, value_states, gate_states, layer_idx)
+ cache.update_state(state, sum_of_keys, layer_idx)
812
+ return cache
813
+
814
+
815
+ # Utilities for `DynamicCache` <> torch.export support
816
+
817
+ if is_torch_greater_or_equal("2.3"):
818
+
819
+ def _get_cache_dict(cache: DynamicCache):
820
+ if any(not isinstance(layer, DynamicLayer) for layer in cache.layers):
821
+ raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
822
+
823
+ if not is_torch_greater_or_equal_than_2_6:
824
+ logger.warning_once(
825
+ "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
826
+ )
827
+
828
+ return {
829
+ "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
830
+ "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
831
+ "gating_cache": [layer.gatings for layer in cache.layers if layer.gatings is not None],
832
+ "state_cache": [layer.state for layer in cache.layers if layer.state is not None],
833
+ "sum_of_keys_cache": [layer.sum_of_keys for layer in cache.layers if layer.sum_of_keys is not None],
834
+ }
835
+
836
+ def _unflatten_dynamic_cache(
837
+ values,
838
+ context: torch.utils._pytree.Context,
839
+ ):
840
+ dictionary = torch.utils._pytree._dict_unflatten(values, context)
841
+ cache = DynamicCache()
842
+ # Reconstruct layers from keys and values lists
843
+ key_list = dictionary.get("key_cache", [])
844
+ value_list = dictionary.get("value_cache", [])
845
+ gating_list = dictionary.get("gating_cache", [])
846
+ state_list = dictionary.get("state_cache", [])
847
+ sum_of_keys_list = dictionary.get("sum_of_keys_cache", [])
848
+ for idx in range(max(len(key_list), len(value_list), len(gating_list), len(state_list), len(sum_of_keys_list))):
849
+ key = key_list[idx] if idx < len(key_list) else None
850
+ value = value_list[idx] if idx < len(value_list) else None
851
+ gating = gating_list[idx] if idx < len(gating_list) else None
852
+ state = state_list[idx] if idx < len(state_list) else None
853
+ sum_of_keys = sum_of_keys_list[idx] if idx < len(sum_of_keys_list) else None
854
+ cache.update_kv(key, value, gating, idx)
+ cache.update_state(state, sum_of_keys, idx)
855
+ return cache
856
+
857
+ torch.utils._pytree.register_pytree_node(
858
+ DynamicCache,
859
+ lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
860
+ _unflatten_dynamic_cache,
861
+ serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
862
+ flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
863
+ _get_cache_dict(dynamic_cache)
864
+ ),
865
+ )
866
+ # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
867
+ torch.fx._pytree.register_pytree_flatten_spec(
868
+ DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec)
869
+ )
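
The cache classes above can also be exercised in isolation, which makes it easier to see how the gated five-tuple (keys, values, gatings, state, sum of keys) flows through `update_kv` and `update_state`. Below is a minimal sketch, not part of the uploaded files: the tensor shapes, the dummy state/sum-of-keys tensors, and the `modeling_powercoder` module path are assumptions made purely for illustration.

```python
# Minimal sketch (not part of this upload): drive the gated DynamicCache by hand.
# Shapes and the module path are assumptions for illustration only.
import torch
from modeling_powercoder import DynamicCache  # assumed module path

batch, num_heads, head_dim, D = 2, 2, 128, 64
cache = DynamicCache()

for _ in range(3):  # pretend to decode three tokens for layer 0
    k = torch.randn(batch, num_heads, 1, head_dim)
    v = torch.randn(batch, num_heads, 1, head_dim)
    g = torch.rand(batch, num_heads, 1)  # gatings live on the seq_len axis

    keys, values, gatings, state, sum_of_keys = cache.update_kv(k, v, g, layer_idx=0)

    # A real attention layer would derive these from its recurrence; dummies here.
    cache.update_state(
        torch.zeros(batch, num_heads, D, head_dim),
        torch.zeros(batch, num_heads, head_dim),
        layer_idx=0,
    )

print(cache.get_seq_length(0))  # 3 cached tokens for layer 0
print(cache[0][0].shape)        # keys: torch.Size([2, 2, 3, 128])
```

During normal use the cache is simply passed as `past_key_values` to the model forward, as in the `DynamicCache` docstring example; the manual loop above only illustrates the per-layer update contract.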
merges.txt ADDED
The diff for this file is too large to render. See raw diff
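
The `model.safetensors.index.json` added below is the shard index for the checkpoint: its `weight_map` maps every parameter name to the `.safetensors` shard that stores it. The following is a minimal sketch (not part of the upload) of how such an index can be used to locate and load a single tensor; the local checkpoint path is hypothetical, while the parameter and shard names come from the weight map below.

```python
# Minimal sketch (not part of this upload): resolve a parameter to its shard via the
# index below and load just that tensor. The checkpoint directory is hypothetical.
import json
from safetensors import safe_open

ckpt_dir = "./powercoder"  # hypothetical local copy of the uploaded files

with open(f"{ckpt_dir}/model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.0.self_attn.g_proj.weight"
shard = index["weight_map"][name]  # "model-00001-of-00003.safetensors" per the map below

with safe_open(f"{ckpt_dir}/{shard}", framework="pt") as sf:
    tensor = sf.get_tensor(name)
print(shard, tuple(tensor.shape))
```
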
 
model.safetensors.index.json ADDED
@@ -0,0 +1,552 @@
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 378819426,
4
+ "total_size": 12726202608
5
+ },
6
+ "weight_map": {
7
+ "lm_head.weight": "model-00003-of-00003.safetensors",
8
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
9
+ "model.layers.0.input_layernorm.bias": "model-00001-of-00003.safetensors",
10
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
11
+ "model.layers.0.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
12
+ "model.layers.0.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.0.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
14
+ "model.layers.0.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
15
+ "model.layers.0.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
16
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
17
+ "model.layers.0.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
18
+ "model.layers.0.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
20
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.0.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
22
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
24
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
25
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
26
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
27
+ "model.layers.1.input_layernorm.bias": "model-00001-of-00003.safetensors",
28
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
29
+ "model.layers.1.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
30
+ "model.layers.1.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
31
+ "model.layers.1.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
32
+ "model.layers.1.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
33
+ "model.layers.1.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
34
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
35
+ "model.layers.1.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
36
+ "model.layers.1.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
37
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
38
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
39
+ "model.layers.1.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
40
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
41
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
42
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
43
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
44
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
45
+ "model.layers.10.input_layernorm.bias": "model-00001-of-00003.safetensors",
46
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
47
+ "model.layers.10.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
48
+ "model.layers.10.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
49
+ "model.layers.10.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
50
+ "model.layers.10.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
51
+ "model.layers.10.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
52
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
53
+ "model.layers.10.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
54
+ "model.layers.10.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
55
+ "model.layers.10.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
56
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
57
+ "model.layers.10.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
58
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
59
+ "model.layers.10.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
60
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
61
+ "model.layers.10.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
62
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
63
+ "model.layers.11.input_layernorm.bias": "model-00002-of-00003.safetensors",
64
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
65
+ "model.layers.11.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
66
+ "model.layers.11.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
67
+ "model.layers.11.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
68
+ "model.layers.11.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
69
+ "model.layers.11.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
70
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
71
+ "model.layers.11.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
72
+ "model.layers.11.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
73
+ "model.layers.11.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
74
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
75
+ "model.layers.11.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
76
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
77
+ "model.layers.11.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
78
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
79
+ "model.layers.11.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
80
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
81
+ "model.layers.12.input_layernorm.bias": "model-00002-of-00003.safetensors",
82
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.12.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
84
+ "model.layers.12.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.12.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
86
+ "model.layers.12.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.12.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
88
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.12.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
90
+ "model.layers.12.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
91
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
92
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
93
+ "model.layers.12.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
94
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
95
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
96
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
97
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
98
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
99
+ "model.layers.13.input_layernorm.bias": "model-00002-of-00003.safetensors",
100
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
101
+ "model.layers.13.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
102
+ "model.layers.13.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
103
+ "model.layers.13.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
104
+ "model.layers.13.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
105
+ "model.layers.13.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
106
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
107
+ "model.layers.13.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
108
+ "model.layers.13.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
109
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
110
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
111
+ "model.layers.13.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
112
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
113
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
114
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
115
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
116
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
117
+ "model.layers.14.input_layernorm.bias": "model-00002-of-00003.safetensors",
118
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
119
+ "model.layers.14.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
120
+ "model.layers.14.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
121
+ "model.layers.14.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
122
+ "model.layers.14.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
123
+ "model.layers.14.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
124
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
125
+ "model.layers.14.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
126
+ "model.layers.14.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
127
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
128
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
129
+ "model.layers.14.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
130
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
131
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
132
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
134
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
135
+ "model.layers.15.input_layernorm.bias": "model-00002-of-00003.safetensors",
136
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
137
+ "model.layers.15.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
138
+ "model.layers.15.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.15.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
140
+ "model.layers.15.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
141
+ "model.layers.15.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
142
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
143
+ "model.layers.15.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
144
+ "model.layers.15.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
145
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
146
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
147
+ "model.layers.15.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
148
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
149
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
150
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
151
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
152
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
153
+ "model.layers.16.input_layernorm.bias": "model-00002-of-00003.safetensors",
154
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
155
+ "model.layers.16.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
156
+ "model.layers.16.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
157
+ "model.layers.16.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
158
+ "model.layers.16.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
159
+ "model.layers.16.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
160
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
161
+ "model.layers.16.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
162
+ "model.layers.16.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
163
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
164
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
165
+ "model.layers.16.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
166
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
167
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
168
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
169
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
170
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
171
+ "model.layers.17.input_layernorm.bias": "model-00002-of-00003.safetensors",
172
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
173
+ "model.layers.17.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
174
+ "model.layers.17.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
175
+ "model.layers.17.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
176
+ "model.layers.17.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
177
+ "model.layers.17.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
178
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
179
+ "model.layers.17.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
180
+ "model.layers.17.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
181
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
182
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
183
+ "model.layers.17.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
184
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
185
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
186
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
187
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
188
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
189
+ "model.layers.18.input_layernorm.bias": "model-00002-of-00003.safetensors",
190
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
191
+ "model.layers.18.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
192
+ "model.layers.18.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
193
+ "model.layers.18.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
194
+ "model.layers.18.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
195
+ "model.layers.18.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
196
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
197
+ "model.layers.18.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
198
+ "model.layers.18.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
199
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
200
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
201
+ "model.layers.18.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
202
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
203
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
204
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
205
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
206
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
207
+ "model.layers.19.input_layernorm.bias": "model-00002-of-00003.safetensors",
208
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
209
+ "model.layers.19.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
210
+ "model.layers.19.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
211
+ "model.layers.19.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
212
+ "model.layers.19.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
213
+ "model.layers.19.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
214
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
215
+ "model.layers.19.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
216
+ "model.layers.19.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
217
+ "model.layers.19.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
218
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
219
+ "model.layers.19.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
220
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
221
+ "model.layers.19.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
222
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
223
+ "model.layers.19.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
224
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
225
+ "model.layers.2.input_layernorm.bias": "model-00001-of-00003.safetensors",
226
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
227
+ "model.layers.2.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
228
+ "model.layers.2.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
229
+ "model.layers.2.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
230
+ "model.layers.2.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
231
+ "model.layers.2.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
232
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
233
+ "model.layers.2.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
234
+ "model.layers.2.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
235
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
236
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
237
+ "model.layers.2.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
238
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
239
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
240
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
241
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
242
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
243
+ "model.layers.20.input_layernorm.bias": "model-00002-of-00003.safetensors",
244
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
245
+ "model.layers.20.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
246
+ "model.layers.20.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
247
+ "model.layers.20.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
248
+ "model.layers.20.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
249
+ "model.layers.20.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
250
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
251
+ "model.layers.20.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
252
+ "model.layers.20.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
253
+ "model.layers.20.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
254
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
255
+ "model.layers.20.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
256
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
257
+ "model.layers.20.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
258
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
259
+ "model.layers.20.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
260
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
261
+ "model.layers.21.input_layernorm.bias": "model-00002-of-00003.safetensors",
262
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
263
+ "model.layers.21.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
264
+ "model.layers.21.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
265
+ "model.layers.21.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
266
+ "model.layers.21.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
267
+ "model.layers.21.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
268
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
269
+ "model.layers.21.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
270
+ "model.layers.21.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
271
+ "model.layers.21.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
272
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
273
+ "model.layers.21.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
274
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
275
+ "model.layers.21.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
276
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
277
+ "model.layers.21.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
278
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
279
+ "model.layers.22.input_layernorm.bias": "model-00002-of-00003.safetensors",
280
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
281
+ "model.layers.22.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
282
+ "model.layers.22.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
283
+ "model.layers.22.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
284
+ "model.layers.22.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
285
+ "model.layers.22.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
286
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
287
+ "model.layers.22.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
288
+ "model.layers.22.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
289
+ "model.layers.22.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
290
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
291
+ "model.layers.22.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
292
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
293
+ "model.layers.22.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
294
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
295
+ "model.layers.22.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
296
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
297
+ "model.layers.23.input_layernorm.bias": "model-00002-of-00003.safetensors",
298
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
299
+ "model.layers.23.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
300
+ "model.layers.23.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
301
+ "model.layers.23.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
302
+ "model.layers.23.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
303
+ "model.layers.23.post_attention_layernorm.bias": "model-00002-of-00003.safetensors",
304
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
305
+ "model.layers.23.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
306
+ "model.layers.23.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
307
+ "model.layers.23.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
308
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
309
+ "model.layers.23.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
310
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
311
+ "model.layers.23.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
312
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
313
+ "model.layers.23.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
314
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
315
+ "model.layers.24.input_layernorm.bias": "model-00003-of-00003.safetensors",
316
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
317
+ "model.layers.24.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
318
+ "model.layers.24.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
319
+ "model.layers.24.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
320
+ "model.layers.24.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
321
+ "model.layers.24.post_attention_layernorm.bias": "model-00003-of-00003.safetensors",
322
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
323
+ "model.layers.24.self_attn.g_proj.bias": "model-00002-of-00003.safetensors",
324
+ "model.layers.24.self_attn.g_proj.weight": "model-00002-of-00003.safetensors",
325
+ "model.layers.24.self_attn.k_proj.bias": "model-00002-of-00003.safetensors",
326
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
327
+ "model.layers.24.self_attn.o_proj.bias": "model-00002-of-00003.safetensors",
328
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
329
+ "model.layers.24.self_attn.q_proj.bias": "model-00002-of-00003.safetensors",
330
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
331
+ "model.layers.24.self_attn.v_proj.bias": "model-00002-of-00003.safetensors",
332
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
333
+ "model.layers.25.input_layernorm.bias": "model-00003-of-00003.safetensors",
334
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
335
+ "model.layers.25.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
336
+ "model.layers.25.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
337
+ "model.layers.25.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
338
+ "model.layers.25.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
339
+ "model.layers.25.post_attention_layernorm.bias": "model-00003-of-00003.safetensors",
340
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
341
+ "model.layers.25.self_attn.g_proj.bias": "model-00003-of-00003.safetensors",
342
+ "model.layers.25.self_attn.g_proj.weight": "model-00003-of-00003.safetensors",
343
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00003.safetensors",
344
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
345
+ "model.layers.25.self_attn.o_proj.bias": "model-00003-of-00003.safetensors",
346
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
347
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00003.safetensors",
348
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
349
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00003.safetensors",
350
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
351
+ "model.layers.26.input_layernorm.bias": "model-00003-of-00003.safetensors",
352
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
353
+ "model.layers.26.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
354
+ "model.layers.26.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
355
+ "model.layers.26.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
356
+ "model.layers.26.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
357
+ "model.layers.26.post_attention_layernorm.bias": "model-00003-of-00003.safetensors",
358
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
359
+ "model.layers.26.self_attn.g_proj.bias": "model-00003-of-00003.safetensors",
360
+ "model.layers.26.self_attn.g_proj.weight": "model-00003-of-00003.safetensors",
361
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00003.safetensors",
362
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
363
+ "model.layers.26.self_attn.o_proj.bias": "model-00003-of-00003.safetensors",
364
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
365
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00003.safetensors",
366
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
367
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00003.safetensors",
368
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
369
+ "model.layers.27.input_layernorm.bias": "model-00003-of-00003.safetensors",
370
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
371
+ "model.layers.27.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
372
+ "model.layers.27.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
373
+ "model.layers.27.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
374
+ "model.layers.27.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
375
+ "model.layers.27.post_attention_layernorm.bias": "model-00003-of-00003.safetensors",
376
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
377
+ "model.layers.27.self_attn.g_proj.bias": "model-00003-of-00003.safetensors",
378
+ "model.layers.27.self_attn.g_proj.weight": "model-00003-of-00003.safetensors",
379
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00003.safetensors",
380
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
381
+ "model.layers.27.self_attn.o_proj.bias": "model-00003-of-00003.safetensors",
382
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
383
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00003.safetensors",
384
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
385
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00003.safetensors",
386
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
387
+ "model.layers.28.input_layernorm.bias": "model-00003-of-00003.safetensors",
388
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
389
+ "model.layers.28.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
390
+ "model.layers.28.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
391
+ "model.layers.28.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
392
+ "model.layers.28.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
393
+ "model.layers.28.post_attention_layernorm.bias": "model-00003-of-00003.safetensors",
394
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
395
+ "model.layers.28.self_attn.g_proj.bias": "model-00003-of-00003.safetensors",
396
+ "model.layers.28.self_attn.g_proj.weight": "model-00003-of-00003.safetensors",
397
+ "model.layers.28.self_attn.k_proj.bias": "model-00003-of-00003.safetensors",
398
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
399
+ "model.layers.28.self_attn.o_proj.bias": "model-00003-of-00003.safetensors",
400
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
401
+ "model.layers.28.self_attn.q_proj.bias": "model-00003-of-00003.safetensors",
402
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
403
+ "model.layers.28.self_attn.v_proj.bias": "model-00003-of-00003.safetensors",
404
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
405
+ "model.layers.29.input_layernorm.bias": "model-00003-of-00003.safetensors",
406
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
407
+ "model.layers.29.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
408
+ "model.layers.29.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
409
+ "model.layers.29.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
410
+ "model.layers.29.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
411
+ "model.layers.29.post_attention_layernorm.bias": "model-00003-of-00003.safetensors",
412
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
413
+ "model.layers.29.self_attn.g_proj.bias": "model-00003-of-00003.safetensors",
414
+ "model.layers.29.self_attn.g_proj.weight": "model-00003-of-00003.safetensors",
415
+ "model.layers.29.self_attn.k_proj.bias": "model-00003-of-00003.safetensors",
416
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
417
+ "model.layers.29.self_attn.o_proj.bias": "model-00003-of-00003.safetensors",
418
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
419
+ "model.layers.29.self_attn.q_proj.bias": "model-00003-of-00003.safetensors",
420
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
421
+ "model.layers.29.self_attn.v_proj.bias": "model-00003-of-00003.safetensors",
422
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
423
+ "model.layers.3.input_layernorm.bias": "model-00001-of-00003.safetensors",
424
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
425
+ "model.layers.3.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
426
+ "model.layers.3.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
427
+ "model.layers.3.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
428
+ "model.layers.3.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
429
+ "model.layers.3.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
430
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
431
+ "model.layers.3.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
432
+ "model.layers.3.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
433
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
434
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
435
+ "model.layers.3.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
436
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
437
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
438
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
439
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
440
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
441
+ "model.layers.4.input_layernorm.bias": "model-00001-of-00003.safetensors",
442
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
443
+ "model.layers.4.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
444
+ "model.layers.4.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
445
+ "model.layers.4.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
446
+ "model.layers.4.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
447
+ "model.layers.4.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
448
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
449
+ "model.layers.4.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
450
+ "model.layers.4.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
451
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
452
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
453
+ "model.layers.4.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
454
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
455
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
456
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
457
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
458
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
459
+ "model.layers.5.input_layernorm.bias": "model-00001-of-00003.safetensors",
460
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
461
+ "model.layers.5.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
462
+ "model.layers.5.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
463
+ "model.layers.5.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
464
+ "model.layers.5.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
465
+ "model.layers.5.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
466
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
467
+ "model.layers.5.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
468
+ "model.layers.5.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
469
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
470
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
471
+ "model.layers.5.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
472
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
473
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
474
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
475
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
476
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
477
+ "model.layers.6.input_layernorm.bias": "model-00001-of-00003.safetensors",
478
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
479
+ "model.layers.6.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
480
+ "model.layers.6.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
481
+ "model.layers.6.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
482
+ "model.layers.6.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
483
+ "model.layers.6.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
484
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
485
+ "model.layers.6.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
486
+ "model.layers.6.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
487
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
488
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
489
+ "model.layers.6.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
490
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
491
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
492
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
493
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
494
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
495
+ "model.layers.7.input_layernorm.bias": "model-00001-of-00003.safetensors",
496
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
497
+ "model.layers.7.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
498
+ "model.layers.7.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
499
+ "model.layers.7.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
500
+ "model.layers.7.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
501
+ "model.layers.7.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
502
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
503
+ "model.layers.7.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
504
+ "model.layers.7.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
505
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
506
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
507
+ "model.layers.7.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
508
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
509
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
510
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
511
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
512
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
513
+ "model.layers.8.input_layernorm.bias": "model-00001-of-00003.safetensors",
514
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
515
+ "model.layers.8.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
516
+ "model.layers.8.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
517
+ "model.layers.8.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
518
+ "model.layers.8.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
519
+ "model.layers.8.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
520
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
521
+ "model.layers.8.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
522
+ "model.layers.8.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
523
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
524
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
525
+ "model.layers.8.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
526
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
527
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
528
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
529
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
530
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
531
+ "model.layers.9.input_layernorm.bias": "model-00001-of-00003.safetensors",
532
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
533
+ "model.layers.9.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
534
+ "model.layers.9.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
535
+ "model.layers.9.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
536
+ "model.layers.9.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
537
+ "model.layers.9.post_attention_layernorm.bias": "model-00001-of-00003.safetensors",
538
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
539
+ "model.layers.9.self_attn.g_proj.bias": "model-00001-of-00003.safetensors",
540
+ "model.layers.9.self_attn.g_proj.weight": "model-00001-of-00003.safetensors",
541
+ "model.layers.9.self_attn.k_proj.bias": "model-00001-of-00003.safetensors",
542
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
543
+ "model.layers.9.self_attn.o_proj.bias": "model-00001-of-00003.safetensors",
544
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
545
+ "model.layers.9.self_attn.q_proj.bias": "model-00001-of-00003.safetensors",
546
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
547
+ "model.layers.9.self_attn.v_proj.bias": "model-00001-of-00003.safetensors",
548
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
549
+ "model.norm.bias": "model-00003-of-00003.safetensors",
550
+ "model.norm.weight": "model-00003-of-00003.safetensors"
551
+ }
552
+ }
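The weight map above assigns every PowerCoder parameter to one of the three safetensors shards. Below is a minimal sketch of using that index to pull a single tensor without loading the whole checkpoint; the local paths are assumptions (a downloaded copy of this repo in the working directory), and it relies only on the standard `json` module and the `safetensors` library.

```python
import json

from safetensors import safe_open

# Assumed local path: a downloaded copy of this repo in the working directory.
INDEX_PATH = "model.safetensors.index.json"


def load_single_tensor(name: str):
    """Resolve a parameter name to its shard via the index, then load only that tensor."""
    with open(INDEX_PATH) as f:
        index = json.load(f)
    shard_file = index["weight_map"][name]  # e.g. "model-00002-of-00003.safetensors"
    with safe_open(shard_file, framework="pt", device="cpu") as shard:
        return shard.get_tensor(name)


# Fetch one attention projection without materializing the full 30-layer checkpoint.
w = load_single_tensor("model.layers.21.self_attn.q_proj.weight")
print(w.shape)
```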
modeling_powercoder.py ADDED
@@ -0,0 +1,543 @@
1
+
2
+ from typing import Callable, Optional, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from retention.triton import power_retention, power_retention_inference
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.generation import GenerationMixin
11
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
12
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
13
+ from transformers.modeling_layers import (
14
+ GenericForSequenceClassification,
15
+ GenericForTokenClassification,
16
+ GradientCheckpointingLayer,
17
+ )
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
23
+ from .configuration_powercoder import PowerCoderConfig
24
+ from .kvgs_dynamic_cache import Cache, DynamicCache
25
+
26
+ class PowerCoderMLP(nn.Module):
27
+ def __init__(self, config: PowerCoderConfig):
28
+ super().__init__()
29
+ embed_dim = config.hidden_size
30
+ self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
31
+ self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
32
+ self.act = ACT2FN[config.hidden_act]
33
+ self.residual_dropout = config.residual_dropout
34
+
35
+ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor:
36
+ hidden_states = self.c_fc(hidden_states)
37
+ hidden_states = self.act(hidden_states)
38
+ hidden_states = self.c_proj(hidden_states)
39
+ hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
40
+ return hidden_states
41
+
42
+
43
+ def rotate_half(x):
44
+ """Rotates half the hidden dims of the input."""
45
+ x1 = x[..., : x.shape[-1] // 2]
46
+ x2 = x[..., x.shape[-1] // 2 :]
47
+ return torch.cat((-x2, x1), dim=-1)
48
+
49
+
50
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
51
+ """Applies Rotary Position Embedding to the query and key tensors.
52
+
53
+ Args:
54
+ q (`torch.Tensor`): The query tensor.
55
+ k (`torch.Tensor`): The key tensor.
56
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
57
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
58
+ position_ids (`torch.Tensor`, *optional*):
59
+ Deprecated and unused.
60
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
61
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
62
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
63
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
64
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
65
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
66
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
67
+ Returns:
68
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
69
+ """
70
+ cos = cos.unsqueeze(unsqueeze_dim)
71
+ sin = sin.unsqueeze(unsqueeze_dim)
72
+ q_embed = (q * cos) + (rotate_half(q) * sin)
73
+ k_embed = (k * cos) + (rotate_half(k) * sin)
74
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
75
+
76
+
77
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
78
+ """
79
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
80
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
81
+ """
82
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
83
+ if n_rep == 1:
84
+ return hidden_states
85
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
86
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
87
+
88
+
89
+ def eager_power_attention_forward(
90
+ module: nn.Module,
91
+ query: torch.Tensor,
92
+ key: torch.Tensor,
93
+ value: torch.Tensor,
94
+ attention_mask: Optional[torch.Tensor],
95
+ scaling: float,
96
+ dropout: float = 0.0,
97
+ **kwargs: Unpack[TransformersKwargs],
98
+ ):
99
+ key_states = repeat_kv(key, module.num_key_value_groups)
100
+ value_states = repeat_kv(value, module.num_key_value_groups)
101
+
102
+ attn_weights = 2*torch.log(torch.abs( torch.matmul(query, key_states.transpose(2, 3)) * scaling + 1e-5))
103
+ if attention_mask is not None:
104
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
105
+ attn_weights = attn_weights + causal_mask
106
+
107
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
108
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
109
+ attn_output = torch.matmul(attn_weights, value_states)
110
+ attn_output = attn_output.transpose(1, 2).contiguous()
111
+
112
+ return attn_output, attn_weights
113
+
114
+
115
+ class PowerCoderAttention(nn.Module):
116
+ """Multi-headed attention with power retention: standard q/k/v projections plus a gate projection (g_proj) and an optional interpolation with exponential attention."""
117
+
118
+ def __init__(self, config: PowerCoderConfig, layer_idx: Optional[int] = None):
119
+ super().__init__()
120
+ self.config = config
121
+ self.layer_idx = layer_idx
122
+ self.chunk_size = config.chunk_size
123
+ self.switch_over_seq_len = config.switch_over_seq_len
124
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
125
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
126
+ self.scaling = self.head_dim**-0.5
127
+ self.attention_dropout = config.attention_dropout
128
+ self.is_causal = True
129
+
130
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
131
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
132
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
133
+ self.g_proj = nn.Linear(config.hidden_size, config.num_key_value_heads, bias=config.use_bias)
134
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
135
+ self.residual_dropout = config.residual_dropout
136
+
137
+ def forward(
138
+ self,
139
+ hidden_states: torch.Tensor,
140
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
141
+ padding_starts: Optional[torch.Tensor],
142
+ past_key_value: Optional[Cache] = None,
143
+ cache_position: Optional[torch.LongTensor] = None,
144
+ **kwargs: Unpack[FlashAttentionKwargs],
145
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
146
+ input_shape = hidden_states.shape[:-1]
147
+ hidden_shape = (*input_shape, -1, self.head_dim)
148
+ interpolate_exp_amount = kwargs.get('interpolate_exp', 0)
149
+ assert 0 <= interpolate_exp_amount <= 1, f'{interpolate_exp_amount=}'
150
+ run_exp = interpolate_exp_amount > 0
151
+
152
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
153
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
154
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
155
+ gate_states = self.g_proj(hidden_states).view(hidden_shape[:-1]).transpose(1, 2)
156
+ gate_states = nn.functional.logsigmoid(gate_states.to(torch.float32))
157
+
158
+ cos, sin = position_embeddings
159
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
160
+
161
+ if past_key_value is not None:
162
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
163
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
164
+ key_states, value_states, gate_states, state, sum_of_keys = past_key_value.update_kv(key_states, value_states, gate_states, self.layer_idx, cache_kwargs)
165
+
166
+ if run_exp:
167
+ attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
168
+
169
+ exp_attn_output, exp_attn_weights = attention_interface(
170
+ self,
171
+ query_states,
172
+ key_states,
173
+ value_states,
174
+ is_causal=True,
175
+ attention_mask=None,
176
+ dropout=0.0 if not self.training else self.attention_dropout,
177
+ scaling=self.scaling,
178
+ **kwargs,
179
+ )
180
+
181
+ if query_states.shape[2] == 1:
182
+ key_len = key_states.shape[2]
183
+ power_attn_output, state, sum_of_keys = power_retention_inference(
184
+ query_states.transpose(1, 2),
185
+ key_states.transpose(1, 2),
186
+ value_states.transpose(1, 2),
187
+ gate_states.transpose(1, 2),
188
+ initial_state=state,
189
+ sum_of_keys=sum_of_keys,
190
+ deg=2,
191
+ scale=self.scaling,
192
+ switch_over_seq_len=self.switch_over_seq_len,
193
+ )
194
+ if self.switch_over_seq_len is not None and key_len >= self.switch_over_seq_len:
195
+ past_key_value.clean_kv(self.layer_idx)
196
+ past_key_value.update_state(state, sum_of_keys, self.layer_idx, cache_kwargs)
197
+
198
+ else:
199
+ key_len = key_states.shape[2]
200
+ power_attn_output = power_retention(
201
+ query_states.transpose(1, 2),
202
+ key_states.transpose(1, 2),
203
+ value_states.transpose(1, 2),
204
+ gate_states.transpose(1, 2),
205
+ deg=2,
206
+ scale=self.scaling,
207
+ chunk_size=self.chunk_size, # enable chunked prefilling by default
208
+ )
209
+
210
+ if interpolate_exp_amount == 1:
211
+ attn_output = exp_attn_output
212
+ elif interpolate_exp_amount == 0:
213
+ attn_output = power_attn_output
214
+ else:
215
+ attn_output = interpolate_exp_amount * exp_attn_output + (1 - interpolate_exp_amount) * power_attn_output
216
+
217
+ assert attn_output.shape == (input_shape[0], query_states.shape[2], self.config.num_attention_heads, self.head_dim),\
218
+ f'{attn_output.shape=} {(input_shape[0], query_states.shape[2], self.config.num_attention_heads, self.head_dim)=}'
219
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
220
+ attn_output = self.o_proj(attn_output)
221
+ attn_output = nn.functional.dropout(
222
+ attn_output, p=self.residual_dropout, training=self.training
223
+ ) # diff with Llama
224
+
225
+ return attn_output
226
+
227
+
228
+ class PowerCoderDecoderLayer(GradientCheckpointingLayer):
229
+ def __init__(self, config: PowerCoderConfig, layer_idx: int):
230
+ super().__init__()
231
+ self.hidden_size = config.hidden_size
232
+ self.self_attn = PowerCoderAttention(config=config, layer_idx=layer_idx)
233
+ self.mlp = PowerCoderMLP(config)
234
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
235
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
236
+
237
+ def forward(
238
+ self,
239
+ hidden_states: torch.Tensor,
240
+ padding_starts: Optional[torch.Tensor] = None,
241
+ position_ids: Optional[torch.LongTensor] = None,
242
+ past_key_value: Optional[Cache] = None,
243
+ use_cache: Optional[bool] = False,
244
+ cache_position: Optional[torch.LongTensor] = None,
245
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
246
+ **kwargs: Unpack[TransformersKwargs],
247
+ ) -> tuple[torch.Tensor]:
248
+ residual = hidden_states
249
+ hidden_states = self.input_layernorm(hidden_states)
250
+ # Self Attention
251
+ hidden_states = self.self_attn(
252
+ hidden_states=hidden_states,
253
+ padding_starts=padding_starts,
254
+ position_ids=position_ids,
255
+ past_key_value=past_key_value,
256
+ use_cache=use_cache,
257
+ cache_position=cache_position,
258
+ position_embeddings=position_embeddings,
259
+ **kwargs,
260
+ )
261
+ hidden_states = residual + hidden_states
262
+ # Fully Connected
263
+ residual = hidden_states
264
+ hidden_states = self.post_attention_layernorm(hidden_states)
265
+ hidden_states = self.mlp(hidden_states)
266
+ hidden_states = residual + hidden_states
267
+ return hidden_states
268
+
269
+
270
+ class PowerCoderRotaryEmbedding(nn.Module):
271
+ def __init__(self, config: PowerCoderConfig, device=None):
272
+ super().__init__()
273
+ # BC: "rope_type" was originally "type"
274
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
275
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
276
+ else:
277
+ self.rope_type = "default"
278
+ self.max_seq_len_cached = config.max_position_embeddings
279
+ self.original_max_seq_len = config.max_position_embeddings
280
+
281
+ self.config = config
282
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
283
+
284
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
285
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
286
+ self.original_inv_freq = self.inv_freq
287
+
288
+ @torch.no_grad()
289
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
290
+ def forward(self, x, position_ids):
291
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
292
+ position_ids_expanded = position_ids[:, None, :].float()
293
+
294
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
295
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
296
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
297
+ emb = torch.cat((freqs, freqs), dim=-1)
298
+ cos = emb.cos() * self.attention_scaling
299
+ sin = emb.sin() * self.attention_scaling
300
+
301
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
302
+
303
+
304
+ @auto_docstring
305
+ class PowerCoderPreTrainedModel(PreTrainedModel):
306
+ config: PowerCoderConfig
307
+ base_model_prefix = "model"
308
+ supports_gradient_checkpointing = True
309
+ _no_split_modules = ["PowerCoderDecoderLayer"]
310
+ _skip_keys_device_placement = ["past_key_values"]
311
+ _supports_flash_attn = True
312
+ _supports_sdpa = True
313
+ _supports_flex_attn = True
314
+
315
+ _can_compile_fullgraph = True
316
+ _supports_attention_backend = True
317
+ _can_record_outputs = {
318
+ "hidden_states": PowerCoderDecoderLayer,
319
+ "attentions": PowerCoderAttention,
320
+ }
321
+
322
+
323
+ @auto_docstring
324
+ class PowerCoderModel(PowerCoderPreTrainedModel):
325
+ def __init__(self, config: PowerCoderConfig):
326
+ super().__init__(config)
327
+ self.padding_idx = config.pad_token_id
328
+ self.vocab_size = config.vocab_size
329
+
330
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
331
+ self.layers = nn.ModuleList(
332
+ [PowerCoderDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
333
+ )
334
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
335
+ self.rotary_emb = PowerCoderRotaryEmbedding(config=config)
336
+ self.gradient_checkpointing = False
337
+ self.embedding_dropout = config.embedding_dropout
338
+
339
+ # Initialize weights and apply final processing
340
+ self.post_init()
341
+
342
+ def forward(
343
+ self,
344
+ input_ids: Optional[torch.LongTensor] = None,
345
+ attention_mask: Optional[torch.Tensor] = None,
346
+ position_ids: Optional[torch.LongTensor] = None,
347
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
348
+ inputs_embeds: Optional[torch.FloatTensor] = None,
349
+ use_cache: Optional[bool] = None,
350
+ cache_position: Optional[torch.LongTensor] = None,
351
+ **kwargs: Unpack[TransformersKwargs],
352
+ ) -> BaseModelOutputWithPast:
353
+ if (input_ids is None) ^ (inputs_embeds is not None):
354
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
355
+
356
+ if inputs_embeds is None:
357
+ inputs_embeds = self.embed_tokens(input_ids)
358
+
359
+ # Always use our local DynamicCache implementation for compatibility with gating
360
+ if use_cache:
361
+ if past_key_values is None or not isinstance(past_key_values, Cache):
362
+ past_key_values = DynamicCache()
363
+
364
+ if cache_position is None:
365
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
366
+ cache_position = torch.arange(
367
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
368
+ )
369
+
370
+ if position_ids is None:
371
+ position_ids = cache_position.unsqueeze(0)
372
+
373
+ # mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
374
+ # causal_mask = mask_function(
375
+ # config=self.config,
376
+ # input_embeds=inputs_embeds,
377
+ # attention_mask=attention_mask,
378
+ # cache_position=cache_position,
379
+ # past_key_values=past_key_values,
380
+ # position_ids=position_ids,
381
+ # )
382
+ padding_starts = attention_mask.argmin(-1) if attention_mask is not None else None
383
+
384
+ hidden_states = inputs_embeds
385
+ hidden_states = nn.functional.dropout(
386
+ hidden_states, p=self.embedding_dropout, training=self.training
387
+ ) # main diff with Llama
388
+
389
+ # create position embeddings to be shared across the decoder layers
390
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
391
+ for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
392
+ hidden_states = decoder_layer(
393
+ hidden_states,
394
+ padding_starts=padding_starts,
395
+ position_ids=position_ids,
396
+ past_key_value=past_key_values,
397
+ use_cache=use_cache,
398
+ cache_position=cache_position,
399
+ position_embeddings=position_embeddings,
400
+ **kwargs,
401
+ )
402
+
403
+ hidden_states = self.norm(hidden_states)
404
+
405
+ return BaseModelOutputWithPast(
406
+ last_hidden_state=hidden_states,
407
+ past_key_values=past_key_values if use_cache else None,
408
+ )
409
+
410
+
411
+ @auto_docstring
412
+ class PowerCoderForCausalLM(PowerCoderPreTrainedModel, GenerationMixin):
413
+ _tied_weights_keys = ["lm_head.weight"]
414
+ _tp_plan = {"lm_head": "colwise_rep"}
415
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
416
+
417
+ def __init__(self, config, chunk_size=None, switch_over_seq_len=None):
418
+ if chunk_size is not None:
419
+ config.chunk_size = chunk_size
420
+ if switch_over_seq_len is not None:
421
+ config.switch_over_seq_len = switch_over_seq_len
422
+ super().__init__(config)
423
+ self.model = PowerCoderModel(config)
424
+ self.vocab_size = config.vocab_size
425
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
426
+
427
+ # Initialize weights and apply final processing
428
+ self.post_init()
429
+
430
+ def set_decoder(self, decoder):
431
+ self.model = decoder
432
+
433
+ def get_decoder(self):
434
+ return self.model
435
+
436
+ @can_return_tuple
437
+ @auto_docstring
438
+ def forward(
439
+ self,
440
+ input_ids: Optional[torch.LongTensor] = None,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ position_ids: Optional[torch.LongTensor] = None,
443
+ past_key_values: Optional[Cache] = None,
444
+ inputs_embeds: Optional[torch.FloatTensor] = None,
445
+ labels: Optional[torch.LongTensor] = None,
446
+ use_cache: Optional[bool] = None,
447
+ cache_position: Optional[torch.LongTensor] = None,
448
+ logits_to_keep: Union[int, torch.Tensor] = 0,
449
+ **kwargs: Unpack[TransformersKwargs],
450
+ ) -> CausalLMOutputWithPast:
451
+ r"""
452
+ Example:
453
+
454
+ ```python
455
+ >>> from transformers import AutoTokenizer, PowerCoderForCausalLM
456
+
457
+ >>> model = PowerCoderForCausalLM.from_pretrained("meta-PowerCoder/PowerCoder-2-7b-hf")
458
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-PowerCoder/PowerCoder-2-7b-hf")
459
+
460
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
461
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
462
+
463
+ >>> # Generate
464
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
465
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
466
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
467
+ ```
468
+
469
+ Args:
470
+ input_ids (`Optional[torch.LongTensor]`, *optional*):
471
+ Indices of input sequence tokens in the vocabulary.
472
+ attention_mask (`Optional[torch.Tensor]`, *optional*):
473
+ Mask to avoid performing attention on padding token indices.
474
+ position_ids (`Optional[torch.LongTensor]`, *optional*):
475
+ Indices of positions of each input sequence tokens.
476
+ past_key_values (`Optional[Cache]`, *optional*):
477
+ Cache containing pre-computed key and value states for attention layers, used for faster inference.
478
+ If `use_cache` is True, the cache will be used and updated with new key/value states.
479
+ inputs_embeds (`Optional[torch.FloatTensor]`, *optional*):
480
+ Pre-computed input embeddings. Useful for scenarios where you want to compute embeddings separately.
481
+ labels (`Optional[torch.LongTensor]`, *optional*):
482
+ Labels for computing language modeling loss.
483
+ use_cache (`Optional[bool]`, *optional*):
484
+ If True, past key/value states are returned and can be used for future predictions.
485
+ cache_position (`Optional[torch.LongTensor]`, *optional*):
486
+ Position indices for cached key/value states when using incremental decoding.
487
+ logits_to_keep (`Union[int, torch.Tensor]`, *optional*, defaults to 0):
488
+ Number of logits to compute from the end of the sequence, or specific indices to compute.
489
+ **kwargs:
490
+ Additional arguments passed to the underlying model's forward method.
491
+
492
+ Returns:
493
+ `CausalLMOutputWithPast`: A dataclass containing:
494
+ - loss (`Optional[torch.FloatTensor]`): Language modeling loss if labels were provided.
495
+ - logits (`torch.FloatTensor`): Prediction scores for the vocabulary.
496
+ - past_key_values (`Optional[Cache]`): Updated key/value states for attention layers if use_cache=True.
497
+ - hidden_states (`Optional[Tuple[torch.FloatTensor]]`): Model's hidden states.
498
+ - attentions (`Optional[Tuple[torch.FloatTensor]]`): Attention weights if output_attentions=True.
499
+ """
500
+ outputs: BaseModelOutputWithPast = self.model(
501
+ input_ids=input_ids,
502
+ attention_mask=attention_mask,
503
+ position_ids=position_ids,
504
+ past_key_values=past_key_values,
505
+ inputs_embeds=inputs_embeds,
506
+ use_cache=use_cache,
507
+ cache_position=cache_position,
508
+ **kwargs,
509
+ )
510
+
511
+ hidden_states = outputs.last_hidden_state
512
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
513
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
514
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
515
+
516
+ loss = None
517
+ if labels is not None:
518
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
519
+
520
+ return CausalLMOutputWithPast(
521
+ loss=loss,
522
+ logits=logits,
523
+ past_key_values=outputs.past_key_values,
524
+ hidden_states=outputs.hidden_states,
525
+ attentions=outputs.attentions,
526
+ )
527
+
528
+
529
+ class PowerCoderForSequenceClassification(GenericForSequenceClassification, PowerCoderPreTrainedModel):
530
+ pass
531
+
532
+
533
+ class PowerCoderForTokenClassification(GenericForTokenClassification, PowerCoderPreTrainedModel):
534
+ pass
535
+
536
+
537
+ __all__ = [
538
+ "PowerCoderForCausalLM",
539
+ "PowerCoderModel",
540
+ "PowerCoderPreTrainedModel",
541
+ "PowerCoderForSequenceClassification",
542
+ "PowerCoderForTokenClassification",
543
+ ]
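With the configuration, modeling file and Auto-class registration in place, the checkpoint should be loadable through the standard `transformers` Auto classes. A minimal usage sketch follows; the checkpoint path is a placeholder (not the actual repo id), and it assumes the `retention` package providing the Triton power-retention kernels imported above is installed, plus `accelerate` for `device_map="auto"`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path: point this at a local clone of the uploaded repo (or its hub id).
CKPT = "./powercoder-checkpoint"

# trust_remote_code=True lets transformers import the custom PowerCoderConfig /
# PowerCoderForCausalLM classes shipped alongside the weights.
tokenizer = AutoTokenizer.from_pretrained(CKPT, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    CKPT,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # assumption: bf16 inference is acceptable
    device_map="auto",           # requires accelerate
)

prompt = "def fibonacci(n):"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```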
special_tokens_map.json ADDED
@@ -0,0 +1,64 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|endoftext|>",
4
+ "<fim_prefix>",
5
+ "<fim_middle>",
6
+ "<fim_suffix>",
7
+ "<fim_pad>",
8
+ "<repo_name>",
9
+ "<file_sep>",
10
+ "<issue_start>",
11
+ "<issue_comment>",
12
+ "<issue_closed>",
13
+ "<jupyter_start>",
14
+ "<jupyter_text>",
15
+ "<jupyter_code>",
16
+ "<jupyter_output>",
17
+ "<jupyter_script>",
18
+ "<empty_output>",
19
+ "<code_to_intermediate>",
20
+ "<intermediate_to_code>",
21
+ "<pr>",
22
+ "<pr_status>",
23
+ "<pr_is_merged>",
24
+ "<pr_base>",
25
+ "<pr_file>",
26
+ "<pr_base_code>",
27
+ "<pr_diff>",
28
+ "<pr_diff_hunk>",
29
+ "<pr_comment>",
30
+ "<pr_event_id>",
31
+ "<pr_review>",
32
+ "<pr_review_state>",
33
+ "<pr_review_comment>",
34
+ "<pr_in_reply_to_review_id>",
35
+ "<pr_in_reply_to_comment_id>",
36
+ "<pr_diff_hunk_comment_line>",
37
+ "<NAME>",
38
+ "<EMAIL>",
39
+ "<KEY>",
40
+ "<PASSWORD>"
41
+ ],
42
+ "bos_token": {
43
+ "content": "<|endoftext|>",
44
+ "lstrip": false,
45
+ "normalized": false,
46
+ "rstrip": false,
47
+ "single_word": false
48
+ },
49
+ "eos_token": {
50
+ "content": "<|endoftext|>",
51
+ "lstrip": false,
52
+ "normalized": false,
53
+ "rstrip": false,
54
+ "single_word": false
55
+ },
56
+ "pad_token": "<|endoftext|>",
57
+ "unk_token": {
58
+ "content": "<|endoftext|>",
59
+ "lstrip": false,
60
+ "normalized": false,
61
+ "rstrip": false,
62
+ "single_word": false
63
+ }
64
+ }
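The map above exposes the StarCoder-style fill-in-the-middle tokens (`<fim_prefix>`, `<fim_suffix>`, `<fim_middle>`, `<fim_pad>`). Below is a small sketch of a prefix-suffix-middle prompt built from them; whether this checkpoint was trained with exactly this ordering is an assumption, so treat it as illustrative only.

```python
# Prefix-suffix-middle prompt built from the declared FIM tokens (assumed ordering).
prefix = "def add(a, b):\n    "
suffix = "\n    return result\n"

fim_prompt = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"

# The model is then expected to generate the missing middle (e.g. "result = a + b"),
# typically terminated by <|endoftext|> or <fim_pad>.
print(fim_prompt)
```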
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,358 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<fim_prefix>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<fim_middle>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<fim_suffix>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "4": {
37
+ "content": "<fim_pad>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "5": {
45
+ "content": "<repo_name>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "6": {
53
+ "content": "<file_sep>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "7": {
61
+ "content": "<issue_start>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "8": {
69
+ "content": "<issue_comment>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "9": {
77
+ "content": "<issue_closed>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "10": {
85
+ "content": "<jupyter_start>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "11": {
93
+ "content": "<jupyter_text>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "12": {
101
+ "content": "<jupyter_code>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "13": {
109
+ "content": "<jupyter_output>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "14": {
117
+ "content": "<jupyter_script>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "15": {
125
+ "content": "<empty_output>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "16": {
133
+ "content": "<code_to_intermediate>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "17": {
141
+ "content": "<intermediate_to_code>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "18": {
149
+ "content": "<pr>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "19": {
157
+ "content": "<pr_status>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "20": {
165
+ "content": "<pr_is_merged>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "21": {
173
+ "content": "<pr_base>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "22": {
181
+ "content": "<pr_file>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "23": {
189
+ "content": "<pr_base_code>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "24": {
197
+ "content": "<pr_diff>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "25": {
205
+ "content": "<pr_diff_hunk>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "26": {
213
+ "content": "<pr_comment>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "27": {
221
+ "content": "<pr_event_id>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "28": {
229
+ "content": "<pr_review>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "29": {
237
+ "content": "<pr_review_state>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "30": {
245
+ "content": "<pr_review_comment>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "31": {
253
+ "content": "<pr_in_reply_to_review_id>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32": {
261
+ "content": "<pr_in_reply_to_comment_id>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "33": {
269
+ "content": "<pr_diff_hunk_comment_line>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "34": {
277
+ "content": "<NAME>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "35": {
285
+ "content": "<EMAIL>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "36": {
293
+ "content": "<KEY>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "37": {
301
+ "content": "<PASSWORD>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ }
308
+ },
309
+ "additional_special_tokens": [
310
+ "<|endoftext|>",
311
+ "<fim_prefix>",
312
+ "<fim_middle>",
313
+ "<fim_suffix>",
314
+ "<fim_pad>",
315
+ "<repo_name>",
316
+ "<file_sep>",
317
+ "<issue_start>",
318
+ "<issue_comment>",
319
+ "<issue_closed>",
320
+ "<jupyter_start>",
321
+ "<jupyter_text>",
322
+ "<jupyter_code>",
323
+ "<jupyter_output>",
324
+ "<jupyter_script>",
325
+ "<empty_output>",
326
+ "<code_to_intermediate>",
327
+ "<intermediate_to_code>",
328
+ "<pr>",
329
+ "<pr_status>",
330
+ "<pr_is_merged>",
331
+ "<pr_base>",
332
+ "<pr_file>",
333
+ "<pr_base_code>",
334
+ "<pr_diff>",
335
+ "<pr_diff_hunk>",
336
+ "<pr_comment>",
337
+ "<pr_event_id>",
338
+ "<pr_review>",
339
+ "<pr_review_state>",
340
+ "<pr_review_comment>",
341
+ "<pr_in_reply_to_review_id>",
342
+ "<pr_in_reply_to_comment_id>",
343
+ "<pr_diff_hunk_comment_line>",
344
+ "<NAME>",
345
+ "<EMAIL>",
346
+ "<KEY>",
347
+ "<PASSWORD>"
348
+ ],
349
+ "bos_token": "<|endoftext|>",
350
+ "clean_up_tokenization_spaces": true,
351
+ "eos_token": "<|endoftext|>",
352
+ "extra_special_tokens": {},
353
+ "model_max_length": 1000000000000000019884624838656,
354
+ "pad_token": "<|endoftext|>",
355
+ "tokenizer_class": "GPT2Tokenizer",
356
+ "unk_token": "<|endoftext|>",
357
+ "vocab_size": 49152
358
+ }
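The tokenizer config reuses `<|endoftext|>` for the bos, eos, pad and unk roles and registers the special tokens starting at id 0. A quick sanity-check sketch after loading, assuming a local clone of this repo (the path is a placeholder):

```python
from transformers import AutoTokenizer

# Placeholder path: a local clone of this repo.
tok = AutoTokenizer.from_pretrained("./powercoder-checkpoint", trust_remote_code=True)

# bos/eos/pad/unk all resolve to <|endoftext|> per the config above.
print(tok.bos_token, tok.bos_token_id)
print(tok.eos_token, tok.eos_token_id)
print(tok.pad_token, tok.pad_token_id)

# The FIM tokens sit at the start of the added-token range (id 1 onwards).
print(tok.convert_tokens_to_ids("<fim_prefix>"))
```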
vocab.json ADDED
The diff for this file is too large to render. See raw diff