import copy
import json
import time
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
import os
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .modeling_mixtral_kv import MixtralForCausalLM as KVMixtralForCausalLM
from .modeling_qwen2_kv import Qwen2ForCausalLM as KVQwen2ForCausalLM
from .utils import *
from .kv_cache import initialize_past_key_values
from .cnets import Model
from .cnets1 import Model as Model1
from .configs import EConfig
""" Modified to support Eagle-3, marked by xxx """
from .modeling_minicpm_kv import MiniCPMForCausalLM as KVMiniCPMForCausalLM # use modified opensource impl
class EaModel(nn.Module):
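    """Couples a base causal LM with an EAGLE draft head for speculative tree decoding.

    Minimal usage sketch (the model paths below are placeholders, not defaults shipped
    with this module):

        model = EaModel.from_pretrained(
            use_eagle3=True,
            base_model_path="path/to/base-model",   # placeholder
            ea_model_path="path/to/eagle-head",     # placeholder
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model.eval()
        tokenizer = model.get_tokenizer()
        input_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(model.base_model.device)
        output_ids = model.eagenerate(input_ids, temperature=0.0, max_new_tokens=128)
        print(tokenizer.decode(output_ids[0]))
    """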
def __init__(
self,
use_eagle3,
base_model,
base_model_name_or_path,
ea_model_path,
total_token,
depth,
top_k,
threshold,
ea_layer_state_dict,
):
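        """
        Args:
            use_eagle3: If True, build the EAGLE-3 draft head (Model); otherwise the
                original EAGLE head (Model1).
            base_model: The already-loaded base causal LM.
            base_model_name_or_path: Base model path/name, used to load the tokenizer
                and (via load_emb=True) the embedding weights for the draft head.
            ea_model_path: Path to the draft head's config.json.
            total_token: Number of draft tokens per tree.
            depth: Draft tree depth (forwarded to the draft head).
            top_k: Number of candidates expanded per node (forwarded to the draft head).
            threshold: Expansion threshold (forwarded to the draft head).
            ea_layer_state_dict: State dict holding the draft head's weights.
        """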
super().__init__()
self.base_model = base_model
self.config = base_model.config
self.hidden_size = base_model.lm_head.weight.shape[-1]
self.vocab_size = base_model.lm_head.weight.shape[0]
self.base_model_name_or_path = base_model_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path, use_fast=False)
self.use_eagle3 = use_eagle3
        config = EConfig.from_pretrained(ea_model_path)
        # ea_model_path is the path to the draft head's config.json (see from_pretrained).
        with open(ea_model_path, "r") as f:
            con = json.loads(f.read())
        bias = con.get("bias", True)
        if use_eagle3:
            self.ea_layer = Model(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k,
                                  threshold=threshold, path=base_model_name_or_path, load_emb=True)
        else:
            self.ea_layer = Model1(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k,
                                   threshold=threshold, path=base_model_name_or_path, load_emb=True)
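        # If the last decoder layer and lm_head sit on different devices (e.g. under
        # model parallelism), give the draft head its own copy of the lm_head weight
        # unless running in low-memory mode.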
low_memory = False
device = base_model.model.layers[-1].self_attn.q_proj.weight.device
if device != base_model.lm_head.weight.device:
self.ea_layer.diff_device = True
if not low_memory:
self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device)
else:
self.ea_layer.layer_device = device
else:
self.ea_layer.diff_device = False
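        # When the EAGLE-3 draft vocabulary equals the full vocabulary, the d2t/t2d
        # mapping buffers are not needed, so drop them before loading the state dict.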
        if self.use_eagle3 and config.vocab_size == config.draft_vocab_size:
            del self.ea_layer.d2t, self.ea_layer.t2d
        load_ = self.ea_layer.load_state_dict(ea_layer_state_dict, strict=False)
self.ea_layer.to(self.base_model.dtype).to(device)
self.ea_layer.init_tree()
def get_tokenizer(self):
"""Get the tokenizer of the base model.
Returns:
Tokenizer: The tokenizer of the base model.
"""
return self.tokenizer
@classmethod
def from_pretrained(
cls,
use_eagle3=True,
base_model_path=None,
ea_model_path=None,
total_token=60,
depth=7,
top_k=10,
threshold=1.0,
**kwargs,
):
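        """Build an EaModel from a base model and an EAGLE draft head.

        The base model class is selected from the architecture recorded in its config
        (LLaMA, Qwen2, MiniCPM; anything else falls back to Mixtral). `ea_model_path`
        may be a local directory, a Hub repo id, or a repo id followed by a subfolder.
        Extra **kwargs are forwarded to the base model's `from_pretrained` (the MiniCPM
        branch instead hardcodes bfloat16 on CUDA). `total_token=-1` benchmarks a few
        candidate draft-tree sizes and picks one automatically.
        """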
        # Dispatch on the base model's architecture (LLaMA, Qwen2, MiniCPM, Mixtral as fallback).
Type = AutoConfig.from_pretrained(base_model_path, trust_remote_code=True).architectures[0]
if Type == 'LlamaForCausalLM':
base_model = KVLlamaForCausalLM.from_pretrained(
base_model_path, **kwargs
)
elif Type == 'Qwen2ForCausalLM':
base_model = KVQwen2ForCausalLM.from_pretrained(
base_model_path, **kwargs
)
elif Type == 'MiniCPMForCausalLM': # support MiniCPMForCausalLM
base_model = KVMiniCPMForCausalLM.from_pretrained(
base_model_path, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True,
)
else:
base_model = KVMixtralForCausalLM.from_pretrained(
base_model_path, **kwargs
)
        # Load the head's config and weights locally if present; otherwise fetch them from
        # the Hugging Face Hub, falling back to repos that store the head in a subfolder.
try:
configpath = os.path.join(ea_model_path, "config.json")
load_model_path = os.path.join(ea_model_path, "pytorch_model.bin")
if not os.path.exists(configpath):
configpath = hf_hub_download(ea_model_path, "config.json")
if not os.path.exists(load_model_path):
load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin")
        except Exception:
            # ea_model_path may name a subfolder inside a Hub repo, e.g. "org/repo/subfolder".
            folder_names = ea_model_path.split("/")
            repo = "/".join(folder_names[:-1])
            subfolder = folder_names[-1]
            configpath = hf_hub_download(repo_id=repo, subfolder=subfolder, filename="config.json")
            load_model_path = hf_hub_download(repo_id=repo, subfolder=subfolder, filename="pytorch_model.bin")
ea_layer_state_dict = torch.load(load_model_path, map_location=base_model.device)
model = cls(
use_eagle3,
base_model,
base_model_path,
configpath,
total_token,
depth,
top_k,
threshold,
ea_layer_state_dict
)
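        # total_token == -1 requests automatic tuning: time the base model on a few
        # candidate draft lengths, normalize by the weights in `x`, and pick the cheapest.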
if total_token == -1:
device = model.base_model.model.layers[0].self_attn.q_proj.weight.device
cans = [40, 48, 50, 56, 60]
x = [1, 1.05, 1.07, 1.1, 1.13]
times = []
for i in range(len(cans)):
length = cans[i]
input_ids = torch.randint(0, model.config.vocab_size - 200, (1, length)).to(device)
torch.cuda.synchronize()
start_time = time.time()
for _ in range(20):
torch.cuda.synchronize()
with torch.no_grad():
outputs = model.base_model(input_ids)
torch.cuda.synchronize()
torch.cuda.synchronize()
end_time = time.time()
times.append((end_time - start_time) / x[i])
total_token = cans[times.index(min(times))]
model.ea_layer.total_tokens = total_token - 1
return model
def forward(
self,
input_ids=None,
attention_mask=None,
past_key_values=None,
output_orig=False,
position_ids=None,
):
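        """Run the base model's decoder and return its outputs plus the last hidden states.

        When `output_orig` is True, the lm_head logits over those hidden states are
        returned as well.
        """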
with torch.inference_mode():
# Pass input through the base model
outputs = self.base_model.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
            hidden_states = outputs[0]
            if output_orig:
                orig = self.base_model.lm_head(hidden_states)
                return outputs, orig, hidden_states
            return outputs, hidden_states
@torch.no_grad()
def eagenerate(
self,
input_ids,
temperature=0.0,
top_p=0.0,
top_k=0.0,
max_new_tokens=512,
max_length=2048,
log=False,
is_llama3=False,
):
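        """Generate with EAGLE speculative decoding and return the final input_ids.

        Each iteration drafts a token tree with the EAGLE head, scores it with a single
        base-model forward pass (tree_decoding), keeps the longest accepted prefix
        (evaluate_posterior), and appends it to the sequence. With `log=True`, the number
        of newly generated tokens and the number of decoding steps are returned as well.
        """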
if is_llama3:
stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
if temperature > 1e-5:
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
else:
logits_processor = None
# assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
# Avoid modifying the input_ids in-place
padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
input_ids = input_ids.clone()
self.ea_layer.reset_kv()
# Initialize the past key and value states
if hasattr(self, "past_key_values"):
past_key_values = self.past_key_values
past_key_values_data = self.past_key_values_data
current_length_data = self.current_length_data
# Reset the past key and value states
current_length_data.zero_()
else:
(
past_key_values,
past_key_values_data,
current_length_data,
            ) = initialize_past_key_values(self.base_model, max_length=max_length)
self.past_key_values = past_key_values
self.past_key_values_data = past_key_values_data
self.current_length_data = current_length_data
input_len = input_ids.shape[1]
reset_tree_mode(self)
# prefill
draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
input_ids, self, past_key_values, logits_processor
)
new_token = 0
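        # Leave headroom for the draft tree (plus a small margin) so the KV cache,
        # which was allocated with size max_length, is never overrun.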
max_length = max_length - self.ea_layer.total_tokens - 10
for idx in range(max_length):
# with Timer("all"):
self.base_model.model.tree_mask = tree_mask
draft_tokens = draft_tokens.to(input_ids.device)
# Target model forward, get logits
logits, hidden_state_new, outputs = tree_decoding(
self,
draft_tokens,
past_key_values,
tree_position_ids,
input_ids,
retrieve_indices,
)
# retrieve_indices=tree_buffers["retrieve_indices"]
# logits = logits[0, retrieve_indices]
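            # Append a -1 padding column so that -1 entries in retrieve_indices index it.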
draft_tokens = torch.cat((draft_tokens, padding), dim=1)
candidates = draft_tokens[0, retrieve_indices]
# verification
best_candidate, accept_length, sample_p = evaluate_posterior(
logits, candidates, logits_processor
)
# print(accept_length)
# Adjusting the input sequence, draft model forward
input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
input_ids,
candidates,
best_candidate,
accept_length,
retrieve_indices,
logits_processor,
new_token,
past_key_values_data,
current_length_data,
self,
hidden_state_new,
sample_p
)
if is_llama3:
if stop_token_id in input_ids[0, input_len:].tolist():
break
if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
break
if new_token > max_new_tokens:
break
if input_ids.shape[1] > max_length:
break
if not log:
return input_ids
else:
return input_ids, new_token, idx
@torch.no_grad()
def naivegenerate(
self,
input_ids,
temperature=0.0,
top_p=0.0,
top_k=0.0,
max_new_tokens=512,
max_length=2048,
log=False,
is_llama3=False,
):
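        """Baseline autoregressive generation with the base model only (no speculation).

        With `log=True`, also returns the number of new tokens and decoding steps.
        """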
if is_llama3:
stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
if temperature > 1e-5:
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
else:
logits_processor = None
# assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
# Avoid modifying the input_ids in-place
padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
input_ids = input_ids.clone()
self.ea_layer.reset_kv()
# Initialize the past key and value states
if hasattr(self, "past_key_values"):
past_key_values = self.past_key_values
past_key_values_data = self.past_key_values_data
current_length_data = self.current_length_data
# Reset the past key and value states
current_length_data.zero_()
else:
(
past_key_values,
past_key_values_data,
current_length_data,
            ) = initialize_past_key_values(self.base_model, max_length=max_length)
self.past_key_values = past_key_values
self.past_key_values_data = past_key_values_data
self.current_length_data = current_length_data
input_len = input_ids.shape[1]
reset_tree_mode(self)
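        # Prefill: run the full prompt once to populate the base model's KV cache.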
outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
new_token = 0
max_length = max_length - self.ea_layer.total_tokens - 10
for idx in range(max_length):
if logits_processor is not None:
logits = outputs.logits[:, -1]
logits = logits_processor(None, logits)
probabilities = torch.nn.functional.softmax(logits, dim=-1)
input_id = torch.multinomial(probabilities, 1)
else:
input_id = outputs.logits[:, -1:].argmax(dim=-1)
outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
input_ids = torch.cat([input_ids, input_id], dim=-1)
new_token += 1
if is_llama3:
if stop_token_id in input_ids[0, input_len:].tolist():
break
if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
break
if new_token > max_new_tokens:
break
if input_ids.shape[1] > max_length:
break
if not log:
return input_ids
else:
return input_ids, new_token, idx
@torch.no_grad()
def ea_generate(
self,
input_ids,
temperature=0.0,
top_p=0.0,
top_k=0.0,
max_new_tokens=512,
max_length=2048,
log=False,
is_llama3=False,
):
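        """Streaming version of eagenerate: yields the growing input_ids after every
        speculative decoding step."""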
if is_llama3:
stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
if temperature > 1e-5:
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
else:
logits_processor = None
# assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
# Avoid modifying the input_ids in-place
padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
input_ids = input_ids.clone()
self.ea_layer.reset_kv()
# Initialize the past key and value states
if hasattr(self, "past_key_values"):
past_key_values = self.past_key_values
past_key_values_data = self.past_key_values_data
current_length_data = self.current_length_data
# Reset the past key and value states
current_length_data.zero_()
else:
(
past_key_values,
past_key_values_data,
current_length_data,
            ) = initialize_past_key_values(self.base_model, max_length=max_length)
self.past_key_values = past_key_values
self.past_key_values_data = past_key_values_data
self.current_length_data = current_length_data
input_len = input_ids.shape[1]
reset_tree_mode(self)
draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
input_ids, self, past_key_values, logits_processor
)
new_token = 0
max_length = max_length - self.ea_layer.total_tokens - 10
for idx in range(max_length):
# with Timer("all"):
self.base_model.model.tree_mask = tree_mask
draft_tokens = draft_tokens.to(input_ids.device)
# with Timer("tree_decoding"):
logits, hidden_state_new, outputs = tree_decoding(
self,
draft_tokens,
past_key_values,
tree_position_ids,
input_ids,
retrieve_indices,
)
# retrieve_indices=tree_buffers["retrieve_indices"]
# logits = logits[0, retrieve_indices]
draft_tokens = torch.cat((draft_tokens, padding), dim=1)
candidates = draft_tokens[0, retrieve_indices]
best_candidate, accept_length, sample_p = evaluate_posterior(
logits, candidates, logits_processor
)
# print(accept_length)
# with Timer("update_inference_inputs"):
input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
input_ids,
candidates,
best_candidate,
accept_length,
retrieve_indices,
logits_processor,
new_token,
past_key_values_data,
current_length_data,
self,
hidden_state_new,
sample_p
)
yield input_ids
if is_llama3:
if stop_token_id in input_ids[0, input_len:].tolist():
break
if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
break
if new_token > max_new_tokens:
break
if input_ids.shape[1] > max_length:
break
@torch.no_grad()
def naive_generate(
self,
input_ids,
temperature=0.0,
top_p=0.0,
top_k=0.0,
max_new_tokens=512,
max_length=2048,
log=False,
is_llama3=False,
):
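        """Streaming version of naivegenerate: yields input_ids after each generated token."""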
if is_llama3:
stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
if temperature > 1e-5:
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
else:
logits_processor = None
# assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
# Avoid modifying the input_ids in-place
padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
input_ids = input_ids.clone()
self.ea_layer.reset_kv()
# Initialize the past key and value states
if hasattr(self, "past_key_values"):
past_key_values = self.past_key_values
past_key_values_data = self.past_key_values_data
current_length_data = self.current_length_data
# Reset the past key and value states
current_length_data.zero_()
else:
(
past_key_values,
past_key_values_data,
current_length_data,
            ) = initialize_past_key_values(self.base_model, max_length=max_length)
self.past_key_values = past_key_values
self.past_key_values_data = past_key_values_data
self.current_length_data = current_length_data
input_len = input_ids.shape[1]
reset_tree_mode(self)
outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
new_token = 0
max_length = max_length - self.ea_layer.total_tokens - 10
for idx in range(max_length):
if logits_processor is not None:
logits = outputs.logits[:, -1]
logits = logits_processor(None, logits)
probabilities = torch.nn.functional.softmax(logits, dim=-1)
input_id = torch.multinomial(probabilities, 1)
else:
input_id = outputs.logits[:, -1:].argmax(dim=-1)
outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
input_ids = torch.cat([input_ids, input_id], dim=-1)
new_token += 1
yield input_ids
if is_llama3:
if stop_token_id in input_ids[0, input_len:].tolist():
break
if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
break
if new_token > max_new_tokens:
break
if input_ids.shape[1] > max_length:
break