make clear distinction between ensemble agents and smart agents
- utils/monitoring_agents.py +135 -0
- utils/weight_management.py +107 -0
utils/monitoring_agents.py
ADDED
@@ -0,0 +1,135 @@
import logging
import time
import torch
import psutil  # Ensure psutil is imported here as well

logger = logging.getLogger(__name__)

class EnsembleMonitorAgent:
    def __init__(self):
        self.performance_metrics = {
            "model_accuracy": {},
            "response_times": {},
            "confidence_distribution": {},
            "consensus_rate": 0.0
        }
        self.alerts = []

    def monitor_prediction(self, model_id, prediction, confidence, response_time):
        """Monitor individual model performance"""
        if model_id not in self.performance_metrics["model_accuracy"]:
            self.performance_metrics["model_accuracy"][model_id] = []
            self.performance_metrics["response_times"][model_id] = []
            self.performance_metrics["confidence_distribution"][model_id] = []

        self.performance_metrics["response_times"][model_id].append(response_time)
        self.performance_metrics["confidence_distribution"][model_id].append(confidence)

        # Check for performance issues
        self._check_performance_issues(model_id)

    def _check_performance_issues(self, model_id):
        """Check for any performance anomalies"""
        response_times = self.performance_metrics["response_times"][model_id]
        if len(response_times) > 10:
            avg_time = sum(response_times[-10:]) / 10
            if avg_time > 2.0:  # More than 2 seconds
                self.alerts.append(f"High latency detected for {model_id}: {avg_time:.2f}s")

class WeightOptimizationAgent:
    def __init__(self, weight_manager):
        self.weight_manager = weight_manager
        self.prediction_history = []  # Stores (ensemble_prediction_label, assumed_actual_label)
        self.optimization_threshold = 0.05  # 5% change in accuracy triggers optimization
        self.min_history_for_optimization = 20  # Minimum samples before optimizing

    def analyze_performance(self, ensemble_prediction_label, actual_label=None):
        """Analyze ensemble performance and record for optimization"""
        # If actual_label is not provided, assume the ensemble is correct if not UNCERTAIN
        assumed_actual_label = actual_label
        if assumed_actual_label is None and ensemble_prediction_label != "UNCERTAIN":
            assumed_actual_label = ensemble_prediction_label

        self.prediction_history.append((ensemble_prediction_label, assumed_actual_label))

        if len(self.prediction_history) >= self.min_history_for_optimization and self._should_optimize():
            self._optimize_weights()

    def _calculate_accuracy(self, history_subset):
        """Calculates accuracy based on history where actual_label is known."""
        correct_predictions = 0
        total_known = 0
        for ensemble_pred, actual_label in history_subset:
            if actual_label is not None:
                total_known += 1
                if ensemble_pred == actual_label:
                    correct_predictions += 1
        return correct_predictions / total_known if total_known > 0 else 0.0

    def _should_optimize(self):
        """Determine if weights should be optimized based on recent performance change."""
        if len(self.prediction_history) < self.min_history_for_optimization * 2:  # Need enough history for comparison
            return False

        # Compare accuracy of the most recent batch with the previous batch
        recent_batch = self.prediction_history[-self.min_history_for_optimization:]
        previous_batch = self.prediction_history[-self.min_history_for_optimization*2:-self.min_history_for_optimization]

        recent_accuracy = self._calculate_accuracy(recent_batch)
        previous_accuracy = self._calculate_accuracy(previous_batch)

        # Trigger optimization if there's a significant drop in accuracy
        if previous_accuracy > 0 and (previous_accuracy - recent_accuracy) / previous_accuracy > self.optimization_threshold:
            logger.warning(f"Performance degradation detected (from {previous_accuracy:.2f} to {recent_accuracy:.2f}). Triggering weight optimization.")
            return True
        return False

    def _optimize_weights(self):
        """Optimize model weights based on performance."""
        logger.info("Optimizing model weights based on recent performance.")
        # Placeholder for sophisticated optimization logic.
        # This is where you would adjust self.weight_manager.base_weights
        # based on which models contributed more to correct predictions or errors.
        # For now, it's just a log message.


class SystemHealthAgent:
    def __init__(self):
        self.health_metrics = {
            "memory_usage": [],
            "gpu_utilization": [],
            "model_load_times": {},
            "error_rates": {}
        }

    def monitor_system_health(self):
        """Monitor overall system health"""
        self._check_memory_usage()
        self._check_gpu_utilization()
        # You might add _check_model_health() here later

    def _check_memory_usage(self):
        """Monitor memory usage"""
        try:
            import psutil
            memory = psutil.virtual_memory()
            self.health_metrics["memory_usage"].append(memory.percent)

            if memory.percent > 90:
                logger.warning(f"High memory usage detected: {memory.percent}%")
        except ImportError:
            logger.warning("psutil not installed. Cannot monitor memory usage.")

    def _check_gpu_utilization(self):
        """Monitor GPU utilization if available"""
        if torch.cuda.is_available():
            try:
                gpu_util = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()
                self.health_metrics["gpu_utilization"].append(gpu_util)

                if gpu_util > 0.9:
                    logger.warning(f"High GPU utilization detected: {gpu_util*100:.2f}%")
            except Exception as e:
                logger.warning(f"Error monitoring GPU utilization: {e}")
        else:
            logger.info("CUDA not available. Skipping GPU utilization monitoring.")
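A minimal usage sketch (not part of this commit; the model ID, label, and confidence value below are hypothetical) showing how the monitoring agents could be driven from an inference loop:

import logging
import time

from utils.monitoring_agents import EnsembleMonitorAgent, SystemHealthAgent

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

monitor = EnsembleMonitorAgent()
health = SystemHealthAgent()

start = time.time()
prediction, confidence = "FAKE", 0.92  # placeholder output from one ensemble member
monitor.monitor_prediction("model_1", prediction, confidence, time.time() - start)

health.monitor_system_health()  # appends memory / GPU readings to health_metrics
if monitor.alerts:
    logger.warning("Ensemble alerts: %s", monitor.alerts)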
utils/weight_management.py
ADDED
@@ -0,0 +1,107 @@
import logging
import torch

logger = logging.getLogger(__name__)

class ContextualWeightOverrideAgent:
    def __init__(self):
        self.context_overrides = {
            # Example: when the image is outdoor, model_1 is penalized, model_5 is boosted
            "outdoor": {
                "model_1": 0.8,  # Example: Reduce weight of model_1 by 20% for outdoor scenes
                "model_5": 1.2,  # Example: Boost weight of model_5 by 20% for outdoor scenes
            },
            "low_light": {
                "model_2": 0.7,
                "model_7": 1.3,
            },
            "sunny": {
                "model_3": 0.9,
                "model_4": 1.1,
            }
            # Add more contexts and their specific model weight adjustments here
        }

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Returns combined weight overrides for given context tags."""
        combined_overrides = {}
        for tag in context_tags:
            if tag in self.context_overrides:
                for model_id, multiplier in self.context_overrides[tag].items():
                    # If a model appears in multiple contexts, the multipliers could be combined
                    # in several ways (average, max); here they are multiplied for a simple cumulative effect.
                    combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        return combined_overrides


class ModelWeightManager:
    def __init__(self):
        self.base_weights = {
            "model_1": 0.15,   # SwinV2 Based
            "model_2": 0.15,   # ViT Based
            "model_3": 0.15,   # SDXL Dataset
            "model_4": 0.15,   # SDXL + FLUX
            "model_5": 0.15,   # ViT Based
            "model_5b": 0.10,  # ViT Based, Newer Dataset
            "model_6": 0.10,   # Swin, Midj + SDXL
            "model_7": 0.05    # ViT
        }
        self.situation_weights = {
            "high_confidence": 1.2,  # Boost weights for high confidence predictions
            "low_confidence": 0.8,   # Reduce weights for low confidence
            "conflict": 0.5,         # Reduce weights when models disagree
            "consensus": 1.5         # Boost weights when models agree
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Dynamically adjust weights based on prediction patterns and optional context."""
        adjusted_weights = self.base_weights.copy()

        # 1. Apply contextual overrides first
        if context_tags:
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier

        # 2. Apply situation-based adjustments (consensus, conflict, confidence)
        # Check for consensus
        if self._has_consensus(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]

        # Check for conflicts
        if self._has_conflicts(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]

        # Adjust based on confidence
        for model, confidence in confidence_scores.items():
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]

        return self._normalize_weights(adjusted_weights)

    def _has_consensus(self, predictions):
        """Check if models agree on prediction"""
        # Ignore None and error predictions before checking for consensus
        non_none_predictions = [p for p in predictions.values() if p is not None and p != "Error"]
        return len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1

    def _has_conflicts(self, predictions):
        """Check if models have conflicting predictions"""
        # Ignore None and error predictions before checking for conflicts
        non_none_predictions = [p for p in predictions.values() if p is not None and p != "Error"]
        return len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1"""
        total = sum(weights.values())
        if total == 0:
            # Handle case where all weights became zero due to aggressive multipliers
            # by falling back to equal weights across all models
            logger.warning("All weights became zero after adjustments. Falling back to equal weights.")
            return {k: 1.0/len(self.base_weights) for k in self.base_weights}
        return {k: v/total for k, v in weights.items()}
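A minimal sketch (not part of this commit; the prediction labels, confidence scores, and context tag are hypothetical) of how ModelWeightManager could feed a weighted vote and, in turn, the WeightOptimizationAgent from utils/monitoring_agents.py:

from utils.monitoring_agents import WeightOptimizationAgent
from utils.weight_management import ModelWeightManager

# Hypothetical per-model outputs; keys must match the IDs in base_weights.
predictions = {"model_1": "FAKE", "model_2": "FAKE", "model_3": "REAL"}
confidences = {"model_1": 0.91, "model_2": 0.85, "model_3": 0.48}

manager = ModelWeightManager()
weights = manager.adjust_weights(predictions, confidences, context_tags=["outdoor"])

# Simple weighted vote over the adjusted, normalized weights.
scores = {}
for model_id, label in predictions.items():
    scores[label] = scores.get(label, 0.0) + weights.get(model_id, 0.0)
ensemble_label = max(scores, key=scores.get)

# Record the result so the optimization agent can watch for accuracy drops over time.
optimizer = WeightOptimizationAgent(manager)
optimizer.analyze_performance(ensemble_label, actual_label=None)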