Spaces:
Runtime error
Runtime error
| from game import AgentPlayer, SQLiteDB, Game | |
| import os | |
| import json | |
| import asyncio | |
| import argparse | |
class Proctor:
    """Schedule a batch of agent games over article pairs and summarize them.

    On construction one ``Run`` is created for every ordered pair of distinct
    entries in ``article_list``, repeated ``num_trials`` times. ``run()``
    executes those runs concurrently and then writes an aggregate JSON report
    into ``output_dir``.
    """

    def __init__(
        self,
        article_list: list[tuple[str, str]],
        num_trials: int,
        num_workers: int,
        max_steps: int,
        agent_settings: dict,
        db_path: str,
        verbose: bool = True,
        output_dir: str = "./proctor_tmp",
        proctor_id: str = "proctor_1",
        starting_seed: int = 42,
    ):
        """Store the configuration, open the database and create all runs.

        Args:
            article_list: Articles to evaluate; each entry is used both as a
                start and as a destination.
                NOTE(review): the annotation says tuples, but entries are
                passed to ``Run`` parameters annotated as ``str`` -- confirm
                the intended element type.
            num_trials: Number of runs created per (start, destination) pair.
            num_workers: Maximum number of runs executing at the same time.
            max_steps: Step limit handed to every run.
            agent_settings: Agent configuration handed to every run and
                copied into the final report.
            db_path: Path passed to ``SQLiteDB``.
            verbose: When true, ``run()`` prints a line as each run starts
                and finishes.
            output_dir: Directory for per-run files and the final report;
                created if it does not exist.
            proctor_id: Prefix for run ids and for the report file name.
            starting_seed: Seed of trial 0; trial ``n`` uses
                ``starting_seed + n``.
        """
        self.article_list = article_list
        self.num_trials = num_trials
        self.num_workers = num_workers
        self.max_steps = max_steps
        self.agent_settings = agent_settings
        self.db_path = db_path
        self.verbose = verbose
        self.output_dir = output_dir
        self.proctor_id = proctor_id
        self.db = SQLiteDB(self.db_path)
        self.starting_seed = starting_seed
        os.makedirs(self.output_dir, exist_ok=True)
        self.runs = []
        self.setup_runs()

    def setup_runs(self):
        """Append one ``Run`` per ordered pair of distinct articles and trial."""
        for start in self.article_list:
            for destination in self.article_list:
                # Compare by value, so duplicated entries never race themselves.
                if start == destination:
                    continue
                for n in range(self.num_trials):
                    run_id = f"{self.proctor_id}_{start}_{destination}_{n}"
                    self.runs.append(
                        Run(
                            start,
                            destination,
                            self.max_steps,
                            self.agent_settings,
                            self.db,
                            self.output_dir,
                            self.verbose,
                            run_id,
                            # Same trial index -> same seed for every pair.
                            self.starting_seed + n,
                        )
                    )
                    print(f"Setup run {run_id}")

    async def run(self):
        """Execute every run, at most ``num_workers`` at once, then analyze.

        Raises:
            ValueError: If ``num_workers`` is less than 1. A semaphore with
                zero permits would leave every run waiting forever.
        """
        if self.num_workers < 1:
            raise ValueError(
                f"num_workers must be at least 1, got {self.num_workers}"
            )
        semaphore = asyncio.Semaphore(self.num_workers)

        async def run_with_semaphore(run_instance):
            # The semaphore is held for the whole run, bounding concurrency.
            async with semaphore:
                if self.verbose:
                    print(f"Starting run {run_instance.id}")
                await run_instance.run()
                if self.verbose:
                    print(f"Finished run {run_instance.id}")

        tasks = [
            asyncio.create_task(run_with_semaphore(run_instance))
            for run_instance in self.runs
        ]
        await asyncio.gather(*tasks)
        self.analyze_runs()

    def analyze_runs(self):
        """Aggregate every run's output file into one JSON report.

        Reads ``output_file`` of each run, counts results equal to ``"win"``
        as wins and everything else as losses, and records
        ``len(steps) - 1`` hops for each win. The report is written to
        ``{output_dir}/{proctor_id}-final-results.json``.

        ``average_hops`` is ``None`` when no run was won, and both rates are
        ``0.0`` when there are no runs, so the report is still written in
        those cases instead of failing on a division by zero.
        """
        final_results = {
            "article_list": self.article_list,
            "num_trials": self.num_trials,
            "num_workers": self.num_workers,
            "max_steps": self.max_steps,
            "agent_settings": self.agent_settings,
            "runs": [],
        }
        win_count = 0
        lose_count = 0
        hops_distribution = []
        for run in self.runs:
            with open(run.output_file, "r") as f:
                result = json.load(f)
            final_results["runs"].append(result)
            if result["result"] == "win":
                win_count += 1
                hops_distribution.append(len(result["steps"]) - 1)
            else:
                lose_count += 1

        total_runs = len(self.runs)
        final_results["hops_distribution"] = hops_distribution
        # An average over zero wins is undefined; report it as null.
        final_results["average_hops"] = (
            sum(hops_distribution) / len(hops_distribution)
            if hops_distribution
            else None
        )
        final_results["win_rate"] = win_count / total_runs if total_runs else 0.0
        final_results["lose_rate"] = lose_count / total_runs if total_runs else 0.0
        with open(f"{self.output_dir}/{self.proctor_id}-final-results.json", "w") as f:
            json.dump(final_results, f, indent=4)
class Run:
    """One game attempt from a start article to a destination article.

    The outcome is saved as JSON at ``output_file``. If that file already
    exists, ``run()`` returns without playing again.
    """

    def __init__(
        self,
        start_article: str,
        destination_article: str,
        max_steps: int,
        agent_settings: dict,
        db: SQLiteDB,
        output_dir: str,
        verbose: bool,
        id: str,
        seed: int,
    ):
        """Record the run parameters and derive the output file path.

        Args:
            start_article: Article the game starts from.
            destination_article: Article the game tries to reach.
            max_steps: Step limit passed to ``Game``.
            agent_settings: Mapping read for the keys ``model``,
                ``api_base``, ``max_links`` and ``max_tries``.
            db: Database handle passed to ``Game``.
            output_dir: Directory in which the result file is written.
            verbose: Stored on the instance; not consulted by ``run()``.
            id: Identifier used in the output file name and in log output.
            seed: Seed passed to ``AgentPlayer`` and stored in the result.
        """
        # Identity and bookkeeping.
        self.id = id
        self.seed = seed
        self.verbose = verbose
        self.output_dir = output_dir
        self.output_file = "{}/run_{}.json".format(output_dir, id)

        # Game configuration.
        self.db = db
        self.agent_settings = agent_settings
        self.max_steps = max_steps
        self.start_article = start_article
        self.destination_article = destination_article

    async def run(self):
        """Play the game once and write its result to ``output_file``."""
        # An existing result file marks the run as done.
        if os.path.exists(self.output_file):
            return

        settings = self.agent_settings
        agent = AgentPlayer(
            model=settings["model"],
            api_base=settings["api_base"],
            max_links=settings["max_links"],
            max_tries=settings["max_tries"],
            verbose=False,
            seed=self.seed,
        )
        steps = await Game(
            self.start_article,
            self.destination_article,
            self.db,
            self.max_steps,
            agent,
            verbose=False,
        ).run()

        # Agent settings come first so the key order of the JSON is stable.
        record = {
            key: settings[key]
            for key in ("model", "api_base", "max_links", "max_tries")
        }
        record.update(
            start_article=self.start_article,
            destination_article=self.destination_article,
            steps=steps,
            seed=self.seed,
            # The type of the last step is stored as the run's result.
            # NOTE(review): assumes the game returns at least one step -- confirm.
            result=steps[-1]["type"],
        )

        with open(self.output_file, "w") as out:
            json.dump(record, out, indent=4)
        print(f"Run {self.id} completed in {len(steps)} steps")
def _build_parser():
    """Return the command-line parser for a parallel evaluation batch."""
    parser = argparse.ArgumentParser(description="Run parallel Wikispeedia evaluations")
    parser.add_argument("--model", type=str, default="gpt-4o", help="Model to use for agent")
    parser.add_argument("--api-base", type=str, default=None, help="API base URL for hosted models")
    parser.add_argument("--workers", type=int, default=20, help="Number of parallel workers")
    parser.add_argument("--trials", type=int, default=1, help="Number of trials per start-destination pair")
    parser.add_argument("--max-steps", type=int, default=20, help="Maximum steps per game")
    parser.add_argument("--max-links", type=int, default=200, help="Maximum links per page for agent")
    parser.add_argument("--max-tries", type=int, default=3, help="Maximum retries for agent")
    parser.add_argument("--db-path", type=str, default="wikihop.db", help="Path to the wikihop database")
    parser.add_argument("--output-dir", type=str, default="./proctor_tmp", help="Directory for output files")
    parser.add_argument("--proctor-id", type=str, default="proctor_1", help="Unique identifier for this proctor run")
    parser.add_argument("--seed", type=int, default=42, help="Starting random seed")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("--article-list", type=str, default="supernodes.json",
                        help="Path to JSON file with list of articles to test")
    return parser


def main(argv=None):
    """Parse arguments, build a ``Proctor`` and run the whole batch.

    Args:
        argv: Argument strings to parse. ``None`` lets ``argparse`` read the
            process command line.

    Raises:
        FileNotFoundError: If the database file or the article-list file
            does not exist.
    """
    args = _build_parser().parse_args(argv)

    # Fail before any work starts if either input file is missing.
    if not os.path.exists(args.db_path):
        raise FileNotFoundError(f"Database file not found at {args.db_path}")
    if not os.path.exists(args.article_list):
        raise FileNotFoundError(f"Article list file not found at {args.article_list}")

    # Read as UTF-8 so non-ASCII article titles do not depend on the locale's
    # default encoding.
    with open(args.article_list, "r", encoding="utf-8") as f:
        article_list = json.load(f)

    agent_settings = {
        "model": args.model,
        "api_base": args.api_base,
        "max_links": args.max_links,
        "max_tries": args.max_tries,
    }
    proctor = Proctor(
        article_list=article_list,
        num_trials=args.trials,
        num_workers=args.workers,
        max_steps=args.max_steps,
        agent_settings=agent_settings,
        db_path=args.db_path,
        verbose=args.verbose,
        output_dir=args.output_dir,
        proctor_id=args.proctor_id,
        starting_seed=args.seed,
    )
    asyncio.run(proctor.run())


if __name__ == "__main__":
    main()