# Apriel-Chat / log_chat.py
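"""Asynchronous chat logging for the Apriel-Chat Space.

log_chat() enqueues a chat; a background worker thread appends it as a row to
a CSV file stored in a private Hugging Face dataset repo.
"""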
import csv
import os
import time
from datetime import datetime
from queue import Queue
import threading
import pandas as pd
from gradio import ChatMessage
from huggingface_hub import HfApi, hf_hub_download
from timer import Timer
from utils import log_warning, log_info, log_debug, log_error
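# Configuration comes from the environment: HF_TOKEN authenticates to the Hub,
# and APRIEL_PROMPT_DATASET names the dataset repo where chats are stored.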
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_REPO_ID = os.environ.get("APRIEL_PROMPT_DATASET")
CSV_FILENAME = "train.csv"
def log_chat(chat_id: str, session_id: str, model_name: str, prompt: str, history: list[dict | ChatMessage], info: dict) -> None:
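    """Queue a chat for asynchronous logging by the background worker thread."""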
log_info(f"log_chat() called for chat: {chat_id}, queue size: {log_chat_queue.qsize()}, model: {model_name}")
log_chat_queue.put((chat_id, session_id, model_name, prompt, history, info))
def _log_chat_worker():
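    """Consume queued chats forever, logging each one to the dataset repo."""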
while True:
chat_id, session_id, model_name, prompt, history, info = log_chat_queue.get()
        try:
            _log_chat(chat_id, session_id, model_name, prompt, history, info)
        except Exception as e:
            log_error(f"Error logging chat: {e}")
        finally:
            log_chat_queue.task_done()
def _log_chat(chat_id: str, session_id: str, model_name: str, prompt: str, history: list[dict | ChatMessage], info: dict) -> bool:
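    """Append one chat as a row to the dataset CSV and upload it.

    Returns True on success, False if configuration is missing or a write fails.
    """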
log_info(f"_log_chat() storing chat {chat_id}")
if DATASET_REPO_ID is None:
log_warning("No dataset repo ID provided. Skipping logging of prompt.")
return False
if HF_TOKEN is None:
log_warning("No HF token provided. Skipping logging of prompt.")
return False
log_timer = Timer('log_chat')
log_timer.start()
# Initialize HF API
api = HfApi(token=HF_TOKEN)
# Check if the dataset repo exists, if not, create it
try:
repo_info = api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset")
log_debug(f"log_chat() --> Dataset repo found: {repo_info.id} private={repo_info.private}")
except Exception: # Create new dataset if none exists
log_debug(f"log_chat() --> No dataset repo found, creating a new one...")
api.create_repo(repo_id=DATASET_REPO_ID, repo_type="dataset", private=True)
# Ensure messages are in the correct format
    messages = [
        {"role": item.role, "content": item.content,
         "type": "thought" if item.metadata else "completion"} if isinstance(item, ChatMessage) else item
        for item in history
        if (isinstance(item, dict) and "role" in item and "content" in item) or isinstance(item, ChatMessage)
    ]
if len(messages) != len(history):
log_warning("log_chat() --> Some messages in history are missing 'role' or 'content' keys.")
user_messages_count = sum(1 for item in messages if isinstance(item, dict) and item.get("role") == "user"
and isinstance(item.get("content"), str))
# These must match the keys in the new row
expected_headers = ["timestamp", "chat_id", "turns", "prompt", "messages", "model", "session_id", "info"]
# Prepare new data row
new_row = {
"timestamp": datetime.now().isoformat(),
"chat_id": chat_id,
"turns": user_messages_count,
"prompt": prompt,
"messages": messages,
"model": model_name,
"session_id": session_id,
"info": info,
}
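    # csv.DictWriter stringifies the nested messages list and info dict when the
    # row is written, so they are stored as their Python repr strings.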
log_timer.add_step("Prepared new data row")
# Try to download existing CSV with retry logic
max_retries = 3
retry_count = 0
    file_exists = False
    csv_path = None
while retry_count < max_retries:
try:
csv_path = hf_hub_download(
repo_id=DATASET_REPO_ID,
filename=CSV_FILENAME,
repo_type="dataset",
token=HF_TOKEN # Only needed if not already logged in
)
            # Read only the first row to confirm the file parses as valid CSV
            pd.read_csv(csv_path, nrows=1)
file_exists = True
break # Success, exit the loop
except Exception as e:
retry_count += 1
if retry_count < max_retries:
                retry_delay = 2 * retry_count  # Linear backoff: 2s, then 4s
log_warning(
f"log_chat() --> Download attempt {retry_count} failed: {e}. Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
log_warning(f"log_chat() --> Failed to download CSV after {max_retries} attempts: {e}")
file_exists = False
log_timer.add_step(f"Downloaded existing CSV (attempts: {retry_count + 1})")
# Handle the case where the CSV file does not exist or is invalid
if file_exists:
# Check that the headers match our standard headers (only read first row)
existing_headers = pd.read_csv(csv_path, nrows=0).columns.tolist()
if set(existing_headers) != set(expected_headers):
            log_warning(f"log_chat() --> CSV {csv_path} has unexpected headers: {existing_headers}. "
                        f"Expected {expected_headers}. Will create a new one.")
dump_hub_csv()
file_exists = False
else:
log_debug(f"log_chat() --> CSV {csv_path} has expected headers: {existing_headers}")
# Write out the new row to the CSV file (append isn't working in HF container, so recreate each time)
log_debug(f"log_chat() --> Writing CSV file, file_exists={file_exists}")
try:
if file_exists:
# Append mode: copy existing file and append new row
# Use chunked reading to avoid loading entire file into memory
            with open(CSV_FILENAME, "w", newline="") as f_out:
writer = csv.DictWriter(f_out, fieldnames=expected_headers)
writer.writeheader()
# Stream copy existing rows in chunks to minimize memory usage
chunk_size = 1000
for chunk in pd.read_csv(csv_path, chunksize=chunk_size):
for _, row in chunk.iterrows():
writer.writerow(row.to_dict())
# Append new row
writer.writerow(new_row)
else:
# Create new file with just the new row
            with open(CSV_FILENAME, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=expected_headers)
writer.writeheader()
writer.writerow(new_row)
log_debug(f"log_chat() --> Wrote out CSV with new row")
# dump_local_csv()
except Exception as e:
log_error(f"log_chat() --> Error writing to CSV: {e}")
return False
# Upload updated CSV
api.upload_file(
path_or_fileobj=CSV_FILENAME,
path_in_repo=CSV_FILENAME,
repo_id=DATASET_REPO_ID,
repo_type="dataset",
commit_message=f"Added new chat entry at {datetime.now().isoformat()}"
)
log_timer.add_step("Uploaded updated CSV")
log_timer.end()
log_debug("log_chat() --> Finished logging chat")
log_debug(log_timer.formatted_result())
return True
def dump_hub_csv():
    """Download the CSV from the hub and log its contents for verification."""
try:
csv_path = hf_hub_download(
repo_id=DATASET_REPO_ID,
filename=CSV_FILENAME,
repo_type="dataset",
token=HF_TOKEN # Only needed if not already logged in
)
df = pd.read_csv(csv_path)
log_info(df)
        if df.empty:
            # Show raw contents of the downloaded CSV file
            log_info("Raw file contents:")
            with open(csv_path, 'r') as f:
                log_info(f.read())
except Exception as e:
log_error(f"Error loading CSV from hub: {e}")
def dump_local_csv():
    """Load the local CSV file and log its contents for verification."""
try:
df = pd.read_csv(CSV_FILENAME)
log_info(df)
except Exception as e:
log_error(f"Error loading CSV from local file: {e}")
def test_log_chat():
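    """Exercise the logging pipeline end to end and dump the resulting CSV."""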
# Example usage
chat_id = "12345"
session_id = "67890"
model_name = "Apriel-Model"
prompt = "Hello"
history = [{"role": "user", "content": prompt}, {"role": "assistant", "content": "Hi there!"}]
prompt = "100 + 1"
history = [{'role': 'user', 'content': prompt}, ChatMessage(
content='Okay, that\'s a simple addition problem. , answer is 2.\n', role='assistant',
metadata={'title': '🧠 Thought'}, options=[]),
ChatMessage(content='\nThe result of adding 1 and 1 is:\n\n**2**\n', role='assistant', metadata={},
options=[])
]
info = {"additional_info": "Some extra data"}
log_debug("Starting test_log_chat()")
dump_hub_csv()
log_chat(chat_id, session_id, model_name, prompt, history, info)
log_debug("log_chat 1 returned")
log_chat(chat_id, session_id, model_name, prompt + " + 2", history, info)
log_debug("log_chat 2 returned")
log_chat(chat_id, session_id, model_name, prompt + " + 3", history, info)
log_debug("log_chat 3 returned")
log_chat(chat_id, session_id, model_name, prompt + " + 4", history, info)
log_debug("log_chat 4 returned")
sleep_seconds = 10
log_debug(f"Sleeping {sleep_seconds} seconds to let it finish and log the result.")
time.sleep(sleep_seconds)
log_debug("Finished sleeping.")
dump_hub_csv()
# Create a queue for logging chat messages
log_chat_queue = Queue()
# Start the worker thread
threading.Thread(target=_log_chat_worker, daemon=True).start()
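# Because the worker is a daemon thread, it never blocks interpreter exit;
# any chats still queued at shutdown are dropped rather than flushed.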
if __name__ == "__main__":
test_log_chat()