|  | """ | 
					
						
						|  | Copied from https://github.com/lm-sys/FastChat. | 
					
						
						|  | Later we will contribute our changes into it. | 
					
						
						|  | """ | 
					
						
						|  | import dataclasses | 
					
						
						|  | from enum import auto, IntEnum | 
					
						
						|  | from typing import List, Any, Dict | 
					
						
						|  | import math | 
					
						
						|  | from typing import List, Optional, Tuple, Union | 
					
						
						|  | import random | 
					
						
						|  | import numpy as np | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.utils.checkpoint | 
					
						
						|  | from torch import nn | 
					
						
						|  | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | 
					
						
						|  |  | 
					
						
						|  | from transformers.activations import ACT2FN | 
					
						
						|  | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast | 
					
						
						|  | from transformers.modeling_utils import PreTrainedModel | 
					
						
						|  | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings | 
					
						
						|  | from transformers import ( | 
					
						
						|  | LogitsProcessorList, | 
					
						
						|  | MinLengthLogitsProcessor, | 
					
						
						|  | TopKLogitsWarper, | 
					
						
						|  | TemperatureLogitsWarper, | 
					
						
						|  | TopPLogitsWarper, | 
					
						
						|  | StoppingCriteriaList, | 
					
						
						|  | MaxLengthCriteria, | 
					
						
						|  | BitsAndBytesConfig, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
					
						
@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # Template name
    name: str
    # System prompt template and message
    system_template: str = "{system_message}"
    system_message: str = ""
    # Role names
    roles: List[str] = (("USER", "ASSISTANT"),)
    # Conversation history: a list of [role, message] pairs
    messages: List[List[str]] = ()
    offset: int = 0
    # Separator style and separators
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = "\n"
    sep2: str = None
    # Stop string(s) and stop token ids used at generation time
    stop_str: str = None
    stop_token_ids: List[int] = None
					
						
    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = "" if system_prompt == "" else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid separator style: {self.sep_style}")
					
						
    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
					
						
# A global registry for all conversation templates.
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert (
            template.name not in conv_templates
        ), f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


def get_conversation_template(model_path: str) -> Conversation:
    """Get the default conversation template for a model path."""
    if "aquila-v1" in model_path:
        return get_conv_template("aquila-v1")
    elif "aquila-v2" in model_path:
        return get_conv_template("aquila-v2")
    elif "aquila-chat" in model_path:
        return get_conv_template("aquila-chat")
    elif "aquila-legacy" in model_path:
        return get_conv_template("aquila-legacy")
    else:
        return get_conv_template("aquila")
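# Illustrative usage sketch (not part of the original file): pick a template from a
# model path and build a prompt. The path string below is just an example.
#
#   conv = get_conversation_template("checkpoints/aquila-chat-7b")  # -> "aquila-chat" template
#   conv.append_message(conv.roles[0], "Hello!")
#   conv.append_message(conv.roles[1], None)
#   prompt = conv.get_prompt()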
					
						
register_conv_template(
    Conversation(
        name="aquila-chat",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="###",
        sep2="",
        stop_str=["###", "</s>", "[UNK]"],
    )
)

register_conv_template(
    Conversation(
        name="aquila-legacy",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        roles=("### Human: ", "### Assistant: ", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="\n",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    )
)

register_conv_template(
    Conversation(
        name="aquila",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep="###",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    )
)

register_conv_template(
    Conversation(
        name="aquila-v1",
        roles=("<|startofpiece|>", "<|endofpiece|>", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="",
        sep2="</s>",
        stop_str=["</s>", "<|endoftext|>"],
    )
)

register_conv_template(
    Conversation(
        name="aquila-v2",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        roles=("<|startofpiece|>", "<|endofpiece|>", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="",
        sep2="</s>",
        stop_str=["</s>", "<|endoftext|>", "<|startofpiece|>", "<|endofpiece|>"],
    )
)
					
						
if __name__ == "__main__":
    # Print a short demo prompt for every registered Aquila template.
    for template_name in ["aquila", "aquila-chat", "aquila-v1", "aquila-legacy", "aquila-v2"]:
        print(f"{template_name} template:")
        conv = get_conv_template(template_name)
        conv.append_message(conv.roles[0], "Hello!")
        conv.append_message(conv.roles[1], "Hi!")
        conv.append_message(conv.roles[0], "How are you?")
        conv.append_message(conv.roles[1], None)
        print(conv.get_prompt())
        print("\n")
					
						
def set_random_seed(seed):
    """Set random seed for reproducibility."""
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
					
						
def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token, convo_template="aquila-chat"):
    """Build input ids for `text`, prepending as much history as fits in `max_token`."""
    conv = get_conv_template(convo_template)

    # The conversation is assembled in reverse and flipped into final order at the end.
    conv.append_message(conv.roles[1], None)
    conv.append_message(conv.roles[0], text)

    example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

    if history is None or not isinstance(history, list):
        history = []

    # Keep adding history turns until the token budget is exhausted.
    while len(history) > 0 and len(example) < max_token:
        tmp = history.pop()
        if tmp[0] == 'ASSISTANT':
            conv.append_message(conv.roles[1], tmp[1])
        else:
            conv.append_message(conv.roles[0], tmp[1])
        example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

    # Drop the last appended turn if it pushed the prompt over the budget.
    if len(example) >= max_token:
        conv.messages.pop()
    conv.messages = conv.messages[::-1]
    print('model in:', conv.get_prompt())
    example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']

    return example
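# Illustrative usage sketch (not part of the original file): `tokenizer` is assumed to be
# a HuggingFace tokenizer for an Aquila model; the history pairs are example values.
#
#   history = [('USER', 'Hello!'), ('ASSISTANT', 'Hi!')]
#   input_ids = covert_prompt_to_input_ids_with_history(
#       "How are you?", history=history, tokenizer=tokenizer,
#       max_token=2048, convo_template="aquila-chat")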
					
						
def predict(model, text, tokenizer=None,
            max_gen_len=200, top_p=0.9,
            seed=123, topk=15,
            temperature=1.0,
            sft=True, convo_template="",
            device="cuda",
            model_name="AquilaChat2-7B",
            history=None,
            **kwargs):
    """Generate a reply for `text` with the given Aquila chat model."""
    vocab = tokenizer.get_vocab()
    id2word = {v: k for k, v in vocab.items()}

    # Default conversation template for each released chat model.
    template_map = {"AquilaChat2-7B": "aquila-v1",
                    "AquilaChat2-34B": "aquila-legacy",
                    "AquilaChat2-70B-Expr": "aquila-v2",
                    "AquilaChat2-7B-16K": "aquila",
                    "AquilaChat2-34B-16K": "aquila"}
    if not convo_template:
        convo_template = template_map.get(model_name, "aquila-chat")

    set_random_seed(seed)
    if temperature == 0:
        # Temperature 0 is treated as greedy decoding: keep only the top-1 token.
        topk = 1
        temperature = 1.0
    if sft:
        tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=20480, convo_template=convo_template)
        tokens = torch.tensor(tokens)[None,].to(device)
    else:
        tokens = tokenizer.encode_plus(text)["input_ids"]
        print(tokenizer.decode(tokens))
        tokens = torch.tensor(tokens)[None,].to(device)
    input_length = len(tokens[0])
    with torch.no_grad():
        logits_processor = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(1, eos_token_id=100007),
            ]
        )

        logits_warper = LogitsProcessorList(
            [
                TopPLogitsWarper(top_p),
                TopKLogitsWarper(topk),
                TemperatureLogitsWarper(temperature),
            ]
        )

        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=input_length + max_gen_len)])
        out = model.sample(
            tokens,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True,
        )

        # Newly generated token ids and their per-step probabilities.
        out_ids = out["sequences"][0][input_length:].cpu().numpy()
        out_scores = out["scores"]
        out_scores = torch.cat(out_scores, dim=0)
        out_scores = torch.nn.functional.softmax(out_scores, dim=-1).cpu().numpy()

        probs = []
        for i in range(len(out_ids)):
            probs.append(float(out_scores[i][out_ids[i]]))

        convert_tokens = []
        for t in out_ids:
            if t == 100006:
                convert_tokens.append("[CLS]")
            else:
                convert_tokens.append(id2word.get(t, "[unknown_token]"))

        out_text = tokenizer.decode(out_ids.tolist())

        out = out_text

        # Truncate at special markers and keep the token-level info consistent with the text.
        if "[UNK]" in out:
            special_index = out.index("[UNK]")
            out = out[:special_index]
            token_length = len(tokenizer.encode_plus(out)["input_ids"])
            convert_tokens = convert_tokens[:token_length]
            probs = probs[:token_length]

        if "</s>" in out:
            special_index = out.index("</s>")
            out = out[:special_index]
            token_length = len(tokenizer.encode_plus(out)["input_ids"])
            convert_tokens = convert_tokens[:token_length]
            probs = probs[:token_length]

        if len(out) > 0 and out[0] == " ":
            out = out[1:]
            convert_tokens = convert_tokens[1:]
            probs = probs[1:]

        if isinstance(history, list):
            # Record the latest exchange so follow-up calls can reuse it.
            history.insert(0, ('ASSISTANT', out))
            history.insert(0, ('USER', text))

        return out
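# Illustrative usage sketch (not part of the original file): the repo id and loading
# flags below are examples and may need to be adapted to your checkpoint.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("BAAI/AquilaChat2-7B", trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained("BAAI/AquilaChat2-7B", trust_remote_code=True).to("cuda").eval()
#   reply = predict(model, "What is machine learning?", tokenizer=tokenizer,
#                   model_name="AquilaChat2-7B", history=[])
#   print(reply)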