"""
Copied from https://github.com/lm-sys/FastChat.
Later we will contribute our changes into it.
"""
import dataclasses
from enum import auto, IntEnum
from typing import Any, Dict, List, Optional, Tuple, Union
import math
import random
import numpy as np

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TopKLogitsWarper,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    StoppingCriteriaList,
    MaxLengthCriteria,
    BitsAndBytesConfig,
)


class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()


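# A Conversation pairs one SeparatorStyle with concrete separator strings (sep,
# sep2); get_prompt() below shows how each style joins roles and messages, e.g.
# ADD_COLON_SINGLE renders "Role: message<sep>" per turn, while NO_COLON_TWO
# concatenates role and message directly and alternates sep/sep2 between turns.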
@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of the roles
    roles: List[str] = ("USER", "ASSISTANT")
    # All messages. Each item is (role, message).
    messages: List[List[str]] = ()
    # The number of few-shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = "\n"
    sep2: Optional[str] = None
    # Stops generation when this string (or any string in this list) appears
    stop_str: Union[str, List[str]] = None
    # Stops generation when any of these token ids is emitted
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "  # must end with a space
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = "" if system_prompt == "" else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

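    # get_prompt() example (a sketch): with sep_style=ADD_COLON_TWO, sep="###",
    # sep2="</s>", and roles ("Human", "Assistant"), one finished exchange plus
    # an empty assistant slot renders as:
    #   "<system>###Human: Hello!###Assistant: Hi!</s>Human: How are you?###Assistant:"
    # i.e. each completed assistant turn is closed by sep2, typically the EOS token.
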
    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert (
            template.name not in conv_templates
        ), f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


def get_conversation_template(model_path: str) -> Conversation:
    """Get the default conversation template."""
    if "aquila-v1" in model_path:
        return get_conv_template("aquila-v1")
    elif "aquila-chat" in model_path:
        return get_conv_template("aquila-chat")
    elif "aquila-legacy" in model_path:
        return get_conv_template("aquila-legacy")
    else:
        return get_conv_template("aquila")


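# Usage sketch: templates are registered once at import time (below) and are
# looked up by name; get_conv_template() returns a copy, so callers can append
# messages without mutating the registered template:
#
#   conv = get_conv_template("aquila-chat")
#   conv.append_message(conv.roles[0], "Hello!")
#   conv.append_message(conv.roles[1], None)  # empty slot for the model reply
#   prompt = conv.get_prompt()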
register_conv_template(
    Conversation(
        name="aquila-chat",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="###",
        sep2="",
        stop_str=["###", "</s>", "[UNK]"],
    )
)

register_conv_template(
    Conversation(
        name="aquila-legacy",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        roles=("### Human: ", "### Assistant: ", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="\n",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    )
)

register_conv_template(
    Conversation(
        name="aquila",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep="###",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    )
)

register_conv_template(
    Conversation(
        name="aquila-v1",
        roles=("<|startofpiece|>", "<|endofpiece|>", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="",
        sep2="</s>",
        stop_str=["</s>", "<|endoftext|>"],
    )
)


if __name__ == "__main__":
    print("aquila template:")
    conv = get_conv_template("aquila")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi!")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())

    print("\n")

    print("aquila-chat template:")
    conv = get_conv_template("aquila-chat")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi!")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())

    print("\n")

    print("aquila-v1 template:")
    conv = get_conv_template("aquila-v1")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi!")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())

    print("\n")

    print("aquila-legacy template:")
    conv = get_conv_template("aquila-legacy")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi!")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())

    print("\n")


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token, convo_template="aquila-chat"):
    """Build input ids for `text`, packing as much history as fits in `max_token`.

    `history` is a list of (role, message) pairs in chronological order; the
    most recent pairs are consumed first, so the oldest turns are dropped when
    the token budget is exceeded.
    """
    conv = get_conv_template(convo_template)

    # Messages are collected newest-first and reversed at the end.
    conv.append_message(conv.roles[1], None)
    conv.append_message(conv.roles[0], text)

    example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

    if history is None or not isinstance(history, list):
        history = []

    while len(history) > 0 and len(example) < max_token:
        tmp = history.pop()
        if tmp[0] == 'ASSISTANT':
            conv.append_message(conv.roles[1], tmp[1])
        else:
            conv.append_message(conv.roles[0], tmp[1])
        example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

    # Drop the message that pushed the prompt over the budget.
    if len(example) >= max_token:
        conv.messages.pop()
    conv.messages = conv.messages[::-1]
    print('model in:', conv.get_prompt())
    example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

    return example


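# Expected history format (illustrative): chronological (role, message) pairs,
#   [("USER", "Hello!"), ("ASSISTANT", "Hi!"), ("USER", "How are you?"), ...]
# predict() below appends each new exchange, so the same list can be passed
# back in on the next call.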
def predict(model, text, tokenizer=None,
            max_gen_len=200, top_p=0.95,
            seed=1234, topk=100,
            temperature=0.9,
            sft=True, convo_template="",
            device="cuda",
            model_name="AquilaChat2-7B",
            history=None,
            **kwargs):

    vocab = tokenizer.get_vocab()

    id2word = {v: k for k, v in vocab.items()}

    # Default conversation template for each released model.
    template_map = {"AquilaChat2-7B": "aquila-v1",
                    "AquilaChat2-34B": "aquila-legacy",
                    "AquilaChat2-7B-16K": "aquila",
                    "AquilaChat2-34B-16K": "aquila"}
    if not convo_template:
        convo_template = template_map.get(model_name, "aquila-chat")

    set_random_seed(seed)
    # temperature == 0 means greedy decoding: keep only the top-1 token.
    if temperature == 0:
        topk = 1
        temperature = 1.0
    if sft:
        tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
        tokens = torch.tensor(tokens)[None,].to(device)
    else:
        tokens = tokenizer.encode_plus(text)["input_ids"]
        print(tokenizer.decode(tokens))
        tokens = torch.tensor(tokens)[None,].to(device)
    input_length = len(tokens[0])
    with torch.no_grad():
        # eos_token_id=100007 matches the "</s>" id used for stopping below.
        logits_processor = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(1, eos_token_id=100007),
            ]
        )
        logits_warper = LogitsProcessorList(
            [
                TopPLogitsWarper(top_p),
                TopKLogitsWarper(topk),
                TemperatureLogitsWarper(temperature),
            ]
        )

        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=input_length + max_gen_len)])
        out = model.sample(
            tokens,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True,
        )

        out_ids = out["sequences"][0][input_length:].cpu().numpy()

        out_scores = out["scores"]

        out_scores = torch.cat(out_scores, dim=0)
        out_scores = torch.nn.functional.softmax(out_scores, dim=-1).cpu().numpy()

        # Per-token probability of each sampled token (kept for inspection; not returned).
        probs = []
        for i in range(len(out_ids)):
            probs.append(float(out_scores[i][out_ids[i]]))

        convert_tokens = []
        for t in out_ids:
            if t == 100006:
                convert_tokens.append("[CLS]")
            else:
                convert_tokens.append(id2word.get(t, "[unknown_token]"))

        out_text = tokenizer.decode(out_ids.tolist())

        out = out_text

    # Truncate at the first stop marker and trim the token-level outputs to match.
    if "[UNK]" in out:
        special_index = out.index("[UNK]")
        out = out[:special_index]
        token_length = len(tokenizer.encode_plus(out)["input_ids"])
        convert_tokens = convert_tokens[:token_length]
        probs = probs[:token_length]

    if "</s>" in out:
        special_index = out.index("</s>")
        out = out[:special_index]
        token_length = len(tokenizer.encode_plus(out)["input_ids"])
        convert_tokens = convert_tokens[:token_length]
        probs = probs[:token_length]

    if len(out) > 0 and out[0] == " ":
        out = out[1:]
        convert_tokens = convert_tokens[1:]
        probs = probs[1:]

    if isinstance(history, list):
        # Record the exchange in chronological order (oldest first);
        # covert_prompt_to_input_ids_with_history() consumes the list from the
        # end, so the newest turns are packed into the prompt first.
        history.append(('USER', text))
        history.append(('ASSISTANT', out))

    return out
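

# Usage sketch (illustrative; assumes `model` and `tokenizer` were loaded from
# this repo, e.g. with AutoModelForCausalLM / AutoTokenizer and trust_remote_code=True):
#
#   history = []
#   print(predict(model, "Hello!", tokenizer=tokenizer,
#                 model_name="AquilaChat2-7B", history=history))
#   print(predict(model, "How are you?", tokenizer=tokenizer,
#                 model_name="AquilaChat2-7B", history=history))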