'''
Chat State and Logging
'''
import json
import os
from typing import Any, Literal, Optional

from conversation import Conversation
import datetime
import uuid

LOG_DIR = os.getenv("LOGDIR", "./logs")
'''
The default output dir of log files.
'''


class ModelChatState:
    '''
    The state of a chat with a model.
    '''

    is_vision: bool
    '''
    Whether the model is vision-based.
    '''

    conv: Conversation
    '''
    The conversation.
    '''

    conv_id: str
    '''
    Unique identifier for the model conversation.
    Unique per chat per model.
    '''

    chat_session_id: str
    '''
    Unique identifier for the chat session.
    Unique per chat. The two battle models share the same chat session id.
    '''

    skip_next: bool
    '''
    Flag to indicate skipping the next operation.
    '''

    model_name: str
    '''
    Name of the model being used.
    '''

    oai_thread_id: Optional[str]
    '''
    Identifier for the OpenAI thread.
    '''

    has_csam_image: bool
    '''
    Indicates if a CSAM image has been uploaded.
    '''

    regen_support: bool
    '''
    Indicates if regeneration is supported for the model.
    '''

    chat_start_time: datetime.datetime
    '''
    Chat start time.
    '''

    chat_mode: Literal['battle_anony', 'battle_named', 'direct']
    '''
    Chat mode.
    '''

    curr_response_type: Literal['chat_multi', 'chat_single', 'regenerate_multi', 'regenerate_single'] | None
    '''
    Current response type. Used for logging.
    '''

    @staticmethod
    def create_chat_session_id() -> str:
        '''
        Create a new chat session id.
        '''
        return uuid.uuid4().hex

    @staticmethod
    def create_battle_chat_states(
        model_name_1: str, model_name_2: str,
        chat_mode: Literal['battle_anony', 'battle_named'],
        is_vision: bool,
    ) -> tuple['ModelChatState', 'ModelChatState']:
        '''
        Create two chat states for a battle.
        '''
        chat_session_id = ModelChatState.create_chat_session_id()
        return (
            ModelChatState(model_name_1, chat_mode,
                           is_vision=is_vision,
                           chat_session_id=chat_session_id),
            ModelChatState(model_name_2, chat_mode,
                           is_vision=is_vision,
                           chat_session_id=chat_session_id),
        )
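
    # Illustrative usage sketch (not part of this module); the model names
    # below are placeholders, not real adapters:
    #
    #   state_a, state_b = ModelChatState.create_battle_chat_states(
    #       "model-a", "model-b", chat_mode="battle_anony", is_vision=False)
    #   assert state_a.chat_session_id == state_b.chat_session_id
    #   assert state_a.conv_id != state_b.conv_id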

    def __init__(self,
                 model_name: str,
                 chat_mode: Literal['battle_anony', 'battle_named', 'direct'],
                 is_vision: bool,
                 chat_session_id: str | None = None,
                 ):
        from fastchat.model.model_adapter import get_conversation_template

        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        # If no chat session id is provided, use the conversation id.
        self.chat_session_id = chat_session_id if chat_session_id else self.conv_id
        self.chat_start_time = datetime.datetime.now()
        self.chat_mode = chat_mode
        self.skip_next = False
        self.model_name = model_name
        self.oai_thread_id = None
        self.is_vision = is_vision
        # No response has been generated yet; set later via set_response_type().
        self.curr_response_type = None

        # NOTE(chris): This could be sort of a hack since it assumes the user
        # only uploads one image. If they can upload multiple, we should store
        # a list of image hashes.
        self.has_csam_image = False

        self.regen_support = True
        if "browsing" in model_name:
            self.regen_support = False

        self.init_system_prompt(self.conv, is_vision)

    def init_system_prompt(self, conv, is_vision):
        '''
        Fill date placeholders in the conversation's system prompt, if any.
        '''
        system_prompt = conv.get_system_message(is_vision)
        if len(system_prompt) == 0:
            return

        current_date = datetime.datetime.now().strftime("%Y-%m-%d")
        system_prompt = system_prompt.replace("{{currentDateTime}}", current_date)

        current_date_v2 = datetime.datetime.now().strftime("%d %b %Y")
        system_prompt = system_prompt.replace("{{currentDateTimev2}}", current_date_v2)

        current_date_v3 = datetime.datetime.now().strftime("%B %Y")
        system_prompt = system_prompt.replace("{{currentDateTimev3}}", current_date_v3)

        conv.set_system_message(system_prompt)
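
    # For reference, the three placeholders expand as follows (illustrative
    # values for a chat started on 2025-01-15):
    #   {{currentDateTime}}   -> "2025-01-15"    (%Y-%m-%d)
    #   {{currentDateTimev2}} -> "15 Jan 2025"   (%d %b %Y)
    #   {{currentDateTimev3}} -> "January 2025"  (%B %Y)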

    def set_response_type(
        self,
        response_type: Literal['chat_multi', 'chat_single', 'regenerate_multi', 'regenerate_single']
    ):
        '''
        Set the response type for the chat state.
        '''
        self.curr_response_type = response_type

    def to_gradio_chatbot(self):
        '''
        Convert the conversation to the Gradio chatbot format.
        '''
        return self.conv.to_gradio_chatbot()

    def get_conv_log_filepath(self, path_prefix: str):
        '''
        Get the filepath for the conversation log.
        Expected directory structure:

        softwarearenlog/
        └── YEAR_MONTH_DAY/
            ├── conv_logs/
            └── sandbox_logs/
        '''
        date_str = self.chat_start_time.strftime('%Y_%m_%d')
        filepath = os.path.join(
            path_prefix,
            date_str,
            'conv_logs',
            self.chat_mode,
            f"conv-log-{self.chat_session_id}.json"
        )
        return filepath
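
    # For example (illustrative values), a battle_anony chat started on
    # 2025-01-15 with chat session id "abc123" would be logged to:
    #   <path_prefix>/2025_01_15/conv_logs/battle_anony/conv-log-abc123.json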

    def to_dict(self):
        '''
        Serialize the chat state (including the conversation) to a dict.
        '''
        base = self.conv.to_dict()
        base.update(
            {
                "chat_session_id": self.chat_session_id,
                "conv_id": self.conv_id,
                "chat_mode": self.chat_mode,
                "chat_start_time": self.chat_start_time,
                "model_name": self.model_name,
            }
        )
        if self.is_vision:
            base.update({"has_csam_image": self.has_csam_image})
        return base

    def generate_vote_record(
        self,
        vote_type: str,
        ip: str
    ) -> dict[str, Any]:
        '''
        Generate a vote record for telemetry.
        '''
        data = {
            "tstamp": round(datetime.datetime.now().timestamp(), 4),
            "type": vote_type,
            "model": self.model_name,
            "state": self.to_dict(),
            "ip": ip,
        }
        return data

    def generate_response_record(
        self,
        gen_params: dict[str, Any],
        start_ts: float,
        end_ts: float,
        ip: str
    ) -> dict[str, Any]:
        '''
        Generate a response record for telemetry.
        '''
        data = {
            "tstamp": round(datetime.datetime.now().timestamp(), 4),
            "type": self.curr_response_type,
            "model": self.model_name,
            "start_ts": round(start_ts, 4),
            "end_ts": round(end_ts, 4),
            "gen_params": gen_params,
            "state": self.to_dict(),
            "ip": ip,
        }
        return data


def save_log_to_local(
    log_data: dict[str, Any],
    log_path: str,
    write_mode: Literal['overwrite', 'append'] = 'append'
):
    '''
    Save the log to a local file, one JSON record per line.
    '''
    log_json = json.dumps(log_data, default=str)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, "w" if write_mode == 'overwrite' else "a") as fout:
        fout.write(log_json + "\n")
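

# Illustrative end-to-end usage sketch; the model name and IP are placeholders,
# and get_conversation_template is expected to fall back to a generic template
# for unrecognized names, so this only demonstrates the call sequence.
if __name__ == "__main__":
    state = ModelChatState("demo-model", chat_mode="direct", is_vision=False)
    state.set_response_type("chat_single")

    # Record a (placeholder) vote and append it to the local conversation log.
    record = state.generate_vote_record(vote_type="upvote", ip="127.0.0.1")
    save_log_to_local(record, state.get_conv_log_filepath(LOG_DIR))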