import os
import copy
import threading
from pathlib import Path

import toml
import google.generativeai as palm_api

from pingpong import PingPong
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

from modules.llms import (
    LLMFactory,
    PromptFmt, PromptManager, PPManager, UIPPManager, LLMService,
)

class PaLMFactory(LLMFactory):
    _palm_api_key = None

    def __init__(self, palm_api_key=None):
        if not PaLMFactory._palm_api_key:
            PaLMFactory.load_palm_api_key(palm_api_key)

        assert PaLMFactory._palm_api_key, "PaLM API Key is missing."
        palm_api.configure(api_key=PaLMFactory._palm_api_key)
    def create_prompt_format(self):
        return PaLMChatPromptFmt()

    def create_prompt_manager(self, prompts_path: str=None):
        return PaLMPromptManager(prompts_path or Path('.') / 'prompts' / 'palm_prompts.toml')

    def create_pp_manager(self):
        return PaLMChatPPManager()

    def create_ui_pp_manager(self):
        return GradioPaLMChatPPManager()

    def create_llm_service(self):
        return PaLMService()
    @classmethod
    def load_palm_api_key(cls, palm_api_key: str=None):
        if palm_api_key:
            cls._palm_api_key = palm_api_key
        else:
            # Fall back to the environment variable, then to a local key file.
            palm_api_key = os.getenv("PALM_API_KEY")

            if palm_api_key is None:
                try:
                    with open('.palm_api_key.txt', 'r') as file:
                        palm_api_key = file.read().strip()
                except FileNotFoundError:
                    palm_api_key = None

            if not palm_api_key:
                raise ValueError("PaLM API Key is missing.")

            cls._palm_api_key = palm_api_key
    @property
    def palm_api_key(self):
        return PaLMFactory._palm_api_key

    @palm_api_key.setter
    def palm_api_key(self, palm_api_key: str):
        assert palm_api_key, "PaLM API Key is missing."

        PaLMFactory._palm_api_key = palm_api_key
        palm_api.configure(api_key=PaLMFactory._palm_api_key)

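# Usage sketch (illustrative, not part of the module): with PALM_API_KEY set
# in the environment, the first construction configures the API and the
# factory hands back the concrete PaLM helpers.
#
#   factory = PaLMFactory()
#   fmt = factory.create_prompt_format()     # PaLMChatPromptFmt
#   service = factory.create_llm_service()   # PaLMService
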
class PaLMChatPromptFmt(PromptFmt):
    @classmethod
    def ctx(cls, context):
        pass

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        ping = pingpong.ping[:truncate_size]
        pong = pingpong.pong

        if pong is None or pong.strip() == "":
            return [
                {
                    "author": "USER",
                    "content": ping
                },
            ]
        else:
            pong = pong[:truncate_size]
            return [
                {
                    "author": "USER",
                    "content": ping
                },
                {
                    "author": "AI",
                    "content": pong
                },
            ]

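# Example of the message shape this produces (assuming pingpong's
# PingPong(ping, pong) constructor):
#
#   PaLMChatPromptFmt.prompt(PingPong("Hi", "Hello!"), truncate_size=None)
#   # -> [{"author": "USER", "content": "Hi"},
#   #     {"author": "AI", "content": "Hello!"}]
#
# Slicing with truncate_size=None keeps the full string, so None means
# "no truncation".
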
class PaLMPromptManager(PromptManager):
    _instance = None
    _lock = threading.Lock()
    _prompts = None

    def __new__(cls, prompts_path):
        if cls._instance is None:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super(PaLMPromptManager, cls).__new__(cls)
                    cls._instance.load_prompts(prompts_path)

        return cls._instance

    def load_prompts(self, prompts_path):
        self._prompts_path = prompts_path
        self.reload_prompts()

    def reload_prompts(self):
        assert self.prompts_path, "Prompt path is missing."
        self._prompts = toml.load(self.prompts_path)

    @property
    def prompts_path(self):
        return self._prompts_path

    @prompts_path.setter
    def prompts_path(self, prompts_path):
        self._prompts_path = prompts_path
        self.reload_prompts()

    @property
    def prompts(self):
        if self._prompts is None:
            self.reload_prompts()
        return self._prompts

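# The double-checked lock above makes this a process-wide singleton: the
# prompts_path passed to any later construction is ignored.
#
#   pm1 = PaLMPromptManager('prompts/palm_prompts.toml')
#   pm2 = PaLMPromptManager('somewhere/else.toml')  # same object, same prompts
#   assert pm1 is pm2
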
class PaLMChatPPManager(PPManager):
    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=None, truncate_size: int=None):
        if fmt is None:
            factory = PaLMFactory()
            fmt = factory.create_prompt_format()

        results = []

        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)

        return results

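# Illustrative use (assumes the base PPManager exposes add_pingpong, as in
# the pingpong library): the whole conversation is flattened into a single
# message list suitable for the chat API.
#
#   ppm = PaLMChatPPManager()
#   ppm.add_pingpong(PingPong("Hi", "Hello!"))
#   messages = ppm.build_prompts()
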
class GradioPaLMChatPPManager(UIPPManager, PaLMChatPPManager):
    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []

        for pingpong in self.pingpongs[from_idx:to_idx]:
            results.append(fmt.ui(pingpong))

        return results

class PaLMService(LLMService):
    def __init__(self):
        self._default_parameters_text = {
            'model': 'models/text-bison-001',
            'temperature': 0.7,
            'candidate_count': 1,
            'top_k': 40,
            'top_p': 0.95,
            'max_output_tokens': 1024,
            'stop_sequences': [],
            'safety_settings': [
                {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 1},
                {"category": "HARM_CATEGORY_TOXICITY", "threshold": 1},
                {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 2},
                {"category": "HARM_CATEGORY_SEXUAL", "threshold": 2},
                {"category": "HARM_CATEGORY_MEDICAL", "threshold": 2},
                {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 2},
            ],
        }

        self._default_parameters_chat = {
            'model': 'models/chat-bison-001',
            'temperature': 0.25,
            'candidate_count': 1,
            'top_k': 40,
            'top_p': 0.95,
        }
    def make_params(self, mode="chat",
                    temperature=None,
                    candidate_count=None,
                    top_k=None,
                    top_p=None,
                    max_output_tokens=None,
                    use_filter=True):
        parameters = None

        # Deep copy so that tweaking safety thresholds below does not
        # mutate the shared default dictionaries.
        if mode == "chat":
            parameters = copy.deepcopy(self._default_parameters_chat)
        elif mode == "text":
            parameters = copy.deepcopy(self._default_parameters_text)

        if temperature is not None:
            parameters['temperature'] = temperature
        if candidate_count is not None:
            parameters['candidate_count'] = candidate_count
        if top_k is not None:
            parameters['top_k'] = top_k
        if top_p is not None:
            parameters['top_p'] = top_p
        if max_output_tokens is not None and mode == "text":
            parameters['max_output_tokens'] = max_output_tokens

        if not use_filter and mode == "text":
            for setting in parameters['safety_settings']:
                setting['threshold'] = 4

        return parameters
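    # Example: relaxed-filter text parameters. use_filter=False only has an
    # effect in "text" mode, since the chat defaults carry no safety_settings.
    #
    #   service = PaLMService()
    #   params = service.make_params(mode="text", temperature=0.3,
    #                                use_filter=False)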
    async def gen_text(
        self,
        prompt,
        mode="chat",  # "chat" or "text"
        parameters=None,
        use_filter=True
    ):
        if parameters is None:
            temperature = 1.0
            top_k = 40
            top_p = 0.95
            max_output_tokens = 1024

            # default safety settings
            safety_settings = [
                {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 1},
                {"category": "HARM_CATEGORY_TOXICITY", "threshold": 1},
                {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 2},
                {"category": "HARM_CATEGORY_SEXUAL", "threshold": 2},
                {"category": "HARM_CATEGORY_MEDICAL", "threshold": 2},
                {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 2},
            ]

            if not use_filter:
                for setting in safety_settings:
                    setting['threshold'] = 4

            if mode == "chat":
                parameters = {
                    'model': 'models/chat-bison-001',
                    'candidate_count': 1,
                    'context': "",
                    'temperature': temperature,
                    'top_k': top_k,
                    'top_p': top_p,
                    'safety_settings': safety_settings,
                }
            else:
                parameters = {
                    'model': 'models/text-bison-001',
                    'candidate_count': 1,
                    'temperature': temperature,
                    'top_k': top_k,
                    'top_p': top_p,
                    'max_output_tokens': max_output_tokens,
                    'safety_settings': safety_settings,
                }
        try:
            if mode == "chat":
                response = await palm_api.chat_async(**parameters, messages=prompt)
            else:
                response = palm_api.generate_text(**parameters, prompt=prompt)
        except Exception as e:
            raise EnvironmentError("PaLM API is not available.") from e

        if use_filter and len(response.filters) > 0:
            raise Exception("PaLM API has withheld a response due to content safety concerns.")
        else:
            if mode == "chat":
                response_txt = response.last
            else:
                response_txt = response.result

        return response, response_txt
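
# Minimal end-to-end sketch (hypothetical driver, assuming a valid key is
# available through one of the sources PaLMFactory checks):
if __name__ == "__main__":
    import asyncio

    async def _demo():
        factory = PaLMFactory()
        service = factory.create_llm_service()
        messages = [{"author": "USER", "content": "Say hello."}]
        params = service.make_params(mode="chat", temperature=0.5)
        _, text = await service.gen_text(messages, mode="chat", parameters=params)
        print(text)

    asyncio.run(_demo())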