# NOTE(review): the original file began with stray extraction artifacts
# ("Spaces:", "Runtime error", "Runtime error") that are not part of the
# program; kept here as a comment so the file remains valid Python.
# Third-party: base model and its dataset loader.
from simpletransformers.conv_ai import ConvAIModel
from simpletransformers.conv_ai.conv_ai_utils import get_dataset

import torch

# Standard library.
import copy
import json
import os
import random
class ConvAIModelExtended(ConvAIModel):
    """ConvAIModel subclass that hosts multiple concurrent persona dialogs.

    Each dialog (created by :meth:`new_dialog`) gets its own deep copy of the
    persona list and of ``self.args``, so per-dialog chat history and sampling
    settings never leak between dialogs.  All dialogs are kept in the
    class-level ``dialogs`` registry, keyed by an integer id.
    """

    # Remote copy of the Indonesian PersonaChat dataset (not downloaded here;
    # kept for reference by whatever provisions ``dataset_path``).
    PERSONACHAT_URL = "https://cloud.uncool.ai/index.php/s/Pazx3rifmFpwNNm/download/id_personachat.json"
    # Local dataset file consumed by get_dataset() in __init__.
    dataset_path = "data/id_personachat.json"
    # JSON file holding the persona templates loaded in __init__.
    persona_list_path = "data/persona_list.json"
    # dialog_id -> {"persona_list", "persona_ids", "args"}; shared across all
    # instances by design (class-level registry).
    dialogs = {}
    # Monotonically increasing source of dialog ids.
    dialogs_counter = 0

    def __init__(self, model_type, model_name, args=None, use_cuda=True, **kwargs):
        """Initialise the base ConvAI model, tokenize the dataset and load personas.

        Args:
            model_type: Base transformer type, forwarded to ConvAIModel.
            model_name: Pretrained model name or path, forwarded to ConvAIModel.
            args: Optional model args, forwarded to ConvAIModel.
            use_cuda: Whether to run the model on GPU.
            **kwargs: Extra arguments forwarded to ConvAIModel.
        """
        super().__init__(model_type, model_name, args, use_cuda, **kwargs)
        os.makedirs(self.args.cache_dir, exist_ok=True)
        self.dataset = get_dataset(
            self.tokenizer,
            dataset_path=ConvAIModelExtended.dataset_path,
            dataset_cache=self.args.cache_dir,
            process_count=self.args.process_count,
            proxies=self.__dict__.get("proxies", None),
            interact=False,
            args=self.args,
        )
        # Flatten every dialog's personality across all dataset splits
        # (train/valid/...), in dataset order.
        self.personalities = [
            dialog["personality"]
            for dataset in self.dataset.values()
            for dialog in dataset
        ]
        with open(ConvAIModelExtended.persona_list_path, "r") as f:
            self.persona_list = json.load(f)

    def new_dialog(self):
        """Create a fresh dialog and return its integer id.

        Every persona gets an empty history and a tokenized (lower-cased)
        personality; the dialog also receives its own deep copy of
        ``self.args`` so that per-call overrides in :meth:`talk` do not
        affect other dialogs.
        """
        tokenizer = self.tokenizer
        ConvAIModelExtended.dialogs_counter += 1
        dialog_id = ConvAIModelExtended.dialogs_counter
        # Deep-copy so each dialog mutates its own personas independently.
        persona_list = copy.deepcopy(self.persona_list)
        for persona in persona_list:
            persona["history"] = []
            persona["personality"] = [
                tokenizer.encode(s.lower()) for s in persona["personality"]
            ]
        persona_ids = {persona["id"]: persona for persona in persona_list}
        ConvAIModelExtended.dialogs[dialog_id] = {
            "persona_list": persona_list,
            "persona_ids": persona_ids,
            # Each dialog has its own independent copy of args.
            "args": copy.deepcopy(self.args),
        }
        return dialog_id

    @staticmethod
    def delete_dialog(dialog_id):
        """Remove a dialog from the class-level registry.

        Raises:
            KeyError: If ``dialog_id`` is unknown.

        Note: the original definition lacked both ``self`` and
        ``@staticmethod``, so calling it on an instance would have passed the
        instance itself as ``dialog_id``; declaring it a static method fixes
        instance calls while keeping class-level calls working.
        """
        del ConvAIModelExtended.dialogs[dialog_id]

    def get_persona_list(self, dialog_id: int):
        """Return a decoded (human-readable) copy of a dialog's persona list.

        The stored personalities are token-id lists; this decodes them back to
        strings on a deep copy so the dialog state itself is untouched.
        """
        tokenizer = self.tokenizer
        persona_list = copy.deepcopy(
            ConvAIModelExtended.dialogs[dialog_id]["persona_list"]
        )
        for persona in persona_list:
            persona["personality"] = [
                tokenizer.decode(tokens) for tokens in persona["personality"]
            ]
        return persona_list

    def set_personality(self, dialog_id: int, persona_id: str, personality: list):
        """Overwrite a persona's mutable personality lines with new strings.

        The first three stored personality entries are preserved — presumably
        fixed facts such as the persona's name (TODO confirm against the
        persona_list.json schema).  ``personality`` must therefore contain
        exactly ``len(stored) - 3`` strings, or an IndexError is raised.
        """
        tokenizer = self.tokenizer
        personality = [tokenizer.encode(s.lower()) for s in personality]
        stored = ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id][
            "personality"
        ]
        # Replace entries 3..end in place, keeping the first three untouched.
        for i in range(3, len(stored)):
            stored[i] = personality[i - 3]

    @staticmethod
    def get_persona_name(dialog_id: int, persona_id: int):
        """Return the display name of a persona within a dialog.

        Note: the original definition lacked ``self``/``@staticmethod`` (same
        defect as ``delete_dialog``); also, ``persona_id`` is annotated ``int``
        here but ``str`` in set_personality — the real key type comes from the
        ``"id"`` field in persona_list.json (TODO confirm).
        """
        return ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]["name"]

    def talk(self, dialog_id: int, persona_id: int, utterance: str,
             do_sample: bool = True, min_length: int = 1, max_length: int = 20,
             temperature: float = 0.7, top_k: int = 0, top_p: float = 0.9):
        """Send a user utterance to a persona and return the model's reply.

        Updates the persona's history with both the utterance and the reply,
        trimming it to the last ``2 * args.max_history + 1`` turns.  Sampling
        parameters are written into the dialog's private args copy, so they
        persist for that dialog only.

        Args:
            dialog_id: Id returned by :meth:`new_dialog`.
            persona_id: Key into the dialog's ``persona_ids`` mapping.
            utterance: Raw user text.
            do_sample, min_length, max_length, temperature, top_k, top_p:
                Decoding parameters forwarded to ``sample_sequence``.

        Returns:
            The decoded reply, or a fixed Indonesian fallback message
            ("Sorry, I don't understand...") when generation yields no tokens.
        """
        model = self.model
        args = ConvAIModelExtended.dialogs[dialog_id]["args"]
        args.do_sample = do_sample
        args.min_length = min_length
        args.max_length = max_length
        args.temperature = temperature
        args.top_k = top_k
        args.top_p = top_p
        tokenizer = self.tokenizer
        persona = ConvAIModelExtended.dialogs[dialog_id]["persona_ids"][persona_id]
        persona["history"].append(tokenizer.encode(utterance))
        with torch.no_grad():
            out_ids = self.sample_sequence(
                persona["personality"],
                persona["history"],
                tokenizer, model, args,
            )
        if len(out_ids) == 0:
            # Fallback shown when the model produced no output tokens.
            return "Ma'af, saya tidak mengerti. Coba tanya yang lain"
        persona["history"].append(out_ids)
        # Keep only the most recent turns (user + bot), bounded by max_history.
        persona["history"] = persona["history"][-(2 * args.max_history + 1):]
        return tokenizer.decode(
            out_ids, skip_special_tokens=args.skip_special_tokens
        )