Spaces:
Runtime error
Runtime error
| """ | |
| This script defines a PromptTemplate class that assists in generating | |
| conversation/prompt templates. The script facilitates formatting prompts | |
| for inference and training by combining various context elements and user inputs. | |
| """ | |
| import dataclasses | |
| from typing import Dict, List, Union | |
@dataclasses.dataclass
class PromptTemplate:
    """Describe how the sections of an LLM prompt are laid out and joined.

    Every ``*_template`` attribute is a ``str.format`` pattern for one section
    of the prompt. Sections are concatenated in the order system, context,
    chat history, question and (training only) answer. Each section after the
    system one is preceded by ``sep`` and the whole prompt ends with ``eos``.

    The class is a dataclass so that instances can be built with keyword
    arguments, which is how the templates registered in this module are
    created.
    """

    # The name of this template (used as the key in the templates registry)
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The template for the system context
    context_template: str = "{user_context}\n{news_context}"
    # The template for the conversation history
    chat_history_template: str = "{chat_history}"
    # The template of the user question
    question_template: str = "{question}"
    # The template of the system answer
    answer_template: str = "{answer}"
    # The system message substituted into ``system_template``
    system_message: str = ""
    # Separator inserted before every section that follows the system prompt
    sep: str = "\n"
    # End-of-sequence marker appended to the end of every prompt
    eos: str = "</s>"

    # NOTE(review): kept as a plain method because nothing in this file reads
    # it as an attribute; confirm against external callers before converting
    # it to a property like the raw-template accessors below.
    def input_variables(self) -> List[str]:
        """Returns a list of input variables for the prompt template"""
        return ["user_context", "news_context", "chat_history", "question", "answer"]

    def _raw_prefix(self) -> str:
        """Returns the sections shared by the training and inference templates.

        Only the system message is substituted here; the remaining
        placeholders are left untouched so they can be filled in later.
        """
        system = self.system_template.format(system_message=self.system_message)
        sections = (
            self.context_template,
            self.chat_history_template,
            self.question_template,
        )
        return system + "".join(f"{self.sep}{section}" for section in sections)

    # The two accessors below must be properties: ``format_train`` and
    # ``format_infer`` call ``.format`` directly on the attribute.
    @property
    def train_raw_template(self) -> str:
        """Returns the training prompt template format"""
        return f"{self._raw_prefix()}{self.sep}{self.answer_template}{self.eos}"

    @property
    def infer_raw_template(self) -> str:
        """Returns the inference prompt template format"""
        return f"{self._raw_prefix()}{self.eos}"

    def format_train(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Formats the data sample to a training sample

        ``sample`` must contain the keys ``user_context``, ``news_context``,
        ``question`` and ``answer`` (a missing one raises ``KeyError``);
        ``chat_history`` is optional and defaults to an empty string.

        Returns a dict holding the rendered ``prompt`` and, under ``payload``,
        the original ``sample`` object.
        """
        prompt = self.train_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
            answer=sample["answer"],
        )
        return {"prompt": prompt, "payload": sample}

    def format_infer(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Formats the data sample to a testing sample

        ``sample`` must contain the keys ``user_context``, ``news_context``
        and ``question`` (a missing one raises ``KeyError``); ``chat_history``
        is optional and defaults to an empty string.

        Returns a dict holding the rendered ``prompt`` and, under ``payload``,
        the original ``sample`` object.
        """
        prompt = self.infer_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
        )
        return {"prompt": prompt, "payload": sample}
# Module-level registry mapping a template name to its PromptTemplate.
# It starts out empty and is filled through register_llm_template().
templates: Dict[str, PromptTemplate] = dict()
def register_llm_template(template: PromptTemplate):
    """Store ``template`` in the global registry under its ``name``.

    An entry already registered under the same name is replaced.
    """
    key = template.name
    templates[key] = template
def get_llm_template(name: str) -> PromptTemplate:
    """Look up the template registered under ``name``.

    Raises ``KeyError`` when no template with that name has been registered.
    """
    selected = templates[name]
    return selected
##### Register Templates #####

# Built-in templates, registered when this module is imported.

# Mistral 7B Instruct v0.2: the question is wrapped in [INST] ... [/INST].
register_llm_template(
    PromptTemplate(
        name="mistral",
        sep="\n",
        eos=" </s>",
        system_message="You are a helpful assistant, with financial expertise, and you do not answer questions which contain illegal or harmful information.",
        system_template="<s>{system_message}",
        context_template="{user_context}\n{news_context}",
        chat_history_template="Summary: {chat_history}",
        question_template="[INST] {question} [/INST]",
        answer_template="{answer}",
    )
)

# Falcon: each section is introduced by a ">>NAME<<" marker
# (spec: https://huggingface.co/tiiuae/falcon-7b/blob/main/tokenizer.json)
register_llm_template(
    PromptTemplate(
        name="falcon",
        sep="\n",
        eos="<|endoftext|>",
        system_message="You are a helpful assistant, with financial expertise.",
        system_template=">>INTRODUCTION<< {system_message}",
        context_template=">>DOMAIN<< {user_context}\n{news_context}",
        chat_history_template=">>SUMMARY<< {chat_history}",
        question_template=">>QUESTION<< {question}",
        answer_template=">>ANSWER<< {answer}",
    )
)