# Trad / LLM.py
# Riy777's picture
# Update LLM.py
# e35582f
# raw
# history blame
# 16.5 kB
import os, traceback, asyncio, json
from datetime import datetime
from functools import wraps
from backoff import on_exception, expo
from openai import OpenAI, RateLimitError, APITimeoutError
import numpy as np
from sentiment_news import NewsFetcher
from helpers import parse_json_from_response, validate_required_fields, format_technical_indicators, format_strategy_scores, local_analyze_opportunity, local_re_analyze_trade
# API key for the NVIDIA endpoint, read from the environment (None when unset).
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY")
# Default model identifier used by LLMService when no model_name is given.
PRIMARY_MODEL = "nvidia/llama-3.1-nemotron-ultra-253b-v1"
class PatternAnalysisEngine:
    """Summarize OHLCV candles as text and ask an LLM which chart pattern they show.

    Candles are read positionally: index 1 = open, 2 = high, 3 = low,
    4 = close, 5 = volume. Index 0 is not read here (presumably a timestamp --
    TODO confirm against the data source).
    """

    def __init__(self, llm_service):
        # Object exposing an awaitable ``_call_llm(prompt)``; LLMService passes itself.
        # If it also exposes a ``semaphore`` attribute, pattern requests acquire it.
        self.llm = llm_service

    def _format_chart_data_for_llm(self, ohlcv_data):
        """Render the most recent candles (at most 50) as a plain-text report.

        Args:
            ohlcv_data: Sequence of candles in the positional layout described
                on the class.

        Returns:
            The report text; ``"Insufficient chart data for pattern analysis"``
            when fewer than 20 candles are given; or an
            ``"Error formatting chart data: ..."`` string if a candle cannot be
            read. A ratio whose denominator is zero (first close, lowest low) is
            reported as ``N/A`` instead of aborting the whole report.
        """
        if not ohlcv_data or len(ohlcv_data) < 20:
            return "Insufficient chart data for pattern analysis"
        try:
            candles = ohlcv_data[-50:]
            lines = [
                "CANDLE DATA FOR PATTERN ANALYSIS:",
                f"Total candles available: {len(ohlcv_data)}",
                f"Candles used for analysis: {len(candles)}",
                "",
            ]
            if len(candles) >= 10:
                lines.append("Recent 10 Candles (Latest First):")
                for offset, candle in enumerate(reversed(candles[-10:])):
                    # 1-based position inside the analyzed window, newest first.
                    position = len(candles) - offset
                    lines.append(
                        f" Candle {position}: O:{candle[1]:.6f} H:{candle[2]:.6f} "
                        f"L:{candle[3]:.6f} C:{candle[4]:.6f} V:{candle[5]:.0f}"
                    )
            if len(candles) >= 2:
                first_close = candles[0][4]
                last_close = candles[-1][4]
                if first_close:
                    price_change = ((last_close - first_close) / first_close) * 100
                    trend = "BULLISH" if price_change > 2 else "BEARISH" if price_change < -2 else "SIDEWAYS"
                    change_text = f"{price_change:+.2f}%"
                else:
                    # A zero reference close makes the percentage undefined; keep
                    # the rest of the report instead of failing the whole call.
                    trend = "N/A"
                    change_text = "N/A"
                high_max = max(c[2] for c in candles)
                low_min = min(c[3] for c in candles)
                if low_min:
                    volatility_text = f"{((high_max - low_min) / low_min) * 100:.2f}%"
                else:
                    # A zero low (e.g. a bad candle) would divide by zero.
                    volatility_text = "N/A"
                lines.extend([
                    "",
                    "MARKET STRUCTURE ANALYSIS:",
                    f"Trend Direction: {trend}",
                    f"Price Change: {change_text}",
                    f"Volatility Range: {volatility_text}",
                    f"Highest Price: {high_max:.6f}",
                    f"Lowest Price: {low_min:.6f}",
                ])
            if len(candles) >= 5:
                volumes = [c[5] for c in candles]
                avg_volume = sum(volumes) / len(volumes)
                current_volume = candles[-1][5]
                volume_ratio = current_volume / avg_volume if avg_volume > 0 else 1
                volume_signal = "HIGH" if volume_ratio > 2 else "NORMAL" if volume_ratio > 0.5 else "LOW"
                lines.extend([
                    "",
                    "VOLUME ANALYSIS:",
                    f"Current Volume: {current_volume:,.0f}",
                    f"Volume Ratio: {volume_ratio:.2f}x average",
                    f"Volume Signal: {volume_signal}",
                ])
            return "\n".join(lines)
        except Exception as e:
            return f"Error formatting chart data: {str(e)}"

    async def analyze_chart_patterns(self, symbol, ohlcv_data):
        """Ask the LLM to classify the chart pattern in ``ohlcv_data``.

        Returns:
            A placeholder dict (``pattern_detected == "insufficient_data"``) for
            fewer than 20 candles, the parsed pattern dict from
            ``_parse_pattern_response`` otherwise, or None if the request fails.
        """
        try:
            if not ohlcv_data or len(ohlcv_data) < 20:
                return {"pattern_detected": "insufficient_data", "pattern_confidence": 0.1, "pattern_analysis": "Insufficient candle data for pattern analysis"}
            chart_text = self._format_chart_data_for_llm(ohlcv_data)
            prompt = f"Analyze the following candle data for {symbol} and identify patterns.\n\nCANDLE DATA FOR ANALYSIS:\n{chart_text}\n\nOUTPUT FORMAT (JSON):\n{{\"pattern_detected\": \"pattern_name\",\"pattern_confidence\": 0.85,\"pattern_strength\": \"strong/medium/weak\",\"predicted_direction\": \"up/down/sideways\",\"predicted_movement_percent\": 5.50,\"timeframe_expectation\": \"15-25 minutes\",\"entry_suggestion\": 0.1234,\"target_suggestion\": 0.1357,\"stop_suggestion\": 0.1189,\"key_support\": 0.1200,\"key_resistance\": 0.1300,\"pattern_analysis\": \"Detailed explanation\"}}"
            semaphore = getattr(self.llm, "semaphore", None)
            if semaphore is None:
                response = await self.llm._call_llm(prompt)
            else:
                # Respect the owning service's concurrency cap like its other
                # LLM calls do. LLMService runs pattern analysis before it
                # acquires this semaphore itself, so there is no nested acquire.
                async with semaphore:
                    response = await self.llm._call_llm(prompt)
            return self._parse_pattern_response(response)
        except Exception as e:
            print(f"Chart pattern analysis failed for {symbol}: {e}")
            return None

    def _parse_pattern_response(self, response_text):
        """Extract and validate the pattern JSON object from the LLM reply.

        Never raises. Returns the parsed dict when it is a JSON object with
        ``pattern_detected``, ``pattern_confidence`` and ``predicted_direction``;
        otherwise a placeholder dict tagged ``parse_error`` or ``incomplete_data``.
        """
        try:
            json_str = parse_json_from_response(response_text)
            if not json_str:
                return {"pattern_detected": "parse_error", "pattern_confidence": 0.1, "pattern_analysis": "Could not parse pattern analysis response"}
            pattern_data = json.loads(json_str)
            required = ['pattern_detected', 'pattern_confidence', 'predicted_direction']
            # A JSON array/scalar cannot carry the required keys.
            if not isinstance(pattern_data, dict) or not validate_required_fields(pattern_data, required):
                return {"pattern_detected": "incomplete_data", "pattern_confidence": 0.1, "pattern_analysis": "Incomplete pattern analysis data"}
            # Models sometimes quote numbers ("0.85"); consumers format the
            # confidence as a percentage, which needs a real number.
            confidence = pattern_data.get('pattern_confidence')
            if isinstance(confidence, str):
                try:
                    pattern_data['pattern_confidence'] = float(confidence)
                except ValueError:
                    pass  # Non-numeric wording (e.g. "high") is left as the model sent it.
            return pattern_data
        except Exception as e:
            print(f"Error parsing pattern response: {e}")
            return {"pattern_detected": "parse_error", "pattern_confidence": 0.1, "pattern_analysis": f"Error parsing pattern analysis: {str(e)}"}
class LLMService:
    """Async trade-decision service backed by a chat model on NVIDIA's endpoint.

    Each public call gathers news and an LLM chart-pattern read, builds a text
    prompt, asks the model for a JSON verdict, and falls back to the local
    heuristics from ``helpers`` when the model call or parsing fails.
    """

    def __init__(self, api_key=NVIDIA_API_KEY, model_name=PRIMARY_MODEL, temperature=0.7):
        self.api_key = api_key
        self.model_name = model_name
        self.temperature = temperature
        # Synchronous OpenAI SDK client pointed at NVIDIA's base URL.
        self.client = OpenAI(base_url="https://integrate.api.nvidia.com/v1", api_key=self.api_key)
        self.news_fetcher = NewsFetcher()
        self.pattern_engine = PatternAnalysisEngine(self)
        # Caps concurrent model requests made through this service.
        self.semaphore = asyncio.Semaphore(5)

    def _rate_limit_nvidia_api(func):
        # Plain function used as a decorator while the class body executes
        # (not meant to be called as a method). Retries the wrapped coroutine
        # up to 5 times with exponential backoff when RateLimitError escapes.
        @wraps(func)
        @on_exception(expo, RateLimitError, max_tries=5)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
        return wrapper

    async def get_trading_decision(self, data_payload: dict):
        """Return an entry decision for ``data_payload``.

        On success the model's JSON dict is returned with ``model_source`` and
        ``pattern_analysis`` added. If the reply is unusable or anything raises,
        ``local_analyze_opportunity(data_payload)`` is returned instead.
        """
        try:
            symbol = data_payload.get('symbol', 'unknown')
            target_strategy = data_payload.get('target_strategy', 'GENERIC')
            news_text = await self.news_fetcher.get_news_for_symbol(symbol)
            pattern_analysis = await self._get_pattern_analysis(data_payload)
            prompt = self._create_enhanced_trading_prompt(data_payload, news_text, pattern_analysis)
            async with self.semaphore:
                response = await self._call_llm(prompt)
            decision_dict = self._parse_llm_response_enhanced(response, target_strategy, symbol)
            if decision_dict:
                decision_dict['model_source'] = self.model_name
                decision_dict['pattern_analysis'] = pattern_analysis
                return decision_dict
            return local_analyze_opportunity(data_payload)
        except Exception as e:
            print(f"Error getting LLM decision for {data_payload.get('symbol', 'unknown')}: {e}")
            return local_analyze_opportunity(data_payload)

    def _parse_llm_response_enhanced(self, response_text: str, fallback_strategy: str, symbol: str) -> dict:
        """Parse the entry-decision JSON; return None if it is missing or incomplete.

        A missing or ``'unknown'`` ``strategy`` is replaced by ``fallback_strategy``.
        """
        try:
            json_str = parse_json_from_response(response_text)
            if not json_str:
                return None
            decision_data = json.loads(json_str)
            required_fields = ['action', 'reasoning', 'risk_assessment', 'trade_type', 'stop_loss', 'take_profit', 'expected_target_minutes', 'confidence_level']
            if not validate_required_fields(decision_data, required_fields):
                return None
            strategy_value = decision_data.get('strategy')
            if not strategy_value or strategy_value == 'unknown':
                decision_data['strategy'] = fallback_strategy
            return decision_data
        except Exception as e:
            print(f"Error parsing LLM response for {symbol}: {e}")
            return None

    @staticmethod
    def _is_candle_series(data):
        # The candle formatter slices and indexes its input, so only accept a
        # list/tuple long enough to analyze. Anything else (for instance a
        # mapping of indicator values -- what advanced_indicators may hold,
        # TODO confirm) would otherwise cost an LLM call on an error-string prompt.
        return isinstance(data, (list, tuple)) and len(data) >= 20

    async def _get_pattern_analysis(self, data_payload):
        """Run LLM chart-pattern analysis on the payload's 1h series, if present.

        ``raw_ohlcv['1h']`` is tried first, then ``advanced_indicators['1h']``.
        Returns the pattern dict, or None when no usable series exists or the
        analysis fails.
        """
        try:
            symbol = data_payload['symbol']
            for source in ('raw_ohlcv', 'advanced_indicators'):
                if source in data_payload and '1h' in data_payload[source]:
                    ohlcv_data = data_payload[source]['1h']
                    if self._is_candle_series(ohlcv_data):
                        return await self.pattern_engine.analyze_chart_patterns(symbol, ohlcv_data)
            return None
        except Exception as e:
            print(f"Pattern analysis failed for {data_payload.get('symbol')}: {e}")
            return None

    def _create_enhanced_trading_prompt(self, payload: dict, news_text: str, pattern_analysis: dict) -> str:
        """Build the entry-decision prompt, ending with the expected JSON schema."""
        symbol = payload.get('symbol', 'N/A')
        current_price = payload.get('current_price', 'N/A')
        reasons = payload.get('reasons_for_candidacy', [])
        sentiment_data = payload.get('sentiment_data', {})
        advanced_indicators = payload.get('advanced_indicators', {})
        strategy_scores = payload.get('strategy_scores', {})
        recommended_strategy = payload.get('recommended_strategy', 'N/A')
        target_strategy = payload.get('target_strategy', 'GENERIC')
        final_score = payload.get('final_score', 'N/A')
        enhanced_final_score = payload.get('enhanced_final_score', 'N/A')
        whale_data = payload.get('whale_data', {})
        final_score_display = f"{final_score:.2f}" if isinstance(final_score, (int, float)) else str(final_score)
        enhanced_score_display = f"{enhanced_final_score:.2f}" if isinstance(enhanced_final_score, (int, float)) else str(enhanced_final_score)
        indicators_summary = format_technical_indicators(advanced_indicators)
        strategies_summary = format_strategy_scores(strategy_scores, recommended_strategy)
        pattern_summary = self._format_pattern_analysis(pattern_analysis)
        whale_analysis_section = self._format_whale_analysis(sentiment_data.get('general_whale_activity', {}), whale_data, symbol)
        prompt = f"""
TRADING ANALYSIS FOR {symbol}
STRATEGY: {target_strategy}
Current Price: {current_price}
System Score: {final_score_display}
Enhanced Score: {enhanced_score_display}
CHART PATTERN ANALYSIS:
{pattern_summary}
TECHNICAL INDICATORS:
{indicators_summary}
STRATEGY ANALYSIS:
{strategies_summary}
MARKET CONTEXT:
- BTC Trend: {sentiment_data.get('btc_sentiment', 'N/A')}
- Fear & Greed: {sentiment_data.get('fear_and_greed_index', 'N/A')}
WHALE ANALYSIS:
{whale_analysis_section}
NEWS:
{news_text}
OUTPUT (JSON):
{{
"action": "BUY/SELL/HOLD",
"reasoning": "Detailed explanation",
"risk_assessment": "Risk analysis",
"trade_type": "LONG/SHORT",
"stop_loss": 0.0000,
"take_profit": 0.0000,
"expected_target_minutes": 15,
"confidence_level": 0.85,
"strategy": "{target_strategy}",
"pattern_influence": "Pattern influence description"
}}
"""
        return prompt

    def _format_pattern_analysis(self, pattern_analysis):
        """Summarize a pattern dict in four text lines, or a fixed message if empty."""
        if not pattern_analysis:
            return "No clear patterns detected"
        confidence = pattern_analysis.get('pattern_confidence', 0)
        pattern_name = pattern_analysis.get('pattern_detected', 'unknown')
        # The confidence comes from model output and may be a quoted number or
        # a word; a bare ``:.1%`` would raise and abort the whole prompt.
        try:
            confidence_text = f"{float(confidence):.1%}"
        except (TypeError, ValueError):
            confidence_text = str(confidence)
        analysis_lines = [
            f"Pattern: {pattern_name}",
            f"Confidence: {confidence_text}",
            f"Predicted Move: {pattern_analysis.get('predicted_direction', 'N/A')}",
            f"Analysis: {pattern_analysis.get('pattern_analysis', 'No detailed analysis')}",
        ]
        return "\n".join(analysis_lines)

    def _format_whale_analysis(self, general_whale_activity, symbol_whale_data, symbol):
        """Delegate whale-activity formatting to ``SentimentAnalyzer.format_whale_analysis``."""
        from sentiment_news import SentimentAnalyzer
        temp_analyzer = SentimentAnalyzer(None)
        return temp_analyzer.format_whale_analysis(general_whale_activity, symbol_whale_data, symbol)

    async def re_analyze_trade_async(self, trade_data: dict, processed_data: dict):
        """Ask the model whether to hold, close or update an open trade.

        Returns the model's dict with ``model_source`` added, or
        ``local_re_analyze_trade(trade_data, processed_data)`` on any failure.
        """
        try:
            symbol = trade_data['symbol']
            original_strategy = trade_data.get('strategy', 'GENERIC')
            news_text = await self.news_fetcher.get_news_for_symbol(symbol)
            pattern_analysis = await self._get_pattern_analysis(processed_data)
            prompt = self._create_re_analysis_prompt(trade_data, processed_data, news_text, pattern_analysis)
            async with self.semaphore:
                response = await self._call_llm(prompt)
            re_analysis_dict = self._parse_re_analysis_response(response, original_strategy, symbol)
            if re_analysis_dict:
                re_analysis_dict['model_source'] = self.model_name
                return re_analysis_dict
            return local_re_analyze_trade(trade_data, processed_data)
        except Exception as e:
            print(f"Error in LLM re-analysis: {e}")
            return local_re_analyze_trade(trade_data, processed_data)

    def _parse_re_analysis_response(self, response_text: str, fallback_strategy: str, symbol: str) -> dict:
        """Parse the re-analysis JSON; return None if it is missing or has no ``action``.

        A missing or ``'unknown'`` ``strategy`` is replaced by ``fallback_strategy``.
        """
        try:
            json_str = parse_json_from_response(response_text)
            if not json_str:
                return None
            decision_data = json.loads(json_str)
            # Mirror the entry-decision parser: a verdict without an action
            # cannot be acted on, so let the caller use its local fallback.
            if not validate_required_fields(decision_data, ['action']):
                return None
            strategy_value = decision_data.get('strategy')
            if not strategy_value or strategy_value == 'unknown':
                decision_data['strategy'] = fallback_strategy
            return decision_data
        except Exception as e:
            print(f"Error parsing re-analysis response for {symbol}: {e}")
            return None

    def _create_re_analysis_prompt(self, trade_data: dict, processed_data: dict, news_text: str, pattern_analysis: dict) -> str:
        """Build the open-trade re-analysis prompt, ending with the expected JSON schema."""
        symbol = trade_data.get('symbol', 'N/A')
        entry_price = trade_data.get('entry_price', 'N/A')
        current_price = processed_data.get('current_price', 'N/A')
        strategy = trade_data.get('strategy', 'GENERIC')
        try:
            price_change = ((current_price - entry_price) / entry_price) * 100
            price_change_display = f"{price_change:+.2f}%"
        except (TypeError, ZeroDivisionError):
            # Missing ('N/A') or zero prices: show the performance as unknown.
            price_change_display = "N/A"
        indicators_summary = format_technical_indicators(processed_data.get('advanced_indicators', {}))
        pattern_summary = self._format_pattern_analysis(pattern_analysis)
        whale_analysis_section = self._format_whale_analysis(processed_data.get('sentiment_data', {}).get('general_whale_activity', {}), processed_data.get('whale_data', {}), symbol)
        prompt = f"""
TRADE RE-ANALYSIS FOR {symbol}
TRADE CONTEXT:
- Strategy: {strategy}
- Entry Price: {entry_price}
- Current Price: {current_price}
- Performance: {price_change_display}
UPDATED PATTERN ANALYSIS:
{pattern_summary}
UPDATED TECHNICALS:
{indicators_summary}
UPDATED WHALE DATA:
{whale_analysis_section}
LATEST NEWS:
{news_text}
OUTPUT (JSON):
{{
"action": "HOLD/CLOSE_TRADE/UPDATE_TRADE",
"reasoning": "Justification",
"new_stop_loss": 0.0000,
"new_take_profit": 0.0000,
"new_expected_minutes": 15,
"confidence_level": 0.85,
"strategy": "{strategy}",
"pattern_influence_reanalysis": "Pattern influence description"
}}
"""
        return prompt

    @_rate_limit_nvidia_api
    async def _call_llm(self, prompt: str) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        ``self.client`` is the synchronous OpenAI SDK client, so the request is
        run in the event loop's default executor. Calling it directly here
        would block the loop, and every other coroutine, for the whole HTTP
        round-trip. API errors are logged and re-raised; RateLimitError is
        retried by the decorator.
        """
        loop = asyncio.get_running_loop()
        try:
            response = await loop.run_in_executor(
                None,
                lambda: self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.temperature,
                    seed=42,
                ),
            )
            return response.choices[0].message.content
        except (RateLimitError, APITimeoutError) as e:
            print(f"LLM API Error: {e}. Retrying...")
            raise
        except Exception as e:
            print(f"Unexpected LLM API error: {e}")
            raise