Spaces:
Runtime error
Runtime error
import argparse
import asyncio
import json
import re
import sqlite3
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Tuple

import litellm
class SQLiteDB:
    """Read access to the ``core_articles`` table of a WikiRun SQLite database.

    Rows are read from ``core_articles`` using the ``title`` and ``links_json``
    columns; ``links_json`` is decoded with ``json.loads`` into the list of
    outgoing link titles.
    """

    def __init__(self, db_path: str):
        """Open the SQLite database at *db_path* and count its articles.

        Raises:
            sqlite3.OperationalError: e.g. when the database has no
                ``core_articles`` table.
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        # sqlite3.Row lets columns be read by name (article["links_json"]).
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()
        self._article_count = self._get_article_count()
        print(f"Connected to SQLite database with {self._article_count} articles")

    def _get_article_count(self) -> int:
        """Return the number of rows in ``core_articles``."""
        self.cursor.execute("SELECT COUNT(*) FROM core_articles")
        return self.cursor.fetchone()[0]

    def get_article_with_links(self, article_title: str) -> Tuple[Optional[str], List[str]]:
        """Look up *article_title* and return ``(title, links)``.

        Returns ``(None, [])`` when no row has that exact title, and an empty
        link list when the stored ``links_json`` is NULL or empty.
        """
        self.cursor.execute(
            "SELECT title, links_json FROM core_articles WHERE title = ?",
            (article_title,),
        )
        article = self.cursor.fetchone()
        if article is None:
            return None, []
        raw_links = article["links_json"]
        # json.loads(None) raises TypeError; treat a missing value as "no links".
        links = json.loads(raw_links) if raw_links else []
        return article["title"], links

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

    def __enter__(self) -> "SQLiteDB":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
class Player:
    """Interactive player that chooses links by typing a 0-based index.

    Args:
        name: Display name used in the move metadata message.
        input_fn: Optional callable used to read the user's answer; it is
            called with the prompt string and must return a string. When
            None, the builtin ``input`` is used.
    """

    def __init__(self, name: str, input_fn: Optional[Callable[[str], str]] = None):
        self.name = name
        self.input_fn = input_fn

    async def get_move(self, game_state: List[Dict]) -> Tuple[str, Dict]:
        """Ask the user to choose one of the latest step's links.

        Re-prompts until an integer index within ``0..len(links)-1`` is
        entered, then returns ``(link_title, {"message": ...})``. If the latest
        step has no links, returns ``-1`` without prompting (the same failure
        value ``AgentPlayer.get_move`` uses).
        """
        links = game_state[-1]["links"]
        if not links:
            return -1, {"message": f"{self.name} has no links to choose from"}

        # Resolved at call time; note a synchronous reader blocks the event loop.
        read = self.input_fn if self.input_fn is not None else input

        print("Link choices:")
        for i, link in enumerate(links):
            print(f"{i}: {link}")

        while True:
            raw = read("Enter the index of the link you want to select: ")
            try:
                idx = int(raw)
            except ValueError:
                print(f"'{raw}' is not a number, try again.")
                continue
            # Reject negatives explicitly: Python indexing would pick from the end.
            if 0 <= idx < len(links):
                break
            print(f"Index must be between 0 and {len(links) - 1}, try again.")

        return links[idx], {"message": f"{self.name} selected link #{idx}"}
class AgentPlayer(Player):
    """Player that chooses links by prompting a model through ``litellm``.

    Args:
        model: Model name passed to ``litellm.acompletion``; also used as the
            player's name.
        api_base: API base URL passed to ``litellm.acompletion``.
        verbose: Stored on the instance; not read by this class.
        max_links: If not None, only the first ``max_links`` links of the
            current article are shown to (and selectable by) the model.
        max_tries: Maximum number of model calls per move before giving up.
        target_article: Title the agent is trying to reach (used in the prompt).
        seed: Passed through to ``litellm.acompletion`` as ``seed``.
    """

    # Matches <answer>N</answer>, tolerating whitespace around the number.
    _ANSWER_RE = re.compile(r"<answer>\s*(\d+)\s*</answer>")

    def __init__(
        self,
        model: str,
        api_base: str,
        verbose: bool = True,
        max_links: Optional[int] = None,
        max_tries: int = 10,
        target_article: Optional[str] = None,
        seed: Optional[int] = None,
    ):
        super().__init__(model)
        self.model = model
        self.api_base = api_base
        self.verbose = verbose
        self.max_links = max_links
        self.max_tries = max_tries
        self.target_article = target_article
        self.seed = seed

    def _candidate_links(self, game_state: List[Dict]) -> List[str]:
        """Return the latest step's links, truncated to ``max_links`` when set."""
        links = game_state[-1]["links"]
        if self.max_links is None:
            return links
        return links[: self.max_links]

    async def get_move(self, game_state: List[Dict]) -> Tuple[str, Dict]:
        """Ask the model for a link, feeding parse errors back for a retry.

        Returns ``(link_title, metadata)`` on success, where
        ``metadata["tries"]`` is the number of model calls made and
        ``metadata["conversation"]`` is the message history. Returns
        ``(-1, metadata)`` if no valid answer was obtained within
        ``max_tries`` calls, or immediately if there are no links.
        """
        links = self._candidate_links(game_state)
        if not links:
            # No possible answer; don't spend model calls on an unanswerable prompt.
            return -1, {"tries": 0, "conversation": []}

        prompt = self.construct_prompt(game_state)
        conversation = [{"role": "user", "content": prompt}]

        for attempt in range(1, self.max_tries + 1):
            response = await litellm.acompletion(
                model=self.model,
                api_base=self.api_base,
                messages=conversation,
                seed=self.seed,
            )
            # Message content can be None; treat it as a reply with no answer.
            content = response.choices[0].message.content or ""
            conversation.append({"role": "assistant", "content": content})

            answer, message = self._attempt_to_extract_answer(content, maximum_answer=len(links))
            if answer == -1:
                # Give the model the error message and another chance.
                conversation.append({"role": "user", "content": message})
                continue

            # _attempt_to_extract_answer guarantees 1 <= answer <= len(links).
            return links[answer - 1], {"tries": attempt, "conversation": conversation}

        return -1, {"tries": self.max_tries, "conversation": conversation}

    def construct_prompt(self, game_state: List[Dict]) -> str:
        """Build the user prompt listing the (1-based) candidate links and path so far."""
        current = game_state[-1]["article"]
        target = self.target_article
        available_links = self._candidate_links(game_state)
        formatted_links = "\n".join(f"{i}. {link}" for i, link in enumerate(available_links, start=1))
        path_so_far = [step["article"] for step in game_state]
        try:
            formatted_path = ' -> '.join(path_so_far)
        except TypeError as e:
            # join() only fails on non-string entries; dump state to help debugging.
            print(f"Error formatting path: {e}")
            print(game_state)
            print("Path so far: ", path_so_far)
            raise
        return f"""You are playing WikiRun, trying to navigate from one Wikipedia article to another using only links.
IMPORTANT: You MUST put your final answer in <answer>NUMBER</answer> tags, where NUMBER is the link number.
For example, if you want to choose link 3, output <answer>3</answer>.
Current article: {current}
Target article: {target}
Available links (numbered):
{formatted_links}
Your path so far: {formatted_path}
Think about which link is most likely to lead you toward the target article.
First, analyze each link briefly and how it connects to your goal, then select the most promising one.
Remember to format your final answer by explicitly writing out the xml number tags like this: <answer>NUMBER</answer>
"""

    def _attempt_to_extract_answer(self, response: str, maximum_answer: Optional[int] = None) -> Tuple[int, str]:
        """Parse a 1-based link number from ``<answer>N</answer>`` tags.

        Returns ``(answer, "")`` on success, or ``(-1, feedback)`` where
        *feedback* is a message to send back to the model. When
        *maximum_answer* is None only the lower bound (``>= 1``) is enforced.
        """
        matches = self._ANSWER_RE.findall(response)
        if not matches:
            return -1, f"No answer found in response. Please respond with a number between 1 and {maximum_answer} in <answer>NUMBER</answer> tags."
        if len(matches) > 1:
            return -1, "Multiple answers found in response. Please respond with just one."

        raw_answer = matches[0]
        try:
            answer = int(raw_answer)
        except ValueError:
            return -1, f"You answered with {raw_answer} but it could not be converted to an integer. Please respond with a number between 1 and {maximum_answer}."

        too_high = maximum_answer is not None and answer > maximum_answer
        if answer < 1 or too_high:
            return -1, f"You answered with {answer} but you have to select a number between 1 and {maximum_answer}."
        return answer, ""
class Game:
    """A single WikiRun game from *start_article* to *target_article*.

    Args:
        start_article: Title of the article the game starts on.
        target_article: Title the player must reach to win.
        db: Article store exposing ``get_article_with_links(title)`` which
            returns ``(title, links)``.
        max_allowed_steps: Maximum number of moves the player may make.
        player: Object with an async ``get_move(steps)`` returning
            ``(link_title_or_-1, metadata)``.
        verbose: Print progress messages when True.
    """

    def __init__(
        self,
        start_article: str,
        target_article: str,
        db: "SQLiteDB",
        max_allowed_steps: int,
        player: "Player",
        verbose: bool = True,
    ):
        self.start_article = start_article
        self.target_article = target_article
        self.db = db
        self.max_allowed_steps = max_allowed_steps
        self.steps = []
        self.steps_taken = 0
        self.player = player
        self.verbose = verbose
        # Agent players build their prompt from target_article, so keep it in sync.
        if isinstance(self.player, AgentPlayer):
            self.player.target_article = self.target_article

    async def run(self) -> List[Dict]:
        """Play the game and return the recorded steps.

        Each step dict has ``type`` ("start", "move", "win" or "lose"),
        ``article`` and ``metadata``; "start" and "move" steps also carry
        ``links``. The game ends on a win, on a loss (the player returned -1,
        the chosen article has no links, or the start article has no links),
        or after ``max_allowed_steps`` moves, in which case the last step is
        a "move". Each call starts a fresh game.
        """
        # Reset so calling run() again doesn't append to the previous game's history.
        self.steps = []
        self.steps_taken = 0

        if self.verbose:
            print(f"Starting game from {self.start_article} to {self.target_article}")

        _, links = self.db.get_article_with_links(self.start_article)
        self.steps.append(
            {
                "type": "start",
                "article": self.start_article,
                "links": links,
                "metadata": {"message": "Game started"},
            }
        )

        if self.start_article == self.target_article:
            self.steps.append(
                {"type": "win", "article": self.start_article, "metadata": {"message": "Start article is the target"}}
            )
            return self.steps

        if not links:
            # Unknown start article or a dead end: the player has nothing to choose from.
            self.steps.append(
                {"type": "lose", "article": self.start_article, "metadata": {"message": "Start article has no links"}}
            )
            return self.steps

        while self.steps_taken < self.max_allowed_steps:
            self.steps_taken += 1
            player_move, metadata = await self.player.get_move(self.steps)

            # -1 means the player could not produce a valid link.
            if player_move == -1:
                self.steps.append({"type": "lose", "article": player_move, "metadata": metadata})
                break

            if self.verbose:
                print(f" -> Step {self.steps_taken}: {player_move}")

            if player_move == self.target_article:
                self.steps.append({"type": "win", "article": player_move, "metadata": metadata})
                break

            _, links = self.db.get_article_with_links(player_move)
            if not links:
                # Dead end (or article missing from the DB): no further moves possible.
                self.steps.append({"type": "lose", "article": player_move, "metadata": metadata})
                break

            self.steps.append(
                {
                    "type": "move",
                    "article": player_move,
                    "links": links,
                    "metadata": metadata,
                }
            )

        return self.steps
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Play the WikiRun game")

    # Exactly one player type must be chosen.
    player_group = parser.add_mutually_exclusive_group(required=True)
    player_group.add_argument("--human", action="store_true", help="Play as a human")
    player_group.add_argument("--agent", action="store_true", help="Use an AI agent to play")

    # Game parameters
    parser.add_argument("--start", type=str, default="British Library", help="Starting article title")
    parser.add_argument("--end", type=str, default="Saint Lucia", help="Target article title")
    parser.add_argument("--db", type=str, required=True, help="Path to SQLite database")
    parser.add_argument("--max-steps", type=int, default=10, help="Maximum number of steps allowed (default: 10)")

    # Agent parameters (only used with --agent)
    parser.add_argument("--model", type=str, default="gpt-4o", help="Model to use for the agent (default: gpt-4o)")
    parser.add_argument("--api-base", type=str, default="https://api.openai.com/v1",
                        help="API base URL (default: https://api.openai.com/v1)")
    parser.add_argument("--max-links", type=int, default=200, help="Maximum number of links to consider (default: 200)")
    parser.add_argument("--max-tries", type=int, default=3, help="Maximum number of tries for the agent (default: 3)")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility")
    args = parser.parse_args()

    db = SQLiteDB(args.db)
    try:
        if args.human:
            player = Player("Human")
        else:  # args.agent is True (the group is required and exclusive)
            player = AgentPlayer(
                model=args.model,
                api_base=args.api_base,
                verbose=True,
                max_links=args.max_links,
                max_tries=args.max_tries,
                target_article=args.end,
                seed=args.seed,
            )

        game = Game(
            start_article=args.start,
            target_article=args.end,
            db=db,
            max_allowed_steps=args.max_steps,
            player=player,
            verbose=True,
        )
        steps = asyncio.run(game.run())
    finally:
        # Release the SQLite connection even if the game raises (e.g. an API error).
        db.conn.close()

    outcome = steps[-1]["type"]
    if outcome not in ("win", "lose"):
        outcome = "step limit reached"
    # `steps` also holds the initial "start" entry (and usually a final win/lose
    # entry), so report the number of moves actually attempted instead of len(steps).
    print(f"Game over after {game.steps_taken} moves ({outcome})")
    for i, step in enumerate(steps):
        print(f"Step {i}: {step['type']}")
        print(f"  Article: {step['article']}")
        print(f"  Links: {step.get('links', [])}")
        print(f"  Metadata: {step.get('metadata', {})}")
        print("\n\n")