# Spaces: Running
| import csv | |
| import os | |
| import time | |
| from datetime import datetime | |
| from queue import Queue | |
| import threading | |
| import pandas as pd | |
| from gradio import ChatMessage | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from timer import Timer | |
| from utils import log_warning, log_info, log_debug, log_error | |
# Hugging Face access token and target dataset repo, both read from the environment.
# Either may be None when the variable is unset; _log_chat() skips logging in that case.
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO_ID = os.getenv("APRIEL_PROMPT_DATASET")

# Name of the CSV file in the dataset repo; the same name is used for the local copy.
CSV_FILENAME = "train.csv"
def log_chat(chat_id: str, session_id: str, model_name: str, prompt: str, history: list[str], info: dict) -> None:
    """Queue one chat record for background storage.

    Logs the current queue size, then puts the arguments on ``log_chat_queue``
    as a single tuple. Nothing is stored here; the queue consumer does that.
    """
    pending = log_chat_queue.qsize()
    log_info(f"log_chat() called for chat: {chat_id}, queue size: {pending}, model: {model_name}")
    record = (chat_id, session_id, model_name, prompt, history, info)
    log_chat_queue.put(record)
def _log_chat_worker():
    """Consume ``log_chat_queue`` forever, storing each record via ``_log_chat()``.

    Any exception raised while storing is logged and swallowed so the loop keeps
    running, and ``task_done()`` is called for every dequeued item either way.
    """
    while True:
        chat_id, session_id, model_name, prompt, history, info = log_chat_queue.get()
        try:
            _log_chat(chat_id, session_id, model_name, prompt, history, info)
        except Exception as err:
            log_error(f"Error logging chat: {err}")
        finally:
            # Always balance the get() above so Queue.join() callers are released.
            log_chat_queue.task_done()
def _log_chat(chat_id: str, session_id: str, model_name: str, prompt: str, history: list[str], info: dict) -> bool:
    """Append one chat record to the CSV stored in the Hugging Face dataset repo.

    The existing CSV is downloaded, rewritten locally with the new row appended
    (the whole file is recreated on every call), and uploaded back to the repo.

    Args:
        chat_id: Identifier of the chat being logged.
        session_id: Identifier of the session the chat belongs to.
        model_name: Name of the model that produced the responses.
        prompt: The user prompt for this turn.
        history: Chat history; items are either dicts with ``role`` and
            ``content`` keys or ``ChatMessage`` objects. Other items are dropped.
        info: Extra data stored alongside the record.

    Returns:
        True when the updated CSV was uploaded; False when logging was skipped
        (missing repo id or token) or the local CSV could not be written.

    Raises:
        Exception: Errors from creating the repo or uploading the file are not
            caught here and propagate to the caller.
    """
    log_info(f"_log_chat() storing chat {chat_id}")
    if DATASET_REPO_ID is None:
        log_warning("No dataset repo ID provided. Skipping logging of prompt.")
        return False
    if HF_TOKEN is None:
        log_warning("No HF token provided. Skipping logging of prompt.")
        return False

    log_timer = Timer('log_chat')
    log_timer.start()

    # Initialize HF API
    api = HfApi(token=HF_TOKEN)

    # Check if the dataset repo exists, if not, create it
    try:
        repo_info = api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset")
        log_debug(f"log_chat() --> Dataset repo found: {repo_info.id} private={repo_info.private}")
    except Exception:  # Create new dataset if none exists
        log_debug("log_chat() --> No dataset repo found, creating a new one...")
        # exist_ok=True: the lookup above can fail for reasons other than a missing
        # repo; in that case creation must not fail just because the repo exists.
        api.create_repo(repo_id=DATASET_REPO_ID, repo_type="dataset", private=True, exist_ok=True)

    # Ensure messages are in the correct format
    messages = _normalize_history(history)
    if len(messages) != len(history):
        log_warning("log_chat() --> Some messages in history are missing 'role' or 'content' keys.")
    user_messages_count = sum(1 for item in messages if item.get("role") == "user")

    # Prepare new data row
    new_row = {
        "timestamp": datetime.now().isoformat(),
        "chat_id": chat_id,
        "turns": user_messages_count,
        "prompt": prompt,
        "messages": messages,
        "model": model_name,
        "session_id": session_id,
        "info": info,
    }
    # Derived from the row itself so headers and row keys cannot drift apart.
    expected_headers = list(new_row)
    log_timer.add_step("Prepared new data row")

    # Try to download the existing CSV (parsed once and reused below)
    csv_path, existing, attempts = _download_existing_csv()
    log_timer.add_step(f"Downloaded existing CSV (attempts: {attempts})")

    # NOTE(review): when the download fails for any reason (not only "file not
    # found"), the code below still uploads a CSV containing only the new row,
    # replacing whatever is on the hub. Confirm this is intended.

    # Handle the case where the CSV file is empty or has the wrong columns
    if existing is not None and len(existing) == 0:
        log_warning(f"log_chat() --> CSV {csv_path} exists but is empty, will create a new one.")
        dump_hub_csv()
        existing = None
    elif existing is not None:
        existing_headers = existing.columns.tolist()
        if set(existing_headers) != set(expected_headers):
            log_warning(f"log_chat() --> CSV {csv_path} has unexpected headers: {existing_headers}. "
                        f"\nExpected {expected_headers} "
                        f"Will create a new one.")
            dump_hub_csv()
            existing = None
        else:
            log_debug(f"log_chat() --> CSV {csv_path} has expected headers: {existing_headers}")

    # Write out the new row to the CSV file (append isn't working in HF container, so recreate each time)
    file_exists = existing is not None
    log_debug(f"log_chat() --> Writing CSV file, file_exists={file_exists}")
    try:
        # newline="" is what the csv module expects; utf-8 keeps non-ASCII content
        # (e.g. emoji in messages) readable by pandas regardless of the locale.
        with open(CSV_FILENAME, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=expected_headers)
            writer.writeheader()  # Always write the header
            if file_exists:
                writer.writerows(existing.to_dict("records"))  # Write existing rows
            writer.writerow(new_row)  # Write the new row
        log_debug("log_chat() --> Wrote out CSV with new row")
    except Exception as e:
        log_error(f"log_chat() --> Error writing to CSV: {e}")
        return False

    # Upload updated CSV
    api.upload_file(
        path_or_fileobj=CSV_FILENAME,
        path_in_repo=CSV_FILENAME,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message=f"Added new chat entry at {datetime.now().isoformat()}"
    )
    log_timer.add_step("Uploaded updated CSV")
    log_timer.end()
    log_debug("log_chat() --> Finished logging chat")
    log_debug(log_timer.formatted_result())
    return True


def _normalize_history(history: list) -> list[dict]:
    """Convert history items to plain dicts, dropping items that are neither a ChatMessage nor a dict with role/content."""
    messages = []
    for item in history:
        if isinstance(item, ChatMessage):
            messages.append({
                "role": item.role,
                "content": item.content,
                # A ChatMessage with non-empty metadata is recorded as a "thought".
                "type": "thought" if item.metadata else "completion",
            })
        elif isinstance(item, dict) and "role" in item and "content" in item:
            messages.append(item)
    return messages


def _download_existing_csv(max_retries: int = 3) -> tuple:
    """Download and parse the hub CSV, retrying with a linearly growing delay.

    Returns:
        ``(csv_path, dataframe, attempts)``. ``dataframe`` is None when every
        attempt failed; ``csv_path`` is then the last downloaded path, or None.
    """
    csv_path = None
    for attempt in range(1, max_retries + 1):
        try:
            csv_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=CSV_FILENAME,
                repo_type="dataset",
                token=HF_TOKEN  # Only needed if not already logged in
            )
            # Read every cell as text so rewriting the file does not alter stored
            # values (no "nan" for empty cells, no dropped leading zeros).
            existing = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
        except Exception as e:
            if attempt < max_retries:
                retry_delay = 2 * attempt  # Linear backoff: 2s, 4s, ...
                log_warning(
                    f"log_chat() --> Download attempt {attempt} failed: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                log_warning(f"log_chat() --> Failed to download CSV after {max_retries} attempts: {e}")
        else:
            log_debug(f"log_chat() --> Downloaded existing CSV with {len(existing)} rows")
            return csv_path, existing, attempt
    return csv_path, None, max_retries
def dump_hub_csv():
    """Download the CSV from the hub dataset repo and log its contents.

    Used for inspection: the parsed DataFrame is logged, and when it has no
    rows the raw file text is logged as well. Any failure (download, parse,
    read) is logged as an error and not re-raised.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=CSV_FILENAME,
            repo_type="dataset",
            token=HF_TOKEN  # Only needed if not already logged in
        )
        df = pd.read_csv(csv_path)
        log_info(df)
        if df.empty:
            # Show the raw contents of the downloaded file; routed through
            # log_info (not print) so it lands with the rest of the output.
            log_info("Raw file contents:")
            with open(csv_path, 'r', encoding="utf-8") as f:
                log_info(f.read())
    except Exception as e:
        log_error(f"Error loading CSV from hub: {e}")
def dump_local_csv():
    """Load the local CSV file and log it as a DataFrame.

    Failures while reading or logging are reported through ``log_error``
    and not re-raised.
    """
    try:
        log_info(pd.read_csv(CSV_FILENAME))
    except Exception as load_error:
        log_error(f"Error loading CSV from local file: {load_error}")
def test_log_chat():
    """Manual smoke test: queue four chat records and dump the hub CSV before and after.

    Talks to the real dataset repo configured through the environment, so it
    is meant to be run by hand (see the ``__main__`` guard), not as a unit test.
    """
    chat_id = "12345"
    session_id = "67890"
    model_name = "Apriel-Model"
    prompt = "100 + 1"
    # Mixed history on purpose: a plain dict plus ChatMessage objects with and
    # without metadata, covering each shape _log_chat() converts.
    history = [
        {'role': 'user', 'content': prompt},
        ChatMessage(
            content='Okay, that\'s a simple addition problem. , answer is 2.\n', role='assistant',
            metadata={'title': '🧠 Thought'}, options=[]),
        ChatMessage(
            content='\nThe result of adding 1 and 1 is:\n\n**2**\n', role='assistant', metadata={},
            options=[]),
    ]
    info = {"additional_info": "Some extra data"}

    log_debug("Starting test_log_chat()")
    dump_hub_csv()

    for call_number, suffix in enumerate(("", " + 2", " + 3", " + 4"), start=1):
        log_chat(chat_id, session_id, model_name, prompt + suffix, history, info)
        log_debug(f"log_chat {call_number} returned")

    # Block until the worker has processed every queued record. A fixed sleep
    # can expire before the uploads finish, and the daemon worker is killed
    # when the interpreter exits.
    log_debug("Waiting for the logging queue to drain before dumping the result.")
    log_chat_queue.join()
    log_debug("Finished waiting.")
    dump_hub_csv()
# Queue of pending chat records: log_chat() puts items, _log_chat_worker() takes them.
log_chat_queue = Queue()

# Run the consumer on a daemon thread so it does not keep the process alive at exit.
threading.Thread(daemon=True, target=_log_chat_worker).start()

if __name__ == "__main__":
    test_log_chat()