Update main.py
main.py
CHANGED
@@ -25,95 +25,139 @@ step_threshold = 5
Previous version (lines 25-119; removed lines are prefixed with "-", some truncated in the source):

questions_for_api: List[Dict[str, Any]] = []
ground_truth_answers: Dict[str, str] = {}
filtered_dataset = None

# --- Define ErrorResponse if not already defined ---
class ErrorResponse(BaseModel):
    detail: str


def load_questions():
    global filtered_dataset
    global questions_for_api
    global ground_truth_answers
-   global task_file_paths

    tempo_filtered = []
-   # Clear existing data
    questions_for_api.clear()
    ground_truth_answers.clear()
-   task_file_paths.clear() # Clear the mapping

    logger.info("Starting to load and filter GAIA dataset (validation split)...")
    try:
        dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
        logger.info(f"GAIA dataset validation split loaded. Features: {dataset.features}")
    except Exception as e:
        logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
        raise RuntimeError("Could not load the primary GAIA dataset.") from e

-   # --- Filtering Logic
-   # [ ... Same filtering code as before ... ]
    for item in dataset:
        metadata = item.get('Annotator Metadata')
-
            num_tools_str = metadata.get('Number of tools')
            num_steps_str = metadata.get('Number of steps')

            if num_tools_str is not None and num_steps_str is not None:
                try:
                    num_tools = int(num_tools_str)
                    num_steps = int(num_steps_str)
                    if num_tools < tool_threshold and num_steps < step_threshold:
-                       tempo_filtered.append(item)
                except ValueError:
-
-
-
-

-   filtered_dataset = tempo_filtered
-   logger.info(f"Found {len(filtered_dataset)} questions matching the criteria.")

    processed_count = 0
-   # ---
    for item in filtered_dataset:
        task_id = item.get('task_id')
        original_question_text = item.get('Question')
        final_answer = item.get('Final answer')
-       local_file_path = item.get('file_path') #
-       file_name = item.get('file_name')

-       # Validate essential fields
        if task_id and original_question_text and final_answer is not None:
-
            processed_item = {
                "task_id": str(task_id),
-               "question": str(original_question_text),
                "Level": item.get("Level"),
-               "file_name": file_name, # Include filename for info
            }
-           #
            processed_item = {k: v for k, v in processed_item.items() if v is not None}

            questions_for_api.append(processed_item)

-           # Store ground truth
            ground_truth_answers[str(task_id)] = str(final_answer)

-           #
-           if local_file_path and file_name:
-               #
-               if os.path.
-
-
                else:
-


            processed_count += 1
        else:
-

    logger.info(f"Successfully processed {processed_count} questions for the API.")
    logger.info(f"Stored file path mappings for {len(task_file_paths)} tasks.")
    if not questions_for_api:
-       logger.error("CRITICAL: No valid questions loaded after filtering
Updated version (lines 25-163; added lines are prefixed with "+"):

questions_for_api: List[Dict[str, Any]] = []
ground_truth_answers: Dict[str, str] = {}
filtered_dataset = None
+
+ALLOWED_CACHE_BASE = os.path.abspath("/app/.cache")
+
# --- Define ErrorResponse if not already defined ---
class ErrorResponse(BaseModel):
    detail: str

+
def load_questions():
+   """
+   Loads the GAIA dataset, filters questions based on tool/step counts,
+   populates 'questions_for_api' with data for the API (excluding sensitive/internal fields),
+   stores ground truth answers, and maps task IDs to their local file paths on the server.
+   """
    global filtered_dataset
    global questions_for_api
    global ground_truth_answers
+   global task_file_paths # Declare modification of global
+
    tempo_filtered = []
+   # Clear existing data from previous runs or restarts
    questions_for_api.clear()
    ground_truth_answers.clear()
+   task_file_paths.clear() # Clear the file path mapping

    logger.info("Starting to load and filter GAIA dataset (validation split)...")
    try:
+       # Load the specified split
        dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
        logger.info(f"GAIA dataset validation split loaded. Features: {dataset.features}")
    except Exception as e:
        logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
+       # Depending on requirements, you might want to exit or raise a more specific error
        raise RuntimeError("Could not load the primary GAIA dataset.") from e

+   # --- Filtering Logic based on Annotator Metadata ---
    for item in dataset:
        metadata = item.get('Annotator Metadata')
+
+       if metadata:
            num_tools_str = metadata.get('Number of tools')
            num_steps_str = metadata.get('Number of steps')
+
            if num_tools_str is not None and num_steps_str is not None:
                try:
                    num_tools = int(num_tools_str)
                    num_steps = int(num_steps_str)
+                   # Apply filter conditions
                    if num_tools < tool_threshold and num_steps < step_threshold:
+                       tempo_filtered.append(item) # Add the original item if it matches filter
                except ValueError:
+                   logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Could not convert tool/step count in metadata: tools='{num_tools_str}', steps='{num_steps_str}'.")
+           else:
+               logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - 'Number of tools' or 'Number of steps' missing in Metadata.")
+       else:
+           # If metadata is essential for filtering, you might want to skip items without it
+           logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")

+   filtered_dataset = tempo_filtered # Store the list of filtered original dataset items
+   logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")

    processed_count = 0
+   # --- Process filtered items for API and File Mapping ---
    for item in filtered_dataset:
+       # Extract data from the dataset item
        task_id = item.get('task_id')
        original_question_text = item.get('Question')
        final_answer = item.get('Final answer')
+       local_file_path = item.get('file_path') # Server-local path from dataset
+       file_name = item.get('file_name') # Filename from dataset

+       # Validate essential fields needed for processing & ground truth
+       # Note: We proceed even if file path/name are missing, just won't map the file.
        if task_id and original_question_text and final_answer is not None:
+
+           # 1. Create the dictionary to be exposed via the API
+           # (Includes 'file_name' for info, but excludes 'file_path')
            processed_item = {
                "task_id": str(task_id),
+               "question": str(original_question_text), # Rename 'Question' -> 'question'
+               # Include other desired fields, using .get() for safety
                "Level": item.get("Level"),
+               "file_name": file_name, # Include filename for client info
            }
+           # Optional: Remove keys with None values if you prefer cleaner JSON
            processed_item = {k: v for k, v in processed_item.items() if v is not None}

            questions_for_api.append(processed_item)

+           # 2. Store the ground truth answer separately
            ground_truth_answers[str(task_id)] = str(final_answer)

+           # 3. Store the file path mapping if file details exist and are valid
+           if local_file_path and file_name:
+               # Log if the path from the dataset isn't absolute (might indicate issues)
+               if not os.path.isabs(local_file_path):
+                   logger.warning(f"Task {task_id}: Path '{local_file_path}' from dataset is not absolute. This might cause issues finding the file on the server.")
+                   # Depending on dataset guarantees, you might try making it absolute:
+                   # Assuming WORKDIR is /app as per Dockerfile if paths are relative
+                   # local_file_path = os.path.abspath(os.path.join("/app", local_file_path))
+
+               # Check if the file actually exists at the path ON THE SERVER
+               if os.path.exists(local_file_path) and os.path.isfile(local_file_path):
+                   # Path exists, store the mapping
+                   task_file_paths[str(task_id)] = local_file_path
+                   logger.debug(f"Stored file path mapping for task_id {task_id}: {local_file_path}")
                else:
+                   # Path does *not* exist or is not a file on server filesystem
+                   logger.warning(f"File path '{local_file_path}' for task_id {task_id} does NOT exist or is not a file on server. Mapping skipped.")
+           # Log if file info was missing in the first place
+           elif task_id: # Log only if we have a task_id to reference
+               # Check which specific part was missing for better debugging
+               if not local_file_path and not file_name:
+                   logger.debug(f"Task {task_id}: No 'file_path' or 'file_name' found in dataset item. No file mapping stored.")
+               elif not local_file_path:
+                   logger.debug(f"Task {task_id}: 'file_path' is missing in dataset item (file_name: '{file_name}'). No file mapping stored.")
+               else: # Not file_name
+                   logger.debug(f"Task {task_id}: 'file_name' is missing in dataset item (file_path: '{local_file_path}'). No file mapping stored.")


            processed_count += 1
        else:
+           # Log skipping due to missing core fields (task_id, Question, Final answer)
+           logger.warning(f"Skipping item processing due to missing essential fields: task_id={task_id}, has_question={original_question_text is not None}, has_answer={final_answer is not None}")

+   # Final summary logging
    logger.info(f"Successfully processed {processed_count} questions for the API.")
    logger.info(f"Stored file path mappings for {len(task_file_paths)} tasks.")
+
    if not questions_for_api:
+       logger.error("CRITICAL: No valid questions were loaded after filtering and processing. API endpoints like /questions will fail.")
+       # Consider raising an error if the application cannot function without questions
+       # raise RuntimeError("Failed to load mandatory question data after filtering.")
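The updated hunk defines ALLOWED_CACHE_BASE but never references it in the lines shown, so its role has to come from elsewhere in main.py. A common pattern is to confine any file looked up through task_file_paths to that cache directory before serving it to clients. The following is only a minimal sketch of that idea, assuming a FastAPI app (the diff's ErrorResponse is a pydantic model, so FastAPI is plausible but not confirmed); the app object, the /files/{task_id} route name, and the way task_file_paths is shared are all assumptions, not the Space's actual code.

import os

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse

app = FastAPI()

ALLOWED_CACHE_BASE = os.path.abspath("/app/.cache")  # same constant as in the diff
task_file_paths: dict = {}  # in main.py this mapping is populated by load_questions()

@app.get("/files/{task_id}")  # hypothetical route, for illustration only
def get_task_file(task_id: str):
    local_path = task_file_paths.get(task_id)
    if not local_path:
        raise HTTPException(status_code=404, detail="No file associated with this task_id.")

    # Resolve the stored path and require it to sit inside the allowed cache base,
    # so a malformed dataset entry cannot point the server at an arbitrary file.
    resolved = os.path.abspath(local_path)
    if os.path.commonpath([resolved, ALLOWED_CACHE_BASE]) != ALLOWED_CACHE_BASE:
        raise HTTPException(status_code=403, detail="File path outside the allowed cache directory.")

    if not os.path.isfile(resolved):
        raise HTTPException(status_code=404, detail="File not found on server.")

    return FileResponse(resolved, filename=os.path.basename(resolved))

Comparing os.path.commonpath of the two absolute paths against the base is one way to guard against path traversal; pathlib's Path.is_relative_to (Python 3.9+) would serve the same purpose.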