Spaces:
Sleeping
Sleeping
| """ | |
| Async Natural Language to MCP Command Translator | |
| NON-BLOCKING version that never freezes the game loop | |
| Uses async model manager for instant response | |
| """ | |
| import json | |
| import re | |
| import time | |
| from typing import Dict, Optional, Tuple | |
| from pathlib import Path | |
| from model_manager import get_shared_model, RequestStatus | |
class AsyncNLCommandTranslator:
    """Translate natural-language RTS commands into JSON tool calls without blocking.

    Requests are handed to the shared model manager with ``submit_translation()``
    and polled with ``check_translation()``. Only one translation is kept active
    at a time: submitting a new command cancels the previous one.
    """

    def __init__(self, model_path: str = "qwen2.5-coder-1.5b-instruct-q4_0.gguf"):
        """Set up request tracking, language detection patterns and system prompts.

        Args:
            model_path: Model file passed to the shared model manager on load.
        """
        self.model_path = model_path
        self.model_manager = get_shared_model()
        self.last_error = None

        # command_text -> (request_id, submitted_at, language)
        self._pending_requests: Dict[str, Tuple[str, float, str]] = {}
        # Request that a new submission must cancel; None when nothing is active.
        self._current_request_id = None

        # Language detection patterns
        self.lang_patterns = {
            'zh': re.compile(r'[\u4e00-\u9fff]'),  # Chinese characters
            'fr': re.compile(r'[àâçèéêëîïôùûü]', re.IGNORECASE)  # French accents
        }

        # System prompts, one per supported language. The prompt text is sent to
        # the model verbatim, so it is kept flush-left inside the literals.
        self.system_prompts = {
            "en": """You are an AI assistant for an RTS game. Convert user commands into JSON tool calls.
Available tools:
- get_game_state(): Get current game state
- move_units(unit_ids: list, target_x: int, target_y: int): Move units to position
- attack_unit(attacker_ids: list, target_id: str): Attack enemy unit
- build_unit(unit_type: str): Build a unit (infantry, tank, helicopter, harvester)
- build_building(building_type: str, x: int, y: int): Build a building (barracks, war_factory, power_plant, refinery, defense_turret)
Respond ONLY with valid JSON containing "tool" and "params" fields.
For parameterless functions, you may omit the params field.
Example: {"tool": "move_units", "params": {"unit_ids": ["unit_1"], "target_x": 200, "target_y": 300}}""",
            "fr": """Tu es un assistant IA pour un jeu RTS. Convertis les commandes utilisateur en appels d'outils JSON.
Outils disponibles :
- get_game_state(): Obtenir l'état du jeu
- move_units(unit_ids: list, target_x: int, target_y: int): Déplacer des unités
- attack_unit(attacker_ids: list, target_id: str): Attaquer une unité ennemie
- build_unit(unit_type: str): Construire une unité (infantry, tank, helicopter, harvester)
- build_building(building_type: str, x: int, y: int): Construire un bâtiment (barracks, war_factory, power_plant, refinery, defense_turret)
Réponds UNIQUEMENT avec du JSON valide contenant les champs "tool" et "params".""",
            "zh": """你是一个RTS游戏的AI助手。将用户命令转换为JSON工具调用。
可用工具:
- get_game_state(): 获取当前游戏状态
- move_units(unit_ids: list, target_x: int, target_y: int): 移动单位到位置
- attack_unit(attacker_ids: list, target_id: str): 攻击敌方单位
- build_unit(unit_type: str): 建造单位(infantry步兵, tank坦克, helicopter直升机, harvester采集车)
- build_building(building_type: str, x: int, y: int): 建造建筑(barracks兵营, war_factory战争工厂, power_plant发电厂, refinery精炼厂, defense_turret防御塔)
仅响应包含"tool"和"params"字段的有效JSON。"""
        }

    def model_loaded(self) -> bool:
        """Return the shared model manager's ``model_loaded`` attribute.

        NOTE(review): this is a method, so callers must *call* it; the bare
        attribute ``translator.model_loaded`` is a bound method and always truthy.
        """
        return self.model_manager.model_loaded

    def load_model(self) -> Tuple[bool, Optional[str]]:
        """Load the model through the shared model manager.

        Returns:
            Whatever the manager's ``load_model`` returns; used in this file as
            a ``(success, error)`` pair.
        """
        return self.model_manager.load_model(self.model_path)

    def detect_language(self, text: str) -> str:
        """Detect the command language (Chinese > French > English).

        Args:
            text: Natural language command.

        Returns:
            'zh' if any CJK ideograph is present, else 'fr' if any French
            accented letter is present, else 'en'.
        """
        if self.lang_patterns['zh'].search(text):
            return 'zh'
        if self.lang_patterns['fr'].search(text):
            return 'fr'
        return 'en'

    def extract_json_from_response(self, text: str) -> Optional[Dict]:
        """Extract the first JSON object found in an LLM response.

        Strategies are tried in order: the whole (stripped) text, a fenced code
        block, then the first brace-delimited span with up to one level of
        nesting. A strategy that fails to parse falls through to the next one
        instead of aborting the whole extraction.

        Args:
            text: Raw model output.

        Returns:
            The parsed dict, or None when no strategy yields a JSON object.
        """
        if not text:
            return None

        candidates = []

        # Whole response is (or at least starts like) a JSON object
        stripped = text.strip()
        if stripped.startswith('{'):
            candidates.append(stripped)

        # JSON inside a markdown code fence
        fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
        if fenced:
            candidates.append(fenced.group(1))

        # Any JSON-looking object embedded in surrounding prose
        embedded = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
        if embedded:
            candidates.append(embedded.group(0))

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue  # e.g. trailing chatter after the object; try the next strategy
            if isinstance(parsed, dict):
                return parsed
        return None

    def _forget_request(self, request_id: str) -> str:
        """Drop local tracking for a finished request and return its language.

        Returns "unknown" when the request is no longer tracked (for example
        because the same command text was submitted again and replaced it).
        """
        if self._current_request_id == request_id:
            self._current_request_id = None

        language = "unknown"
        for command, entry in list(self._pending_requests.items()):
            if entry[0] == request_id:
                language = entry[2]
                del self._pending_requests[command]
        return language

    def submit_translation(self, nl_command: str, language: Optional[str] = None) -> str:
        """
        Submit translation request (NON-BLOCKING - returns immediately)

        Cancels any previous translation request so that only the latest
        command is processed. This method itself applies no timeout.

        Args:
            nl_command: Natural language command
            language: Optional language override

        Returns:
            request_id: Use this to check result with check_translation()

        Raises:
            RuntimeError: If the model is not loaded and loading it fails.
        """
        # Cancel previous request if any (one active translation at a time)
        if self._current_request_id is not None:
            previous_id = self._current_request_id
            self.model_manager.cancel_request(previous_id)
            # Clear first so a failure below cannot leave a cancelled ID as "current"
            self._current_request_id = None
            print(f"🔄 Cancelled previous translation request {previous_id} (new command received)")

        # Ensure model is loaded. model_loaded is a method and must be called;
        # testing the bare attribute is always truthy and skips loading.
        if not self.model_loaded():
            success, error = self.load_model()
            if not success:
                raise RuntimeError(f"Model not loaded: {error}")

        # Detect language
        if language is None:
            language = self.detect_language(nl_command)

        # Unsupported language codes fall back to the English prompt
        system_prompt = self.system_prompts.get(language, self.system_prompts["en"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": nl_command}
        ]

        # Submit async request
        request_id = self.model_manager.submit_async(
            messages=messages,
            max_tokens=64,  # Reduced from 128 - JSON commands are short
            temperature=0.1
        )

        # Track request
        self._pending_requests[nl_command] = (request_id, time.time(), language)
        self._current_request_id = request_id
        return request_id

    def check_translation(self, request_id: str) -> Dict:
        """
        Check translation result (NON-BLOCKING - returns status immediately)

        Args:
            request_id: ID from submit_translation()

        Returns:
            Dict with "ready" plus, once ready, "success" and either
            "json_command"/"raw_response"/"language" or "error".
        """
        status, result_text, error_message = self.model_manager.get_result(request_id, remove=False)

        # Not ready yet
        if status in (RequestStatus.PENDING, RequestStatus.PROCESSING):
            return {
                "ready": False,
                "status": status.value,
                "message": "Translation in progress..."
            }

        # Every remaining state is terminal: stop tracking the request locally
        # so it is neither cancelled again later nor kept forever.
        language = self._forget_request(request_id)

        # Failed or cancelled
        if status in (RequestStatus.FAILED, RequestStatus.CANCELLED):
            self.model_manager.get_result(request_id, remove=True)
            return {
                "ready": True,
                "success": False,
                "error": error_message or "Translation failed",
                "status": status.value
            }

        # Completed - parse result (an empty response is an extraction failure)
        if status == RequestStatus.COMPLETED:
            self.model_manager.get_result(request_id, remove=True)
            json_command = self.extract_json_from_response(result_text) if result_text else None
            if json_command and 'tool' in json_command:
                return {
                    "ready": True,
                    "success": True,
                    "json_command": json_command,
                    "raw_response": result_text,
                    "language": language
                }
            return {
                "ready": True,
                "success": False,
                "error": "Could not extract valid JSON from response",
                "raw_response": result_text
            }

        # Unknown state
        return {
            "ready": True,
            "success": False,
            "error": "Unknown status",
            "status": getattr(status, "value", str(status))
        }

    def translate_blocking(self, nl_command: str, language: Optional[str] = None, max_wait: float = 300.0) -> Dict:
        """
        Translate and wait for completion (for backward compatibility)

        Polls check_translation() every 0.1s until the request reaches a final
        state. max_wait is a safety limit only; it is not meant as an inference
        timeout.

        Args:
            nl_command: Natural language command
            language: Optional language override
            max_wait: Safety limit in seconds before giving up on polling

        Returns:
            The dict from check_translation(), or a dict with "success": False
            and "error" on safety-limit expiry or on any exception.
        """
        try:
            # Submit (cancels any previous translation)
            request_id = self.submit_translation(nl_command, language)

            # Monotonic clock: wall-clock adjustments must not shorten or extend the wait
            deadline = time.monotonic() + max_wait
            while time.monotonic() < deadline:
                result = self.check_translation(request_id)
                if result["ready"]:
                    return result
                # Wait a bit before checking again
                time.sleep(0.1)

            # Safety limit reached (extremely long inference)
            return {
                "success": False,
                "error": f"Translation exceeded safety limit ({max_wait}s) - model may be stuck",
                "timeout": True
            }
        except Exception as e:
            # Legacy callers expect an error dict rather than an exception
            return {
                "success": False,
                "error": f"Translation error: {str(e)}"
            }

    def cleanup_old_requests(self, max_age: float = 60.0):
        """Cancel and forget tracked requests older than ``max_age`` seconds.

        Note that this also cancels the current request if it is that old.
        """
        now = time.time()
        stale_commands = [
            command
            for command, entry in self._pending_requests.items()
            if now - entry[1] > max_age
        ]
        for command in stale_commands:
            request_id = self._pending_requests.pop(command)[0]
            self.model_manager.cancel_request(request_id)
            if self._current_request_id == request_id:
                self._current_request_id = None

    # Legacy API compatibility
    def translate(self, nl_command: str, language: Optional[str] = None) -> Dict:
        """Legacy blocking API - delegates to translate_blocking() with its default safety limit"""
        return self.translate_blocking(nl_command, language)

    def translate_command(self, nl_command: str, language: Optional[str] = None) -> Dict:
        """Alias for translate() - for API compatibility"""
        return self.translate(nl_command, language)

    def get_example_commands(self, language: str = "en") -> list:
        """Get example commands for the given language (English when unsupported)"""
        examples = {
            "en": [
                "Show me the game state",
                "Move my infantry to position 200, 300",
                "Build a tank",
                "Construct a power plant at 150, 150",
                "Attack the enemy base",
            ],
            "fr": [
                "Montre-moi l'état du jeu",
                "Déplace mon infanterie vers 200, 300",
                "Construis un char",
                "Construit une centrale électrique à 150, 150",
                "Attaque la base ennemie",
            ],
            "zh": [
                "显示游戏状态",
                "移动我的步兵到200, 300",
                "建造一个坦克",
                "在150, 150建造发电厂",
                "攻击敌人的基地",
            ]
        }
        return examples.get(language, examples["en"])
# Global instance (created lazily by get_nl_translator)
_translator = None


def get_nl_translator() -> AsyncNLCommandTranslator:
    """Get the singleton translator instance, creating it on first use.

    On creation the model is loaded eagerly; a load failure is printed but
    the translator is still returned.

    Returns:
        The shared AsyncNLCommandTranslator instance.
    """
    global _translator
    if _translator is None:
        _translator = AsyncNLCommandTranslator()
        # Auto-load model. model_loaded is a method and must be called: the
        # bare attribute is a bound method, which is always truthy, so the
        # load would never be attempted.
        if not _translator.model_loaded():
            print("🔄 Loading NL translator model...")
            success, error = _translator.load_model()
            if success:
                print("✅ NL translator model loaded successfully")
            else:
                print(f"❌ Failed to load NL translator model: {error}")
    return _translator