# Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional, List, Tuple, Dict, Any
from transformers.cache_utils import Cache
from contextlib import contextmanager


class BlockFlowMatchingCache(Cache):
    """Key/value cache for block-wise flow matching: text (prompt) states and generated-block states are cached separately."""

    def __init__(
        self,
        text_lengths: Optional[torch.Tensor] = None,
        block_size: Optional[int] = None,
        num_history_block: Optional[int] = None
    ) -> None:
        super().__init__()
        self._seen_tokens = 0
        self.text_key_cache: List[torch.Tensor] = []
        self.text_value_cache: List[torch.Tensor] = []
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.text_lengths = text_lengths
        self.block_size = block_size
        self.num_history_block = num_history_block
        self.is_cache_text = False
        self.is_storage_cache = False
        assert (
            (
                self.num_history_block is not None
                and self.block_size is not None
            )
            or self.num_history_block is None
        ), "block_size must be set when num_history_block is set."
    @contextmanager
    def cache_text(self):
        """Within this context, `update` stores key/value states in the text cache."""
        self.is_cache_text = True
        try:
            yield self
        finally:
            self.is_cache_text = False

    @contextmanager
    def cache_context(self):
        """Within this context, `update` persists the concatenated key/value states for reuse by later blocks."""
        self.is_storage_cache = True
        try:
            yield self
        finally:
            self.is_storage_cache = False
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in
                `BlockFlowMatchingCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Text-caching mode: store the text/prompt key and value states separately and return them as-is.
        if self.is_cache_text:
            if self.text_lengths is None:
                self.text_lengths = torch.LongTensor([key_states.shape[-2]] * key_states.shape[0])
            self.text_key_cache.append(key_states)
            self.text_value_cache.append(value_states)
            return self.text_key_cache[layer_idx], self.text_value_cache[layer_idx]

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]
        # Update the cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx + 1):
                    self.key_cache.append([])
                    self.value_cache.append([])
            cached_key_state = self.key_cache[layer_idx]
            cached_value_state = self.value_cache[layer_idx]
            # Prepend whatever has already been stored for this layer.
            if len(cached_key_state) != 0:
                key_states = torch.cat([cached_key_state, key_states], dim=-2)
                value_states = torch.cat([cached_value_state, value_states], dim=-2)
            # Keep only the most recent `num_history_block` blocks plus the current one.
            if self.num_history_block is not None:
                history_length = self.block_size * (self.num_history_block + 1)
                key_states = key_states[:, :, -history_length:, :]
                value_states = value_states[:, :, -history_length:, :]
            # Persist the concatenated states only inside `cache_context()`.
            if self.is_storage_cache:
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states

        # Prepend each sample's cached text key/value states (up to its text length) and re-batch.
        k_s = []
        v_s = []
        text_key_cache = (
            self.text_key_cache[layer_idx]
            if len(self.text_key_cache) > layer_idx
            else torch.zeros(
                key_states.shape[0], key_states.shape[1], 0, key_states.shape[3],
                device=key_states.device, dtype=key_states.dtype,
            )
        )
        text_value_cache = (
            self.text_value_cache[layer_idx]
            if len(self.text_value_cache) > layer_idx
            else torch.zeros(
                value_states.shape[0], value_states.shape[1], 0, value_states.shape[3],
                device=value_states.device, dtype=value_states.dtype,
            )
        )
        for b in range(self.text_lengths.shape[0]):
            k_s.append(torch.cat([text_key_cache[b][:, :self.text_lengths[b], :], key_states[b]], dim=-2))
            v_s.append(torch.cat([text_value_cache[b][:, :self.text_lengths[b], :], value_states[b]], dim=-2))
        k_s = torch.nn.utils.rnn.pad_sequence(k_s, batch_first=True)
        v_s = torch.nn.utils.rnn.pad_sequence(v_s, batch_first=True)
        return k_s, v_s
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        is_empty_layer = (
            len(self.key_cache) == 0  # no cache in any layer
            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
        )
        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
        return layer_seq_length
    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length of the cache object. `BlockFlowMatchingCache` does not have a maximum length."""
        return None
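

# The sketch below is not part of the original module; it is a minimal illustration of one
# way the cache might be driven during block-wise generation. Tensor shapes, the number of
# layers/blocks/steps, and the call pattern are assumptions, not the model's actual loop.
if __name__ == "__main__":
    num_layers, batch, heads, head_dim = 2, 1, 4, 8
    cache = BlockFlowMatchingCache(block_size=16, num_history_block=2)

    # 1) Cache the text/prompt key-value states once per layer.
    with cache.cache_text():
        for layer in range(num_layers):
            k = torch.randn(batch, heads, 10, head_dim)
            v = torch.randn(batch, heads, 10, head_dim)
            cache.update(k, v, layer)

    # 2) For each generated block, run a few (hypothetical) denoising steps that read the
    #    cache, then persist the block's final key/value states inside `cache_context()`.
    for block in range(3):
        for step in range(4):
            for layer in range(num_layers):
                k = torch.randn(batch, heads, 16, head_dim)
                v = torch.randn(batch, heads, 16, head_dim)
                k_full, v_full = cache.update(k, v, layer)  # text + stored history + current block
        with cache.cache_context():
            for layer in range(num_layers):
                k = torch.randn(batch, heads, 16, head_dim)
                v = torch.randn(batch, heads, 16, head_dim)
                cache.update(k, v, layer)

    print("cached sequence length for layer 0:", cache.get_seq_length(0))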