Spaces:
Configuration error
Configuration error
| """ | |
| ToGMAL + ML Integration | |
| This module integrates the clustering-based anomaly detection models | |
| with the ToGMAL MCP server, enabling ML-enhanced safety detection. | |
| """ | |
import os
import pickle
import sys
from typing import Dict, Any, Tuple, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    import numpy as np

try:
    import numpy as np
except Exception as e:
    raise RuntimeError("Required ML dependencies missing. Please install: numpy, scikit-learn") from e
| # ============================================================================ | |
| # ML-ENHANCED DETECTION | |
| # ============================================================================ | |
class MLEnhancedDetector:
    """
    Wrapper for clustering models that can be used alongside heuristic detection.

    Model files in ``models_dir`` are pickled dicts with (at least) the keys
    ``'model'`` and ``'feature_extractor'``. A sample is flagged when its
    feature vector lies closer than ``DISTANCE_THRESHOLD`` to the centre of
    one of the model's dangerous clusters. Result dicts contain only built-in
    Python types (bool/int/float/str), so they can be JSON-serialised as-is.

    Diagnostics are written to stderr: when this module runs inside an MCP
    server using the stdio transport, stdout is reserved for protocol messages.
    """

    # Cluster ids found dangerous in the original training runs. Used only as
    # a fallback when a pickled model has no ``dangerous_clusters_`` attribute.
    DEFAULT_PROMPT_DANGEROUS_CLUSTERS: Tuple[int, ...] = (1, 2)
    DEFAULT_JOINT_DANGEROUS_CLUSTERS: Tuple[int, ...] = (0,)
    # Distance to a dangerous centre below which a sample is flagged.
    DISTANCE_THRESHOLD: float = 1.0
    # Confidence falls linearly from 1.0 (distance 0) to 0.0 (this distance).
    CONFIDENCE_SCALE: float = 2.0

    def __init__(self, models_dir: str = "./models"):
        self.models_dir = models_dir
        self.prompt_model: Optional[Dict[str, Any]] = None
        # Never populated by load_models(); kept for interface compatibility.
        self.response_model: Optional[Dict[str, Any]] = None
        self.joint_model: Optional[Dict[str, Any]] = None
        self._loaded = False

    # ------------------------------------------------------------------ loading

    def _load_model_file(self, filename: str, label: str) -> Optional[Dict[str, Any]]:
        """Unpickle ``models_dir/filename``; return None if the file does not exist."""
        path = os.path.join(self.models_dir, filename)
        if not os.path.exists(path):
            return None
        # SECURITY: pickle.load can execute arbitrary code. Only point
        # models_dir at trusted, locally produced model files.
        with open(path, 'rb') as f:
            data = pickle.load(f)
        model = data['model']
        feature_extractor = data['feature_extractor']
        dangerous = getattr(model, 'dangerous_clusters_', None)
        entry = {
            'model': model,
            'feature_extractor': feature_extractor,
            # Normalised to a plain list of ints (the attribute may be an array).
            'dangerous_clusters': [] if dangerous is None else [int(c) for c in dangerous],
        }
        print(f"✓ Loaded {label} clustering model from {path}", file=sys.stderr)
        return entry

    def load_models(self):
        """
        Load all available trained models from ``models_dir``.

        Missing files are skipped. Each file is loaded independently, so a
        corrupt joint model does not prevent the prompt model from being used
        (and vice versa).

        Returns:
            True if every existing model file loaded successfully, False if at
            least one failed (the error is reported on stderr).
        """
        all_ok = True
        for attr, filename, label in (
            ('prompt_model', "prompt_clustering.pkl", 'prompt'),
            ('joint_model', "joint_clustering.pkl", 'joint'),
        ):
            try:
                entry = self._load_model_file(filename, label)
            except Exception as e:
                print(f"✗ Failed to load models: {e}", file=sys.stderr)
                all_ok = False
                continue
            if entry is not None:
                setattr(self, attr, entry)
        # The loading attempt is complete; the per-model ``is None`` checks in
        # the analyze_* methods decide which analyses are actually available.
        self._loaded = True
        return all_ok

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _make_result(
        detected: bool,
        cluster_id: int,
        confidence: float,
        method: str,
        error: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Build a result dict using plain Python types (JSON-safe)."""
        result: Dict[str, Any] = {
            'detected': bool(detected),
            'cluster_id': int(cluster_id),
            'is_dangerous_cluster': bool(detected),
            'confidence': float(confidence),
            'method': method,
        }
        if error is not None:
            result['error'] = error
        return result

    @staticmethod
    def _dangerous_clusters(entry: Dict[str, Any], default: Tuple[int, ...]) -> list:
        """Return the model's stored dangerous cluster ids, or ``default`` if none."""
        clusters = entry.get('dangerous_clusters')
        if clusters is None or len(clusters) == 0:
            return list(default)
        return [int(c) for c in clusters]

    def _score(self, model: Any, features: Any, dangerous_clusters: list) -> Tuple[bool, float]:
        """
        Score ``features`` against the centres of ``dangerous_clusters``.

        Returns:
            (is_dangerous, confidence) based on the distance to the nearest
            dangerous cluster centre; (False, 0.0) if the model exposes no
            ``cluster_centers_``.
        """
        if not hasattr(model, 'cluster_centers_'):
            return False, 0.0
        # Assumes ``features`` holds a single row, so this broadcasts to one
        # Euclidean distance per cluster centre.
        distances = np.linalg.norm(model.cluster_centers_ - features, axis=1)
        closest = min(
            (float(distances[i]) for i in dangerous_clusters if 0 <= i < len(distances)),
            default=float('inf'),
        )
        is_dangerous = closest < self.DISTANCE_THRESHOLD
        confidence = 1.0 - min(closest / self.CONFIDENCE_SCALE, 1.0)
        return bool(is_dangerous), float(confidence)

    # ----------------------------------------------------------------- analysis

    def analyze_prompt_ml(self, prompt: str) -> Dict[str, Any]:
        """
        Analyze a prompt using ML clustering model.

        Returns:
            dict with keys:
            - detected: bool
            - cluster_id: int
            - is_dangerous_cluster: bool
            - confidence: float
            - method: str = 'ml_clustering'
              ('ml_clustering_unavailable' if no prompt model is loaded,
              'ml_clustering_error' plus an 'error' key on failure)
        """
        if not self._loaded or self.prompt_model is None:
            return self._make_result(False, -1, 0.0, 'ml_clustering_unavailable')
        try:
            features = self.prompt_model['feature_extractor'].transform_prompts([prompt])
            model = self.prompt_model['model']
            cluster_id = model.predict(features)[0]
            dangerous = self._dangerous_clusters(
                self.prompt_model, self.DEFAULT_PROMPT_DANGEROUS_CLUSTERS
            )
            is_dangerous, confidence = self._score(model, features, dangerous)
            return self._make_result(is_dangerous, cluster_id, confidence, 'ml_clustering')
        except Exception as e:
            print(f"ML analysis error: {e}", file=sys.stderr)
            return self._make_result(False, -1, 0.0, 'ml_clustering_error', error=str(e))

    def analyze_pair_ml(self, prompt: str, response: str) -> Dict[str, Any]:
        """
        Analyze a prompt-response pair using ML clustering model.

        The pair is joined as ``"<prompt> [SEP] <response>"``, vectorised with
        the joint feature extractor's ``prompt_vectorizer`` and ``scaler``, and
        scored like ``analyze_prompt_ml``. Returns a dict with the same keys.
        """
        if not self._loaded or self.joint_model is None:
            return self._make_result(False, -1, 0.0, 'ml_clustering_unavailable')
        try:
            combined = f"{prompt} [SEP] {response}"
            feature_extractor = self.joint_model['feature_extractor']
            features = feature_extractor.prompt_vectorizer.transform([combined]).toarray()
            features = feature_extractor.scaler.transform(features)
            model = self.joint_model['model']
            cluster_id = model.predict(features)[0]
            dangerous = self._dangerous_clusters(
                self.joint_model, self.DEFAULT_JOINT_DANGEROUS_CLUSTERS
            )
            is_dangerous, confidence = self._score(model, features, dangerous)
            return self._make_result(is_dangerous, cluster_id, confidence, 'ml_clustering')
        except Exception as e:
            print(f"ML analysis error: {e}", file=sys.stderr)
            return self._make_result(False, -1, 0.0, 'ml_clustering_error', error=str(e))
| # ============================================================================ | |
| # HYBRID DETECTION (Heuristics + ML) | |
| # ============================================================================ | |
def combine_detections(
    heuristic_results: Dict[str, Any],
    ml_results: Dict[str, Any],
    weight_heuristic: float = 0.7,
    weight_ml: float = 0.3
) -> Dict[str, Any]:
    """
    Combine heuristic and ML detection results.

    Detection is a logical OR of both detectors. Confidence is the weighted
    average of both confidences, with the weights normalised to sum to 1.

    Args:
        heuristic_results: Results from heuristic detection (ToGMAL); reads
            'detected', 'confidence' and 'categories'.
        ml_results: Results from ML clustering; reads 'detected',
            'confidence', 'method' and 'cluster_id'.
        weight_heuristic: Weight for heuristic confidence (0-1)
        weight_ml: Weight for ML confidence (0-1)

    Returns:
        Combined detection result with ensemble confidence. 'categories' keeps
        first-seen order with duplicates removed, and gains the ML method name
        when the ML detector fired.

    Raises:
        ValueError: if weight_heuristic + weight_ml is not positive.
    """
    total_weight = weight_heuristic + weight_ml
    if total_weight <= 0:
        raise ValueError(
            f"detection weights must sum to a positive value, "
            f"got {weight_heuristic} + {weight_ml}"
        )
    w_heuristic = weight_heuristic / total_weight
    w_ml = weight_ml / total_weight

    # `or 0.0` also covers an explicit None confidence.
    heuristic_conf = float(heuristic_results.get('confidence') or 0.0)
    ml_conf = float(ml_results.get('confidence') or 0.0)
    combined_confidence = w_heuristic * heuristic_conf + w_ml * ml_conf

    # Flag if either detector fires; bool() keeps the value JSON-serialisable
    # even when a detector returned e.g. numpy.bool_.
    combined_detected = bool(
        heuristic_results.get('detected', False) or ml_results.get('detected', False)
    )

    categories = list(heuristic_results.get('categories') or [])
    ml_method = ml_results.get('method')
    if ml_results.get('detected') and ml_method:
        categories.append(ml_method)
    # dict.fromkeys de-duplicates while preserving order, unlike set().
    combined_categories = list(dict.fromkeys(categories))

    return {
        'detected': combined_detected,
        'confidence': combined_confidence,
        'categories': combined_categories,
        'heuristic_confidence': heuristic_conf,
        'ml_confidence': ml_conf,
        'ml_cluster_id': ml_results.get('cluster_id', -1),
        'method': 'hybrid_ensemble'
    }
| # ============================================================================ | |
| # INTEGRATION WITH ToGMAL | |
| # ============================================================================ | |
# Global ML detector instance (lazy loaded)
_ml_detector: Optional[MLEnhancedDetector] = None


def get_ml_detector(models_dir: Optional[str] = None) -> MLEnhancedDetector:
    """
    Get or create the shared ML detector instance.

    Args:
        models_dir: Directory holding the pickled models. When omitted, an
            existing shared detector is reused as-is; if none exists yet,
            "./models" is used.

    Returns:
        The shared MLEnhancedDetector. If ``models_dir`` is given and differs
        from the shared detector's directory, a new detector is built for it,
        loaded, and replaces the shared one. Previously the argument was
        silently ignored after the first call.
    """
    global _ml_detector
    if _ml_detector is not None and (
        models_dir is None
        or os.path.abspath(models_dir) == os.path.abspath(_ml_detector.models_dir)
    ):
        return _ml_detector
    _ml_detector = MLEnhancedDetector(models_dir if models_dir is not None else "./models")
    _ml_detector.load_models()
    return _ml_detector