import gc
import logging
import multiprocessing
import os
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import List, Tuple

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse

from api.code_execution import untrusted_check

# A single check outcome: an overall status string plus per-test pass/fail flags.
Result = Tuple[str, List[bool]]

def create_app() -> FastAPI:
    level = os.environ.get("LOG_LEVEL", logging.INFO)
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)

    app = FastAPI()

    @app.get("/")
    def root():
        return RedirectResponse("/docs")

    @app.get("/health")
    def health():
        return

    @app.post("/evaluate")
    async def evaluate(
        samples: List[dict],
        calibrate: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples.

        Each sample must provide "task_id", "res_id", "test", "solution",
        and "entry_point"; when calibrate is True, "code_prompt" as well.
        """
        if parallel < 1:
            n_workers = max(1, multiprocessing.cpu_count() // 2)
        else:
            n_workers = parallel

        if not no_gt:
            expected_time = get_groundtruth()
        else:
            expected_time = {}

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of result dicts
            remainings = set()

            for i, sample in enumerate(samples):
                # TODO: investigate why HTTPException detail is not passed to client.
                for key in ["task_id", "res_id", "test", "solution", "entry_point"]:
                    if key not in sample:
                        raise HTTPException(status_code=400, detail=f"'{key}' not in sample {i}!")
                if not isinstance(sample["solution"], str):
                    raise HTTPException(status_code=400, detail="Solution must be a string!")
                sample["_identifier"] = sample["task_id"] + f" (line {i + 1})"
                task_id = sample["task_id"]
                solution = sample["solution"]
                if calibrate:
                    if "code_prompt" not in sample:
                        raise HTTPException(status_code=400, detail=f"'code_prompt' not in sample {i}!")
                    solution = sample["code_prompt"] + "\n    pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    sample["res_id"],
                    task_id,
                    solution,
                    sample["test"],
                    sample["entry_point"],
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    expected_time.get(task_id) or 20,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished set"
            # assert len(completion_id) == len(problems), "Missing problems in samples"

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                # Drop references to each finished future eagerly to cap memory usage.
                del future, result
                gc.collect()

        # Sort the results for each problem by completion_id.
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

        return results

    return app

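# Illustrative sketch only: a minimal sample of the shape that /evaluate
# validates above. All field values here are hypothetical and not taken from
# any real benchmark; the exact "test" format is defined by
# api.code_execution.untrusted_check, and "code_prompt" is only needed when
# calibrate=True.
EXAMPLE_SAMPLE = {
    "task_id": "Demo/0",
    "res_id": 0,
    "solution": "def add(a, b):\n    return a + b\n",
    "test": "assert add(1, 2) == 3",
    "entry_point": "add",
    "code_prompt": "def add(a, b):",
}
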
def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> dict:
    ret = {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
    }
| ret["base"] = untrusted_check( | |
| solution, | |
| test, | |
| entry_point, | |
| max_as_limit, | |
| max_data_limit, | |
| max_stack_limit, | |
| min_time_limit, | |
| gt_time_limit, | |
| ) | |
| return ret | |
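# For illustration (assumed shape, following the Result alias above): a
# finished check might look like the dict below. The status value "pass" is
# an assumption; the actual strings are whatever untrusted_check emits.
# {
#     "completion_id": 0,
#     "res_id": 0,
#     "task_id": "Demo/0",
#     "_identifier": "Demo/0 (line 1)",
#     "solution": "...",
#     "base": ("pass", [True]),
# }
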
def get_groundtruth():
    raise HTTPException(status_code=405, detail="Groundtruth execution is not implemented yet!")

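# A minimal local-run sketch, assuming uvicorn is available in the
# environment; the host and port here are arbitrary choices, not taken from
# the original deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(create_app(), host="0.0.0.0", port=8000)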