diff --git a/README.md b/README.md index 6f0187820075350082295838df5a819e821edc60..952064c5417b15f1606e48edabcb9ecc194c34bf 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ --- -title: MiniCPM4.1 8B Demo -emoji: 🦀 -colorFrom: pink -colorTo: red +title: MiniCPM4.1 8B Eagle3 Straming +emoji: 🚀 +colorFrom: yellow +colorTo: blue sdk: gradio -sdk_version: 5.46.0 +sdk_version: 5.44.1 app_file: app.py pinned: false -license: apache-2.0 -short_description: chat with MiniCPM4.1-8B with speculative decoding +tags: +- anycoder --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..81de8d5534becba52fcb43acbdca54c59002d8e4 --- /dev/null +++ b/app.py @@ -0,0 +1,319 @@ +# MiniCPM-4.1-8B-Eagle3 + +from pathlib import Path +import time +import logging +import gradio as gr +import torch +import spaces +import threading +from transformers import AutoTokenizer, TextIteratorStreamer +# 导入模型相关模块 +from eagle.model.ea_model import EaModel +from utils_chatbot import organize_messages, stream2display_text, mtp_new_tokens + +# 日志配置 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# 全局模型实例 +global_model = None +# 全局模型缓存(在GPU进程中) +_gpu_model_cache = None +# 全局模型配置 +model_config = dict( + base_model_path = "openbmb/MiniCPM4.1-8B", + ea_model_path = "openbmb/MiniCPM4.1-8B-Eagle3/MiniCPM4_1-8B-Eagle3-bf16", + total_token=40, + depth=3, + top_k=10, + threshold=1.0, + use_eagle3=True, + device_map = "cpu", + trust_remote_code=True, +) + +# 提前加载 tokenizer +tokenizer = AutoTokenizer.from_pretrained( + "openbmb/MiniCPM4.1-8B", + use_fast=False, + device_map="cpu", +) + +def _initialize_gpu_model(): + """在GPU进程中获取模型并移到GPU""" + global _gpu_model_cache + if _gpu_model_cache is None: + logger.info(f"在GPU进程中初始化模型") + _gpu_model_cache = EaModel.from_pretrained(**model_config) + logger.info(f"模型在CPU上初始化完成") + return _gpu_model_cache + +@spaces.GPU(duration=42) # default is 60 +def gpu_handler(inputs): + prompt_text = tokenizer.apply_chat_template( + inputs, + tokenize=False, + add_generation_prompt=True, + ) + model_inputs = tokenizer([prompt_text], return_tensors="pt") + inputs = { + "model_inputs": model_inputs, + "max_new_tokens": 65536, + "temperature": 0.6, + "top_p": 0.95, + "top_k": 50, + "max_length": 65536, + } + + logger.info(f"向 GPU 搬运 global_model") + + """GPU推理处理器""" + model = _initialize_gpu_model() + + cuda_inputs = dict( + input_ids=inputs["model_inputs"].input_ids.to("cuda"), + # attention_mask=inputs["model_inputs"].attention_mask.to("cuda"), + max_new_tokens=inputs["max_new_tokens"], + temperature=inputs["temperature"], + top_p=inputs["top_p"], + top_k=inputs["top_k"], + max_length=inputs["max_length"], + ) + + model.base_model.to("cuda") + model.ea_layer.to("cuda") + model.ea_layer.tree_mask_init.to("cuda") + + logger.info(f"pass inputs to global_model") + + output_ids = model.eagenerate(**cuda_inputs) + + logger.info(f"got outputs from global_model.eagenerate") + new_text = tokenizer.decode( + output_ids[0][model_inputs.input_ids.shape[1]:], + skip_special_tokens=True, + ) + + return new_text + +@spaces.GPU(duration=60) # default is 60 +def gpu_handler_s( + inputs, + history, + temperature, + top_p, + use_eagle, +): + prompt_text = tokenizer.apply_chat_template( + inputs, + tokenize=False, + add_generation_prompt=True, + ) + model_inputs = tokenizer([prompt_text], return_tensors="pt") + inputs = { + "model_inputs": model_inputs, + "max_new_tokens": 4096, + "temperature": temperature, + "top_p": top_p, + "top_k": 50, + "max_length": 65536, + } + + logger.info(f"向 GPU 搬运 global_model") + + """GPU推理处理器""" + model = _initialize_gpu_model() + + cuda_inputs = dict( + input_ids=inputs["model_inputs"].input_ids.to("cuda"), + # attention_mask=inputs["model_inputs"].attention_mask.to("cuda"), + max_new_tokens=inputs["max_new_tokens"], + temperature=inputs["temperature"], + top_p=inputs["top_p"], + top_k=inputs["top_k"], + max_length=inputs["max_length"], + ) + + model.base_model.to("cuda") + model.ea_layer.to("cuda") + model.ea_layer.tree_mask_init.to("cuda") + + logger.info(f"pass inputs to global_model") + + yield "", history + + stop_token_ids = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids("<|eot_id|>") + ] + gen_tk_count, existing_tk_count = 0, len(inputs["model_inputs"].input_ids[0]) + + stream_text, start_time = "", time.time() + + generate_func = model.ea_generate if use_eagle else model.naive_generate + + for output_ids in generate_func(**cuda_inputs): + # for output_ids in model.ea_generate(**cuda_inputs): + new_tokens, gen_tk_count = mtp_new_tokens(output_ids, gen_tk_count, existing_tk_count, stop_token_ids) + new_token_text = tokenizer.decode(new_tokens, skip_special_tokens=False) + logger.info(f"[TOKEN]'''{new_token_text}'''") + stream_text += new_token_text + token_per_sec = gen_tk_count / (time.time() - start_time) + display_text = stream2display_text(stream_text, token_per_sec) + history[-1] = (history[-1][0], display_text) + yield "", history + + # logger.info(f"all gen text: \n{stream_text}") + history[-1] = (history[-1][0], stream_text.replace("<|im_end|>", "")) + # 替换 history 为非 display 形态的 text + + +class Model: + """模型封装类,不持有实际模型对象""" + + def __init__(self): + logger.info(f"创建封装类") + + def handler(self, inputs): + """非流式推理处理器""" + return gpu_handler(inputs) + + def stream_handler(self, inputs, history, **kwargs): + """流式推理处理器""" + yield from gpu_handler_s(inputs, history, **kwargs) + + +def initialize_model(): + """初始化全局模型""" + global global_model, _gpu_model_cache + + # 默认配置 + logger.info(f"="*50) + logger.info(f"启动 MiniCPM-4.1-8B-Eagle3 Chatbot 服务") + logger.info(f"="*50) + + # 创建模型封装类 + global_model = Model() + + # 在主进程中预加载模型到CPU(For faster 首次推理) + try: + logger.info("在主进程中预加载模型到 CPU...") + _gpu_model_cache = EaModel.from_pretrained(**model_config) + logger.info("模型在主进程CPU上预加载完成") + except Exception as e: + logger.warning(f"主进程预加载模型失败, 将在GPU进程中加载: {e}") + _gpu_model_cache = None + + return global_model + + +def gen_response(message, history, temperature, top_p): + chat_msg_ls = organize_messages(message, history) + + new_text = global_model.handler(chat_msg_ls) + + history.append((message, new_text)) + return "", history + +def gen_response_stream( + message, + history, + temperature, + top_p, + use_eagle, +): + chat_msg_ls = organize_messages(message, history) + + history.append((message, "")) + + sampling_kwargs = dict( + temperature = temperature, + top_p = top_p, + use_eagle = use_eagle, + ) + + yield from global_model.stream_handler(chat_msg_ls, history, **sampling_kwargs) + +def create_app(): + assets_path = Path.cwd().absolute()/"assets" + logger.info(f"设置静态资源路径: {assets_path}") + gr.set_static_paths(paths=[assets_path]) + logger.info("静态资源路径设置完成") + + theme = gr.themes.Soft( + primary_hue="blue", + secondary_hue="gray", + neutral_hue="slate", + font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"], + ) + + # # Add border styling to components + # theme = theme.set( + # primary_border_size='1px', # 组件外框 + # primary_border_color='*neutral_400', # 用主题里的 slate-400 灰色 + # ) + + with gr.Blocks( + theme=theme, + css=""" + .logo-container { + text-align: center; + margin: 0.5rem 0 1rem 0; + } + .logo-container img { + height: 96px; + width: auto; + max-width: 200px; + display: inline-block; + } + .input-box { + border: 1px solid #2f63b8; + border-radius: 8px; + } + """, + ) as demo: + with gr.Row(): + with gr.Column(scale=1): + gr.HTML('
MiniCPM Logo
') + + blank_1 = gr.HTML("
") + + temperature = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="Temperature", scale=1) + top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.01, label="Top-p", scale=1) + use_eagle = gr.Checkbox(label="Speculative Decoding", value=True) + + blank_2 = gr.HTML("
") + + clear = gr.Button("Clear History") + + gr.Markdown( + """ + Built with anycoder + """ + ) + with gr.Column(scale=4): + chatbot = gr.Chatbot(label="Chat History", placeholder="Input to start a new chat", height=500) + prompt = gr.Textbox( + label="Input Text", + placeholder="Type your message here...", + lines=1, + # submit_btn=True, + elem_classes=["input-box"], # 自定义 class 供 css 使用 + ) + + prompt.submit(gen_response_stream, inputs=[prompt, chatbot, temperature, top_p, use_eagle], outputs=[prompt, chatbot]) + clear.click(lambda: None, None, chatbot, queue=False) + + return demo + + +if __name__ == "__main__": + # 初始化模型 + initialize_model() + + # 创建并启动应用 + demo = create_app() + demo.launch() + diff --git a/eagle/model/__init__.py b/eagle/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/eagle/model/choices.py b/eagle/model/choices.py new file mode 100644 index 0000000000000000000000000000000000000000..3db0a23eb8808c59a09c5b464587c003777aa885 --- /dev/null +++ b/eagle/model/choices.py @@ -0,0 +1,3 @@ +mc_sim_7b_63 = [[0],[1],[2],[3],[0,0],[0,1],[0,2],[1,0],[1,1],[2,0],[2,1],[3,0] + ,[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,2,0],[0,2,1],[1,0,0], + [0,0,0,0],[0,0,0,1],[0,0,0,2],[0,0,0,0,0],[0,0,0,0,1]] diff --git a/eagle/model/cnets.py b/eagle/model/cnets.py new file mode 100644 index 0000000000000000000000000000000000000000..b8fd50070ed97e6f246371fe219ccae63fbb1c96 --- /dev/null +++ b/eagle/model/cnets.py @@ -0,0 +1,887 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import copy +import os +# os.environ["CUDA_VISIBLE_DEVICES"] = "5" +import math +from typing import List, Optional, Tuple, Union +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from huggingface_hub import hf_hub_download + + +try: + from .configs import EConfig + from .utils_c import * + from .choices import * +except: + from configs import EConfig + from utils_c import * + from choices import * + from utils import prepare_logits_processor + + + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class MiniCPMLongRoPE(LlamaRotaryEmbedding): + """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, short_factor=None, long_factor=None, original_max_position_embeddings=None): + self.short_factor = short_factor + self.long_factor = long_factor + self.original_max_position_embeddings = original_max_position_embeddings + scale = (max_position_embeddings / self.original_max_position_embeddings) + self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device) + + freqs = torch.mul( + torch.outer(t, 1.0 / ext_factors).to(device=device), + self.inv_freq.to(device=device).to(dtype) + ) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False) + self.register_buffer('sin_cached', emb.sin()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size * 2, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + if hasattr(self.config, "rope_theta"): + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta) + else: + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, + max_position_embeddings=self.max_position_embeddings) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "longrope": + self.rotary_emb = MiniCPMLongRoPE( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + short_factor=self.config.rope_scaling['short_factor'], + long_factor=self.config.rope_scaling['long_factor'], + base=self.rope_theta, + original_max_position_embeddings=self.config.rope_scaling['original_max_position_embeddings'] + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LlamaDecoderLayeremb(nn.Module): + def __init__(self, config, last=True): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.last = last + # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # if self.index!=0: + + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_emb: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.hidden_norm(hidden_states) + input_emb = self.input_layernorm(input_emb) + + hidden_states = torch.cat((input_emb, hidden_states), dim=-1) + + + # cache_hidden.append(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@torch.no_grad() +def padding(tensor, left=True): + zeropadding = torch.zeros_like(tensor[:, -1:]) + if left: + tensor = torch.cat((zeropadding, tensor[:, :-1]), dim=1) + else: + tensor = torch.cat((tensor[:, 1:], zeropadding), dim=1) + return tensor + + + +def len_list(x, n): + return [i for i in x if len(i) <= n] + + +class Model(nn.Module): + def __init__(self, config, load_emb=False, path=None, bias=True, total_tokens=63, depth=5, top_k=8, threshold=1.0): + super().__init__() + self.config=config + self.gradient_checkpointing = True + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.lm_head=nn.Linear(config.hidden_size,config.draft_vocab_size,bias=False) + if load_emb and not hasattr(config, "target_hidden_size"): + from safetensors import safe_open + import json + try: + index_json_path = os.path.join(path, "model.safetensors.index.json") + if not os.path.exists(index_json_path): + index_json_path = hf_hub_download(path, "model.safetensors.index.json") + with open(index_json_path, "r") as f: + index_json = json.loads(f.read()) + emb_path = index_json["weight_map"]["model.embed_tokens.weight"] + local_emb_path = os.path.join(path, emb_path) + if not os.path.exists(local_emb_path): + local_emb_path = hf_hub_download(path, emb_path) + with safe_open(local_emb_path, + framework="pt", + device="cpu") as f: + tensor_slice = f.get_slice("model.embed_tokens.weight") + vocab_size, hidden_dim = tensor_slice.get_shape() + tensor = tensor_slice[:, :hidden_dim].float() + except: + index_json_path = os.path.join(path, "pytorch_model.bin.index.json") + if not os.path.exists(index_json_path): + index_json_path = hf_hub_download(path, "pytorch_model.bin.index.json") + with open(index_json_path, "r") as f: + index_json = json.loads(f.read()) + emb_path = index_json["weight_map"]["model.embed_tokens.weight"] + local_emb_path = os.path.join(path, emb_path) + if not os.path.exists(local_emb_path): + local_emb_path = hf_hub_download(path, emb_path) + weights = torch.load(local_emb_path) + tensor = weights["model.embed_tokens.weight"].float() + self.embed_tokens.weight.data = tensor + + self.top_k = top_k + self.total_tokens = total_tokens - 1 + self.depth = depth + self.threshold = math.log(threshold) + # print("total_tokens",total_tokens) + # print("depth",depth) + # print("top_k",top_k) + # print("threshold",threshold) + self.hidden_size = config.hidden_size + self.midlayer = LlamaDecoderLayeremb(config) + if hasattr(config, "target_hidden_size"): + self.fc = nn.Linear(config.target_hidden_size * 3, self.hidden_size, bias=False) + else: + self.fc = nn.Linear(config.hidden_size * 3, self.hidden_size, bias=False) + self.norm=LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.logsoftmax = nn.LogSoftmax(dim=-1) + + d2t=torch.zeros((config.draft_vocab_size),dtype=torch.long) + t2d=torch.zeros((config.vocab_size),dtype=torch.bool) + self.register_buffer("d2t", d2t) + self.register_buffer("t2d", t2d) + + for param in self.embed_tokens.parameters(): + param.requires_grad = False + + def init_tree(self): + self.tree_mask_init = torch.eye(self.top_k, device=self.embed_tokens.weight.device)[None, None] + self.position_ids = torch.zeros(self.top_k, device=self.embed_tokens.weight.device, dtype=torch.long) + self.tree_mask_init = self.tree_mask_init.to(self.embed_tokens.weight.device) + + def reset(self): + self.tree_mask = None + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + # inputs_embeds.dtype, + torch.float32, # [MODIFIED] force to cast to float32 + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + # [MODIFIED] add tree mask + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + _, _, tree_shape0, tree_shape1 = tree_mask.shape + combined_attention_mask[:, :, -tree_shape0:, -tree_shape1:][ + tree_mask == 0 + ] = torch.finfo(torch.float32).min + + return combined_attention_mask + + def forward( + self, + hidden_states, + input_ids, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + std=None + ): + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + with torch.no_grad(): + inputs_embeds = self.embed_tokens(input_ids) + # inputs_embeds = inputs_embeds.detach() + + # if std is not None: + # noise = torch.randn(inputs_embeds.size(),device=inputs_embeds.device) * std + # inputs_embeds=inputs_embeds+noise + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: + device = hidden_states.device if hidden_states is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + #position_ids=position_ids//4 + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + # if self.gradient_checkpointing and self.training: + # if use_cache: + # use_cache = False + + # hidden_states=self.act(self.fc(torch.cat((inputs_embeds,hidden_states),dim=-1))) + inputs_embeds = inputs_embeds.to(hidden_states.dtype) + if hidden_states.shape[-1]!=inputs_embeds.shape[-1]: + hidden_states = self.fc(hidden_states) + # hidden_states = self.fc(hidden_states) + + all_hidden_states = () if output_hidden_states else None + next_decoder_cache = () if use_cache else None + + past_key_value = past_key_values[0] if past_key_values is not None else None + layer_outputs = self.midlayer( + input_emb=inputs_embeds, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=True, + ) + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + hidden_states = layer_outputs[0] + + + if use_cache: + return hidden_states, next_decoder_cache + + return hidden_states + + def reset_kv(self): + self.stable_kv = None + + @torch.no_grad() + def topK_genrate(self, hidden_states, input_ids, head, logits_processor): + + input_ids = input_ids.to(hidden_states.device) + total_tokens = self.total_tokens + depth = self.depth + top_k = self.top_k + + sample_token = input_ids[:, -1] + + scores_list = [] + parents_list = [] + ss_token = [] + + input_ids = input_ids[:, 1:] + input_ids = input_ids.to(hidden_states.device) + + len_posi = input_ids.shape[1] + self.reset() + + # with Timer("draft many"): + if hasattr(self, "stable_kv") and self.stable_kv is not None: + kv_len = self.stable_kv[0][0].shape[2] + out_hidden, past_key_values = self(hidden_states, input_ids=input_ids[:, kv_len:], + past_key_values=self.stable_kv, use_cache=True) + else: + out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True) + self.stable_kv = past_key_values + last_hidden = out_hidden[:, -1] + + # last_headout = head(last_hidden) + last_headout = self.lm_head(self.norm(last_hidden)) + + last_p = self.logsoftmax(last_headout) + top = torch.topk(last_p, top_k, dim=-1) + topk_index, topk_p = top.indices, top.values + scores = topk_p[0] + scores_list.append(scores[None]) + parents_list.append(torch.zeros(1, dtype=torch.long, device=scores.device)) + if self.config.vocab_size==self.config.draft_vocab_size: + ss_token.append(topk_index) + input_ids = topk_index + else: + ss_token.append(topk_index+self.d2t[topk_index]) + input_ids = topk_index+self.d2t[topk_index] + input_hidden = last_hidden[None].repeat(1, top_k, 1) + tree_mask = self.tree_mask_init + topk_cs_index = torch.arange(top_k, device=self.embed_tokens.weight.device) + + # 4 + for i in range(depth): + self.tree_mask = tree_mask + position_ids = len_posi + self.position_ids + # with Timer("draft one"): + out_hidden, past_key_values = self(input_hidden, input_ids=input_ids, past_key_values=past_key_values, + position_ids=position_ids, use_cache=True) + len_posi += 1 + + # with Timer("sort1"): + bias1 = top_k if i > 0 else 0 + bias2 = max(0, i - 1) + bias = 1 + top_k ** 2 * bias2 + bias1 + parents = (topk_cs_index + bias) + parents_list.append(parents) + + last_headout = self.lm_head(self.norm(out_hidden[0])) + last_p = self.logsoftmax(last_headout) + + top = torch.topk(last_p, top_k, dim=-1) + topk_index, topk_p = top.indices, top.values + + cu_scores = topk_p + scores[:, None] + + topk_cs = torch.topk(cu_scores.view(-1), top_k, dim=-1) + topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values + scores = topk_cs_p + + out_ids = topk_cs_index // top_k + input_hidden = out_hidden[:, out_ids] + + input_ids = topk_index.view(-1)[topk_cs_index][None] + + if self.config.vocab_size == self.config.draft_vocab_size: + ss_token.append(topk_index) + else: + input_ids = input_ids + self.d2t[input_ids] + ss_token.append(topk_index+self.d2t[topk_index]) + scores_list.append(cu_scores) + + # JQZ 250912 + # tree_mask = torch.cat((tree_mask[:, :, out_ids], self.tree_mask_init), dim=3) + # for dynamic moving between cpu and gpu + out_ids_for_mask = out_ids.to(tree_mask.device) + tree_mask = torch.cat((tree_mask[:, :, out_ids_for_mask], self.tree_mask_init), dim=3) + # + + + scores_list = torch.cat(scores_list, dim=0).view(-1) + ss_token_list = torch.cat(ss_token, dim=0).view(-1) + top_scores = torch.topk(scores_list, total_tokens, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + + draft_tokens = ss_token_list[top_scores_index] + draft_tokens = torch.cat((sample_token, draft_tokens), dim=0) + + draft_parents = torch.cat(parents_list, dim=0)[top_scores_index // top_k].long() + mask_index = torch.searchsorted(top_scores_index, draft_parents - 1, right=False) + # mask_index[(top_scores_index[mask_index]!=draft_parents - 1)]=-1 + mask_index[draft_parents == 0] = -1 + mask_index = mask_index + 1 + mask_index_list = mask_index.tolist() + # with Timer("mask"): + tree_mask = torch.eye(total_tokens + 1).bool() + tree_mask[:, 0] = True + for i in range(total_tokens): + tree_mask[i + 1].add_(tree_mask[mask_index_list[i]]) + + + tree_position_ids = torch.sum(tree_mask, dim=1) - 1 + + tree_mask = tree_mask.float()[None, None] + draft_tokens = draft_tokens[None] + + del parents_list, scores_list, ss_token, ss_token_list, draft_parents + + # with Timer("retrieve"): + + max_depth = torch.max(tree_position_ids) + 1 + noleaf_index = torch.unique(mask_index).tolist() + noleaf_num = len(noleaf_index) - 1 + leaf_num = total_tokens - noleaf_num + + retrieve_indices = torch.zeros(leaf_num, max_depth.item(), dtype=torch.long) - 1 + retrieve_indices = retrieve_indices.tolist() + + rid = 0 + position_ids_list = tree_position_ids.tolist() + + for i in range(total_tokens + 1): + if i not in noleaf_index: + cid = i + depth = position_ids_list[i] + for j in reversed(range(depth + 1)): + retrieve_indices[rid][j] = cid + cid = mask_index_list[cid - 1] + rid += 1 + + if logits_processor is not None: + maxitem = total_tokens + 5 + + def custom_sort(lst): + # sort_keys=[len(list)] + sort_keys = [] + for i in range(len(lst)): + sort_keys.append(lst[i] if lst[i] >= 0 else maxitem) + return sort_keys + + retrieve_indices = sorted(retrieve_indices, key=custom_sort) + + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + del mask_index, mask_index_list, noleaf_index, noleaf_num, leaf_num, max_depth, rid + tree_position_ids = tree_position_ids.to(hidden_states.device) + + return draft_tokens, retrieve_indices, tree_mask, tree_position_ids + + + + +import torch + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters()) + + +if __name__ == "__main__": + config = EConfig.from_pretrained('config.json') + model = Model(config, load_emb=False) + print(model) diff --git a/eagle/model/cnets1.py b/eagle/model/cnets1.py new file mode 100644 index 0000000000000000000000000000000000000000..a8beb32cb9d0dd7b96952ff12c2ad2b69dc830e4 --- /dev/null +++ b/eagle/model/cnets1.py @@ -0,0 +1,835 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import copy +import os +# os.environ["CUDA_VISIBLE_DEVICES"] = "5" +import math +from typing import List, Optional, Tuple, Union +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from huggingface_hub import hf_hub_download + + +try: + from .configs import EConfig + from .utils_c import * + from .choices import * +except: + from configs import EConfig + from utils_c import * + from choices import * + from utils import prepare_logits_processor + + + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + if hasattr(config, "qkv_bias"): + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.qkv_bias) + else: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + if hasattr(self.config, "rope_theta"): + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta) + else: + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, + max_position_embeddings=self.max_position_embeddings) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config, index): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.index = index + if self.index != 0: + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + if self.index != 0: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class I(nn.Module): + def __init__(self): + super().__init__() + self.dummy = nn.Parameter(torch.ones(1, dtype=torch.float32)) + + def forward(self, x): + return x + self.dummy - self.dummy # (also tried x+self.dummy) + + +def len_list(x, n): + return [i for i in x if len(i) <= n] + + +class Model(nn.Module): + def __init__(self, config, load_emb=False, path=None, bias=True, total_tokens=63, depth=5, top_k=8, threshold=1.0): + super().__init__() + + self.gradient_checkpointing = True + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if load_emb: + from safetensors import safe_open + import json + try: + index_json_path = os.path.join(path, "model.safetensors.index.json") + if not os.path.exists(index_json_path): + index_json_path = hf_hub_download(path, "model.safetensors.index.json") + with open(index_json_path, "r") as f: + index_json = json.loads(f.read()) + emb_path = index_json["weight_map"]["model.embed_tokens.weight"] + local_emb_path = os.path.join(path, emb_path) + if not os.path.exists(local_emb_path): + local_emb_path = hf_hub_download(path, emb_path) + with safe_open(local_emb_path, + framework="pt", + device="cpu") as f: + tensor_slice = f.get_slice("model.embed_tokens.weight") + vocab_size, hidden_dim = tensor_slice.get_shape() + tensor = tensor_slice[:, :hidden_dim].float() + except: + index_json_path = os.path.join(path, "pytorch_model.bin.index.json") + if not os.path.exists(index_json_path): + index_json_path = hf_hub_download(path, "pytorch_model.bin.index.json") + with open(index_json_path, "r") as f: + index_json = json.loads(f.read()) + emb_path = index_json["weight_map"]["model.embed_tokens.weight"] + local_emb_path = os.path.join(path, emb_path) + if not os.path.exists(local_emb_path): + local_emb_path = hf_hub_download(path, emb_path) + weights = torch.load(local_emb_path) + tensor = weights["model.embed_tokens.weight"].float() + self.embed_tokens.weight.data = tensor + + self.top_k = top_k + self.total_tokens = total_tokens - 1 + self.depth = depth + self.threshold = math.log(threshold) + # print("total_tokens",total_tokens) + # print("depth",depth) + # print("top_k",top_k) + # print("threshold",threshold) + + self.layers = nn.ModuleList([LlamaDecoderLayer(config, index) for index in range(config.num_hidden_layers)]) + self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=bias) + self.act = ACT2FN[config.hidden_act] + self.logsoftmax = nn.LogSoftmax(dim=-1) + for param in self.embed_tokens.parameters(): + param.requires_grad = False + + def init_tree(self): + self.tree_mask_init = torch.eye(self.top_k, device=self.embed_tokens.weight.device)[None, None] + self.position_ids = torch.zeros(self.top_k, device=self.embed_tokens.weight.device, dtype=torch.long) + self.tree_mask_init = self.tree_mask_init.to(self.embed_tokens.weight.device) + + def reset(self): + self.tree_mask = None + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + # inputs_embeds.dtype, + torch.float32, # [MODIFIED] force to cast to float32 + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + # [MODIFIED] add tree mask + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + _, _, tree_shape0, tree_shape1 = tree_mask.shape + combined_attention_mask[:, :, -tree_shape0:, -tree_shape1:][ + tree_mask == 0 + ] = torch.finfo(torch.float32).min + + return combined_attention_mask + + def forward( + self, + hidden_states, + input_ids, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + std=None + ): + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + with torch.no_grad(): + inputs_embeds = self.embed_tokens(input_ids) + # inputs_embeds = inputs_embeds.detach() + + # if std is not None: + # noise = torch.randn(inputs_embeds.size(),device=inputs_embeds.device) * std + # inputs_embeds=inputs_embeds+noise + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: + device = hidden_states.device if hidden_states is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + #position_ids=position_ids//4 + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + # if self.gradient_checkpointing and self.training: + # if use_cache: + # use_cache = False + + # hidden_states=self.act(self.fc(torch.cat((inputs_embeds,hidden_states),dim=-1))) + inputs_embeds = inputs_embeds.to(hidden_states.dtype) + hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) + + all_hidden_states = () if output_hidden_states else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if use_cache: + return hidden_states, next_decoder_cache + + return hidden_states + + def reset_kv(self): + self.stable_kv = None + + @torch.no_grad() + def topK_genrate(self, hidden_states, input_ids, head, logits_processor): + + input_ids = input_ids.to(hidden_states.device) + total_tokens = self.total_tokens + depth = self.depth + top_k = self.top_k + + sample_token = input_ids[:, -1] + + scores_list = [] + parents_list = [] + ss_token = [] + + input_ids = input_ids[:, 1:] + input_ids = input_ids.to(hidden_states.device) + + len_posi = input_ids.shape[1] + self.reset() + + # with Timer("draft many"): + if hasattr(self, "stable_kv") and self.stable_kv is not None: + kv_len = self.stable_kv[0][0].shape[2] + out_hidden, past_key_values = self(hidden_states, input_ids=input_ids[:, kv_len:], + past_key_values=self.stable_kv, use_cache=True) + else: + out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True) + self.stable_kv = past_key_values + last_hidden = out_hidden[:, -1] + + last_headout = head(last_hidden) + + last_p = self.logsoftmax(last_headout) + top = torch.topk(last_p, top_k, dim=-1) + topk_index, topk_p = top.indices, top.values + scores = topk_p[0] + scores_list.append(scores[None]) + parents_list.append(torch.zeros(1, dtype=torch.long, device=scores.device)) + ss_token.append(topk_index) + input_ids = topk_index + input_hidden = last_hidden[None].repeat(1, top_k, 1) + tree_mask = self.tree_mask_init + topk_cs_index = torch.arange(top_k, device=self.embed_tokens.weight.device) + + # 4 + for i in range(depth): + self.tree_mask = tree_mask + position_ids = len_posi + self.position_ids + # with Timer("draft one"): + out_hidden, past_key_values = self(input_hidden, input_ids=input_ids, past_key_values=past_key_values, + position_ids=position_ids, use_cache=True) + len_posi += 1 + + # with Timer("sort1"): + bias1 = top_k if i > 0 else 0 + bias2 = max(0, i - 1) + bias = 1 + top_k ** 2 * bias2 + bias1 + parents = (topk_cs_index + bias) + parents_list.append(parents) + + last_headout = head(out_hidden[0]) + last_p = self.logsoftmax(last_headout) + + top = torch.topk(last_p, top_k, dim=-1) + topk_index, topk_p = top.indices, top.values + + cu_scores = topk_p + scores[:, None] + + topk_cs = torch.topk(cu_scores.view(-1), top_k, dim=-1) + topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values + scores = topk_cs_p + + out_ids = topk_cs_index // top_k + input_hidden = out_hidden[:, out_ids] + + input_ids = topk_index.view(-1)[topk_cs_index][None] + + ss_token.append(topk_index) + scores_list.append(cu_scores) + tree_mask = torch.cat((tree_mask[:, :, out_ids], self.tree_mask_init), dim=3) + + + + scores_list = torch.cat(scores_list, dim=0).view(-1) + ss_token_list = torch.cat(ss_token, dim=0).view(-1) + top_scores = torch.topk(scores_list, total_tokens, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + + draft_tokens = ss_token_list[top_scores_index] + draft_tokens = torch.cat((sample_token, draft_tokens), dim=0) + + draft_parents = torch.cat(parents_list, dim=0)[top_scores_index // top_k].long() + mask_index = torch.searchsorted(top_scores_index, draft_parents - 1, right=False) + # mask_index[(top_scores_index[mask_index]!=draft_parents - 1)]=-1 + mask_index[draft_parents == 0] = -1 + mask_index = mask_index + 1 + mask_index_list = mask_index.tolist() + # with Timer("mask"): + tree_mask = torch.eye(total_tokens + 1).bool() + tree_mask[:, 0] = True + for i in range(total_tokens): + tree_mask[i + 1].add_(tree_mask[mask_index_list[i]]) + + + tree_position_ids = torch.sum(tree_mask, dim=1) - 1 + + tree_mask = tree_mask.float()[None, None] + draft_tokens = draft_tokens[None] + + del parents_list, scores_list, ss_token, ss_token_list, draft_parents + + # with Timer("retrieve"): + + max_depth = torch.max(tree_position_ids) + 1 + noleaf_index = torch.unique(mask_index).tolist() + noleaf_num = len(noleaf_index) - 1 + leaf_num = total_tokens - noleaf_num + + retrieve_indices = torch.zeros(leaf_num, max_depth.item(), dtype=torch.long) - 1 + retrieve_indices = retrieve_indices.tolist() + + rid = 0 + position_ids_list = tree_position_ids.tolist() + + for i in range(total_tokens + 1): + if i not in noleaf_index: + cid = i + depth = position_ids_list[i] + for j in reversed(range(depth + 1)): + retrieve_indices[rid][j] = cid + cid = mask_index_list[cid - 1] + rid += 1 + + if logits_processor is not None: + maxitem = total_tokens + 5 + + def custom_sort(lst): + # sort_keys=[len(list)] + sort_keys = [] + for i in range(len(lst)): + sort_keys.append(lst[i] if lst[i] >= 0 else maxitem) + return sort_keys + + retrieve_indices = sorted(retrieve_indices, key=custom_sort) + + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + del mask_index, mask_index_list, noleaf_index, noleaf_num, leaf_num, max_depth, rid + tree_position_ids = tree_position_ids.to(hidden_states.device) + + return draft_tokens, retrieve_indices, tree_mask, tree_position_ids + + + + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters()) + + +if __name__ == "__main__": + config = EConfig.from_pretrained('config.json') + model = Model(config, load_emb=False) + print(model) \ No newline at end of file diff --git a/eagle/model/configs.py b/eagle/model/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..0e11c92bf93260ba9b06dcff781c8b0af113e5d7 --- /dev/null +++ b/eagle/model/configs.py @@ -0,0 +1,147 @@ +from transformers.configuration_utils import PretrainedConfig +class EConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + pretraining_tp (`int`, *optional*, defaults to `1`): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + + Example: + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + rope_theta=10000, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + # self._rope_scaling_validation() + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + # def _rope_scaling_validation(self): + # """ + # Validate the `rope_scaling` configuration. + # """ + # if self.rope_scaling is None: + # return + + # if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + # raise ValueError( + # "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + # f"got {self.rope_scaling}" + # ) + # rope_scaling_type = self.rope_scaling.get("type", None) + # rope_scaling_factor = self.rope_scaling.get("factor", None) + # if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + # raise ValueError( + # f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + # ) + # if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + # raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") \ No newline at end of file diff --git a/eagle/model/configuration_minicpm.py b/eagle/model/configuration_minicpm.py new file mode 100644 index 0000000000000000000000000000000000000000..077e27ddbb862cb3008b25ba11dba8f91445adb0 --- /dev/null +++ b/eagle/model/configuration_minicpm.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2025 The OpenBMB Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MiniCPM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class MiniCPMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MiniCPM-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniCPMModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens, + MiniCPM 2 up to 4096, CodeMiniCPM up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import MiniCPMModel, MiniCPMConfig + >>> # Initializing a MiniCPM minicpm-7b style configuration + >>> configuration = MiniCPMConfig() + >>> # Initializing a model from the minicpm-7b style configuration + >>> model = MiniCPMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = 'minicpm' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act='silu', + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + scale_emb=1, + dim_model_base=1, + scale_depth=1, + mup_denominator=32, + sparse_config=None, + **kwargs): + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.scale_emb = scale_emb + self.dim_model_base = dim_model_base + self.scale_depth = scale_depth + # only used for Eagle Head + self.mup_denominator = mup_denominator + + # sparse config + self.sparse_config = sparse_config + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + try: + import flash_attn + self._attn_implementation = 'flash_attention_2' + except: + pass + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' + f'got {self.rope_scaling}' + ) + rope_scaling_type = self.rope_scaling.get('type', None) + rope_scaling_factor = self.rope_scaling.get('factor', None) + if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") \ No newline at end of file diff --git a/eagle/model/ea_model.py b/eagle/model/ea_model.py new file mode 100644 index 0000000000000000000000000000000000000000..55b43ed46489a0b441754a7545e4e5a532c30ad5 --- /dev/null +++ b/eagle/model/ea_model.py @@ -0,0 +1,582 @@ +import copy +import json +import time + +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer +import os +from transformers import PreTrainedModel, PretrainedConfig, AutoConfig + +from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM +from .modeling_mixtral_kv import MixtralForCausalLM as KVMixtralForCausalLM +#from .modeling_qwen2_kv import LlamaForCausalLM as KVQwen2ForCausalLM +from .modeling_qwen2_kv import Qwen2ForCausalLM as KVQwen2ForCausalLM +from .utils import * +from .kv_cache import initialize_past_key_values + +from .cnets import Model +from .cnets1 import Model as Model1 +from .configs import EConfig + +""" Modified to support Eagle-3, marked by xxx """ +# from .modeling_minicpm_kv import HackConvertMiniCPMForCausalLM as KVMiniCPMForCausalLM # convert opensource impl to llama +from .modeling_minicpm_kv import MiniCPMForCausalLM as KVMiniCPMForCausalLM # use modified opensource impl + +class EaModel(nn.Module): + + def __init__( + self, + use_eagle3, + base_model, + base_model_name_or_path, + ea_model_path, + total_token, + depth, + top_k, + threshold, + ea_layer_state_dict, + ): + + super().__init__() + self.base_model = base_model + self.config = base_model.config + self.hidden_size = base_model.lm_head.weight.shape[-1] + self.vocab_size = base_model.lm_head.weight.shape[0] + self.base_model_name_or_path = base_model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path, use_fast=False) + self.use_eagle3 = use_eagle3 + config = EConfig.from_pretrained(ea_model_path) + with open(ea_model_path, "r") as f: + con = json.loads(f.read()) + try: + bias = con["bias"] + except: + bias = True + if use_eagle3: + self.ea_layer = Model(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k, + threshold=threshold, path=base_model_name_or_path,load_emb=True) + else: + self.ea_layer = Model1(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k, + threshold=threshold, path=base_model_name_or_path,load_emb=True) + + low_memory = False + + device = base_model.model.layers[-1].self_attn.q_proj.weight.device + if device != base_model.lm_head.weight.device: + self.ea_layer.diff_device = True + if not low_memory: + self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device) + else: + self.ea_layer.layer_device = device + + else: + self.ea_layer.diff_device = False + if self.use_eagle3 and config.vocab_size==config.draft_vocab_size: + del self.ea_layer.d2t,self.ea_layer.t2d + load_=self.ea_layer.load_state_dict(ea_layer_state_dict, strict=False) + self.ea_layer.to(self.base_model.dtype).to(device) + self.ea_layer.init_tree() + + def get_tokenizer(self): + """Get the tokenizer of the base model. + + Returns: + Tokenizer: The tokenizer of the base model. + """ + return self.tokenizer + + @classmethod + def from_pretrained( + cls, + use_eagle3=True, + base_model_path=None, + ea_model_path=None, + total_token=60, + depth=7, + top_k=10, + threshold=1.0, + **kwargs, + ): + # assert Type=="LLaMA" or "Mixtral" + Type = AutoConfig.from_pretrained(base_model_path, trust_remote_code=True).architectures[0] + + if Type == 'LlamaForCausalLM': + base_model = KVLlamaForCausalLM.from_pretrained( + base_model_path, **kwargs + ) + elif Type == 'Qwen2ForCausalLM': + base_model = KVQwen2ForCausalLM.from_pretrained( + base_model_path, **kwargs + ) + elif Type == 'MiniCPMForCausalLM': # support MiniCPMForCausalLM + base_model = KVMiniCPMForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True, + ) + # + else: + base_model = KVMixtralForCausalLM.from_pretrained( + base_model_path, **kwargs + ) + + # + # configpath = os.path.join(ea_model_path, "config.json") + # if not os.path.exists(configpath): + # configpath = hf_hub_download(ea_model_path, "config.json") + + # try: + # load_model_path = os.path.join(ea_model_path, "pytorch_model.bin") + # if not os.path.exists(load_model_path): + # load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin") + # ea_layer_state_dict = torch.load(load_model_path, + # map_location=base_model.device) + # except: + # from safetensors.torch import load_file + # load_model_path = os.path.join(ea_model_path, "model.safetensors") + # if not os.path.exists(load_model_path): + # load_model_path = hf_hub_download(ea_model_path, "model.safetensors") + # ea_layer_state_dict = load_file(load_model_path) + # ------------------------------------------------- + # ### new loading logic to support subfolder on hf api + try: + configpath = os.path.join(ea_model_path, "config.json") + load_model_path = os.path.join(ea_model_path, "pytorch_model.bin") + if not os.path.exists(configpath): + configpath = hf_hub_download(ea_model_path, "config.json") + if not os.path.exists(load_model_path): + load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin") + except: + folder_names = ea_model_path.split("/") + repo = "/".join(folder_names[:-1]) + subfolder = folder_names[-1] + configpath = hf_hub_download(repo_id = repo, subfolder = subfolder, filename = "config.json") + load_model_path = hf_hub_download(repo_id = repo, subfolder = subfolder, filename = "pytorch_model.bin") + + ea_layer_state_dict = torch.load(load_model_path, map_location=base_model.device) + # + + model = cls( + use_eagle3, + base_model, + base_model_path, + configpath, + total_token, + depth, + top_k, + threshold, + ea_layer_state_dict + ) + + if total_token == -1: + device = model.base_model.model.layers[0].self_attn.q_proj.weight.device + cans = [40, 48, 50, 56, 60] + x = [1, 1.05, 1.07, 1.1, 1.13] + times = [] + + for i in range(len(cans)): + length = cans[i] + input_ids = torch.randint(0, model.config.vocab_size - 200, (1, length)).to(device) + torch.cuda.synchronize() + start_time = time.time() + for _ in range(20): + torch.cuda.synchronize() + with torch.no_grad(): + outputs = model.base_model(input_ids) + torch.cuda.synchronize() + torch.cuda.synchronize() + end_time = time.time() + times.append((end_time - start_time) / x[i]) + total_token = cans[times.index(min(times))] + model.ea_layer.total_tokens = total_token - 1 + + return model + + def forward( + self, + input_ids=None, + attention_mask=None, + past_key_values=None, + output_orig=False, + position_ids=None, + ): + + with torch.inference_mode(): + # Pass input through the base model + outputs = self.base_model.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + if output_orig: + orig = self.base_model.lm_head(outputs[0]) + hidden_states = outputs[0] + + if output_orig: + return outputs, orig, hidden_states + else: + return outputs, hidden_states + + @torch.no_grad() + def eagenerate( + self, + input_ids, + temperature=0.0, + top_p=0.0, + top_k=0.0, + max_new_tokens=512, + max_length=2048, + log=False, + is_llama3=False, + + ): + if is_llama3: + stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") + + + if temperature > 1e-5: + logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) + else: + logits_processor = None + # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" + # Avoid modifying the input_ids in-place + + padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device) + input_ids = input_ids.clone() + self.ea_layer.reset_kv() + + # Initialize the past key and value states + if hasattr(self, "past_key_values"): + past_key_values = self.past_key_values + past_key_values_data = self.past_key_values_data + current_length_data = self.current_length_data + # Reset the past key and value states + current_length_data.zero_() + else: + ( + past_key_values, + past_key_values_data, + current_length_data, + ) = initialize_past_key_values(self.base_model,max_length=max_length) + self.past_key_values = past_key_values + self.past_key_values_data = past_key_values_data + self.current_length_data = current_length_data + + input_len = input_ids.shape[1] + reset_tree_mode(self) + # prefill + draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree( + input_ids, self, past_key_values, logits_processor + ) + new_token = 0 + max_length = max_length - self.ea_layer.total_tokens - 10 + for idx in range(max_length): + # with Timer("all"): + self.base_model.model.tree_mask = tree_mask + + draft_tokens = draft_tokens.to(input_ids.device) + # Target model forward, get logits + logits, hidden_state_new, outputs = tree_decoding( + self, + draft_tokens, + past_key_values, + tree_position_ids, + input_ids, + retrieve_indices, + ) + # retrieve_indices=tree_buffers["retrieve_indices"] + # logits = logits[0, retrieve_indices] + draft_tokens = torch.cat((draft_tokens, padding), dim=1) + candidates = draft_tokens[0, retrieve_indices] + # verification + best_candidate, accept_length, sample_p = evaluate_posterior( + logits, candidates, logits_processor + ) + # print(accept_length) + # Adjusting the input sequence, draft model forward + input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs( + input_ids, + candidates, + best_candidate, + accept_length, + retrieve_indices, + logits_processor, + new_token, + past_key_values_data, + current_length_data, + self, + hidden_state_new, + sample_p + ) + + if is_llama3: + if stop_token_id in input_ids[0, input_len:].tolist(): + break + + if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + if input_ids.shape[1] > max_length: + break + if not log: + return input_ids + else: + return input_ids, new_token, idx + + @torch.no_grad() + def naivegenerate( + self, + input_ids, + temperature=0.0, + top_p=0.0, + top_k=0.0, + max_new_tokens=512, + max_length=2048, + log=False, + is_llama3=False, + + ): + if is_llama3: + stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") + + + if temperature > 1e-5: + logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) + else: + logits_processor = None + # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" + # Avoid modifying the input_ids in-place + + padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device) + input_ids = input_ids.clone() + self.ea_layer.reset_kv() + + # Initialize the past key and value states + if hasattr(self, "past_key_values"): + past_key_values = self.past_key_values + past_key_values_data = self.past_key_values_data + current_length_data = self.current_length_data + # Reset the past key and value states + current_length_data.zero_() + else: + ( + past_key_values, + past_key_values_data, + current_length_data, + ) = initialize_past_key_values(self.base_model,max_length=max_length) + self.past_key_values = past_key_values + self.past_key_values_data = past_key_values_data + self.current_length_data = current_length_data + + input_len = input_ids.shape[1] + reset_tree_mode(self) + outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True) + new_token = 0 + max_length = max_length - self.ea_layer.total_tokens - 10 + for idx in range(max_length): + if logits_processor is not None: + logits = outputs.logits[:, -1] + logits = logits_processor(None, logits) + probabilities = torch.nn.functional.softmax(logits, dim=-1) + input_id = torch.multinomial(probabilities, 1) + else: + input_id = outputs.logits[:, -1:].argmax(dim=-1) + outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values) + input_ids = torch.cat([input_ids, input_id], dim=-1) + new_token += 1 + + if is_llama3: + if stop_token_id in input_ids[0, input_len:].tolist(): + break + + if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + if input_ids.shape[1] > max_length: + break + if not log: + return input_ids + else: + return input_ids, new_token, idx + + @torch.no_grad() + def ea_generate( + self, + input_ids, + temperature=0.0, + top_p=0.0, + top_k=0.0, + max_new_tokens=512, + max_length=2048, + log=False, + is_llama3=False, + + ): + if is_llama3: + stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") + + + if temperature > 1e-5: + logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) + else: + logits_processor = None + # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" + # Avoid modifying the input_ids in-place + + padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device) + input_ids = input_ids.clone() + self.ea_layer.reset_kv() + + # Initialize the past key and value states + if hasattr(self, "past_key_values"): + past_key_values = self.past_key_values + past_key_values_data = self.past_key_values_data + current_length_data = self.current_length_data + # Reset the past key and value states + current_length_data.zero_() + else: + ( + past_key_values, + past_key_values_data, + current_length_data, + ) = initialize_past_key_values(self.base_model,max_length=max_length) + self.past_key_values = past_key_values + self.past_key_values_data = past_key_values_data + self.current_length_data = current_length_data + + input_len = input_ids.shape[1] + reset_tree_mode(self) + draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree( + input_ids, self, past_key_values, logits_processor + ) + new_token = 0 + max_length = max_length - self.ea_layer.total_tokens - 10 + for idx in range(max_length): + # with Timer("all"): + self.base_model.model.tree_mask = tree_mask + + draft_tokens = draft_tokens.to(input_ids.device) + # with Timer("tree_decoding"): + logits, hidden_state_new, outputs = tree_decoding( + self, + draft_tokens, + past_key_values, + tree_position_ids, + input_ids, + retrieve_indices, + ) + # retrieve_indices=tree_buffers["retrieve_indices"] + # logits = logits[0, retrieve_indices] + draft_tokens = torch.cat((draft_tokens, padding), dim=1) + candidates = draft_tokens[0, retrieve_indices] + best_candidate, accept_length, sample_p = evaluate_posterior( + logits, candidates, logits_processor + ) + # print(accept_length) + # with Timer("update_inference_inputs"): + input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs( + input_ids, + candidates, + best_candidate, + accept_length, + retrieve_indices, + logits_processor, + new_token, + past_key_values_data, + current_length_data, + self, + hidden_state_new, + sample_p + ) + + yield input_ids + + if is_llama3: + if stop_token_id in input_ids[0, input_len:].tolist(): + break + + if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + if input_ids.shape[1] > max_length: + break + + @torch.no_grad() + def naive_generate( + self, + input_ids, + temperature=0.0, + top_p=0.0, + top_k=0.0, + max_new_tokens=512, + max_length=2048, + log=False, + is_llama3=False, + + ): + if is_llama3: + stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") + + + if temperature > 1e-5: + logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) + else: + logits_processor = None + # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" + # Avoid modifying the input_ids in-place + + padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device) + input_ids = input_ids.clone() + self.ea_layer.reset_kv() + + # Initialize the past key and value states + if hasattr(self, "past_key_values"): + past_key_values = self.past_key_values + past_key_values_data = self.past_key_values_data + current_length_data = self.current_length_data + # Reset the past key and value states + current_length_data.zero_() + else: + ( + past_key_values, + past_key_values_data, + current_length_data, + ) = initialize_past_key_values(self.base_model,max_length=max_length) + self.past_key_values = past_key_values + self.past_key_values_data = past_key_values_data + self.current_length_data = current_length_data + + input_len = input_ids.shape[1] + reset_tree_mode(self) + outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True) + new_token = 0 + max_length = max_length - self.ea_layer.total_tokens - 10 + for idx in range(max_length): + if logits_processor is not None: + logits = outputs.logits[:, -1] + logits = logits_processor(None, logits) + probabilities = torch.nn.functional.softmax(logits, dim=-1) + input_id = torch.multinomial(probabilities, 1) + else: + input_id = outputs.logits[:, -1:].argmax(dim=-1) + + outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values) + input_ids = torch.cat([input_ids, input_id], dim=-1) + new_token += 1 + + yield input_ids + + if is_llama3: + if stop_token_id in input_ids[0, input_len:].tolist(): + break + + if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + if input_ids.shape[1] > max_length: + break diff --git a/eagle/model/kv_cache.py b/eagle/model/kv_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..ab89abe844794d6251865932c1d163dee71f11bd --- /dev/null +++ b/eagle/model/kv_cache.py @@ -0,0 +1,157 @@ +import torch + + +class KVCache: + """ + A key-value cache for the model. + + This class provides a mechanism to maintain a growing cache of keys and values, + particularly useful for models that benefit from caching previous states, + like transformers during autoregressive decoding. + + Attributes: + data (torch.Tensor): The tensor storing keys and values. + current_length (int): Current length of the data being stored. + """ + + def __init__(self, data, current_length): + """ + Initialize the KVCache. + + Args: + data (torch.Tensor): Initial tensor to store the keys and values. + current_length (int): Initial length of the data. + """ + self.data = data + self.current_length = current_length + + @property + def shape(self): + """Return the shape of the data tensor with updated length.""" + return ( + self.data.shape[0], + self.data.shape[1], + self.current_length.item(), + self.data.shape[3], + ) + + def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2): + """ + Copy values from the current data at specified indices to a new location. + + Args: + indices (torch.Tensor): Indices of the data tensor to be copied. + prev_length (int): Previous length before adding new data. + dim (int, optional): Dimension along which copying should be performed. Default is 2. + """ + tgt = self.data.index_select(dim, indices) + dst = self.data.narrow(dim, prev_length, tgt.shape[dim]) + dst.copy_(tgt, non_blocking=True) + self.current_length.fill_(prev_length + tgt.shape[dim]) + + def cat(self, tensor: torch.Tensor, dim: int = 2): + """ + Concatenate the given tensor with the current data. + + Args: + tensor (torch.Tensor): The tensor to be concatenated. + dim (int, optional): The dimension along which concatenation should be done. Default is 2. + + Returns: + torch.Tensor: The data tensor after concatenation up to the current length. + """ + dst = self.data.narrow(dim, self.current_length, tensor.shape[dim]) + dst.copy_(tensor) + self.current_length.add_(tensor.shape[dim]) + return torch.narrow(self.data, 2, 0, self.current_length) + + +def initialize_past_key_values(model,max_length=2200): + """ + Initialize past key and value states for a given transformer model. + + This function prepares key-value cache structures for the model, allowing it to store and reuse + past key and value states during autoregressive decoding, which can improve efficiency. + + Args: + model (nn.Module): The transformer model for which past key-value states need to be initialized. + + Returns: + tuple: + - past_key_values (list): A list of KVCache objects for each layer in the model. + - past_key_values_data (torch.Tensor): The tensor that will store all keys and values. + - current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache. + """ + # Extracting configuration from the model + config = model.config + # Initializing the batch size to 1, this can be modified if different batch sizes are required + batch_size = 1 + # Initializing a tensor to store past keys and values for all layers + + devices=[] + for i in range(config.num_hidden_layers): + try: + device = model.model.layers[i].self_attn.q_proj.weight.device + except: + device=model.layers[i].self_attn.q_proj.weight.device + devices.append(device) + past_key_values_data_list=[] + startnum=0 + startdevice=devices[0] + for id,i in enumerate(devices): + if startdevice!=i: + past_key_values_data = torch.zeros( + startnum * 2, + batch_size, + config.num_key_value_heads, + max_length, + config.hidden_size // config.num_attention_heads, + device=startdevice, + dtype=model.dtype, + ) + past_key_values_data_list.append(past_key_values_data) + startdevice = i + startnum=0 + startnum += 1 + past_key_values_data = torch.zeros( + startnum * 2, + batch_size, + config.num_key_value_heads, + max_length, + config.hidden_size // config.num_attention_heads, + device=startdevice, + dtype=model.dtype, + ) + past_key_values_data_list.append(past_key_values_data) + # Initialize tensor to store the current length of the cached data for all layers. + # [IMPORTANT] It needs to be kept on CPU for quick access and updates. + current_length_data = torch.zeros( + config.num_hidden_layers * 2, dtype=torch.long, device="cpu" + ) + # Creating a KVCache for each pair of key and value in all layers + past_key_values = [] * config.num_hidden_layers + + bias=0 + start_data_m=devices[0].index + for i in range(config.num_hidden_layers): + data_m=devices[i].index + if data_m!=start_data_m: + bias=0 + start_data_m=data_m + try: + past_key_values.append( + [ + KVCache(past_key_values_data_list[data_m-devices[0].index][2*bias + j], current_length_data[i * 2 + j]) + for j in range(2) + ] + ) + except: + past_key_values.append( + [ + KVCache(past_key_values_data_list[0][2 * bias + j], + current_length_data[i * 2 + j]) + for j in range(2) + ] + ) + bias+=1 + return past_key_values, past_key_values_data_list, current_length_data diff --git a/eagle/model/modeling_llama_kv.py b/eagle/model/modeling_llama_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..05c1321882a7a8f94e5795fbc1c2a2d31fa85f81 --- /dev/null +++ b/eagle/model/modeling_llama_kv.py @@ -0,0 +1,1597 @@ +# Source: https://github.com/huggingface/transformers/blob/v4.31-release/src/transformers/models/llama/modeling_llama.py +# Modifications are denoted by the symbol: [MODIFIED] + + +""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +# [MODIFIED] Import from transformer library +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers import LlamaConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Create a causal mask for bi-directional self-attention. + + Args: + input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len). + dtype (torch.dtype): The data type of the mask. + device (torch.device): The device on which the mask will be placed. + past_key_values_length (int, optional): The length of past key values. Default is 0. + + Returns: + torch.Tensor: The causal mask tensor. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + + Args: + mask (torch.Tensor): The attention mask tensor of shape `[bsz, seq_len]`. + dtype (torch.dtype): The data type of the mask. + tgt_len (Optional[int], optional): The target sequence length. If None, it defaults to the source sequence length. + + Returns: + torch.Tensor: The expanded mask tensor. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + + + +class LlamaRMSNorm(nn.Module): + """ + LlamaRMSNorm is equivalent to T5LayerNorm. + + Args: + hidden_size (int): The size of the hidden states. + eps (float, optional): A small value to prevent division by zero. Default is 1e-6. + """ + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + """ + Apply LlamaRMSNorm to the input hidden states. + + Args: + hidden_states (torch.Tensor): Input hidden states. + + Returns: + torch.Tensor: The normalized and scaled hidden states. + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LlamaRotaryEmbedding(nn.Module): + """ + Llama Rotary Positional Embedding Module. + + Args: + dim (int): The dimension of the embedding. + max_position_embeddings (int, optional): The maximum position for embeddings. Default is 2048. + base (int, optional): The base value for rotational encoding. Default is 10000. + device (str, optional): The device on which the computation will be performed. Default is None. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + """ + Set the cosine and sine cache for positional embeddings. + + Args: + seq_len (int): The sequence length. + device (str): The device on which the cache tensors will be stored. + dtype: The data type of the cache tensors. + """ + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + def forward(self, x, seq_len=None): + """ + Forward pass of the LlamaRotaryEmbedding module. + + Args: + x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size]. + seq_len (int): The sequence length. If greater than the cached length, the cache will be updated. + + Returns: + tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim]. + """ + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaRotaryEmbedding_L31(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """ + LlamaRotaryEmbedding extended with linear scaling. + + This class adds linear scaling to LlamaRotaryEmbedding. Credits to the Reddit user /u/kaiokendev. + + Args: + dim (int): The dimension of the embedding. + max_position_embeddings (int, optional): The maximum number of position embeddings. Default is 2048. + base (int, optional): The base value for the rotational embeddings. Default is 10000. + device (str or torch.device, optional): The device where the embeddings should be stored. Default is None. + scaling_factor (float, optional): The scaling factor for the embeddings. Default is 1.0. + """ + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + """ + Set the cosine and sine cache for the rotary embeddings. + + Args: + seq_len (int): The sequence length. + device (str or torch.device): The device where the cache should be stored. + dtype: The data type for the cache. + """ + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """ + LlamaRotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla. + """ + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + """ + Initialize the LlamaDynamicNTKScalingRotaryEmbedding. + + Args: + dim (int): The dimensionality of the embedding. + max_position_embeddings (int, optional): Maximum number of position embeddings. Default is 2048. + base (int, optional): Base value for scaling calculations. Default is 10000. + device: The device to place tensors on. If None, uses the default device. + scaling_factor (float, optional): Scaling factor for NTK scaling. Default is 1.0. + """ + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + """ + Set the cached values for cosine and sine. + + Args: + seq_len (int): The sequence length. + device: The device to place tensors on. + dtype: The data type of tensors. + """ + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + +def rotate_half(x): + """ + Rotates half the hidden dimensions of the input. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Tensor with half of its hidden dimensions rotated. + """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + """ + Apply rotary position embeddings to query and key tensors. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + cos (torch.Tensor): Cosine values. + sin (torch.Tensor): Sine values. + position_ids (torch.Tensor): Position IDs. + + Returns: + torch.Tensor: Query and key tensors with rotary position embeddings applied. + """ + cos = cos.squeeze(1).squeeze(0) + sin = sin.squeeze(1).squeeze(0) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_L31(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + """ + LlamaMLP is a multi-layer perceptron module used in the Llama model. + + Args: + config: The configuration for the MLP. + + Attributes: + pretraining_tp (int): The pretraining time periods. + hidden_size (int): The size of the hidden layer. + intermediate_size (int): The size of the intermediate layer. + gate_proj (nn.Linear): The linear projection for gating. + up_proj (nn.Linear): The linear projection for the up projection. + down_proj (nn.Linear): The linear projection for the down projection. + act_fn: The activation function. + + """ + + def __init__(self, config): + super().__init__() + self.pretraining_tp = config.pretraining_tp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + """ + Forward pass of the MLP. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + if self.pretraining_tp > 1: + slice = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], + dim=-1, + ) + up_proj = torch.cat( + [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], + dim=-1, + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Repeat key and value tensors n times along the specified dimension. + + Args: + hidden_states (torch.Tensor): Input tensor with shape (batch, num_key_value_heads, seqlen, head_dim). + n_rep (int): Number of times to repeat. + + Returns: + torch.Tensor: Repeated tensor with shape (batch, num_key_value_heads * n_rep, seqlen, head_dim). + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class HackMiniCPMLongRoPE(LlamaRotaryEmbedding): + """https://huggingface.co/openbmb/MiniCPM4.1-8B/blob/main/modeling_minicpm.py""" + """Extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, short_factor=None, long_factor=None, original_max_position_embeddings=None): + self.short_factor = short_factor + self.long_factor = long_factor + self.original_max_position_embeddings = original_max_position_embeddings + scale = (max_position_embeddings / self.original_max_position_embeddings) + self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device) + + freqs = torch.mul( + torch.outer(t, 1.0 / ext_factors).to(device=device), + self.inv_freq.to(device=device).to(dtype) + ) + # # Different from paper, but it uses a different permutation in order to obtain the same calculation + # emb = torch.cat((freqs, freqs), dim=-1) + # self.register_buffer('cos_cached', emb.cos().to(dtype) * self.scaling_factor, persistent=False) + # self.register_buffer('sin_cached', emb.sin().to(dtype) * self.scaling_factor, persistent=False) + + + # t = t / ext_factors + # # t = t / self.scaling_factor + + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # # Different from paper, but it uses a different permutation in order to obtain the same calculation + + # 250911 + # [DIFF] shape + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + # # 250914 prev modification forgot to add scaling factor + # # [DIFF] shape + # emb = torch.cat((freqs, freqs), dim=-1) + # self.register_buffer( + # "cos_cached", emb.cos()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False + # ) + # self.register_buffer( + # "sin_cached", emb.sin()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False + # ) + +class LlamaAttention(nn.Module): + """ + LlamaAttention is a multi-headed attention module based on the 'Attention Is All You Need' paper. + + Args: + config (LlamaConfig): Configuration for the attention module. + + Attributes: + config (LlamaConfig): Configuration for the attention module. + hidden_size (int): The size of the hidden layer. + num_heads (int): The number of attention heads. + head_dim (int): The dimension of each attention head. + num_key_value_heads (int): The number of key-value attention heads. + num_key_value_groups (int): The number of key-value groups. + pretraining_tp (int): The pretraining time periods. + max_position_embeddings (int): The maximum position embeddings. + + """ + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.pretraining_tp = config.pretraining_tp + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.config.rope_theta + ) + else: + + # add: Support MiniCPM4.1-8B | JQZ 250910 + try: + assert "rope_type" in self.config.rope_scaling.keys() + assert self.config.rope_scaling["rope_type"] == "longrope" + scaling_type = "longrope" + except: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + # /add + + + # scaling_type == "longrope": # add: Support MiniCPM4.1-8B | JQZ 250910 + self.rotary_emb = HackMiniCPMLongRoPE( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + short_factor=self.config.rope_scaling["short_factor"], + long_factor=self.config.rope_scaling["long_factor"], + original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"], + ) + + # try: + # scaling_type = self.config.rope_scaling["type"] + # scaling_factor = self.config.rope_scaling["factor"] + # if scaling_type == "linear": + # self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + # self.head_dim, + # max_position_embeddings=self.max_position_embeddings, + # scaling_factor=scaling_factor, + # base=self.config.rope_theta, + # ) + # elif scaling_type == "dynamic": + # self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + # self.head_dim, + # max_position_embeddings=self.max_position_embeddings, + # scaling_factor=scaling_factor, + # base=self.config.rope_theta, + # ) + # else: + # raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + # except: + # # print("For LLaMA 31") + # self.rotary_emb = LlamaRotaryEmbedding_L31(config=self.config) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) + for i in range(self.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) + for i in range(self.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) + for i in range(self.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + if isinstance(self.rotary_emb, LlamaRotaryEmbedding_L31): + cos, sin = self.rotary_emb(query_states,position_ids) + query_states, key_states = apply_rotary_pos_emb_L31(query_states, key_states, cos, sin) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + # [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization + # past_key_value is utilized to leverage previously computed key and value states. + # If past_key_value is available, reuse the states for k, v, and self_attention. + if past_key_value is not None: + key_states = past_key_value[0].cat(key_states, dim=2) + value_states = past_key_value[1].cat(value_states, dim=2) + # Reset past_key_value to avoid return past_key_value. + past_key_value = None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.pretraining_tp > 1: + attn_output = attn_output.split( + self.hidden_size // self.pretraining_tp, dim=2 + ) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.pretraining_tp, dim=1 + ) + attn_output = sum( + [ + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.pretraining_tp) + ] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + """ + LlamaDecoderLayer represents a single layer of the Llama decoder. + + Args: + config (LlamaConfig): Configuration for the decoder layer. + + Attributes: + hidden_size (int): The size of the hidden layer. + self_attn (LlamaAttention): Multi-headed self-attention module. + mlp (LlamaMLP): Multi-layer perceptron module. + input_layernorm (LlamaRMSNorm): Layer normalization for input. + post_attention_layernorm (LlamaRMSNorm): Layer normalization after self-attention. + """ + + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Forward pass for the LlamaDecoderLayer. + + Args: + hidden_states (torch.FloatTensor): Input tensor of shape `(batch, seq_len, embed_dim)`. + attention_mask (torch.FloatTensor, optional): Attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (torch.LongTensor, optional): Positional IDs tensor. + past_key_value (Tuple[torch.FloatTensor], optional): Cached past key and value projection states. + output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. + use_cache (bool, optional): If set to `True`, `past_key_values` key-value states are returned and can be + used to speed up decoding. + + Returns: + Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Tuple containing: + - hidden_states (torch.FloatTensor): Output tensor. + - self_attn_weights (Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]): Self-attention weights if + `output_attentions` is `True`. + - present_key_value (Optional[Tuple[torch.FloatTensor]]): Cached key and value projection states if + `use_cache` is `True`. + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + # inputs_embeds.dtype, + torch.float32, # [MODIFIED] force to cast to float32 + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + tree_len = tree_mask.size(-1) + combined_attention_mask[:, :, -tree_len:, -tree_len:][ + tree_mask == 0 + ] = combined_attention_mask.min() + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values=None, # [MODIFIED] past_key_value is KVCache class + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if 1 else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if idx==len(self.layers)-3 or idx==len(self.layers)//2 or idx==2: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # !!! + # all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values=None, # [MODIFIED] past_key_value is KVCache class + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split( + self.vocab_size // self.pretraining_tp, dim=0 + ) + logits = [ + F.linear(hidden_states, lm_head_slices[i]) + for i in range(self.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/eagle/model/modeling_minicpm_kv.py b/eagle/model/modeling_minicpm_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..8b883c710d1f6510632dcf6a0ea25006cf8cf017 --- /dev/null +++ b/eagle/model/modeling_minicpm_kv.py @@ -0,0 +1,2487 @@ +# coding=utf-8 +# Copyright 2025 The OpenBMB Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch MiniCPM model.""" +""" Modified to support Eagle-3, marked by xxx """ +import math +import re +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available + +from .configuration_minicpm import MiniCPMConfig + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from infllm_v2 import ( + infllmv2_attn_stage1, + infllmv2_attn_varlen_func, + infllmv2_attn_with_kvcache, + max_pooling_1d, + max_pooling_1d_varlen + ) +except: + pass + +from functools import lru_cache + +from .modeling_llama_kv import _make_causal_mask, _expand_mask # use eagle's impl + + +def compressed_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + topk: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float = None, + init_blocks: int = 1, + local_blocks: int = 2, + cache_lens: torch.Tensor = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention. + + Args: + q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + kernel_size (int): kernel size in compress_key_value + kernel_stride (int): stride of compress_key_value + block_size (int): key value block size for topk sparse attention. + topk (int): number of blocks for each query. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. + max_seqlen_q (int): max q len of the batch. + max_seqlen_k (int): max k len of the batch. + sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). + init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. + local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. + cache_lens (torch.Tensor, optional): shape [batch_size], used to record the cache length of each query. Defaults to None. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention + """ + with torch.no_grad(): + batch_size = cu_seqlens_q.shape[0] - 1 + + # Check if it's prefilling stage + is_prefilling = cache_lens is None or (cache_lens == 0).all().item() + + # prefilling stage + if is_prefilling: + # Calculate q_idx for each query position in each batch + cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device) + q_idx = torch.cat([ + (torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) + + max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size + for i in range(batch_size) + ], dim=0) # shape: [total_q_len] + # decoding stage + else: + # Each batch has only one query (last position). Shape: [batch_size] = [total_q_len] in decoding + q_idx = cache_lens // block_size + + # compute attention score + score = infllmv2_attn_stage1( + q.contiguous(), + k.contiguous(), + v.contiguous(), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=is_prefilling) + # Shape: [num_heads, total_q_len, num_blocks] + score = score[:, :q_idx.shape[0], :] + + # Shape: [num_heads, total_q_len, num_blocks] + block_score = max_pooling_1d_varlen( + score.contiguous(), + cu_seqlens_q, + cu_seqlens_k, + cache_lens, + max_seqlen_q, + max_seqlen_k, + local_blocks=local_blocks, + init_blocks=init_blocks, + block_size=block_size, + stride=kernel_stride) + + # get topk + topk = min(topk, block_score.shape[-1]) + topk_idx = block_score.topk(topk, dim=-1).indices.sort(-1).values + topk_idx[topk_idx > q_idx[None, :, None]] = -1 + topk_idx = topk_idx.to(torch.int32) + + return topk_idx + + +@lru_cache(maxsize=16) +def calc_chunks_with_stride(cu_seqlen, chunk_size, kernel_stride): + """ + Compute the chunks that require Sparse attention, with stride support. + + Args: + cu_seqlen (torch.Tensor): Cumulative sequence lengths for each sample. + chunk_size (int): Chunk size used for Sparse attention. + kernel_stride (int): Stride size when sliding over the sequence. + + Returns: + filtered_indices (torch.Tensor): Indices used to directly index into the key/value tensors. + cu_seqlens_compressed (torch.Tensor): Cumulative sequence lengths after compression. + """ + # 1. Compute the length of each sequence + batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1] + + # 2. Compute the start positions of chunks for each sequence (with stride) + max_seq_len = torch.max(batch_sizes) + max_num_chunks_per_seq = (max_seq_len - chunk_size) // kernel_stride + 1 + chunk_start_offsets = torch.arange(0, max_num_chunks_per_seq * kernel_stride, kernel_stride, device=cu_seqlen.device) + seq_starts = cu_seqlen[:-1] + chunk_start_in_seq = seq_starts[:, None] + chunk_start_offsets[None, :] # [batch_size, max_num_chunks_per_seq] + + # 3. Filter out chunks that exceed sequence length or are smaller than the full chunk size + chunk_end_in_seq = chunk_start_in_seq + chunk_size + valid_chunk_mask = (chunk_end_in_seq <= (seq_starts[:, None] + batch_sizes[:, None])) + + # 4. Filter valid chunk start positions using the valid_chunk_mask + valid_chunk_starts = chunk_start_in_seq[valid_chunk_mask] # [num_valid_chunks] + del chunk_start_in_seq + # 5. Generate filtered_indices + chunk_indices = torch.arange( + 0, chunk_size, device=cu_seqlen.device + )[None, :] # [1, chunk_size] + filtered_indices = valid_chunk_starts[:, None] + chunk_indices # [num_valid_chunks, chunk_size] + filtered_indices = filtered_indices.view(-1) # Flatten to 1D indices + + # 6. Compute compressed cumulative sequence lengths + num_filtered_chunks_per_batch = valid_chunk_mask.sum(dim=1) # Number of valid chunks per batch + cu_seqlens_compressed = torch.zeros( + len(cu_seqlen), dtype=torch.int32, device=cu_seqlen.device + ) + cu_seqlens_compressed[1:] = num_filtered_chunks_per_batch.cumsum(dim=0) + del num_filtered_chunks_per_batch, chunk_start_offsets, seq_starts, chunk_end_in_seq, valid_chunk_mask, chunk_indices + return filtered_indices, cu_seqlens_compressed + + +class CompressK(torch.nn.Module): + def __init__(self, head_num_k, head_dim, kernel_size, kernel_stride=16): + """ + Module for compressing key (K) representations. + + Args: + head_num_k (int): Number of key attention heads. + head_dim (int): Dimension of each attention head. + kernel_size (int): Size of each chunk used for compression. + kernel_stride (int, optional): Stride used when dividing input into chunks. Default is 16. + """ + super().__init__() + self.kernel_size = kernel_size + self.head_num_k = head_num_k + self.head_dim = head_dim + self.kernel_stride = kernel_stride + + def forward(self, k: torch.Tensor, cu_seqlens): + """ + Forward pass for compressing the key (K) tensor. + + Args: + k (torch.Tensor): Input key tensor of shape (total_seq_len, num_heads, head_dim). + cu_seqlens (torch.Tensor): Cumulative sequence lengths for each sample in the batch, typically used for handling variable-length sequences. + + Returns: + compress_k (torch.Tensor): Compressed key tensor. + cu_seqlens_compressed (torch.Tensor): Updated cumulative sequence lengths after compression. + + """ + # Compute chunk-related metadata, with stride support + filtered_k_indices, cu_seqlens_compressed = calc_chunks_with_stride( + cu_seqlens, self.kernel_size, self.kernel_stride + ) + + # Extract filtered key vectors + filtered_k = k.index_select(0, filtered_k_indices.view(-1)) + + # split + filtered_k = filtered_k.view(filtered_k.shape[0] // self.kernel_size, self.kernel_size, self.head_num_k, self.head_dim) # [l, block_size,h,d] + + compressed_k = filtered_k.mean(dim=1) + return compressed_k, cu_seqlens_compressed + + + +class InfLLMv2CacheLayer(DynamicLayer): + def __init__(self): + super().__init__() + # Initialize any additional attributes specific to InfLLMv2CacheLayer + self.no_rope_keys = torch.tensor([], dtype=torch.float32) + self.compress_k_cache = [] + self.no_compress_k_cache = [] + self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32) + self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32) + + def update_no_rope_key(self, key_states): + if self.no_rope_keys.numel() == 0: + self.no_rope_keys = key_states + else: + self.no_rope_keys = torch.cat([self.no_rope_keys, key_states], dim=1) + return self.no_rope_keys + + def update_compress_k(self, key_states, cu_seqlens=None): + if len(self.compress_k_cache) == 0: + if cu_seqlens is not None: + self.cached_compressed_cu_seqlens = cu_seqlens.clone() + self.compress_k_cache_varlen = key_states + split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + self.compress_k_cache = list(torch.split(key_states, split_sizes)) + else: + for index, k in enumerate(key_states): + if k is not None: + self.compress_k_cache[index] = torch.cat([self.compress_k_cache[index], k], dim=0) + new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k_cache], dtype=torch.int32) + new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32) + + self.compress_k_cache_varlen = torch.cat(self.compress_k_cache, dim=0) + self.cached_compressed_cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k_cache_varlen.device) + return self.compress_k_cache_varlen, self.cached_compressed_cu_seqlens + + def update_no_compress_k(self, key_states, kernel_size=32, kernel_stride=16): + k_chunk_list = [] + for index, k in enumerate(key_states): + if len(self.no_compress_k_cache) <= index: + self.no_compress_k_cache.append(k) + else: + self.no_compress_k_cache[index] = torch.cat([self.no_compress_k_cache[index], k], dim=0) + current_len = self.no_compress_k_cache[index].shape[0] + if current_len >= kernel_size: + k_chunk_list.append(self.no_compress_k_cache[index][:kernel_size]) + self.no_compress_k_cache[index] = self.no_compress_k_cache[index][kernel_stride:] + else: + k_chunk_list.append(None) + return k_chunk_list + +class InfLLMv2Cache(DynamicCache): + def __init__(self, + config,num_hidden_layers: Optional[int] = None) -> None: + super().__init__(config=config) + self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else [] + self._seen_tokens = 0 + + def update(self, key_states, value_states, layer_idx, cache_kwargs=None): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + + def update_no_rope_key(self, key_states, layer_idx, cache_kwargs=None): + return self.layers[layer_idx].update_no_rope_key(key_states) + + def update_compress_k(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None): + return self.layers[layer_idx].update_compress_k(key_states, cu_seqlens) + + def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None): + return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride) + + def crop(self, max_length): + for layer in self.layers: + layer.crop(max_length) + + def batch_repeat_interleave(self, repeats): + for layer in self.layers: + layer.batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices): + for layer in self.layers: + layer.batch_select_indices(indices) + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'MiniCPMConfig' + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + + + +# @torch.jit.script # type: ignore +def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float): + old_dtype = hidden.dtype + variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True) + hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype) + return hidden * weight + + +class MiniCPMRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniCPMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return rms_layernorm(hidden_states, self.weight, self.variance_epsilon) + + +ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm) + + +class MiniCPMRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + # seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32 + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + # tensor shape diff + # return ( + # self.cos_cached[:seq_len].to(dtype=x.dtype), + # self.sin_cached[:seq_len].to(dtype=x.dtype), + # ) + # ------------------------------------------------- + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + # + + +class MiniCPMLongRoPE(MiniCPMRotaryEmbedding): + """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, short_factor=None, long_factor=None, original_max_position_embeddings=None): + self.short_factor = short_factor + self.long_factor = long_factor + self.original_max_position_embeddings = original_max_position_embeddings + scale = (max_position_embeddings / self.original_max_position_embeddings) + self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device) + + freqs = torch.mul( + torch.outer(t, 1.0 / ext_factors).to(device=device), + self.inv_freq.to(device=device).to(dtype) + ) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + + # tensor shape diff + # self.register_buffer('cos_cached', emb.cos().to(dtype) * self.scaling_factor, persistent=False) + # self.register_buffer('sin_cached', emb.sin().to(dtype) * self.scaling_factor, persistent=False) + # ------------------------------------------------- + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False + ) + # + + +class MiniCPMLinearScalingRotaryEmbedding(MiniCPMRotaryEmbedding): + """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + + # tensor shape diff + # self.register_buffer('cos_cached', emb.cos().to(dtype) * self.scaling_factor, persistent=False) + # self.register_buffer('sin_cached', emb.sin().to(dtype) * self.scaling_factor, persistent=False) + # ------------------------------------------------- + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False + ) + # + + +class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding): + """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + + # tensor shape diff + # self.register_buffer('cos_cached', emb.cos().to(dtype) * self.scaling_factor, persistent=False) + # self.register_buffer('sin_cached', emb.sin().to(dtype) * self.scaling_factor, persistent=False) + # ------------------------------------------------- + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False + ) + # + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # cos = cos[position_ids].unsqueeze(unsqueeze_dim) + # sin = sin[position_ids].unsqueeze(unsqueeze_dim) + # q_embed = (q * cos) + (rotate_half(q) * sin) + # k_embed = (k * cos) + (rotate_half(k) * sin) + cos = cos.squeeze(1).squeeze(0) + sin = sin.squeeze(1).squeeze(0) + orig_dtype = k.dtype + cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + q_fp32 = q.to(dtype=torch.float32, device=q.device) + k_fp32 = k.to(dtype=torch.float32, device=k.device) + q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin) + k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin) + return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype) + + +class MiniCPMMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + +def _unpad_one_tensor(hidden_states, attention_mask): + # Unpad the hidden states using the indices + indices, cu_seqlens, max_seqlen_in_batch = _get_unpad_data(attention_mask) + batch_size, seq_len = hidden_states.shape[:2] + + # Get the remaining dimensions + remaining_dims = hidden_states.shape[2:] + + # Reshape to (batch_size * seq_len, *remaining_dims) + reshaped_states = hidden_states.reshape(batch_size * seq_len, *remaining_dims) + + # Apply unpadding using indices + unpadded_states = index_first_axis(reshaped_states, indices) + + return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + The hidden states go from + (batch, num_key_value_heads, seqlen, head_dim) + to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MiniCPMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f'Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will ' + 'to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` ' + 'when creating this class.' + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).' + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = MiniCPMRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['rope_type'] + scaling_factor = self.config.rope_scaling.get('factor', None) + if scaling_type == 'linear': + self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == 'dynamic': + self.rotary_emb = MiniCPMDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == 'longrope': + self.rotary_emb = MiniCPMLongRoPE( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + short_factor=self.config.rope_scaling['short_factor'], + long_factor=self.config.rope_scaling['long_factor'], + base=self.rope_theta, + original_max_position_embeddings=self.config.rope_scaling['original_max_position_embeddings'] + ) + else: + raise ValueError(f'Unknown RoPE scaling type {scaling_type}') + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = position_ids.max().item() + 1 + cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # + # if past_key_value is not None: + # cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # ------------------------------------------------- + # ### Copied from modeling_llama_kv.py, Line 709. class LlamaAttention, function forward(). + # ### [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization + # ### past_key_value is utilized to leverage previously computed key and value states. + # ### If past_key_value is available, reuse the states for k, v, and self_attention. + if past_key_value is not None: + key_states = past_key_value[0].cat(key_states, dim=2) + value_states = past_key_value[1].cat(value_states, dim=2) + # ### Reset past_key_value to avoid return past_key_value. + past_key_value = None + # + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + # raise ValueError( + # f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is' + # f' {attn_weights.size()}' + # ) + + if attention_mask is not None: + # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + # raise ValueError( + # f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + # ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}' + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MiniCPMFlashAttention2(MiniCPMAttention): + """ + MiniCPM flash attention module. This module inherits from `MiniCPMAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MiniCPMFlashAttention2 attention does not support output_attentions + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = position_ids.max().item() + 1 + cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # + # if past_key_value is not None: + # cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # ------------------------------------------------- + # ### Copied from modeling_llama_kv.py, Line 709. class LlamaAttention, function __init__(). + # ### [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization + # ### past_key_value is utilized to leverage previously computed key and value states. + # ### If past_key_value is available, reuse the states for k, v, and self_attention. + if past_key_value is not None: + key_states = past_key_value[0].cat(key_states, dim=2) + value_states = past_key_value[1].cat(value_states, dim=2) + # ### Reset past_key_value to avoid return past_key_value. + past_key_value = None + # + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (MiniCPMRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, this might be related to' + f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in' + f' {target_dtype}.' + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class MiniCPMInfLLMv2Attention(MiniCPMAttention): + """ + MiniCPM flash attention module. This module inherits from `MiniCPMAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.config._attn_implementation == 'flash_attention_2', 'Only flash_attention_2 is supported for sparse attention' + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # -------sparse------- + self.kernel_size = self.config.sparse_config.get('kernel_size', 32) + self.kernel_stride = self.config.sparse_config.get('kernel_stride', 16) + self.init_blocks = self.config.sparse_config.get('init_blocks', 1) + self.block_size = self.config.sparse_config.get('block_size', 64) + self.window_size = self.config.sparse_config.get('window_size', 2048) + self.dense_len = self.config.sparse_config.get('dense_len', 8192) + + self.local_blocks = self.window_size // self.block_size # local_blocks + self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size) + self.use_nope = self.config.sparse_config.get('use_nope', False) + self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MiniCPMFlashAttention2 attention does not support output_attentions + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # !save no rope + if self.use_nope: + query_states_no_rope = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states_no_rope = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = position_ids.max().item() + 1 + cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # + # if past_key_value is not None: + # cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # ------------------------------------------------- + # ### Copied from modeling_llama_kv.py, Line 709. class LlamaAttention, function __init__(). + # ### [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization + # ### past_key_value is utilized to leverage previously computed key and value states. + # ### If past_key_value is available, reuse the states for k, v, and self_attention. + if past_key_value is not None: + key_states = past_key_value[0].cat(key_states, dim=2) + value_states = past_key_value[1].cat(value_states, dim=2) + # ### Reset past_key_value to avoid return past_key_value. + past_key_value = None + # + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if self.use_nope: + key_states_no_rope = past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx) + no_rope_param = { + 'key_states_no_rope': key_states_no_rope, + 'query_states_no_rope': query_states_no_rope, + } + else: + no_rope_param = None + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (MiniCPMRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, this might be related to' + f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in' + f' {target_dtype}.' + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + if kv_seq_len < self.dense_len: + attn_output = self._flash_attention_forward_dense( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) + else: + attn_output = self._sparse_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, + no_rope_param=no_rope_param, # if past_key_value is not None else None, + past_key_value=past_key_value) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _sparse_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + no_rope_param=None, + past_key_value=None): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.' + if past_key_value!=None: + compressed_k, compressed_cu_seqlens = self.get_compress_k( + key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'], # This can be optimized a bit; + attention_mask=attention_mask, + past_key_value=past_key_value) + + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + if no_rope_param != None: + if max_seqlen_in_batch_q == 1: + no_rope_param['query_states_no_rope'] = no_rope_param['query_states_no_rope'].squeeze(1) + else: + no_rope_param['query_states_no_rope'],_, _, _ = _unpad_one_tensor(no_rope_param['query_states_no_rope'],attention_mask=attention_mask) + if past_key_value==None: + # compress_k use varlen form + compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k) + + attn_output_unpad = self.sparse_forward( + query_states, + key_states, + value_states, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_in_batch_q, + max_seqlen_in_batch_k, + no_rope_param=no_rope_param, + compressed_k=compressed_k, + compressed_cu_seqlens=compressed_cu_seqlens) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + raise ValueError('Need attention mask') + + return attn_output + + def get_compress_k(self, key_states, attention_mask, past_key_value): + """ + Get compressed key states and corresponding cumulative sequence lengths. + + Args: + key_states: Key states tensor + cu_seqlens_k: Cumulative sequence lengths for keys + past_key_value: Past key-value cache + no_rope_param: Optional parameter containing key states without rope + + Returns: + Tuple of (compressed_k, compressed_cu_seqlens) + """ + # Check if this is prefilling or initial compression condition + is_prefilling = ( + key_states.shape[1] >= self.dense_len and + ( + not past_key_value.layers[self.layer_idx].compress_k_cache + ) + ) + + if is_prefilling: + unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask) + # Compress the keys + compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens) + + past_key_value.update_compress_k( + compressed_k, self.layer_idx, compressed_cu_seqlens) + + no_compress_k_list = [] + # Compute and update no_compress_k + for i in range(len(compressed_cu_seqlens)-1): + no_compress_k_start = (compressed_cu_seqlens[i+1]- compressed_cu_seqlens[i]) * self.kernel_stride + + no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone()) + + past_key_value.update_no_compress_k( + no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride, + kernel_size=self.kernel_size) + else: + # Decode case: incremental update + batch_size = key_states.shape[0] # key_states.shape = [batch_size, seq, k_head_num, head_dim] + key_states_split = list(torch.split( + key_states[:,-1:].squeeze(1), #[batch_size, seq, k_head_num, head_dim]->[batch_size, 1, k_head_num, head_dim]-> [batch_size, k_head_num, head_dim] + [1] * batch_size,dim=0, + )) + # Try to update no_compress_k buffer + no_compress_k_list = past_key_value.update_no_compress_k( + key_states_split, self.layer_idx, + kernel_stride=self.kernel_stride, + kernel_size=self.kernel_size) + new_compressed_k_list = [] + for no_compress_k in no_compress_k_list: + if no_compress_k is not None: + # We have enough tokens to compress + new_compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim] + new_compressed_k_list.append(new_compressed_k) + else: + new_compressed_k_list.append(None) + compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,) + + return compressed_k, compressed_cu_seqlens + + def sparse_forward(self, + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_in_batch_q, + max_seqlen_in_batch_k, + no_rope_param=None, + compressed_k=None, + compressed_cu_seqlens=None): + compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1] + cache_lens = None + if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding + seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + cache_lens = seq_lens_k-1 + + topk_idx = compressed_attention( + query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'], + compressed_k, + compressed_k.clone(), + self.kernel_size, + self.kernel_stride, + self.block_size, + self.topk, + cu_seqlens_q, + compressed_cu_seqlens, + max_seqlen_in_batch_q, + compressed_seqlens.max().item(), + None, + init_blocks=self.init_blocks, + local_blocks=self.local_blocks, + cache_lens=cache_lens + ) + topk_attn_output = infllmv2_attn_varlen_func( + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_in_batch_q, + max_seqlen_in_batch_k, + dropout_p=0.0, + deterministic=False, + softmax_scale=None, + causal=max_seqlen_in_batch_q != 1, + return_attn_probs=False, + # block_window_size=self.window_size // self.block_size, + topk_idx=topk_idx + ) + + return topk_attn_output + + def _flash_attention_forward_dense( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class MiniCPMSdpaAttention(MiniCPMAttention): + """ + MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MiniCPMAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MiniCPMAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + 'MiniCPMModel is using MiniCPMSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, ' + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = position_ids.max().item() + 1 + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # + # if past_key_value is not None: + # cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # ------------------------------------------------- + # ### Copied from modeling_llama_kv.py, Line 709. class LlamaAttention, function __init__(). + # ### [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization + # ### past_key_value is utilized to leverage previously computed key and value states. + # ### If past_key_value is available, reuse the states for k, v, and self_attention. + if past_key_value is not None: + key_states = past_key_value[0].cat(key_states, dim=2) + value_states = past_key_value[1].cat(value_states, dim=2) + # ### Reset past_key_value to avoid return past_key_value. + past_key_value = None + # + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # skip + # if attention_mask is not None: + # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + # raise ValueError( + # f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + # ) + # + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == 'cuda' and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # + attention_mask = attention_mask.to(dtype=query_states.dtype) + # + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + +# dev: support sdpa only +# MINICPM_ATTENTION_CLASSES = { +# 'eager': MiniCPMAttention, +# 'flash_attention_2': MiniCPMFlashAttention2, +# 'sdpa': MiniCPMSdpaAttention, +# } +# ------------------------------------------------- +MINICPM_ATTENTION_CLASSES = { + 'eager': MiniCPMAttention, + 'flash_attention_2': MiniCPMAttention, + 'sdpa': MiniCPMAttention, +} +# + +class MiniCPMDecoderLayer(nn.Module): + def __init__(self, config: MiniCPMConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + if config.sparse_config is not None and torch.cuda.is_available(): + self.self_attn = MiniCPMInfLLMv2Attention(config=config, layer_idx=layer_idx) + else: + self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = MiniCPMMLP(config) + self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.scale_depth = config.scale_depth + self.num_hidden_layers = config.num_hidden_layers + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MINICPM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MiniCPMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + 'The bare MiniCPM Model outputting raw hidden-states without any specific head on top.', + MINICPM_START_DOCSTRING, +) +class MiniCPMPreTrainedModel(PreTrainedModel): + config_class = MiniCPMConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['MiniCPMDecoderLayer'] + _skip_keys_device_placement = 'past_key_values' + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MINICPM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare MiniCPM Model outputting raw hidden-states without any specific head on top.', + MINICPM_START_DOCSTRING, +) +class MiniCPMModel(MiniCPMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`] + + Args: + config: MiniCPMConfig + """ + + def __init__(self, config: MiniCPMConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + # dev: support sdpa only + # self._use_sdpa = config._attn_implementation == 'sdpa' + # self._use_flash_attention_2 = config._attn_implementation == 'flash_attention_2' + # ------------------------------------------------- + self._use_sdpa, self._use_flash_attention_2 = True, False + # + + self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # + # Copied from eagle/model/modeling_llama_kv.py, Line 1010, class LlamaModel, function _prepare_decoder_attention_mask(). + # ### Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + # inputs_embeds.dtype, + torch.float32, # [MODIFIED] force to cast to float32 + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + tree_len = tree_mask.size(-1) + combined_attention_mask[:, :, -tree_len:, -tree_len:][ + tree_mask == 0 + ] = combined_attention_mask.min() + + return combined_attention_mask + # + + + @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values=None, # [MODIFIED] past_key_value is KVCache class + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time') + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError('You have to specify either input_ids or inputs_embeds') + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + past_key_values_length = 0 + + # use eagle tree KVCache + # if use_cache: + # use_legacy_cache = not isinstance(past_key_values, Cache) + # if use_legacy_cache: + # raise ValueError( + # 'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.' + # ) + + # # Calculate the usable length of past key values + # past_key_values_length = past_key_values.get_seq_length() if isinstance(past_key_values, InfLLMv2Cache) else 0 + + # # Initialize InfLLMv2Cache if needed + # if self.config.sparse_config is not None and torch.cuda.is_available() and past_key_values_length == 0: + # past_key_values = InfLLMv2Cache(config = self.config, num_hidden_layers=self.config.num_hidden_layers) + # ------------------------------------------------- + # From modeling_llama_kv.py Line 1088:... + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + # + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + # + # position_ids = position_ids.unsqueeze(0) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + # + else: # + position_ids = position_ids.view(-1, seq_length).long() + # + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb + + # + # 暂时不支持flash attention, 使用 sdpa + # if self._use_flash_attention_2: + # # 2d mask is passed through the layers + # # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + # if attention_mask is None: + # raise ValueError( + # f'need attention_mask for flash attention, but got {attention_mask}.' + # ) + # elif self._use_sdpa and not output_attentions: + # # output_attentions=True can not be supported when using SDPA, and we fall back on + # # the manual implementation that requires a 4D causal mask in all cases. + # attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + # attention_mask, + # (batch_size, seq_length), + # inputs_embeds, + # past_key_values_length, + # ) + # else: + # # 4d mask is passed through the layers + # attention_mask = _prepare_4d_causal_attention_mask( + # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + # ) + # ------------------------------------------------- + if not self._use_sdpa: + raise NotImplementedError("JQZ 250917 | Currently support sdpa **ONLY**, further impl for flash attention or infllm attention not finished yet.") + # # below is copied from modeling_llama_kv.py, Line 1110 + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + # (batch_size, seq_length), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # + + # embed positions + hidden_states = inputs_embeds + + # + # ### decoder layers + # all_hidden_states = () if output_hidden_states else None + # all_self_attns = () if output_attentions else None + # next_decoder_cache = None + # ------------------------------------------------- + # decoder layers + all_hidden_states = () + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + # + + # + # for decoder_layer in self.layers: + # if output_hidden_states: + # all_hidden_states += (hidden_states,) + + # if self.gradient_checkpointing and self.training: + # layer_outputs = self._gradient_checkpointing_func( + # decoder_layer.__call__, + # hidden_states, + # attention_mask, + # position_ids, + # past_key_values, + # output_attentions, + # use_cache, + # ) + # else: + # layer_outputs = decoder_layer( + # hidden_states, + # attention_mask=attention_mask, + # position_ids=position_ids, + # past_key_value=past_key_values, + # output_attentions=output_attentions, + # use_cache=use_cache, + # ) + + # hidden_states = layer_outputs[0] + + # if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + # if output_attentions: + # all_self_attns += (layer_outputs[1],) + # ------------------------------------------------- + # below is simplified based on modeling_llama_kv.py, Line 1137 + for idx, decoder_layer in enumerate(self.layers): + if idx==len(self.layers)-3 or idx==len(self.layers)//2 or idx==2: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + # + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + # + # next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache + # + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MiniCPMForCausalLM(MiniCPMPreTrainedModel): + _tied_weights_keys = ['lm_head.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = MiniCPMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values=None, # [MODIFIED] past_key_value is KVCache class + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniCPMForCausalLM + + >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + hidden_states = hidden_states[:, slice_indices, :].contiguous() + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base)) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + # Use the new Cache class methods + cache_length = past_key_values.get_seq_length() + + if self.config.sparse_config is not None and torch.cuda.is_available() and cache_length == 0: + past_key_values = InfLLMv2Cache(config = self.config, num_hidden_layers=self.config.num_hidden_layers) + past_length = cache_length + max_cache_length = None + else: + raise ValueError( + 'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.' + ) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update( + { + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + } + ) + # Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + @torch.inference_mode() + def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = 'user', + max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor: + gen_kwargs = { + 'max_length': max_length, + 'num_beams': num_beams, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + else: + gen_kwargs = { + 'max_length': max_length, + 'num_beams': num_beams, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + + history.append({'role': role, 'content': query}) + history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False) + inputs = tokenizer(history_str, return_tensors='pt').to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):-1] + response = tokenizer.decode(outputs) + pattern = re.compile(r'.*?(?=|<用户>)', re.DOTALL) + matches = pattern.findall(response) + if len(matches) > 0: + response = matches[0] + history.append({'role': 'assistant', 'content': response}) + return response, history + + +@add_start_docstrings( + """ + The MiniCPM Model transformer with a sequence classification head on top (linear layer). + + [`MiniCPMForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MINICPM_START_DOCSTRING, +) +class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MiniCPMModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.') + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + + +# hack version +from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM + +class HackConvertMiniCPMForCausalLM: + def from_pretrained(model_path, **kwargs): + model = KVLlamaForCausalLM.from_pretrained(model_path, **kwargs) + + state_dict = model.state_dict() + scale_emb = 12 + dim_model_base = 256 + scale_depth = 1.4 + num_layers = 32 + hidden_size = 4096 + + print(state_dict["model.embed_tokens.weight"]) + embedding = state_dict["model.embed_tokens.weight"] + #model.embed_tokens.weight * scale_emb + new_emb = embedding.clone() * scale_emb + state_dict["model.embed_tokens.weight"] = new_emb + + #lm_head.weight / (hidden_size / dim_model_base) + new_emb = state_dict["lm_head.weight"].clone() / (hidden_size / dim_model_base) + state_dict["lm_head.weight"] = new_emb + + #model.layers.34.self_attn.o_proj.weight * (scale_depth / sqrt(num_layers)) + for i in range(num_layers): + attn_out_name = f"model.layers.{i}.self_attn.o_proj.weight" + new_weight = state_dict[attn_out_name] * (scale_depth / math.sqrt(num_layers)) + state_dict[attn_out_name] = new_weight + + ffn_down_proj_name = f"model.layers.{i}.mlp.down_proj.weight" + new_weight = state_dict[ffn_down_proj_name] * (scale_depth / math.sqrt(num_layers)) + state_dict[ffn_down_proj_name] = new_weight + + print(f"Converting: reload from converted state_dict.\nCheck sd:\n{model}") + + model.load_state_dict(state_dict) + print(f"Convert to llama: DONE.") + + return model diff --git a/eagle/model/modeling_mixtral_kv.py b/eagle/model/modeling_mixtral_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..9081ee1657e901e690c8cfb760d1ae76f5c3b619 --- /dev/null +++ b/eagle/model/modeling_mixtral_kv.py @@ -0,0 +1,1199 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mixtral model.""" +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union +from .kv_cache import KVCache + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +# [MODIFIED] Import from transformer library +from transformers.activations import ACT2FN + +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers import MixtralConfig + + + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. + + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MixtralConfig" + + +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Create a causal mask for bi-directional self-attention. + + Args: + input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len). + dtype (torch.dtype): The data type of the mask. + device (torch.device): The device on which the mask will be placed. + past_key_values_length (int, optional): The length of past key values. Default is 0. + + Returns: + torch.Tensor: The causal mask tensor. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + + Args: + mask (torch.Tensor): The attention mask tensor of shape `[bsz, seq_len]`. + dtype (torch.dtype): The data type of the mask. + tgt_len (Optional[int], optional): The target sequence length. If None, it defaults to the source sequence length. + + Returns: + torch.Tensor: The expanded mask tensor. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_experts]. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None: + return 0 + + if isinstance(gate_logits, tuple): + # cat along the layers? + compute_device = gate_logits[0].device + gate_logits = torch.cat([gate.to(compute_device) for gate in gate_logits], dim=0) + + routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1) + return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)) * (num_experts**2) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral +class MixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral +class MixtralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +class MixtralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[KVCache]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + key_states = past_key_value[0].cat(key_states, dim=2) + value_states = past_key_value[1].cat(value_states, dim=2) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral + + + +class MixtralBLockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states, routing_weights): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return routing_weights * current_hidden_states + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, +} + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[KVCache]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +MIXTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MixtralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral +class MixtralPreTrainedModel(PreTrainedModel): + config_class = MixtralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MixtralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MIXTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +class MixtralModel(MixtralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] + + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + # inputs_embeds.dtype, + torch.float32, # [MODIFIED] force to cast to float32 + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + tree_len = tree_mask.size(-1) + combined_attention_mask[:, :, -tree_len:, -tree_len:][ + tree_mask == 0 + ] = combined_attention_mask.min() + + return combined_attention_mask + + # Ignore copy + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[Tuple[KVCache]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + # if self._use_flash_attention_2: + # # 2d mask is passed through the layers + # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + # else: + # 4d mask is passed through the layers + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + + next_cache = next_decoder_cache if use_cache else None + # if use_cache: + # next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class MixtralForCausalLM(MixtralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + + + + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a sequence classification head on top (linear layer). + + [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL +class MixtralForSequenceClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/eagle/model/modeling_qwen2_kv.py b/eagle/model/modeling_qwen2_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..9b8f69de8a07a234e81f2d4c370d6c9f46f43c81 --- /dev/null +++ b/eagle/model/modeling_qwen2_kv.py @@ -0,0 +1,1513 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" +_CONFIG_FOR_DOC = "Qwen2Config" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 +class Qwen2RotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[Qwen2Config] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # key_states, value_states = past_key_value.cat(key_states, value_states, self.layer_idx) + past_key, past_value = past_key_value[self.layer_idx] + key_states = past_key.cat(key_states) + value_states = past_value.cat(value_states) + past_key_value = None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2FlashAttention2(Qwen2Attention): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length[self.layer_idx][0].current_length.item() > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2SdpaAttention(Qwen2Attention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # key_states, value_states = past_key_value.cat(key_states, value_states, self.layer_idx) + past_key, past_value = past_key_value[self.layer_idx] + key_states = past_key.cat(key_states) + value_states = past_value.cat(value_states) + past_key_value = None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_ATTENTION_CLASSES = { + "eager": Qwen2Attention, + "flash_attention_2": Qwen2FlashAttention2, + "sdpa": Qwen2SdpaAttention, +} + + +class Qwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + # self.self_attn = QWEN2_ATTENTION_CLASSES["eager"](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +QWEN2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + # if use_cache and not isinstance(past_key_values, Cache): + # return_legacy_cache = True + # if past_key_values is None: + # past_key_values = DynamicCache() + # else: + # past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # logger.warning_once( + # "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + # "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + # "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + # ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values[0][0].current_length.item() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + self, + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + tree_len = tree_mask.size(-1) + causal_mask[:, :, -tree_len:, -tree_len:][ + tree_mask == 0 + ] = causal_mask.min() + # causal_mask[:, :, -tree_len:, -tree_len:][ + # tree_mask == 1 + # ] = 0 + + return causal_mask + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + + past_seen_tokens = past_key_values[0][0].current_length.item() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0]:] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN2_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/eagle/model/utils.py b/eagle/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc64baec1c26f32b90382536e9e9835ee2d205d --- /dev/null +++ b/eagle/model/utils.py @@ -0,0 +1,481 @@ +import copy +import random + +# typing +from typing import List, Tuple +import time +import torch + +# TODO +# from transformers import LlamaTokenizer +# tokenizer=LlamaTokenizer.from_pretrained("/home/lyh/weights/hf/vicuna_v13/7B/") + +TOPK = 10 # topk for sparse tree + +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + + +class Timer: + def __init__(self,name): + self.name = name + def __enter__(self): + torch.cuda.synchronize() + self.start = time.perf_counter() + + + def __exit__(self, exc_type, exc_value, traceback): + torch.cuda.synchronize() + elapsed = time.perf_counter() - self.start + print(f'{self.name} took {elapsed} seconds') + + +def prepare_logits_processor( + temperature: float = 0.0, + repetition_penalty: float = 0.0, + top_p: float = 0.0, + top_k: int = 0 +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + if temperature > 1e-5: + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +# test_processor = prepare_logits_processor( +# 0.0, 0.0, -1, 1 +# ) + + +def pad_path(path: List[int], length: int, pad_value: int = -2) -> List[int]: + """ + Pad the given path list with a specific value up to a specified length. + + Parameters: + - path (list): The original list that needs padding. + - length (int): The desired length of the padded list. + - pad_value (optional, default=-2): The value to use for padding. + + Returns: + - list: A new list based on the original path but padded to the desired length. + + Example: + >>> pad_path([1,2,3], 5) + [1, 2, 3, -2, -2] + + Note: + If the given path is already longer than the specified length, + then no padding occurs, and the original path is returned. + """ + + # Calculate the number of padding values needed by subtracting the length + # of the path from the desired length. + # Append the padding values to the original path and return the new list. + return path + [pad_value] * (length - len(path)) + + +def generate_tree_buffers(tree_choices, device="cuda"): + def custom_sort(lst): + # sort_keys=[len(list)] + sort_keys = [] + for i in range(len(lst)): + sort_keys.append(lst[i] if lst[i] >= 0 else maxitem) + return sort_keys + with Timer("sort"): + + sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x)) + tree_len = len(sorted_tree_choices) + 1 + + # Initialize depth_counts to keep track of how many choices have a particular depth + depth_counts = [] + prev_depth = 0 + for path in sorted_tree_choices: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + + tree_attn_mask = torch.eye(tree_len, tree_len) + tree_attn_mask[:, 0] = 1 + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_tree_choice = sorted_tree_choices[start + j] + # retrieve ancestor position + if len(cur_tree_choice) == 1: + continue + ancestor_idx = [] + for c in range(len(cur_tree_choice) - 1): + ancestor_idx.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) + tree_attn_mask[j + start + 1, ancestor_idx] = 1 + start += depth_counts[i] + + tree_indices = torch.zeros(tree_len, dtype=torch.long) + p_indices = [0 for _ in range(tree_len - 1)] + b_indices = [[] for _ in range(tree_len - 1)] + tree_indices[0] = 0 + start = 0 + bias = 0 + for i in range(len(depth_counts)): + inlayer_bias = 0 + b = [] + for j in range(depth_counts[i]): + cur_tree_choice = sorted_tree_choices[start + j] + cur_parent = cur_tree_choice[:-1] + if j != 0: + if cur_parent != parent: + bias += 1 + inlayer_bias += 1 + parent = cur_parent + b = [] + else: + parent = cur_parent + tree_indices[start + j + 1] = cur_tree_choice[-1] + TOPK * (i + bias) + 1 + p_indices[start + j] = inlayer_bias + if len(b) > 0: + b_indices[start + j] = copy.deepcopy(b) + else: + b_indices[start + j] = [] + b.append(cur_tree_choice[-1] + TOPK * (i + bias) + 1) + start += depth_counts[i] + + p_indices = [-1] + p_indices + tree_position_ids = torch.zeros(tree_len, dtype=torch.long) + start = 0 + for i in range(len(depth_counts)): + tree_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1 + start += depth_counts[i] + + retrieve_indices_nest = [] + retrieve_paths = [] + for i in range(len(sorted_tree_choices)): + cur_tree_choice = sorted_tree_choices[-i - 1] + retrieve_indice = [] + if cur_tree_choice in retrieve_paths: + continue + else: + for c in range(len(cur_tree_choice)): + retrieve_indice.append(sorted_tree_choices.index(cur_tree_choice[:c + 1])) + retrieve_paths.append(cur_tree_choice[:c + 1]) + retrieve_indices_nest.append(retrieve_indice) + max_length = max([len(x) for x in retrieve_indices_nest]) + retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest] + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + retrieve_indices = retrieve_indices + 1 + retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], + dim=1) + + maxitem = retrieve_indices.max().item() + 5 + + + + retrieve_indices = retrieve_indices.tolist() + retrieve_indices = sorted(retrieve_indices, key=custom_sort) + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + + + + # Aggregate the generated buffers into a dictionary + tree_buffers = { + "tree_attn_mask": tree_attn_mask.unsqueeze(0).unsqueeze(0), + "tree_indices": tree_indices, + "tree_position_ids": tree_position_ids, + "retrieve_indices": retrieve_indices, + } + + # Move the tensors in the dictionary to the specified device + tree_buffers = { + k: v.clone().to(device) + if isinstance(v, torch.Tensor) + else torch.tensor(v, device=device) + for k, v in tree_buffers.items() + } + + return tree_buffers + + +def initialize_tree0(input_ids, model, past_key_values, logits_processor): + draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, logits, hidden_state, sample_token = model( + input_ids, past_key_values=past_key_values, output_orig=True, logits_processor=logits_processor + ) + + # if logits_processor is not None: + # logits = orig[:, -1] + # logits = logits_processor(None, logits) + # probabilities = torch.nn.functional.softmax(logits, dim=1) + # token = torch.multinomial(probabilities, 1) + # else: + # token = torch.argmax(orig[:, -1]) + # token = token[None, None] + # input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1) + # # Clone the output hidden states + # + # draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head) + # if output_orig: + # return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, orig, hidden_states, token + # return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, hidden_states, token + return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token + +def initialize_tree(input_ids, model, past_key_values, logits_processor): + outputs, orig, hidden_states = model( + input_ids, past_key_values=past_key_values, output_orig=True + ) + + if logits_processor is not None: + logits = orig[:, -1] + logits = logits_processor(None, logits) + probabilities = torch.nn.functional.softmax(logits, dim=1) + token = torch.multinomial(probabilities, 1) + else: + token = torch.argmax(orig[:, -1]) + token = token[None, None] + input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1) + + # Clone the output hidden states + if model.use_eagle3: + ea_device = model.ea_layer.lm_head.weight.device + if outputs["hidden_states"][0].device != ea_device: + outputs["hidden_states"] = [x.to(ea_device) for x in outputs["hidden_states"]] + hidden_states=torch.cat(outputs["hidden_states"],dim=-1) + draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor) + return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token + + +def reset_tree_mode( + model, +): + model.base_model.model.tree_mask = None + model.base_model.model.tree_mode = None + + +def reset_past_key_values(passed_key_values: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Resets the current lengths in the passed key-values to zero. + + This function is designed to be used during the evaluation of a baseline model. + It iterates through each layer's key-values and sets their current lengths to zero, + effectively resetting their state. + + Args: + - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer. + + Returns: + - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths. + """ + for i in range(len(passed_key_values)): + for j in range(2): + passed_key_values[i][j].current_length.fill_(0) + return passed_key_values + + +def generate_candidates(tree_logits, tree_indices, retrieve_indices, sample_token, logits_processor): + sample_token = sample_token.to(tree_indices.device) + + candidates_logit = sample_token[0] + + candidates_tree_logits = tree_logits + + candidates = torch.cat([candidates_logit, candidates_tree_logits.view(-1)], dim=-1) + + tree_candidates = candidates[tree_indices] + + tree_candidates_ext = torch.cat( + [tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device) - 1], dim=0) + + cart_candidates = tree_candidates_ext[retrieve_indices] + + + # Unsqueeze the tree candidates for dimension consistency. + tree_candidates = tree_candidates.unsqueeze(0) + return cart_candidates, tree_candidates + + +def tree_decoding( + model, + tree_candidates, + past_key_values, + tree_position_ids, + input_ids, + retrieve_indices, +): + position_ids = tree_position_ids + input_ids.shape[1] + if position_ids is not None and position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + outputs, tree_logits, hidden_state = model( + tree_candidates, + output_orig=True, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + if model.use_eagle3: + ea_device = model.ea_layer.lm_head.weight.device + if outputs["hidden_states"][0].device != ea_device: + outputs["hidden_states"] = [x.to(ea_device) for x in outputs["hidden_states"]] + hidden_state = torch.cat(outputs["hidden_states"], dim=-1) + + logits = tree_logits[0, retrieve_indices] + return logits, hidden_state, outputs + + + + + +def evaluate_posterior( + logits: torch.Tensor, + candidates: torch.Tensor, + logits_processor, +): + """ + Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. + + Depending on the temperature value, the function either uses greedy decoding or evaluates posterior + probabilities to select the best candidate. + + Args: + - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). + - candidates (torch.Tensor): Candidate token sequences. + - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding. + - posterior_threshold (float): Threshold for posterior probability. + - posterior_alpha (float): Scaling factor for the threshold. + + Returns: + - best_candidate (torch.Tensor): Index of the chosen best candidate. + - accept_length (int): Length of the accepted candidate sequence. + """ + # Greedy decoding based on temperature value + if logits_processor is None: + # Find the tokens that match the maximum logits for each position in the sequence + posterior_mask = ( + candidates[:, 1:].to(logits.device) == torch.argmax(logits[:, :-1], dim=-1) + ).int() + candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) + accept_length = candidates_accept_length.max() + # Choose the best candidate + if accept_length == 0: + # Default to the first candidate if none are accepted + best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) + else: + best_candidate = torch.argmax(candidates_accept_length).to(torch.long) + return best_candidate, accept_length, logits[best_candidate, accept_length] + + else: + accept_length = 1 + accept_cand = candidates[0][:1] + best_candidate = 0 + for i in range(1, candidates.shape[1]): + if i != accept_length: + break + adjustflag = False + is_eq = (candidates[:, :accept_length] == accept_cand).all(dim=1) + fi = torch.nonzero(is_eq, as_tuple=True)[0][0] + gt_logits = logits[fi, i - 1][None] + gt_logits = logits_processor(None, gt_logits)[0] + gtp = torch.softmax(gt_logits, dim=0) + candidates_set = [] + for j in range(candidates.shape[0]): + if is_eq[j]: + x = candidates[j, i] + xi = x.item() + if xi in candidates_set or xi == -1: + continue + candidates_set.append(xi) + r = random.random() + px = gtp[xi] + qx = 1.0 + acp = px / qx + if r <= acp: + accept_cand = torch.cat((accept_cand, x[None]), dim=0) + accept_length += 1 + best_candidate = j + break + else: + gtp[xi] = 0 + gtp = gtp / gtp.sum() + adjustflag = True + if adjustflag and accept_length != candidates.shape[1]: + sample_p = gtp + else: + gt_logits = logits[best_candidate, accept_length - 1][None] + gt_logits = logits_processor(None, gt_logits)[0] + sample_p = torch.softmax(gt_logits, dim=0) + return torch.tensor(best_candidate), accept_length - 1, sample_p + + +@torch.no_grad() +def update_inference_inputs( + input_ids, + candidates, + best_candidate, + accept_length, + retrieve_indices, + logits_processor, + new_token, + past_key_values_data_list, + current_length_data, + model, + hidden_state_new, + sample_p +): + prev_input_len = input_ids.shape[1] + # Map the best candidate indices to the original indices in the sequence + select_indices = ( + retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len + ) + # Append the tokens from the best candidate to the input sequence + input_ids = torch.cat( + [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1 + ) + # Update the past key values based on the selected tokens + # Source tensor that contains relevant past information based on the selected candidate + for past_key_values_data in past_key_values_data_list: + tgt = past_key_values_data[..., select_indices.to(past_key_values_data.device), :] + # Destination tensor where the relevant past information will be stored + dst = past_key_values_data[..., prev_input_len: prev_input_len + tgt.shape[-2], :] + # Copy relevant past information from the source to the destination + dst.copy_(tgt, non_blocking=True) + + # Update the current length tensor (currently only support batch size is 1) + current_length_data.fill_(prev_input_len + tgt.shape[-2]) + + retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices] + accept_hidden_state_new = retrieve_hidden_state_new[:, best_candidate, : accept_length + 1] + # token=model.base_model.lm_head(accept_hidden_state_new[:,-1]).argmax() + # token=token[None,None] + prob = sample_p + if logits_processor is not None: + token = torch.multinomial(prob, 1) + token = token[None] + else: + token = torch.argmax(prob) + token = token[None, None] + # hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1) + draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new, + input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1), + head=model.base_model.lm_head,logits_processor=logits_processor) + + + new_token += accept_length + 1 + + return input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, None, token + + +if __name__ == "__main__": + logits = torch.randn(1, 5) + tp = prepare_logits_processor(0.9, 0, 0.9, 0) + l = tp(None, logits) + if tp is None: + print(tp) diff --git a/eagle/model/utils_c.py b/eagle/model/utils_c.py new file mode 100644 index 0000000000000000000000000000000000000000..c837aa639ee45fbc068ee93fcd8ba5ff729ac03f --- /dev/null +++ b/eagle/model/utils_c.py @@ -0,0 +1,206 @@ +import torch + +# typing +from typing import List + +TOPK = 10 # topk for sparse tree + + +def pad_path(path: List[int], length: int, pad_value: int = -2) -> List[int]: + """ + Pad the given path list with a specific value up to a specified length. + + Parameters: + - path (list): The original list that needs padding. + - length (int): The desired length of the padded list. + - pad_value (optional, default=-2): The value to use for padding. + + Returns: + - list: A new list based on the original path but padded to the desired length. + + Example: + >>> pad_path([1,2,3], 5) + [1, 2, 3, -2, -2] + + Note: + If the given path is already longer than the specified length, + then no padding occurs, and the original path is returned. + """ + + # Calculate the number of padding values needed by subtracting the length + # of the path from the desired length. + # Append the padding values to the original path and return the new list. + return path + [pad_value] * (length - len(path)) + +class node: + def __init__(self,parent=None,value=None,dict_key=None): + self.parent=parent + self.value=value + if parent: + self.depth=parent.depth+1 + parent.children.append(self) + else: + self.depth=0 + self.children=[] + self.dict_key=dict_key + def is_leaf(self): + return len(self.children)==0 + + def all_index(self): + if not self.parent.parent: + return [self.index] + else: + return self.parent.all_index()+[self.index] + + + +class Tree: + def __init__(self,tree_list): + sorted_tree_list = sorted(tree_list, key=lambda x: (len(x), x)) + self.root=node() + self.node_dic={} + for tree_node in sorted_tree_list: + cur_value=tree_node[-1] + if len(tree_node)==1: + cur_node=node(parent=self.root,value=cur_value,dict_key=tuple(tree_node)) + else: + cur_parent=self.node_dic[tuple(tree_node[:-1])] + cur_node = node(parent=cur_parent, value=cur_value,dict_key=tuple(tree_node)) + self.node_dic[tuple(tree_node)] = cur_node + self.indexnode() + + def max_depth(self): + return max([item.depth for item in self.node_dic.values()]) + + def num_node_wchild(self): + num_c=0 + for item in self.node_dic.values(): + if not item.is_leaf(): + num_c+=1 + return num_c + + def get_node_wchild(self): + ns=[] + for item in self.node_dic.values(): + if not item.is_leaf(): + ns.append(item) + return ns + + def indexnode(self): + cur_index=0 + for key in self.node_dic: + cur_node=self.node_dic[key] + if not cur_node.is_leaf(): + cur_node.index=cur_index + cur_index+=1 + + + + +def generate_tree_buffers(tree_choices, device="cuda"): + tree=Tree(tree_choices) + sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x)) + tree_len = tree.num_node_wchild() + + + max_depth=tree.max_depth() + nodes_wc=tree.get_node_wchild() + + depth_counts=[0 for _ in range(max_depth-1)] + for x in nodes_wc: + depth_counts[x.depth-1]+=1 + depth_counts_sum = [sum(depth_counts[:i + 1]) for i in range(len(depth_counts))] + + + tree_attn_mask = torch.eye(tree_len, tree_len) + + for id,x in enumerate(nodes_wc): + tree_attn_mask[id,x.all_index()]=1 + + + + + tree_attn_mask_list0=[tree_attn_mask[:ml,:ml] for ml in depth_counts_sum] + tree_attn_mask_list=[] + for id,x in enumerate(tree_attn_mask_list0): + x=x[-depth_counts[id]:] + tree_attn_mask_list.append(x) + + + + tree_indices_list = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts] + repeat_nums=[[] for _ in depth_counts] + start = 0 + bias = 0 + for i in range(len(depth_counts)): + bias = 0 + repeat_j=0 + for j in range(depth_counts[i]): + cur_node = nodes_wc[start + j] + cur_parent = cur_node.parent + + if j != 0: + if cur_parent != parent: + bias += 1 + parent = cur_parent + repeat_nums[i].append(j-repeat_j) + repeat_j=j + else: + parent = cur_parent + tree_indices_list[i][j] = cur_node.value + TOPK * (bias) + repeat_nums[i].append(j - repeat_j+1) + start += depth_counts[i] + + position_ids = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts] + + # start = 0 + # for i in range(len(depth_counts)): + # position_ids[start: start + depth_counts[i]] = i + # start += depth_counts[i] + + tree_buffers = { + "attn_mask": [i.unsqueeze(0).unsqueeze(0) for i in tree_attn_mask_list], + "tree_indices": tree_indices_list, + "position_ids":position_ids, + "repeat_nums":repeat_nums + } + + # Move the tensors in the dictionary to the specified device + tree_buffers = { + k: [i.clone().to(device) for i in v] + if isinstance(v[0], torch.Tensor) + else ( + torch.tensor(v, device=device) + if isinstance(v, torch.Tensor) + else v + ) + for k, v in tree_buffers.items() + } + return tree_buffers + + +def reset_past_key_values(passed_key_values: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Resets the current lengths in the passed key-values to zero. + + This function is designed to be used during the evaluation of a baseline model. + It iterates through each layer's key-values and sets their current lengths to zero, + effectively resetting their state. + + Args: + - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer. + + Returns: + - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths. + """ + for i in range(len(passed_key_values)): + for j in range(2): + passed_key_values[i][j].current_length.fill_(0) + return passed_key_values + + + +if __name__=="__main__": + from choices import mc_sim_7b_63 + a=generate_tree_buffers(mc_sim_7b_63) + print(a) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ca6d54f3fe0fedfaa5ab148aa1583086ecd5de0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +gradio +git+https://github.com/huggingface/transformers +torch +spaces +accelerate +tokenizers +numpy +Pillow +requests +sentencepiece +flash-attn \ No newline at end of file diff --git a/utils_chatbot.py b/utils_chatbot.py new file mode 100644 index 0000000000000000000000000000000000000000..6948d2515400a09521e8841ad83d227332a6ec58 --- /dev/null +++ b/utils_chatbot.py @@ -0,0 +1,63 @@ + +def organize_messages(message, history): + msg_ls = [dict( + role = "system", + content = "You are a helpful assistant.", + )] + for user, assistant in history: + msg_ls.append(dict( + role = "user", + content = user, + )) + if assistant: + msg_ls.append(dict( + role = "assistant", + content = assistant, + )) + msg_ls.append(dict( + role = "user", + content = message, + )) + return msg_ls + +def stream2display_text(stream_text, token_per_sec): + if stream_text.startswith("think>"): + stream_text = f"<{stream_text}" + + if not stream_text.startswith(""): + return stream_text + + if not "" in stream_text: + think_text, result_text = stream_text.replace("", ""), "" + else: + think_text, result_text = stream_text.split("") + think_text = think_text.replace("", "") + + result_text = result_text.replace("<|im_end|>", "") + + think_block = "\n".join(f"> {line}" if line else ">" for line in think_text.rstrip().splitlines()) + # display_text = f"{think_block}\n\n{result_text}" + + display_text_ls = [think_block] + if result_text: + display_text_ls.append(f"{result_text}") + display_text_ls.append(f"```{token_per_sec:.2f} token/s```") + + display_text = "\n\n".join(display_text_ls) + + return display_text + +def mtp_new_tokens(pred_ids, gen_tk_count, existing_tk_count, stop_token_ids): + output_ids = pred_ids[0][existing_tk_count:] + + if stop_token_ids: + stop_token_ids_index = [ + i + for i, id in enumerate(output_ids) + if id in stop_token_ids + ] + if len(stop_token_ids_index) > 0: + output_ids = output_ids[: stop_token_ids_index[0]] + new_tokens = output_ids[gen_tk_count:] + + return new_tokens, len(output_ids) \ No newline at end of file