import copy
import json
import time
import os

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig

from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .modeling_mixtral_kv import MixtralForCausalLM as KVMixtralForCausalLM
# from .modeling_qwen2_kv import LlamaForCausalLM as KVQwen2ForCausalLM
from .modeling_qwen2_kv import Qwen2ForCausalLM as KVQwen2ForCausalLM
from .utils import *
from .kv_cache import initialize_past_key_values
from .cnets import Model
from .cnets1 import Model as Model1
from .configs import EConfig

"""
Modified to support Eagle-3, marked by xxx
"""
from .modeling_minicpm_kv import MiniCPMForCausalLM as KVMiniCPMForCausalLM  # use modified opensource impl


class EaModel(nn.Module):

    def __init__(
            self,
            use_eagle3,
            base_model,
            base_model_name_or_path,
            ea_model_path,
            total_token,
            depth,
            top_k,
            threshold,
            ea_layer_state_dict,
    ):
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path, use_fast=False)
        self.use_eagle3 = use_eagle3

        config = EConfig.from_pretrained(ea_model_path)
        with open(ea_model_path, "r") as f:
            con = json.loads(f.read())
        try:
            bias = con["bias"]
        except KeyError:
            bias = True

        if use_eagle3:
            # EAGLE-3 draft model
            self.ea_layer = Model(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k,
                                  threshold=threshold, path=base_model_name_or_path, load_emb=True)
        else:
            # EAGLE-1/2 draft model
            self.ea_layer = Model1(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k,
                                   threshold=threshold, path=base_model_name_or_path, load_emb=True)

        low_memory = False

        device = base_model.model.layers[-1].self_attn.q_proj.weight.device
        if device != base_model.lm_head.weight.device:
            self.ea_layer.diff_device = True
            if not low_memory:
                self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device)
            else:
                self.ea_layer.layer_device = device
        else:
            self.ea_layer.diff_device = False

        if self.use_eagle3 and config.vocab_size == config.draft_vocab_size:
            # no draft-vocabulary mapping is needed when target and draft vocabularies match
            del self.ea_layer.d2t, self.ea_layer.t2d
        load_ = self.ea_layer.load_state_dict(ea_layer_state_dict, strict=False)

        self.ea_layer.to(self.base_model.dtype).to(device)
        self.ea_layer.init_tree()

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
""" return self.tokenizer @classmethod def from_pretrained( cls, use_eagle3=True, base_model_path=None, ea_model_path=None, total_token=60, depth=7, top_k=10, threshold=1.0, **kwargs, ): # assert Type=="LLaMA" or "Mixtral" Type = AutoConfig.from_pretrained(base_model_path, trust_remote_code=True).architectures[0] if Type == 'LlamaForCausalLM': base_model = KVLlamaForCausalLM.from_pretrained( base_model_path, **kwargs ) elif Type == 'Qwen2ForCausalLM': base_model = KVQwen2ForCausalLM.from_pretrained( base_model_path, **kwargs ) elif Type == 'MiniCPMForCausalLM': # support MiniCPMForCausalLM base_model = KVMiniCPMForCausalLM.from_pretrained( base_model_path, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True, ) # else: base_model = KVMixtralForCausalLM.from_pretrained( base_model_path, **kwargs ) # # configpath = os.path.join(ea_model_path, "config.json") # if not os.path.exists(configpath): # configpath = hf_hub_download(ea_model_path, "config.json") # try: # load_model_path = os.path.join(ea_model_path, "pytorch_model.bin") # if not os.path.exists(load_model_path): # load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin") # ea_layer_state_dict = torch.load(load_model_path, # map_location=base_model.device) # except: # from safetensors.torch import load_file # load_model_path = os.path.join(ea_model_path, "model.safetensors") # if not os.path.exists(load_model_path): # load_model_path = hf_hub_download(ea_model_path, "model.safetensors") # ea_layer_state_dict = load_file(load_model_path) # ------------------------------------------------- # ### new loading logic to support subfolder on hf api try: configpath = os.path.join(ea_model_path, "config.json") load_model_path = os.path.join(ea_model_path, "pytorch_model.bin") if not os.path.exists(configpath): configpath = hf_hub_download(ea_model_path, "config.json") if not os.path.exists(load_model_path): load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin") except: folder_names = ea_model_path.split("/") repo = "/".join(folder_names[:-1]) subfolder = folder_names[-1] configpath = hf_hub_download(repo_id = repo, subfolder = subfolder, filename = "config.json") load_model_path = hf_hub_download(repo_id = repo, subfolder = subfolder, filename = "pytorch_model.bin") ea_layer_state_dict = torch.load(load_model_path, map_location=base_model.device) # model = cls( use_eagle3, base_model, base_model_path, configpath, total_token, depth, top_k, threshold, ea_layer_state_dict ) if total_token == -1: device = model.base_model.model.layers[0].self_attn.q_proj.weight.device cans = [40, 48, 50, 56, 60] x = [1, 1.05, 1.07, 1.1, 1.13] times = [] for i in range(len(cans)): length = cans[i] input_ids = torch.randint(0, model.config.vocab_size - 200, (1, length)).to(device) torch.cuda.synchronize() start_time = time.time() for _ in range(20): torch.cuda.synchronize() with torch.no_grad(): outputs = model.base_model(input_ids) torch.cuda.synchronize() torch.cuda.synchronize() end_time = time.time() times.append((end_time - start_time) / x[i]) total_token = cans[times.index(min(times))] model.ea_layer.total_tokens = total_token - 1 return model def forward( self, input_ids=None, attention_mask=None, past_key_values=None, output_orig=False, position_ids=None, ): with torch.inference_mode(): # Pass input through the base model outputs = self.base_model.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, position_ids=position_ids, ) if output_orig: orig = 
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            past_key_values=None,
            output_orig=False,
            position_ids=None,
    ):
        with torch.inference_mode():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
            hidden_states = outputs[0]

        if output_orig:
            return outputs, orig, hidden_states
        else:
            return outputs, hidden_states

    @torch.no_grad()
    def eagenerate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,
    ):
        if is_llama3:
            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model, max_length=max_length)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        # prefill
        draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
            input_ids, self, past_key_values, logits_processor
        )
        new_token = 0

        max_length = max_length - self.ea_layer.total_tokens - 10

        for idx in range(max_length):
            # with Timer("all"):
            self.base_model.model.tree_mask = tree_mask

            draft_tokens = draft_tokens.to(input_ids.device)
            # Target model forward, get logits
            logits, hidden_state_new, outputs = tree_decoding(
                self,
                draft_tokens,
                past_key_values,
                tree_position_ids,
                input_ids,
                retrieve_indices,
            )
            # retrieve_indices=tree_buffers["retrieve_indices"]
            # logits = logits[0, retrieve_indices]
            draft_tokens = torch.cat((draft_tokens, padding), dim=1)
            candidates = draft_tokens[0, retrieve_indices]
            # verification
            best_candidate, accept_length, sample_p = evaluate_posterior(
                logits, candidates, logits_processor
            )
            # print(accept_length)
            # Adjusting the input sequence, draft model forward
            input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                retrieve_indices,
                logits_processor,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state_new,
                sample_p
            )

            if is_llama3:
                if stop_token_id in input_ids[0, input_len:].tolist():
                    break

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > max_new_tokens:
                break
            if input_ids.shape[1] > max_length:
                break
        if not log:
            return input_ids
        else:
            return input_ids, new_token, idx
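    # Hedged sketch (comment only; assumes `model` and `input_ids` are already set up as
    # in the example above): with log=True, eagenerate returns (input_ids, new_token, idx),
    # where new_token counts generated tokens and idx is the last verification-step index,
    # so new_token / (idx + 1) approximates the average tokens accepted per target forward.
    #
    #     output_ids, new_token, step_idx = model.eagenerate(
    #         input_ids, temperature=0.0, max_new_tokens=256, log=True
    #     )
    #     print(f"accepted ~{int(new_token) / (step_idx + 1):.2f} tokens per step")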
    @torch.no_grad()
    def naivegenerate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,
    ):
        if is_llama3:
            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model, max_length=max_length)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_token = 0

        max_length = max_length - self.ea_layer.total_tokens - 10

        for idx in range(max_length):
            if logits_processor is not None:
                logits = outputs.logits[:, -1]
                logits = logits_processor(None, logits)
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                input_id = torch.multinomial(probabilities, 1)
            else:
                input_id = outputs.logits[:, -1:].argmax(dim=-1)
            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1

            if is_llama3:
                if stop_token_id in input_ids[0, input_len:].tolist():
                    break

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > max_new_tokens:
                break
            if input_ids.shape[1] > max_length:
                break
        if not log:
            return input_ids
        else:
            return input_ids, new_token, idx
    @torch.no_grad()
    def ea_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,
    ):
        if is_llama3:
            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model, max_length=max_length)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
            input_ids, self, past_key_values, logits_processor
        )
        new_token = 0

        max_length = max_length - self.ea_layer.total_tokens - 10

        for idx in range(max_length):
            # with Timer("all"):
            self.base_model.model.tree_mask = tree_mask

            draft_tokens = draft_tokens.to(input_ids.device)
            # with Timer("tree_decoding"):
            logits, hidden_state_new, outputs = tree_decoding(
                self,
                draft_tokens,
                past_key_values,
                tree_position_ids,
                input_ids,
                retrieve_indices,
            )
            # retrieve_indices=tree_buffers["retrieve_indices"]
            # logits = logits[0, retrieve_indices]
            draft_tokens = torch.cat((draft_tokens, padding), dim=1)
            candidates = draft_tokens[0, retrieve_indices]
            best_candidate, accept_length, sample_p = evaluate_posterior(
                logits, candidates, logits_processor
            )
            # print(accept_length)
            # with Timer("update_inference_inputs"):
            input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                retrieve_indices,
                logits_processor,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state_new,
                sample_p
            )

            yield input_ids

            if is_llama3:
                if stop_token_id in input_ids[0, input_len:].tolist():
                    break

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > max_new_tokens:
                break
            if input_ids.shape[1] > max_length:
                break
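    # Hedged streaming sketch (comment only, placeholder prompt): ea_generate is a
    # generator that yields the full input_ids tensor after every verification step,
    # so a caller can decode incrementally; naive_generate below exposes the same
    # interface without speculative decoding, for comparison.
    #
    #     prompt_len = input_ids.shape[1]
    #     for output_ids in model.ea_generate(input_ids, temperature=0.0, max_new_tokens=256):
    #         text_so_far = model.tokenizer.decode(output_ids[0, prompt_len:])
    #         print(text_so_far, end="\r")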
    @torch.no_grad()
    def naive_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,
    ):
        if is_llama3:
            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model, max_length=max_length)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_token = 0

        max_length = max_length - self.ea_layer.total_tokens - 10

        for idx in range(max_length):
            if logits_processor is not None:
                logits = outputs.logits[:, -1]
                logits = logits_processor(None, logits)
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                input_id = torch.multinomial(probabilities, 1)
            else:
                input_id = outputs.logits[:, -1:].argmax(dim=-1)
            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1

            yield input_ids

            if is_llama3:
                if stop_token_id in input_ids[0, input_len:].tolist():
                    break

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > max_new_tokens:
                break
            if input_ids.shape[1] > max_length:
                break
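

# Hedged smoke test, kept under a __main__ guard so importing this module is unaffected.
# EA_BASE_MODEL_PATH and EA_DRAFT_MODEL_PATH are hypothetical environment variables used
# only for this sketch; it assumes a CUDA device and compares eagenerate against the
# plain autoregressive naivegenerate baseline on a single prompt.
if __name__ == "__main__":
    base_path = os.environ.get("EA_BASE_MODEL_PATH")    # hypothetical, set by the user
    draft_path = os.environ.get("EA_DRAFT_MODEL_PATH")  # hypothetical, set by the user
    if base_path and draft_path:
        model = EaModel.from_pretrained(
            use_eagle3=True,
            base_model_path=base_path,
            ea_model_path=draft_path,
            total_token=60,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model.eval()
        tokenizer = model.get_tokenizer()
        input_ids = tokenizer("Explain speculative decoding briefly.", return_tensors="pt").input_ids.cuda()

        for name, fn in (("eagenerate", model.eagenerate), ("naivegenerate", model.naivegenerate)):
            torch.cuda.synchronize()
            start = time.time()
            output_ids, new_token, _ = fn(input_ids, temperature=0.0, max_new_tokens=128, log=True)
            torch.cuda.synchronize()
            elapsed = time.time() - start
            tok = int(new_token)
            print(f"{name}: {tok} tokens in {elapsed:.2f}s ({tok / elapsed:.1f} tok/s)")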