# ml_engine/strategies.py (Updated to use LearningHub for weights)
"""Strategy scoring engines.

``MultiStrategyEngine`` scores a symbol with several independent strategies,
weights the scores (using weights from the learning hub when available) and
optionally boosts them through ``PatternEnhancedStrategyEngine`` when a chart
pattern was detected.
"""
import asyncio

# (Import from internal modules)
from .patterns import ChartPatternAnalyzer


class PatternEnhancedStrategyEngine:
    """Boost strategy scores according to a detected chart pattern."""

    def __init__(self, data_manager, learning_hub):
        """Store collaborators and create the pattern analyzer.

        ``learning_hub`` replaces the former ``learning_engine`` argument.
        """
        self.data_manager = data_manager
        self.learning_hub = learning_hub
        self.pattern_analyzer = ChartPatternAnalyzer()

    async def enhance_strategy_with_patterns(self, strategy_scores, pattern_analysis, symbol):
        """Scale the scores of pattern-appropriate strategies.

        Args:
            strategy_scores: dict mapping strategy name to score. It is
                modified in place and also returned.
            pattern_analysis: dict read for 'pattern_detected',
                'pattern_confidence' and 'predicted_direction'; may be falsy.
            symbol: not used by the current logic.

        Returns:
            ``strategy_scores``. Scores are only changed when the pattern
            confidence is at least 0.6; boosted scores are capped at 1.0.
        """
        if not pattern_analysis or pattern_analysis.get('pattern_detected') in [
            'no_clear_pattern',
            'insufficient_data',
        ]:
            return strategy_scores

        pattern_confidence = pattern_analysis.get('pattern_confidence', 0)
        pattern_name = pattern_analysis.get('pattern_detected', '')
        predicted_direction = pattern_analysis.get('predicted_direction', '')

        if pattern_confidence >= 0.6:
            enhancement_factor = self._calculate_pattern_enhancement(pattern_confidence, pattern_name)
            enhanced_strategies = self._get_pattern_appropriate_strategies(pattern_name, predicted_direction)
            # (Omitted print statements for brevity)
            for strategy in enhanced_strategies:
                if strategy in strategy_scores:
                    original_score = strategy_scores[strategy]
                    strategy_scores[strategy] = min(original_score * enhancement_factor, 1.0)

        return strategy_scores

    def _calculate_pattern_enhancement(self, pattern_confidence, pattern_name):
        """Return the multiplicative boost for a pattern.

        The boost is ``1 + 0.3 * pattern_confidence``, multiplied by 1.1 for
        the patterns listed below, and never exceeds 1.5.
        """
        base_enhancement = 1.0 + (pattern_confidence * 0.3)
        high_reliability_patterns = ['Double Top', 'Double Bottom', 'Head & Shoulders', 'Cup and Handle']
        if pattern_name in high_reliability_patterns:
            base_enhancement *= 1.1
        return min(base_enhancement, 1.5)

    def _get_pattern_appropriate_strategies(self, pattern_name, direction):
        """Return the strategy names to boost for ``pattern_name``.

        Reversal patterns depend on ``direction`` ('down' versus anything
        else); continuation patterns and unknown patterns each map to a fixed
        pair of strategies.
        """
        reversal_patterns = ['Double Top', 'Double Bottom', 'Head & Shoulders', 'Triple Top', 'Triple Bottom']
        continuation_patterns = ['Flags', 'Pennants', 'Triangles', 'Rectangles']

        if pattern_name in reversal_patterns:
            if direction == 'down':
                return ['breakout_momentum', 'trend_following']
            else:
                return ['mean_reversion', 'breakout_momentum']
        elif pattern_name in continuation_patterns:
            return ['trend_following', 'breakout_momentum']
        else:
            return ['breakout_momentum', 'hybrid_ai']


class MultiStrategyEngine:
    """Evaluate every registered strategy for a symbol and weight the results."""

    def __init__(self, data_manager, learning_hub):
        """Store collaborators and register the strategy coroutines.

        ``learning_hub`` replaces the former ``learning_engine`` argument and
        is also handed to the pattern enhancer.
        """
        self.data_manager = data_manager
        self.learning_hub = learning_hub
        # (Pass the hub to the enhancer)
        self.pattern_enhancer = PatternEnhancedStrategyEngine(data_manager, learning_hub)

        self.strategies = {
            'trend_following': self._trend_following_strategy,
            'mean_reversion': self._mean_reversion_strategy,
            'breakout_momentum': self._breakout_momentum_strategy,
            'volume_spike': self._volume_spike_strategy,
            'whale_tracking': self._whale_tracking_strategy,
            'pattern_recognition': self._pattern_recognition_strategy,
            'hybrid_ai': self._hybrid_ai_strategy,
        }

    async def evaluate_all_strategies(self, symbol_data, market_context):
        """Evaluate all trading strategies.

        Args:
            symbol_data: dict of per-symbol inputs. When at least one strategy
                produced a score, 'recommended_strategy' and
                'strategy_confidence' are written into it.
            market_context: dict; 'market_trend' selects the hub weights.

        Returns:
            ``(strategy_scores, base_scores)``: weighted (and possibly
            pattern-boosted) scores and raw scores, both keyed by strategy
            name. ``({}, {})`` is returned if an unexpected error occurs.
        """
        try:
            # (Get weights from the new Learning Hub)
            optimized_weights = None
            if self.learning_hub and self.learning_hub.initialized:
                try:
                    market_condition = market_context.get('market_trend', 'sideways_market')
                    # (Call the new hub function)
                    optimized_weights = await self.learning_hub.get_optimized_weights(market_condition)
                except Exception as e:
                    print(f"⚠️ Error getting optimized weights from hub: {e}. Using defaults.")
                    optimized_weights = None
            # Fall back to defaults when the hub is unavailable, failed, or
            # returned nothing; otherwise every weight lookup below would fail.
            if optimized_weights is None:
                optimized_weights = await self.get_default_weights()

            strategy_scores = {}
            base_scores = {}

            # hybrid_ai is evaluated last because it consumes the other base scores.
            primary_strategies = [s for s in self.strategies.keys() if s != 'hybrid_ai']

            for strategy_name in primary_strategies:
                strategy_function = self.strategies[strategy_name]
                try:
                    base_score = await strategy_function(symbol_data, market_context)
                    if base_score is None:
                        continue
                    base_scores[strategy_name] = base_score
                    weight = optimized_weights.get(strategy_name, 0.1)
                    weighted_score = base_score * weight
                    strategy_scores[strategy_name] = min(weighted_score, 1.0)
                except Exception as error:
                    print(f"❌ Error evaluating strategy {strategy_name}: {error}")
                    continue

            try:
                hybrid_score = await self._hybrid_ai_strategy(symbol_data, market_context, base_scores)
                if hybrid_score is not None:
                    base_scores['hybrid_ai'] = hybrid_score
                    weight = optimized_weights.get('hybrid_ai', 0.1)
                    strategy_scores['hybrid_ai'] = min(hybrid_score * weight, 1.0)
            except Exception as e:
                print(f"❌ Error in hybrid_ai strategy: {e}")

            # Pattern enhancement (Unchanged)
            pattern_analysis = symbol_data.get('pattern_analysis')
            if pattern_analysis:
                strategy_scores = await self.pattern_enhancer.enhance_strategy_with_patterns(
                    strategy_scores, pattern_analysis, symbol_data.get('symbol')
                )

            if base_scores:
                # The recommendation is based on the raw (unweighted) scores.
                best_strategy = max(base_scores.items(), key=lambda x: x[1])
                best_strategy_name = best_strategy[0]
                best_strategy_score = best_strategy[1]
                symbol_data['recommended_strategy'] = best_strategy_name
                symbol_data['strategy_confidence'] = best_strategy_score

            return strategy_scores, base_scores

        except Exception as error:
            print(f"❌ Error in evaluate_all_strategies: {error}")
            return {}, {}

    async def get_default_weights(self):
        """Return the built-in weight for each strategy name."""
        return {
            'trend_following': 0.15,
            'mean_reversion': 0.12,
            'breakout_momentum': 0.20,
            'volume_spike': 0.13,
            'whale_tracking': 0.20,
            'pattern_recognition': 0.10,
            'hybrid_ai': 0.10,
        }

    # ------------------------------------------------------------------
    # Individual strategy functions.
    # Each returns a score capped at 1.0, or None when it cannot be computed
    # (any exception inside a strategy is deliberately mapped to None).
    # ------------------------------------------------------------------

    async def _trend_following_strategy(self, symbol_data, market_context):
        """Score EMA-based trend signals over the 1h, 15m and 5m timeframes."""
        try:
            score = 0.0
            indicators = symbol_data.get('advanced_indicators', {})
            for timeframe in ['1h', '15m', '5m']:
                if timeframe in indicators:
                    tf_indicators = indicators[timeframe]
                    ema_21 = tf_indicators.get('ema_21')
                    ema_50 = tf_indicators.get('ema_50')
                    adx = tf_indicators.get('adx', 0)
                    if ema_21 is not None and ema_50 is not None:
                        # NOTE(review): the ADX and price bonuses are assumed to be
                        # nested under the ema_21 > ema_50 check — confirm against
                        # the original layout.
                        if ema_21 > ema_50:
                            score += 0.2
                            if adx > 20:
                                score += 0.1
                            if symbol_data['current_price'] > ema_21:
                                score += 0.05
            return min(score, 1.0)
        except Exception:
            return None

    def _check_ema_alignment(self, indicators):
        """Return True when ema_9 > ema_21 > ema_50 and all three keys exist."""
        required_emas = ['ema_9', 'ema_21', 'ema_50']
        if all(ema in indicators for ema in required_emas):
            return (indicators['ema_9'] > indicators['ema_21'] > indicators['ema_50'])
        return False

    async def _mean_reversion_strategy(self, symbol_data, market_context):
        """Score oversold conditions (RSI / Bollinger position) on 1h and 15m."""
        try:
            score = 0.0
            current_price = symbol_data['current_price']
            indicators = symbol_data.get('advanced_indicators', {})
            for timeframe in ['1h', '15m']:
                if timeframe in indicators:
                    tf_indicators = indicators[timeframe]
                    rsi_value = tf_indicators.get('rsi', 50)
                    bb_lower = tf_indicators.get('bb_lower')
                    bb_upper = tf_indicators.get('bb_upper')
                    if bb_lower is None or bb_upper is None:
                        continue

                    # 0.5 (mid-band) is used when the band has no width.
                    position_in_band = 0.5
                    if (bb_upper - bb_lower) > 0:
                        position_in_band = (current_price - bb_lower) / (bb_upper - bb_lower)

                    is_rsi_oversold = rsi_value < 25
                    is_bb_oversold = position_in_band < 0.1

                    if is_rsi_oversold or is_bb_oversold:
                        score += 0.4
                    if is_rsi_oversold and is_bb_oversold:
                        score += 0.2
            return min(score, 1.0)
        except Exception:
            return None

    async def _breakout_momentum_strategy(self, symbol_data, market_context):
        """Score volume-backed momentum on the 1h, 15m and 5m timeframes."""
        try:
            score = 0.0
            current_price = symbol_data['current_price']
            indicators = symbol_data.get('advanced_indicators', {})
            for timeframe in ['1h', '15m', '5m']:
                if timeframe in indicators:
                    tf_indicators = indicators[timeframe]
                    volume_ratio = tf_indicators.get('volume_ratio', 0)
                    # A timeframe contributes nothing unless volume_ratio >= 1.5.
                    if volume_ratio < 1.5:
                        continue
                    score += 0.2

                    macd_hist = tf_indicators.get('macd_hist', 0)
                    if macd_hist > 0:
                        score += 0.1

                    atr_percent = tf_indicators.get('atr_percent', 0)
                    if atr_percent > 1.5:
                        score += 0.1

                    vwap = tf_indicators.get('vwap')
                    if vwap and current_price > vwap:
                        score += 0.05
            return min(score, 1.0)
        except Exception:
            return None

    async def _volume_spike_strategy(self, symbol_data, market_context):
        """Score the size of volume_ratio on the 1h, 15m and 5m timeframes."""
        try:
            score = 0.0
            indicators = symbol_data.get('advanced_indicators', {})
            for timeframe in ['1h', '15m', '5m']:
                if timeframe in indicators:
                    volume_ratio = indicators[timeframe].get('volume_ratio', 0)
                    if volume_ratio > 3.0:
                        score += 0.45
                    elif volume_ratio > 2.0:
                        score += 0.25
                    elif volume_ratio > 1.5:
                        score += 0.15
            return min(score, 1.0)
        except Exception:
            return None

    async def _whale_tracking_strategy(self, symbol_data, market_context):
        """Score a buy-side whale signal obtained from the data manager.

        Returns None when whale data is unavailable, when the signal is
        missing or 'HOLD', or when its action is not 'STRONG_BUY' / 'BUY'.
        """
        try:
            whale_data = symbol_data.get('whale_data', {})
            if not whale_data.get('data_available', False):
                return None

            whale_signal = await self.data_manager.get_whale_trading_signal(
                symbol_data['symbol'], whale_data, market_context
            )
            if whale_signal and whale_signal.get('action') != 'HOLD':
                confidence = whale_signal.get('confidence', 0)
                if whale_signal.get('action') in ['STRONG_BUY', 'BUY']:
                    return min(confidence * 1.2, 1.0)
            return None
        except Exception:
            return None

    async def _pattern_recognition_strategy(self, symbol_data, market_context):
        """Score a confident upward pattern, else fall back to 1h RSI/MACD."""
        try:
            score = 0.0
            pattern_analysis = symbol_data.get('pattern_analysis')
            if pattern_analysis and pattern_analysis.get('pattern_confidence', 0) > 0.6:
                if pattern_analysis.get('predicted_direction') == 'up':
                    score += pattern_analysis.get('pattern_confidence', 0) * 0.8
            else:
                # NOTE(review): this fallback is assumed to belong to the outer
                # confidence check (no confident pattern) — confirm against the
                # original layout.
                indicators = symbol_data.get('advanced_indicators', {})
                if '1h' in indicators:
                    tf_indicators = indicators['1h']
                    if (tf_indicators.get('rsi', 50) > 60 and
                            tf_indicators.get('macd_hist', 0) > 0):
                        score += 0.3
            return min(score, 1.0)
        except Exception:
            return None

    async def _hybrid_ai_strategy(self, symbol_data, market_context, base_scores):
        """Combine the Monte Carlo probability with agreement between strategies.

        ``base_scores`` holds the raw scores of the already-evaluated
        strategies; missing entries count as 0. The result is clamped to
        [0.0, 1.0].
        """
        try:
            score = 0.0
            monte_carlo_prob = symbol_data.get('monte_carlo_probability')
            if monte_carlo_prob is not None:
                score += monte_carlo_prob * 0.4

            breakout_score = base_scores.get('breakout_momentum', 0)
            volume_score = base_scores.get('volume_spike', 0)
            whale_score = base_scores.get('whale_tracking', 0)
            pattern_score = base_scores.get('pattern_recognition', 0)

            if breakout_score > 0.7 and volume_score > 0.6:
                score += 0.3
            if breakout_score > 0.6 and whale_score > 0.7:
                score += 0.4
            if pattern_score > 0.7 and volume_score > 0.5:
                score += 0.2
            # Strong agreement of all three overrides the accumulated score.
            if breakout_score > 0.7 and whale_score > 0.7 and volume_score > 0.7:
                score = 1.0

            return max(0.0, min(score, 1.0))
        except Exception:
            return None


print("✅ ML Module: Strategy Engine loaded (V3 - Integrated LearningHub for weights)")