Spaces:
Build error
Build error
# flake8: noqa: E501
import asyncio
import dataclasses
import json
import os
import pathlib
import shutil
import subprocess
from argparse import Namespace
from typing import Any
from uuid import uuid4

from termcolor import colored

import openhands
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig, OpenHandsConfig, SandboxConfig
from openhands.core.config.utils import load_openhands_config
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.event import Event
from openhands.events.observation import (
    CmdOutputObservation,
    ErrorObservation,
    Observation,
)
from openhands.events.stream import EventStreamSubscriber
from openhands.integrations.service_types import ProviderType
from openhands.resolver.interfaces.issue import Issue
from openhands.resolver.interfaces.issue_definitions import (
    ServiceContextIssue,
    ServiceContextPR,
)
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
from openhands.resolver.resolver_output import ResolverOutput
from openhands.resolver.utils import (
    codeact_user_response,
    get_unique_uid,
    identify_token,
    reset_logger_for_multiprocessing,
)
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync

# The resolver always drives this single agent class. It is deliberately not
# configurable until there are other competitive agents to choose between.
AGENT_CLASS = 'CodeActAgent'
class IssueResolver:
    """Resolve a single GitHub/GitLab issue or pull request with an OpenHands agent.

    The resolver clones the repository into ``output_dir/repo``, copies it into
    a per-issue workspace, runs the agent in a docker runtime, collects the
    resulting ``git diff`` and appends a ``ResolverOutput`` record to
    ``output_dir/output.jsonl``.
    """

    # True when the GITLAB_CI environment variable equals 'true'.
    # Evaluated once, when the class is defined.
    GITLAB_CI = os.getenv('GITLAB_CI') == 'true'

    def __init__(self, args: Namespace) -> None:
        """Initialize the IssueResolver from parsed command-line arguments.

        Args:
            args: Parsed CLI arguments. Attributes read here: ``selected_repo``
                (``owner/repo``), ``token``, ``username``, ``base_domain``,
                ``repo_instruction_file``, ``issue_type`` (``'issue'`` or
                ``'pr'``), ``prompt_file``, ``output_dir``, ``issue_number``,
                ``max_iterations``, ``base_container_image``,
                ``runtime_container_image``, ``is_experimental`` and
                ``comment_id``.

        Raises:
            ValueError: If ``selected_repo`` is not of the form ``owner/repo``,
                or if no username or token is available.
        """
        parts = args.selected_repo.rsplit('/', 1)
        if len(parts) < 2:
            raise ValueError('Invalid repository format. Expected owner/repo')
        owner, repo = parts

        # Fall back to environment variables when not passed explicitly.
        token = args.token or os.getenv('GITHUB_TOKEN') or os.getenv('GITLAB_TOKEN')
        username = args.username if args.username else os.getenv('GIT_USERNAME')
        if not username:
            raise ValueError('Username is required.')
        if not token:
            raise ValueError('Token is required.')

        # Determine the git provider from the token; compared against
        # ProviderType values below.
        platform = call_async_from_sync(
            identify_token,
            GENERAL_TIMEOUT,
            token,
            args.base_domain,
        )

        repo_instruction = None
        if args.repo_instruction_file:
            with open(args.repo_instruction_file, 'r') as f:
                repo_instruction = f.read()

        issue_type = args.issue_type

        # Read the prompt template, defaulting to the bundled one for the issue type.
        prompt_file = args.prompt_file
        if prompt_file is None:
            if issue_type == 'issue':
                prompt_file = os.path.join(
                    os.path.dirname(__file__), 'prompts/resolve/basic-with-tests.jinja'
                )
            else:
                prompt_file = os.path.join(
                    os.path.dirname(__file__), 'prompts/resolve/basic-followup.jinja'
                )
        with open(prompt_file, 'r') as f:
            user_instructions_prompt_template = f.read()
        # The conversation-instructions template sits next to the main template
        # and carries a '-conversation-instructions' suffix.
        with open(
            prompt_file.replace('.jinja', '-conversation-instructions.jinja')
        ) as f:
            conversation_instructions_prompt_template = f.read()

        base_domain = args.base_domain
        if base_domain is None:
            base_domain = (
                'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
            )

        self.output_dir = args.output_dir
        self.issue_type = issue_type
        self.issue_number = args.issue_number
        self.workspace_base = self.build_workspace_base(
            self.output_dir, self.issue_type, self.issue_number
        )
        self.max_iterations = args.max_iterations
        self.app_config = self.update_openhands_config(
            load_openhands_config(),
            self.max_iterations,
            self.workspace_base,
            args.base_container_image,
            args.runtime_container_image,
            args.is_experimental,
        )
        self.owner = owner
        self.repo = repo
        self.platform = platform
        self.user_instructions_prompt_template = user_instructions_prompt_template
        self.conversation_instructions_prompt_template = (
            conversation_instructions_prompt_template
        )
        self.repo_instruction = repo_instruction
        self.comment_id = args.comment_id

        factory = IssueHandlerFactory(
            owner=self.owner,
            repo=self.repo,
            token=token,
            username=username,
            platform=self.platform,
            base_domain=base_domain,
            issue_type=self.issue_type,
            llm_config=self.app_config.get_llm_config(),
        )
        self.issue_handler = factory.create()

    @classmethod
    def update_openhands_config(
        cls,
        config: OpenHandsConfig,
        max_iterations: int,
        workspace_base: str,
        base_container_image: str | None,
        runtime_container_image: str | None,
        is_experimental: bool,
    ) -> OpenHandsConfig:
        """Apply the resolver's fixed settings to ``config`` (in place) and return it.

        Args:
            config: Configuration to modify.
            max_iterations: Maximum number of agent iterations.
            workspace_base: Absolute path of the per-issue workspace copy.
            base_container_image: Optional base image for the sandbox.
            runtime_container_image: Optional prebuilt runtime image.
            is_experimental: When true, no default runtime image is chosen.

        Returns:
            The same ``config`` object, updated.

        Raises:
            ValueError: If both container images are given.
        """
        config.default_agent = AGENT_CLASS
        config.runtime = 'docker'
        # NOTE(review): hard-coded budget cap; its unit is defined by
        # OpenHandsConfig (presumably USD) -- confirm.
        config.max_budget_per_task = 4
        config.max_iterations = max_iterations
        # Both the workspace base and its mount path point at the per-issue
        # copy of the repository.
        config.workspace_base = workspace_base
        config.workspace_mount_path = workspace_base
        # The resolver's agent runs with the 'github' microagent disabled.
        config.agents = {AGENT_CLASS: AgentConfig(disabled_microagents=['github'])}

        cls.update_sandbox_config(
            config,
            base_container_image,
            runtime_container_image,
            is_experimental,
        )
        return config

    @classmethod
    def update_sandbox_config(
        cls,
        openhands_config: OpenHandsConfig,
        base_container_image: str | None,
        runtime_container_image: str | None,
        is_experimental: bool,
    ) -> None:
        """Copy resolver-specific sandbox settings into ``openhands_config.sandbox``.

        When neither image is given and ``is_experimental`` is false, the
        runtime image defaults to the ghcr.io runtime tagged with the installed
        openhands version.

        Raises:
            ValueError: If both ``runtime_container_image`` and
                ``base_container_image`` are provided.
        """
        if runtime_container_image is not None and base_container_image is not None:
            raise ValueError('Cannot provide both runtime and base container images.')

        if (
            runtime_container_image is None
            and base_container_image is None
            and not is_experimental
        ):
            runtime_container_image = (
                f'ghcr.io/all-hands-ai/runtime:{openhands.__version__}-nikolaik'
            )

        # Convert container image values to string or None
        container_base = (
            str(base_container_image) if base_container_image is not None else None
        )
        container_runtime = (
            str(runtime_container_image)
            if runtime_container_image is not None
            else None
        )

        sandbox_config = SandboxConfig(
            base_container_image=container_base,
            runtime_container_image=container_runtime,
            enable_auto_lint=False,
            use_host_network=False,
            timeout=300,
        )

        # Configure sandbox for GitLab CI environment
        if cls.GITLAB_CI:
            sandbox_config.local_runtime_url = os.getenv(
                'LOCAL_RUNTIME_URL', 'http://localhost'
            )
            user_id = os.getuid() if hasattr(os, 'getuid') else 1000
            # When running as root, switch the sandbox to a different uid.
            if user_id == 0:
                sandbox_config.user_id = get_unique_uid()

        openhands_config.sandbox.base_container_image = (
            sandbox_config.base_container_image
        )
        openhands_config.sandbox.runtime_container_image = (
            sandbox_config.runtime_container_image
        )
        openhands_config.sandbox.enable_auto_lint = sandbox_config.enable_auto_lint
        openhands_config.sandbox.use_host_network = sandbox_config.use_host_network
        openhands_config.sandbox.timeout = sandbox_config.timeout
        openhands_config.sandbox.local_runtime_url = sandbox_config.local_runtime_url
        openhands_config.sandbox.user_id = sandbox_config.user_id

    def initialize_runtime(
        self,
        runtime: Runtime,
    ) -> None:
        """Prepare the sandbox before the agent runs.

        Changes into ``/workspace``, chowns the workspace when running on GitLab
        inside GitLab CI, and disables the git pager. It then calls the
        runtime's ``maybe_run_setup_script`` / ``maybe_setup_git_hooks``
        helpers, which look for ``.openhands/setup.sh`` and
        ``.openhands/pre-commit.sh``.

        Raises:
            RuntimeError: If ``cd /workspace`` or the git config command fails.
        """
        logger.info('-' * 30)
        logger.info('BEGIN Runtime Initialization Fn')
        logger.info('-' * 30)
        obs: Observation

        action = CmdRunAction(command='cd /workspace')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
            raise RuntimeError(f'Failed to change directory to /workspace.\n{obs}')

        if self.platform == ProviderType.GITLAB and self.GITLAB_CI:
            # Best effort: the exit code of this chown is not checked.
            # NOTE(review): uid 1001 presumably matches the sandbox user in
            # GitLab CI -- confirm.
            action = CmdRunAction(command='sudo chown -R 1001:0 /workspace/*')
            logger.info(action, extra={'msg_type': 'ACTION'})
            obs = runtime.run_action(action)
            logger.info(obs, extra={'msg_type': 'OBSERVATION'})

        action = CmdRunAction(command='git config --global core.pager ""')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
            raise RuntimeError(f'Failed to set git config.\n{obs}')

        # Run setup script if it exists
        logger.info('Checking for .openhands/setup.sh script...')
        runtime.maybe_run_setup_script()

        # Setup git hooks if they exist
        logger.info('Checking for .openhands/pre-commit.sh script...')
        runtime.maybe_setup_git_hooks()

        logger.info('-' * 30)
        logger.info('END Runtime Initialization Fn')
        logger.info('-' * 30)

    async def complete_runtime(
        self,
        runtime: Runtime,
        base_commit: str,
    ) -> dict[str, Any]:
        """Collect the agent's changes as a git patch after the agent has run.

        Stages every change (``git add -A``, via ``sudo`` on GitLab in GitLab
        CI) and diffs the index against ``base_commit``. The diff is retried up
        to 5 times with a growing hard timeout.

        Args:
            runtime: The runtime the agent worked in.
            base_commit: Commit to diff against.

        Returns:
            ``{'git_patch': patch}``, where ``patch`` is ``None`` if every diff
            attempt failed.

        Raises:
            RuntimeError: If one of the preparatory commands fails.
            ValueError: If the runtime returns an unexpected observation type.
        """
        logger.info('-' * 30)
        logger.info('BEGIN Runtime Completion Fn')
        logger.info('-' * 30)
        obs: Observation

        action = CmdRunAction(command='cd /workspace')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
            raise RuntimeError(
                f'Failed to change directory to /workspace. Observation: {obs}'
            )

        action = CmdRunAction(command='git config --global core.pager ""')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
            raise RuntimeError(f'Failed to set git config. Observation: {obs}')

        action = CmdRunAction(
            command='git config --global --add safe.directory /workspace'
        )
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
            raise RuntimeError(f'Failed to set git config. Observation: {obs}')

        if self.platform == ProviderType.GITLAB and self.GITLAB_CI:
            action = CmdRunAction(command='sudo git add -A')
        else:
            action = CmdRunAction(command='git add -A')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
            raise RuntimeError(f'Failed to git add. Observation: {obs}')

        max_attempts = 5
        git_patch = None
        for attempt in range(max_attempts):
            action = CmdRunAction(command=f'git diff --no-color --cached {base_commit}')
            # Give each retry more time than the previous one.
            action.set_hard_timeout(600 + 100 * attempt)
            logger.info(action, extra={'msg_type': 'ACTION'})
            obs = runtime.run_action(action)
            logger.info(obs, extra={'msg_type': 'OBSERVATION'})
            if isinstance(obs, CmdOutputObservation):
                if obs.exit_code == 0:
                    git_patch = obs.content.strip()
                    break
                logger.info('Failed to get git diff, retrying...')
            elif isinstance(obs, ErrorObservation):
                logger.error(f'Error occurred: {obs.content}. Retrying...')
            else:
                raise ValueError(f'Unexpected observation type: {type(obs)}')
            # No point waiting after the final attempt.
            if attempt < max_attempts - 1:
                await asyncio.sleep(10)
        else:
            logger.error(f'Failed to get git diff after {max_attempts} attempts.')

        logger.info('-' * 30)
        logger.info('END Runtime Completion Fn')
        logger.info('-' * 30)
        return {'git_patch': git_patch}

    @staticmethod
    def build_workspace_base(
        output_dir: str, issue_type: str, issue_number: int
    ) -> str:
        """Return the absolute path ``<output_dir>/workspace/<issue_type>_<issue_number>``."""
        workspace_base = os.path.join(
            output_dir, 'workspace', f'{issue_type}_{issue_number}'
        )
        return os.path.abspath(workspace_base)

    @staticmethod
    def _log_pr_comment_results(
        comment_success: list[bool], result_explanation: str
    ) -> None:
        """Log a checklist pairing each PR review request with its resolved flag."""
        success_log = 'I have updated the PR and resolved some of the issues that were cited in the pull request review. Specifically, I identified the following revision requests, and all the ones that I think I successfully resolved are checked off. All the unchecked ones I was not able to resolve, so manual intervention may be required:\n'
        try:
            explanations = json.loads(result_explanation)
        except (json.JSONDecodeError, TypeError):
            logger.error(
                f'Failed to parse result_explanation as JSON: {result_explanation}'
            )
            explanations = [str(result_explanation)]  # Use raw string as fallback
        if not isinstance(explanations, list):
            # A JSON string/object would otherwise be zipped element-wise
            # (a JSON string character by character).
            explanations = [str(result_explanation)]
        for success_indicator, explanation in zip(comment_success, explanations):
            status = (
                colored('[X]', 'green')
                if success_indicator
                else colored('[ ]', 'red')
            )
            bullet_point = colored('-', 'yellow')
            success_log += f'\n{bullet_point} {status}: {explanation}'
        logger.info(success_log)

    async def process_issue(
        self,
        issue: Issue,
        base_commit: str,
        issue_handler: ServiceContextIssue | ServiceContextPR,
        reset_logger: bool = False,
    ) -> ResolverOutput:
        """Run the agent on ``issue`` and package the outcome as a ResolverOutput.

        The cloned repo at ``output_dir/repo`` is copied into a fresh
        ``workspace_base`` (any previous copy is deleted first).

        Args:
            issue: The issue or PR to resolve.
            base_commit: Commit the resulting patch is diffed against.
            issue_handler: Handler used to build instructions and judge success.
            reset_logger: Reset the logger to a per-issue file under
                ``output_dir/infer_logs`` (for multiprocessing).

        Returns:
            The ResolverOutput. If the agent fails with ValueError or
            RuntimeError, the result is marked unsuccessful and ``error``
            carries the failure message.
        """
        # Setup the logger properly, so you can run multi-processing to parallelize processing
        if reset_logger:
            log_dir = os.path.join(self.output_dir, 'infer_logs')
            reset_logger_for_multiprocessing(logger, str(issue.number), log_dir)
        else:
            logger.info(f'Starting fixing issue {issue.number}.')

        # write the repo to the workspace
        if os.path.exists(self.workspace_base):
            shutil.rmtree(self.workspace_base)
        shutil.copytree(os.path.join(self.output_dir, 'repo'), self.workspace_base)

        runtime = create_runtime(self.app_config)
        await runtime.connect()

        def on_event(evt: Event) -> None:
            logger.info(evt)

        runtime.event_stream.subscribe(
            EventStreamSubscriber.MAIN, on_event, str(uuid4())
        )

        self.initialize_runtime(runtime)

        instruction, conversation_instructions, images_urls = (
            issue_handler.get_instruction(
                issue,
                self.user_instructions_prompt_template,
                self.conversation_instructions_prompt_template,
                self.repo_instruction,
            )
        )
        # Here's how you can run the agent (similar to the `main` function) and get the final task state
        action = MessageAction(content=instruction, image_urls=images_urls)
        last_error: str | None = None
        try:
            state: State | None = await run_controller(
                config=self.app_config,
                initial_user_action=action,
                runtime=runtime,
                fake_user_response_fn=codeact_user_response,
                conversation_instructions=conversation_instructions,
            )
            if state is None:
                raise RuntimeError('Failed to run the agent.')
        except (ValueError, RuntimeError) as e:
            last_error = f'Agent failed with error: {str(e)}'
            logger.error(last_error)
            state = None

        # Get git patch
        return_val = await self.complete_runtime(runtime, base_commit)
        git_patch = return_val['git_patch']
        logger.info(
            f'Got git diff for instance {issue.number}:\n--------\n{git_patch}\n--------'
        )

        # Serialize histories and set defaults for failed state
        if state is None:
            histories = []
            metrics = None
            success = False
            comment_success = None
            result_explanation = 'Agent failed to run'
            # Keep the specific failure reason captured above when there is one.
            last_error = last_error or 'Agent failed to run or crashed'
        else:
            histories = [dataclasses.asdict(event) for event in state.history]
            metrics = state.metrics.get() if state.metrics else None
            # determine success based on the history, issue description and git patch
            success, comment_success, result_explanation = issue_handler.guess_success(
                issue, state.history, git_patch
            )
            if issue_handler.issue_type == 'pr' and comment_success:
                self._log_pr_comment_results(comment_success, result_explanation)
            last_error = state.last_error if state.last_error else None

        # Save the output
        output = ResolverOutput(
            issue=issue,
            issue_type=issue_handler.issue_type,
            instruction=instruction,
            base_commit=base_commit,
            git_patch=git_patch,
            history=histories,
            metrics=metrics,
            success=success,
            comment_success=comment_success,
            result_explanation=result_explanation,
            error=last_error,
        )
        return output

    def extract_issue(self) -> Issue:
        """Fetch the configured issue/PR (optionally narrowed to ``comment_id``).

        Raises:
            ValueError: If the handler returns no issues for ``issue_number``.
        """
        # Load dataset
        issues: list[Issue] = self.issue_handler.get_converted_issues(
            issue_numbers=[self.issue_number], comment_id=self.comment_id
        )

        if not issues:
            raise ValueError(
                f'No issues found for issue number {self.issue_number}. Please verify that:\n'
                f'1. The issue/PR #{self.issue_number} exists in the repository {self.owner}/{self.repo}\n'
                f'2. You have the correct permissions to access it\n'
                f'3. The repository name is spelled correctly'
            )

        return issues[0]

    async def resolve_issue(
        self,
        reset_logger: bool = False,
    ) -> None:
        """Resolve the configured issue and append the result to ``output.jsonl``.

        Does nothing if ``output_dir/output.jsonl`` already holds a record for
        this issue number.

        Args:
            reset_logger: Whether to reset the logger for multiprocessing.

        Raises:
            ValueError: If ``comment_id`` is set but the issue has no matching
                comments, or if a PR has no head branch.
            RuntimeError: If cloning the repository reports a fatal error.
        """
        issue = self.extract_issue()

        if self.comment_id is not None:
            if (
                self.issue_type == 'pr'
                and not issue.review_comments
                and not issue.review_threads
                and not issue.thread_comments
            ):
                raise ValueError(
                    f'Comment ID {self.comment_id} did not have a match for issue {issue.number}'
                )

            if self.issue_type == 'issue' and not issue.thread_comments:
                raise ValueError(
                    f'Comment ID {self.comment_id} did not have a match for issue {issue.number}'
                )

        # TEST METADATA
        model_name = self.app_config.get_llm_config().model.split('/')[-1]

        pathlib.Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        pathlib.Path(os.path.join(self.output_dir, 'infer_logs')).mkdir(
            parents=True, exist_ok=True
        )
        logger.info(f'Using output directory: {self.output_dir}')

        # checkout the repo
        repo_dir = os.path.join(self.output_dir, 'repo')
        if not os.path.exists(repo_dir):
            # NOTE(review): if the clone URL embeds credentials, a
            # CalledProcessError raised here will include it in its message.
            checkout_output = subprocess.check_output(  # noqa: ASYNC101
                [
                    'git',
                    'clone',
                    self.issue_handler.get_clone_url(),
                    f'{self.output_dir}/repo',
                ]
            ).decode('utf-8')
            # check_output already raises on a non-zero exit status; this is a
            # secondary guard on the captured stdout.
            if 'fatal' in checkout_output:
                raise RuntimeError(f'Failed to clone repository: {checkout_output}')

        # get the commit id of current repo for reproducibility
        base_commit = (
            subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo_dir)  # noqa: ASYNC101
            .decode('utf-8')
            .strip()
        )
        logger.info(f'Base commit: {base_commit}')

        if self.repo_instruction is None:
            # Check for .openhands_instructions file in the workspace directory
            openhands_instructions_path = os.path.join(
                repo_dir, '.openhands_instructions'
            )
            if os.path.exists(openhands_instructions_path):
                with open(openhands_instructions_path, 'r') as f:  # noqa: ASYNC101
                    self.repo_instruction = f.read()

        # OUTPUT FILE
        output_file = os.path.join(self.output_dir, 'output.jsonl')
        logger.info(f'Writing output to {output_file}')

        # Check if this issue was already processed
        if os.path.exists(output_file):
            with open(output_file, 'r', encoding='utf-8') as f:  # noqa: ASYNC101
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank lines in the file
                    data = ResolverOutput.model_validate_json(line)
                    if data.issue.number == self.issue_number:
                        logger.warning(
                            f'Issue {self.issue_number} was already processed. Skipping.'
                        )
                        return

        with open(output_file, 'a', encoding='utf-8') as output_fp:  # noqa: ASYNC101
            logger.info(
                f'Resolving issue {self.issue_number} with Agent {AGENT_CLASS}, model {model_name}, max iterations {self.max_iterations}.'
            )
            try:
                # checkout to pr branch if needed
                if self.issue_type == 'pr':
                    branch_to_use = issue.head_branch
                    logger.info(
                        f'Checking out to PR branch {branch_to_use} for issue {issue.number}'
                    )

                    if not branch_to_use:
                        raise ValueError('Branch name cannot be None')

                    # Fetch the branch first to ensure it exists locally
                    fetch_cmd = ['git', 'fetch', 'origin', branch_to_use]
                    subprocess.check_output(  # noqa: ASYNC101
                        fetch_cmd,
                        cwd=repo_dir,
                    )

                    # Checkout the branch
                    checkout_cmd = ['git', 'checkout', branch_to_use]
                    subprocess.check_output(  # noqa: ASYNC101
                        checkout_cmd,
                        cwd=repo_dir,
                    )

                    base_commit = (
                        subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo_dir)  # noqa: ASYNC101
                        .decode('utf-8')
                        .strip()
                    )

                output = await self.process_issue(
                    issue,
                    base_commit,
                    self.issue_handler,
                    reset_logger,
                )
                output_fp.write(output.model_dump_json() + '\n')
                output_fp.flush()
            finally:
                logger.info('Finished.')