# Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from typing import Optional, List, Tuple, Dict, Any
from transformers.cache_utils import Cache
from contextlib import contextmanager

class BlockFlowMatchingCache(Cache):
    def __init__(
            self, 
            text_lengths: Optional[torch.Tensor] = None, 
            block_size: Optional[int] = None, 
            num_history_block: Optional[int] = None
        ) -> None:
        super().__init__()
        self._seen_tokens = 0  # number of non-text tokens seen so far (counted at layer 0)
        # Per-layer caches for the text key/value states.
        self.text_key_cache: List[torch.Tensor] = []
        self.text_value_cache: List[torch.Tensor] = []
        # Per-layer caches for the block key/value states.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.text_lengths = text_lengths
        self.block_size = block_size
        self.num_history_block = num_history_block
        # Mode flags toggled by the `cache_text()` / `cache_context()` context managers.
        self.is_cache_text = False
        self.is_storage_cache = False
        assert self.num_history_block is None or self.block_size is not None, (
            "block_size must be set when num_history_block is set."
        )

    @contextmanager
    def cache_text(self):
        """Within this context, `update` stores key/value states in the per-layer text cache."""
        self.is_cache_text = True
        try:
            yield self
        finally:
            self.is_cache_text = False

    @contextmanager
    def cache_context(self):
        """Within this context, `update` persists block key/value states into the per-layer cache."""
        self.is_storage_cache = True
        try:
            yield self
        finally:
            self.is_storage_cache = False

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `BlockFlowMatchingCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Text-caching mode: store the text key/value states in the separate text cache.
        if self.is_cache_text:
            if self.text_lengths is None:
                # Default to the full text length for every sample in the batch.
                self.text_lengths = torch.LongTensor([key_states.shape[-2]] * key_states.shape[0])
            self.text_key_cache.append(key_states)
            self.text_value_cache.append(value_states)
            return self.text_key_cache[layer_idx], self.text_value_cache[layer_idx]

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx + 1):
                    self.key_cache.append([])
                    self.value_cache.append([])
            cached_key_state = self.key_cache[layer_idx]
            cached_value_state = self.value_cache[layer_idx]
            if len(cached_key_state) != 0:
                key_states = torch.cat([cached_key_state, key_states], dim=-2)
                value_states = torch.cat([cached_value_state, value_states], dim=-2)
            # Keep only the most recent `num_history_block + 1` blocks of key/value states.
            if self.num_history_block is not None:
                history_length = self.block_size * (self.num_history_block + 1)
                key_states = key_states[:, :, -history_length:, :]
                value_states = value_states[:, :, -history_length:, :]
            # Persist the (possibly truncated) states only inside a `cache_context()` block.
            if self.is_storage_cache:
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
        
        # Prepend each sample's cached text key/values (trimmed to its true text length)
        # to the current key/value states, then re-pad to a common length across the batch.
        k_s = []
        v_s = []

        text_key_cache = (
            self.text_key_cache[layer_idx]
            if len(self.text_key_cache) > layer_idx
            else torch.zeros(
                key_states.shape[0], key_states.shape[1], 0, key_states.shape[3],
                device=key_states.device, dtype=key_states.dtype,
            )
        )
        text_value_cache = (
            self.text_value_cache[layer_idx]
            if len(self.text_value_cache) > layer_idx
            else torch.zeros(
                value_states.shape[0], value_states.shape[1], 0, value_states.shape[3],
                device=value_states.device, dtype=value_states.dtype,
            )
        )
        for b in range(self.text_lengths.shape[0]):
            k_s.append(torch.cat([text_key_cache[b][:, :self.text_lengths[b], :], key_states[b]], dim=-2))
            v_s.append(torch.cat([text_value_cache[b][:, :self.text_lengths[b], :], value_states[b]], dim=-2))
        k_s = torch.nn.utils.rnn.pad_sequence(k_s, batch_first=True)
        v_s = torch.nn.utils.rnn.pad_sequence(v_s, batch_first=True)

        return k_s, v_s

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        is_empty_layer = (
            len(self.key_cache) == 0  # no cache in any layer
            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
        )
        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
        return layer_seq_length

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
        return None
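

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: the shapes, layer count, text
    # length, and call order below are illustrative assumptions based only on this file
    # (`cache_text()`, `cache_context()`, `update()`); the actual model may drive the
    # cache differently.
    batch, heads, head_dim, num_layers = 2, 4, 64, 2
    cache = BlockFlowMatchingCache(block_size=16, num_history_block=2)

    # 1) Cache the text key/value states once per layer.
    with cache.cache_text():
        for layer_idx in range(num_layers):
            text_k = torch.randn(batch, heads, 10, head_dim)
            text_v = torch.randn(batch, heads, 10, head_dim)
            cache.update(text_k, text_v, layer_idx)

    # 2) Process successive blocks; persisting each block's key/values lets later blocks
    #    attend to the text prefix plus up to `num_history_block` previous blocks.
    for _ in range(3):
        with cache.cache_context():
            for layer_idx in range(num_layers):
                blk_k = torch.randn(batch, heads, 16, head_dim)
                blk_v = torch.randn(batch, heads, 16, head_dim)
                k, v = cache.update(blk_k, blk_v, layer_idx)
                # k/v = [text prefix | cached history blocks | current block] per sample.

    print("cached sequence length at layer 0:", cache.get_seq_length(0))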