import asyncio
import json
import os
import pathlib
import re
import sqlite3
import subprocess
import zipfile
from typing import Any

import pandas as pd
from datasets import load_dataset
from func_timeout import FunctionTimedOut, func_timeout
from tqdm import tqdm

from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    compatibility_for_eval_history_pairs,
    get_default_sandbox_config_for_eval,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    OpenHandsConfig,
    get_llm_config_arg,
    parse_arguments,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync


def codeact_user_response(state: State) -> str:
    msg = (
        'Please continue working on the task using whatever approach you think is suitable.\n'
        'If you think you have completed the SQL, please finish the interaction using the "finish" tool.\n'
        'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK.\n'
    )
    if state.history:
        # Check whether the agent has already tried to talk to the user three
        # times; if so, let it know it may give up.
        user_msgs = [
            event
            for event in state.history
            if isinstance(event, MessageAction) and event.source == 'user'
        ]
        if len(user_msgs) > 2:
            return (
                msg
                + 'If you want to give up, use the "finish" tool to finish the interaction.\n'
            )
    return msg
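
# These per-agent maps are consumed in process_instance below: the response fn is
# passed to run_controller as fake_user_response_fn (it stands in for a human when
# the agent asks for input), and the suffix is appended to every task instruction.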
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have fixed the issue through code changes, please finish the interaction using the "finish" tool.\n'
}


def get_config(
    metadata: EvalMetadata,
) -> OpenHandsConfig:
    sandbox_config = get_default_sandbox_config_for_eval()
    sandbox_config.base_container_image = 'python:3.12-bookworm'
    config = OpenHandsConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='docker',
        max_iterations=metadata.max_iterations,
        sandbox=sandbox_config,
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    agent_config = config.get_agent_config(metadata.agent_class)
    agent_config.enable_prompt_extensions = False
    return config
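
# Notes on the config above (read off the code, not extra behavior): each instance
# runs in its own Docker sandbox based on python:3.12-bookworm, no host workspace
# is mounted, and prompt extensions are disabled so the instruction built in
# process_instance is the agent's only task description.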


def execute_sql(db_path, gen_sql, gold_sql):
    """Execute the generated SQL and the ground truth SQL and compare the results."""
    with sqlite3.connect(db_path) as conn:
        cursor = conn.cursor()
        cursor.execute(gen_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(gold_sql)
        ground_truth_res = cursor.fetchall()
        res = 0
        if set(predicted_res) == set(ground_truth_res):
            res = 1
        return res
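
# Note the set comparison: row order and duplicate rows are ignored, matching the
# execution-accuracy convention used by BIRD. Hypothetical example ('example.sqlite'
# is an assumed database, not part of this benchmark):
#   execute_sql('example.sqlite', 'SELECT 2, 1', 'SELECT 2, 1')  # -> 1 (result sets match)
#   execute_sql('example.sqlite', 'SELECT 1, 2', 'SELECT 2, 1')  # -> 0 (column order differs)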

LOCAL_DATASET_PATH = os.path.join(os.path.dirname(__file__), 'data')


def load_bird():
    """Main function to handle the flow of downloading, processing, and loading the bird dataset."""

    def _download_bird():
        """Downloads and extracts the bird dataset from a specified URL into a local directory."""
        devset_path = os.path.join(LOCAL_DATASET_PATH, 'dev')
        if not os.path.exists(devset_path):
            logger.info(
                f'{LOCAL_DATASET_PATH} folder does not exist, starting download and extraction...'
            )
            os.makedirs(LOCAL_DATASET_PATH, exist_ok=True)
            download_url = 'https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip'
            download_path = os.path.join(LOCAL_DATASET_PATH, 'dev.zip')
            if not os.path.exists(download_path):
                logger.info('Starting download...')
                # check=True so a failed download raises instead of leaving a partial zip
                subprocess.run(['wget', download_url, '-O', download_path], check=True)
                logger.info('Download completed.')
            if not os.path.exists(devset_path):
                logger.info('Starting extraction...')
                os.makedirs(devset_path, exist_ok=True)
                with zipfile.ZipFile(download_path, 'r') as zip_ref:
                    zip_ref.extractall(devset_path)
                # Move everything in 'dev_20240627' up to the root of the dev folder
                for file in os.listdir(os.path.join(devset_path, 'dev_20240627')):
                    os.rename(
                        os.path.join(devset_path, 'dev_20240627', file),
                        os.path.join(devset_path, file),
                    )
                os.rmdir(os.path.join(devset_path, 'dev_20240627'))
                logger.info('Extraction completed.')
            # Extract the databases
            database_path = os.path.join(devset_path, 'dev_databases.zip')
            assert os.path.exists(database_path)
            logger.info('Starting extraction...')
            with zipfile.ZipFile(database_path, 'r') as zip_ref:
                zip_ref.extractall(devset_path)
            logger.info('Extraction completed.')
        else:
            logger.info(f'{LOCAL_DATASET_PATH} folder already exists.')
        return devset_path

    def _extract_create_table_prompt(db_path, limit_value=0):
        """Generates a SQL prompt with CREATE TABLE statements and sample data from the database."""
        cursor = sqlite3.connect(db_path).cursor()
        table_query = "SELECT * FROM sqlite_master WHERE type='table';"
        tables = cursor.execute(table_query).fetchall()
        prompt = ''
        for table in tables:
            table_name = table[1]
            create_table_statement = table[-1]
            table_info_query = f'PRAGMA table_info(`{table_name}`);'
            top_k_row_query = f'SELECT * FROM `{table_name}` LIMIT {limit_value};'
            try:
                headers = [x[1] for x in cursor.execute(table_info_query).fetchall()]
            except Exception:
                logger.error(f'Error executing: {table_info_query}, {top_k_row_query}')
                exit(1)  # exit non-zero on failure
            prompt += create_table_statement + ';\n'
            if limit_value > 0:
                top_k_rows = cursor.execute(top_k_row_query).fetchall()
                prompt += (
                    f'/*\n{limit_value} example rows:\n'
                    f'{top_k_row_query}\n{" ".join(headers)}\n'
                )
                for row in top_k_rows:
                    # Map NULLs to '' before stringifying; calling str() first
                    # would turn None into the literal 'None'
                    row = ['' if x is None else str(x) for x in row]
                    prompt += ' '.join(row) + '\n'
                prompt += '*/\n'
            prompt += '\n'
        return prompt
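
    # With limit_value=3, each per-table block looks roughly like this (illustrative
    # table and columns; actual content depends on the database):
    #   CREATE TABLE `frpm` (...);
    #   /*
    #   3 example rows:
    #   SELECT * FROM `frpm` LIMIT 3;
    #   CDSCode AcademicYear ...
    #   01100170109835 2014-2015 ...
    #   */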

    def _create_prompt(e, database_path):
        """Create a prompt for the given example."""
        db_id = e['db_id']
        db_path = pathlib.Path(database_path) / db_id / f'{db_id}.sqlite'
        # Extract the CREATE TABLE statements and sample data from the database
        prompt = _extract_create_table_prompt(db_path)
        prompt += f'-- External Knowledge: {e["evidence"]}\n\n'
        prompt += '-- Using valid SQLite and understanding External Knowledge, answer the following questions for the tables provided above.\n\n'
        prompt += f'Question: {e["question"]}\n'
        return prompt

    def _process_bird(dataset_path):
        """Processes the raw bird dataset into a structured format and saves it as JSON."""
        processed_path = os.path.join(LOCAL_DATASET_PATH, 'dev', 'processed_dev.json')
        if not os.path.exists(processed_path):
            logger.info(f'{processed_path} does not exist, starting processing...')
            raw_data_path = os.path.join(LOCAL_DATASET_PATH, 'dev', 'dev.json')
            database_path = os.path.join(LOCAL_DATASET_PATH, 'dev', 'dev_databases')
            processed_data = []
            with pathlib.Path(raw_data_path).open('r') as f:
                data = json.load(f)
                for e in tqdm(data):
                    item = {
                        'instance_id': f'{len(processed_data)}',
                        'db_path': os.path.join(
                            database_path, e['db_id'], f'{e["db_id"]}.sqlite'
                        ),
                        'db_id': e['db_id'],
                        'instruction': _create_prompt(e, database_path),
                        'SQL': e['SQL'],
                    }
                    processed_data.append(item)
            with pathlib.Path(processed_path).open('w') as f:
                json.dump(processed_data, f, indent=2)
            logger.info(f'Processed data saved to {processed_path}')
        else:
            logger.info(f'{processed_path} already exists.')
        bird_dataset = load_dataset('json', data_files={'test': processed_path})
        return bird_dataset
    raw_dataset_path = _download_bird()
    bird_dataset = _process_bird(raw_dataset_path)
    return bird_dataset
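
# Each processed instance carries: instance_id (a running index), db_path (the local
# sqlite file), db_id, instruction (schema prompt + external knowledge + question),
# and SQL (the gold query used for execution-based scoring in complete_runtime).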


def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    """
    logger.info(f'{"-" * 50} BEGIN Runtime Initialization Fn {"-" * 50}')
    obs: CmdOutputObservation

    # Copy the database into the workspace
    db_file = os.path.join(
        LOCAL_DATASET_PATH,
        'dev',
        'dev_databases',
        instance.db_id,
        f'{instance.db_id}.sqlite',
    )
    runtime.copy_to(db_file, '/workspace')

    # Check that the database was copied
    action = CmdRunAction(command='cd /workspace && ls -l')
    obs = runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0
    assert f'{instance.db_id}.sqlite' in obs.content

    logger.info(f'{"-" * 50} END Runtime Initialization Fn {"-" * 50}')


def complete_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required, but it is used to get the workspace_dir_name
) -> dict[str, Any]:
    """Complete the runtime for the agent.

    This function is called after the agent has run. If you need to do something
    in the sandbox to get the correctness metric, modify this function.
    """
    logger.info(f'{"-" * 50} BEGIN Runtime Completion Fn {"-" * 50}')
    obs: CmdOutputObservation
    timeout = 30
    test_result = {'result': {}, 'metadata': {}}

    # Read the generated Python file
    instance_id = instance.instance_id.replace('/', '__')
    path = os.path.join('/workspace', f'{instance_id}.py')
    action = CmdRunAction(command=f'cat {path}')
    obs = runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    if obs.exit_code != 0:
        test_result['result'] = {'passed': 0, 'status': 'error'}
        return test_result
    gen_file = obs.content.strip().replace('\r\n', '\n')

    # Extract the SQL from the Python file; this assumes the agent kept the
    # template's one-line, double-quoted assignment (sql = "...")
    gen_sql = ''
    pattern = r'sql\s*=\s*"([^"]+)"'
    match = re.search(pattern, gen_file)
    if match:
        gen_sql = match.group(1)
    else:
        logger.warning('No SQL assignment found in the generated file.')
    gold_sql = instance.SQL

    # Execute the SQL with a timeout
    try:
        res = func_timeout(
            timeout,
            execute_sql,
            args=(
                instance.db_path,
                gen_sql,
                gold_sql,
            ),
        )
        status = 'success'
    except FunctionTimedOut:
        res = 0
        status = 'timeout'
    except Exception as e:
        res = 0
        status = 'error'
        logger.error(f'Error: {e}')

    # Save the test result
    test_result['result'] = {'passed': res, 'status': status}
    test_result['metadata'] = {
        'timeout': timeout,
        'gen_sql': gen_sql,
        'gold_sql': gold_sql,
    }
    logger.info(f'{"-" * 50} END Runtime Completion Fn {"-" * 50}')
    return test_result
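
# A successful instance yields a dict of roughly this shape (illustrative values):
#   {'result': {'passed': 1, 'status': 'success'},
#    'metadata': {'timeout': 30, 'gen_sql': 'SELECT ...', 'gold_sql': 'SELECT ...'}}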


def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    config = get_config(metadata)
    # Use a sanitized instance id for concurrent evaluation
    instance_id = instance.instance_id.replace('/', '__')

    # Set up the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance_id}.')

    # Create the file template for the BIRD instance
    database_path = os.path.join('/workspace', f'{instance.db_id}.sqlite')
    statements = f"""
import sqlite3

def execute_sql(db_path, sql):
    with sqlite3.connect(db_path) as conn:
        cursor = conn.cursor()
        cursor.execute(sql)
        result = cursor.fetchall()
        return result

if __name__ == '__main__':
    sql = ""  # fill in your SQL here
    db_path = "{database_path}"
    print(db_path)
    result = execute_sql(db_path, sql)
    print(result)
"""
    instruction = (
        'You are a SQL expert and need to complete the following text-to-SQL task.'
        f'\n\n{instance.instruction}\n\n'
        'Please write the SQL in one line without line breaks, '
        f'and write a new Python file named {instance_id}.py to call the SQL you wrote. '
        'You need to follow the code template below:'
        f'\n\n{statements}\n\n'
        'The environment has been set up for you to start working. '
        'You may assume all necessary tools are installed.\n\n'
    )
    instruction += (
        'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
        'You SHOULD INCLUDE PROPER INDENTATION in your edit commands.\n'
    )
    # NOTE: You can set slightly different instructions for different agents
    instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)
    initialize_runtime(runtime, instance)

    # Here's how you can run the agent (similar to the `main` function) and get the final task state
    state: State | None = asyncio.run(
        run_controller(
            config=config,
            initial_user_action=MessageAction(content=instruction),
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                metadata.agent_class
            ],
            runtime=runtime,
        )
    )

    # ======= Attempt to evaluate the agent's edits =======
    test_result = complete_runtime(runtime, instance)

    # If you are working on a simpler benchmark that only evaluates the final model
    # output (e.g., in a MessageAction), you can simply take the LAST `MessageAction`
    # from the returned `state.history` and parse it for evaluation.
    if state is None:
        raise ValueError('State should not be None.')
    metrics = state.metrics.get() if state.metrics else None

    # History is now available as a stream of events, rather than a list of
    # (Action, Observation) pairs. For compatibility with the existing output format,
    # we remake the pairs here. Remove when it becomes unnecessary.
    histories = compatibility_for_eval_history_pairs(state.history)

    # Save the output
    output = EvalOutput(
        instance_id=instance.instance_id,
        instruction=instruction,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
        test_result=test_result,
    )
    return output


if __name__ == '__main__':
    args = parse_arguments()
    bird_dataset = load_bird()
    dataset = bird_dataset['test'].to_pandas()
    dataset.rename(columns={'task_id': 'instance_id'}, inplace=True)

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
        # modify_params must be False for evaluation, for reproducibility and accuracy of results
        llm_config.modify_params = False
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config,
        'BIRD',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
    run_evaluation(
        instances, metadata, output_file, args.eval_num_workers, process_instance
    )
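
# Example invocation (a sketch; flag names come from openhands.core.config.parse_arguments
# and should be verified against your OpenHands checkout):
#   poetry run python evaluation/benchmarks/bird/run_infer.py \
#       --agent-cls CodeActAgent --llm-config llm \
#       --max-iterations 30 --eval-num-workers 4 --eval-n-limit 10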