RickyGuoTheCrazish committed · Commit 33c14bd · Parent: 14eed51
update finbert_market_evaluation

Files changed:
- README.md +45 -4
- requirements.txt +23 -3
- run_app.py +37 -0
- src/__init__.py +1 -0
- src/evaluation.py +294 -0
- src/market_data.py +297 -0
- src/sentiment_analyzer.py +120 -0
- src/streamlit_app.py +369 -35
- src/visualizations.py +302 -0
README.md CHANGED
@@ -7,14 +7,55 @@ sdk: docker
 app_port: 8501
 tags:
 - streamlit
+- finbert
+- sentiment-analysis
+- finance
+- machine-learning
 pinned: false
 short_description: Evaluate FinBERT’s sentiment predictions against market data
 license: mit
 ---
 
-#
-
-
-
+# 🚀 FinBERT Market Evaluation
+
+Evaluate how well FinBERT's financial sentiment predictions match actual stock market movements.
+
+## What It Does
+
+Enter financial news → get the FinBERT sentiment → compare it with the actual stock price movement → see whether the prediction was right.
+
+## How to Use
+
+1. **Paste financial news** (e.g., "Apple reports record earnings")
+2. **Enter stock ticker** (e.g., AAPL)
+3. **Select news date** (when the news was published)
+4. **Get results** - see whether sentiment matched the price movement
+
+## Key Features
+
+- **Smart thresholds** - uses each stock's own volatility (no rigid ±1% rules)
+- **Same-day + 24h analysis** - immediate reaction plus follow-through
+- **Graded scoring** - not just right/wrong, but how right (a 0-1 score)
+- **Market context** - compares the stock against overall market performance
+
+## Example
+
+**News**: "Tesla announces new factory in Germany"
+- **FinBERT says**: Positive sentiment (85% confidence)
+- **Stock moved**: +4.2% same day
+- **Evaluation**: ✅ Aligned (sentiment matched direction)
+- **Score**: 0.91/1.0 (excellent alignment)
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+streamlit run src/streamlit_app.py
+```
+
+## Limitations
+
+- Research tool only (not for trading)
+- 30-second rate limit between requests
+- News must be at least 1 day old (market data is required)
+- Uses Yahoo Finance (free but limited)
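The "smart thresholds" feature described in the README reduces to one rule, implemented in src/evaluation.py below: a move only confirms (or contradicts) the sentiment if it clears the stock's own 14-day volatility. A minimal sketch of that rule, with k = 1.0 as in the commit and an illustrative volatility figure:

```python
# Minimal sketch of the volatility-aware alignment rule from src/evaluation.py
# (k = 1.0 by default; the 2.5% volatility below is an illustrative value).

def is_aligned(sentiment_direction: int, price_return: float,
               volatility_14d: float, k: float = 1.0) -> bool:
    """sentiment_direction: 1 = positive, -1 = negative, 0 = neutral."""
    threshold = k * volatility_14d          # per-stock threshold, not a rigid ±1%
    if sentiment_direction == 0:            # neutral: price should stay inside the band
        return abs(price_return) <= threshold
    if sentiment_direction == 1:            # positive: up-move must clear the band
        return price_return > threshold
    return price_return < -threshold        # negative: down-move must clear the band

# README example: positive Tesla news followed by a +4.2% same-day move
print(is_aligned(1, 4.2, 2.5))  # True → "✅ Aligned"
```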
requirements.txt CHANGED
@@ -1,3 +1,23 @@
-
-
-
+# Core Streamlit and data processing
+streamlit>=1.28.0
+pandas>=1.5.0
+numpy>=1.24.0
+altair>=4.2.0
+
+# Machine Learning and NLP
+transformers>=4.30.0
+torch>=2.0.0
+tokenizers>=0.13.0
+
+# Financial data
+yfinance>=0.2.18
+
+# Visualization and UI
+plotly>=5.15.0
+matplotlib>=3.7.0
+seaborn>=0.12.0
+
+# Utilities
+requests>=2.31.0
+python-dateutil>=2.8.0
+pytz>=2023.3
run_app.py ADDED
@@ -0,0 +1,37 @@

#!/usr/bin/env python3
"""
Simple launcher script for the FinBERT Market Evaluation Streamlit app.
"""

import subprocess
import sys
import os

def main():
    """Launch the Streamlit application."""
    print("🚀 Starting FinBERT Market Evaluation...")
    print("=" * 50)

    # Change to the correct directory
    app_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(app_dir)

    # Launch Streamlit
    try:
        cmd = [sys.executable, "-m", "streamlit", "run", "src/streamlit_app.py"]
        print(f"Running: {' '.join(cmd)}")
        print("=" * 50)

        subprocess.run(cmd, check=True)

    except subprocess.CalledProcessError as e:
        print(f"❌ Error launching Streamlit: {e}")
        return 1
    except KeyboardInterrupt:
        print("\n👋 Application stopped by user")
        return 0

    return 0

if __name__ == "__main__":
    sys.exit(main())
src/__init__.py ADDED
@@ -0,0 +1 @@
# FinBERT Market Evaluation Package
src/evaluation.py ADDED
@@ -0,0 +1,294 @@

# Core evaluation engine with heuristic algorithms
"""
This module implements the core evaluation logic including DAS calculation,
volatility-aware thresholds, WAT scoring, and macro-adjusted evaluation metrics.
"""

import numpy as np
from typing import Dict, List, Optional
import logging
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EvaluationEngine:
    """
    Core engine for evaluating FinBERT predictions against market movements.
    """

    def __init__(self, volatility_multiplier: float = 1.0, confidence_threshold: float = 0.7):
        """
        Initialize the evaluation engine.

        Args:
            volatility_multiplier: k factor for volatility thresholds (default: 1.0);
                threshold = k * 14-day volatility
            confidence_threshold: Minimum confidence for high-confidence predictions (default: 0.7)
        """
        self.volatility_multiplier = volatility_multiplier  # k = 1.0 per framework
        self.confidence_threshold = confidence_threshold

    def calculate_das(self, sentiment_direction: int, price_return: float,
                      volatility: float) -> float:
        """
        Calculate Directional Alignment Score (DAS).

        Args:
            sentiment_direction: 1 for positive, -1 for negative, 0 for neutral
            price_return: Stock return percentage
            volatility: Stock volatility percentage

        Returns:
            DAS score between 0 and 1
        """
        try:
            # Handle neutral sentiment
            if sentiment_direction == 0:
                # For neutral sentiment, score based on how close to zero the return is
                threshold = volatility * self.volatility_multiplier
                if abs(price_return) <= threshold:
                    return 1.0  # Perfect neutral prediction
                else:
                    # Decay score based on how far from neutral
                    excess = abs(price_return) - threshold
                    return max(0.0, 1.0 - (excess / (threshold * 2)))

            # For positive/negative sentiment
            expected_direction = sentiment_direction
            actual_direction = 1 if price_return > 0 else -1 if price_return < 0 else 0

            # Base alignment check
            if expected_direction == actual_direction:
                # Correct direction - score based on magnitude
                magnitude_factor = min(abs(price_return) / (volatility * self.volatility_multiplier), 2.0)
                return min(1.0, 0.7 + 0.3 * magnitude_factor)
            else:
                # Wrong direction - score based on how wrong
                threshold = volatility * self.volatility_multiplier
                if abs(price_return) <= threshold:
                    # Small move in wrong direction - partial credit
                    return 0.3
                else:
                    # Large move in wrong direction - low score
                    return max(0.0, 0.3 - (abs(price_return) - threshold) / (threshold * 3))

        except Exception as e:
            logger.error(f"Error calculating DAS: {str(e)}")
            return 0.0

    def calculate_wat_weight(self, confidence: float, impact: float,
                             days_ago: int = 0, decay_factor: float = 0.95) -> float:
        """
        Calculate Weighted Accuracy over Time (WAT) weight.

        Args:
            confidence: Model confidence score
            impact: Impact magnitude (absolute return)
            days_ago: Days since prediction (for decay)
            decay_factor: Decay factor for time-based weighting

        Returns:
            WAT weight for the prediction
        """
        try:
            # Base weight from confidence and impact
            confidence_weight = confidence
            impact_weight = min(impact / 5.0, 2.0)  # Cap at 2x for very large moves

            # Time decay (optional)
            time_weight = decay_factor ** days_ago if days_ago > 0 else 1.0

            # Combined weight
            wat_weight = confidence_weight * impact_weight * time_weight

            return float(wat_weight)

        except Exception as e:
            logger.error(f"Error calculating WAT weight: {str(e)}")
            return 1.0

    def evaluate_prediction(self, sentiment_data: Dict, market_data: Dict,
                            news_date: datetime) -> Dict:
        """
        Comprehensive evaluation of a single prediction.

        Args:
            sentiment_data: Output from FinBERT analyzer
            market_data: Output from market data service
            news_date: Date when news was published

        Returns:
            Complete evaluation results
        """
        try:
            # Extract key values
            sentiment = sentiment_data.get("sentiment", "neutral")
            confidence = sentiment_data.get("confidence", 0.0)
            return_24h = market_data.get("return_24h")
            volatility_14d = market_data.get("volatility_14d")
            alpha_adjusted = market_data.get("alpha_adjusted")

            # Check for missing data
            if return_24h is None or volatility_14d is None:
                return {
                    "error": "Insufficient market data for evaluation",
                    "sentiment": sentiment,
                    "confidence": confidence
                }

            # Convert sentiment to direction
            sentiment_direction = self._get_sentiment_direction(sentiment)

            # Calculate volatility threshold
            threshold = volatility_14d * self.volatility_multiplier

            # Calculate DAS
            das_score = self.calculate_das(sentiment_direction, return_24h, volatility_14d)

            # Determine correctness
            is_correct = self._is_prediction_correct(sentiment_direction, return_24h, threshold)

            # Calculate WAT weight
            impact = abs(return_24h)
            wat_weight = self.calculate_wat_weight(confidence, impact)

            # Prepare results
            results = {
                "ticker": market_data.get("ticker", "Unknown"),
                "news_date": news_date.strftime("%Y-%m-%d"),
                "sentiment": sentiment,
                "confidence": confidence,
                "return_24h": return_24h,
                "volatility_14d": volatility_14d,
                "threshold": threshold,
                "das_score": das_score,
                "is_correct": is_correct,
                "wat_weight": wat_weight,
                "impact": impact,
                "alpha_adjusted": alpha_adjusted,
                "sentiment_direction": sentiment_direction,
                "evaluation_summary": self._generate_summary(
                    sentiment, confidence, return_24h, das_score, is_correct
                )
            }

            logger.info(f"Evaluation completed - DAS: {das_score:.3f}, Correct: {is_correct}")
            return results

        except Exception as e:
            logger.error(f"Error in prediction evaluation: {str(e)}")
            return {"error": str(e)}

    def _get_sentiment_direction(self, sentiment: str) -> int:
        """Convert sentiment to numerical direction."""
        sentiment_map = {
            "positive": 1,
            "negative": -1,
            "neutral": 0
        }
        return sentiment_map.get(sentiment.lower(), 0)

    def _is_prediction_correct(self, sentiment_direction: int, price_return: float,
                               threshold: float) -> bool:
        """
        Determine whether a prediction is correct based on volatility-aware thresholds.
        """
        if sentiment_direction == 0:  # Neutral
            return abs(price_return) <= threshold
        elif sentiment_direction == 1:  # Positive
            return price_return > threshold
        elif sentiment_direction == -1:  # Negative
            return price_return < -threshold
        else:
            return False

    def _generate_summary(self, sentiment: str, confidence: float,
                          return_24h: float, das_score: float, is_correct: bool) -> str:
        """Generate a human-readable evaluation summary."""
        direction = "📈" if return_24h > 0 else "📉" if return_24h < 0 else "➡️"

        # More nuanced verdict based on DAS score
        if is_correct:
            verdict = "✅ Aligned"
        else:
            if das_score > 0.7:
                verdict = "⚠️ Directionally Right, Magnitude Wrong"  # Right direction, wrong magnitude
            elif das_score > 0.3:
                verdict = "🔄 Partially Aligned"  # Some alignment
            else:
                verdict = "❌ Misaligned"  # Completely wrong

        confidence_level = "High" if confidence > 0.8 else "Medium" if confidence > 0.6 else "Low"

        return (f"{verdict} | {sentiment.title()} sentiment ({confidence_level} conf: {confidence:.2f}) "
                f"vs {direction} {return_24h:+.2f}% return | DAS: {das_score:.3f}")

    def calculate_batch_metrics(self, evaluations: List[Dict]) -> Dict:
        """
        Calculate aggregate metrics for a batch of evaluations.

        Args:
            evaluations: List of evaluation results

        Returns:
            Dictionary with aggregate metrics
        """
        try:
            if not evaluations:
                return {"error": "No evaluations provided"}

            # Filter out error results
            valid_evals = [e for e in evaluations if "error" not in e]

            if not valid_evals:
                return {"error": "No valid evaluations found"}

            # Calculate metrics
            das_scores = [e["das_score"] for e in valid_evals]
            correctness = [e["is_correct"] for e in valid_evals]
            confidences = [e["confidence"] for e in valid_evals]
            wat_weights = [e["wat_weight"] for e in valid_evals]

            # Aggregate metrics
            avg_das = float(np.mean(das_scores))
            accuracy = float(np.mean(correctness))
            avg_confidence = float(np.mean(confidences))

            # Weighted accuracy
            weighted_correctness = [float(c) * float(w) for c, w in zip(correctness, wat_weights)]
            total_weight = sum(wat_weights)
            weighted_accuracy = float(sum(weighted_correctness) / total_weight) if total_weight > 0 else 0.0

            # Confidence-accuracy correlation (handle single evaluation case)
            if len(confidences) > 1:
                try:
                    corr_matrix = np.corrcoef(confidences, correctness)
                    confidence_correlation = float(corr_matrix[0, 1])
                    # Handle NaN case (when all values are the same)
                    if np.isnan(confidence_correlation):
                        confidence_correlation = 0.0
                except Exception:
                    confidence_correlation = 0.0
            else:
                confidence_correlation = 0.0  # Cannot calculate correlation with a single point

            # Count high/low confidence predictions
            high_confidence_count = sum(1 for c in confidences if c > self.confidence_threshold)
            low_confidence_count = sum(1 for c in confidences if c < 0.6)

            return {
                "total_evaluations": len(valid_evals),
                "average_das": avg_das,
                "accuracy": accuracy,
                "weighted_accuracy": weighted_accuracy,
                "average_confidence": avg_confidence,
                "confidence_accuracy_correlation": confidence_correlation,
                "high_confidence_predictions": high_confidence_count,
                "low_confidence_predictions": low_confidence_count
            }

        except Exception as e:
            logger.error(f"Error calculating batch metrics: {str(e)}")
            return {"error": str(e)}
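The engine can also be exercised outside the app with hand-built dicts shaped like the analyzer and market-data outputs. A minimal sketch; the ticker, numbers, and date are illustrative, and src/ is assumed to be on the import path:

```python
from datetime import datetime
from evaluation import EvaluationEngine  # assumes src/ is on sys.path

engine = EvaluationEngine()  # defaults: k = 1.0, confidence threshold 0.7

# Inputs shaped like FinBERTAnalyzer / MarketDataService outputs (values made up)
sentiment_data = {"sentiment": "positive", "confidence": 0.85}
market_data = {"ticker": "TSLA", "return_24h": 4.2,
               "volatility_14d": 2.5, "alpha_adjusted": 3.1}

result = engine.evaluate_prediction(sentiment_data, market_data,
                                    news_date=datetime(2024, 1, 15))
print(result["das_score"], result["is_correct"])  # 1.0 True for these inputs
print(result["evaluation_summary"])
```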
src/market_data.py ADDED
@@ -0,0 +1,297 @@

# Market data fetching service using yfinance
"""
This module handles fetching historical stock price data, calculating returns,
volatility, and market index comparisons for evaluation purposes.
"""

import yfinance as yf
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple
import logging
import streamlit as st

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MarketDataService:
    """
    Service for fetching and processing market data for evaluation.
    """

    def __init__(self, market_index: str = "^GSPC"):
        """
        Initialize the market data service.

        Args:
            market_index: The market index ticker for macro adjustments (default: S&P 500)
        """
        self.market_index = market_index

    @st.cache_data(ttl=3600)  # Cache for 1 hour
    def fetch_stock_data(_self, ticker: str, start_date: datetime, end_date: datetime) -> Optional[pd.DataFrame]:
        """
        Fetch historical stock data for a given ticker and date range.

        Args:
            ticker: Stock ticker symbol
            start_date: Start date for data fetch
            end_date: End date for data fetch

        Returns:
            DataFrame with stock price data or None if failed
        """
        try:
            logger.info(f"Fetching data for {ticker} from {start_date} to {end_date}")
            stock = yf.Ticker(ticker)
            data = stock.history(start=start_date, end=end_date)

            if data.empty:
                logger.warning(f"No data found for ticker {ticker}")
                return None

            return data

        except Exception as e:
            logger.error(f"Error fetching data for {ticker}: {str(e)}")
            return None

    def calculate_same_day_return(self, data: pd.DataFrame, news_date: datetime) -> Optional[float]:
        """
        Calculate stock return on the same day the news was published (intraday).

        Args:
            data: Stock price DataFrame
            news_date: Date when news was published

        Returns:
            Intraday return percentage or None if calculation fails
        """
        try:
            # Convert news_date to date only for comparison
            news_date_only = news_date.date()

            # Find the trading day that matches the news date
            data_dates = data.index.date
            matching_dates = [d for d in data_dates if d == news_date_only]

            if not matching_dates:
                # If no exact match, find the next trading day
                future_dates = [d for d in data_dates if d > news_date_only]
                if not future_dates:
                    logger.warning(f"No trading data available for or after {news_date_only}")
                    return None
                trading_date = future_dates[0]
                logger.info(f"News date {news_date_only} was not a trading day, using next trading day: {trading_date}")
            else:
                trading_date = matching_dates[0]

            # Get the day's data
            day_data = data[data.index.date == trading_date]

            if len(day_data) == 0:
                logger.warning(f"No trading data found for {trading_date}")
                return None

            # Calculate intraday return: (Close - Open) / Open * 100
            open_price = day_data['Open'].iloc[0]
            close_price = day_data['Close'].iloc[-1]

            return_pct = ((close_price - open_price) / open_price) * 100

            logger.info(f"Calculated same-day return for {trading_date}: {return_pct:.2f}% (Open: {open_price:.2f}, Close: {close_price:.2f})")
            return float(return_pct)

        except Exception as e:
            logger.error(f"Error calculating same-day return: {str(e)}")
            return None

    def calculate_next_24h_return(self, data: pd.DataFrame, news_date: datetime) -> Optional[float]:
        """
        Calculate stock return over the next 24 hours after news publication.

        Args:
            data: Stock price DataFrame
            news_date: Date when news was published

        Returns:
            24-hour return percentage or None if calculation fails
        """
        try:
            # Convert news_date to date only for comparison
            news_date_only = news_date.date()

            # Find the trading day that matches the news date
            data_dates = data.index.date
            matching_dates = [d for d in data_dates if d == news_date_only]

            if not matching_dates:
                # If no exact match, find the next trading day
                future_dates = [d for d in data_dates if d > news_date_only]
                if not future_dates:
                    logger.warning(f"No trading data available for or after {news_date_only}")
                    return None
                start_trading_date = future_dates[0]
            else:
                start_trading_date = matching_dates[0]

            # Find the next trading day for 24h comparison
            future_dates = [d for d in data_dates if d > start_trading_date]
            if not future_dates:
                logger.warning(f"No next trading day available after {start_trading_date}")
                return None

            end_trading_date = future_dates[0]

            # Get start and end prices
            start_data = data[data.index.date == start_trading_date]
            end_data = data[data.index.date == end_trading_date]

            if len(start_data) == 0 or len(end_data) == 0:
                logger.warning(f"Insufficient data for 24h return calculation")
                return None

            # Use close of start day and close of next day
            start_price = start_data['Close'].iloc[-1]
            end_price = end_data['Close'].iloc[-1]

            return_pct = ((end_price - start_price) / start_price) * 100

            logger.info(f"Calculated 24h return from {start_trading_date} to {end_trading_date}: {return_pct:.2f}%")
            return float(return_pct)

        except Exception as e:
            logger.error(f"Error calculating 24h return: {str(e)}")
            return None

    def calculate_return(self, data: pd.DataFrame, news_date: datetime, hours: int = 24) -> Optional[float]:
        """
        Legacy method - now returns same-day return for compatibility.
        Use calculate_same_day_return() or calculate_next_24h_return() for specific needs.
        """
        return self.calculate_same_day_return(data, news_date)

    def calculate_volatility(self, data: pd.DataFrame, days: int = 14) -> Optional[float]:
        """
        Calculate rolling volatility for the stock.

        Args:
            data: Stock price DataFrame
            days: Number of days for volatility calculation

        Returns:
            Volatility percentage or None if calculation fails
        """
        try:
            if len(data) < days:
                logger.warning(f"Insufficient data for {days}-day volatility calculation")
                return None

            # Calculate daily returns
            data['Daily_Return'] = data['Close'].pct_change()

            # Calculate rolling volatility (annualized)
            volatility = data['Daily_Return'].rolling(window=days).std() * np.sqrt(252) * 100

            # Return the most recent volatility
            recent_volatility = volatility.dropna().iloc[-1]

            logger.info(f"Calculated {days}-day volatility: {recent_volatility:.2f}%")
            return float(recent_volatility)

        except Exception as e:
            logger.error(f"Error calculating volatility: {str(e)}")
            return None

    def get_market_return(self, news_date: datetime, hours: int = 24) -> Optional[float]:
        """
        Get market index return for the same day as news publication.

        Args:
            news_date: Date when news was published
            hours: Deprecated parameter (kept for compatibility)

        Returns:
            Market return percentage for the news day or None if calculation fails
        """
        try:
            # Fetch market data
            start_date = news_date - timedelta(days=5)  # Buffer for weekends
            end_date = news_date + timedelta(days=5)

            market_data = self.fetch_stock_data(self.market_index, start_date, end_date)

            if market_data is None:
                return None

            return self.calculate_return(market_data, news_date, hours)

        except Exception as e:
            logger.error(f"Error getting market return: {str(e)}")
            return None

    def get_stock_evaluation_data(self, ticker: str, news_date: datetime) -> Dict:
        """
        Get comprehensive stock data for evaluation including both same-day and 24h returns.

        Args:
            ticker: Stock ticker symbol
            news_date: Date when news was published

        Returns:
            Dictionary containing all relevant market data
        """
        try:
            # Define date range (get extra days for volatility calculation)
            start_date = news_date - timedelta(days=30)
            end_date = news_date + timedelta(days=5)

            # Fetch stock data
            stock_data = self.fetch_stock_data(ticker, start_date, end_date)

            if stock_data is None:
                return {"error": f"Could not fetch data for ticker {ticker}"}

            # Calculate both same-day and 24h returns
            same_day_return = self.calculate_same_day_return(stock_data, news_date)
            next_24h_return = self.calculate_next_24h_return(stock_data, news_date)
            volatility_14d = self.calculate_volatility(stock_data, 14)

            # Get market returns for both periods
            market_same_day = self.get_market_return(news_date, 0)  # Same day
            market_24h = self.get_market_return(news_date, 24)  # 24h

            # Calculate alpha-adjusted returns
            alpha_same_day = None
            alpha_24h = None

            if same_day_return is not None and market_same_day is not None:
                alpha_same_day = same_day_return - market_same_day

            if next_24h_return is not None and market_24h is not None:
                alpha_24h = next_24h_return - market_24h

            return {
                "ticker": ticker,
                "return_same_day": same_day_return,
                "return_next_24h": next_24h_return,
                "return_24h": same_day_return,  # Keep for compatibility with existing code
                "volatility_14d": volatility_14d,
                "market_return_same_day": market_same_day,
                "market_return_24h": market_24h,
                "market_return": market_same_day,  # Keep for compatibility
                "alpha_same_day": alpha_same_day,
                "alpha_24h": alpha_24h,
                "alpha_adjusted": alpha_same_day,  # Keep for compatibility
                "data_points": len(stock_data),
                "date_range": {
                    "start": stock_data.index[0].strftime("%Y-%m-%d"),
                    "end": stock_data.index[-1].strftime("%Y-%m-%d")
                }
            }

        except Exception as e:
            logger.error(f"Error getting evaluation data: {str(e)}")
            return {"error": str(e)}
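The two return windows above differ only in their anchor prices: the same-day return is open-to-close on the news day, while the next-24h return is close-to-close into the following trading day. A small worked example with made-up prices:

```python
# Illustrative prices (not real market data)
day1_open, day1_close = 100.0, 104.2   # news day
day2_close = 103.1                     # next trading day

same_day_return = (day1_close - day1_open) / day1_open * 100     # intraday reaction
next_24h_return = (day2_close - day1_close) / day1_close * 100   # follow-through

print(f"{same_day_return:+.2f}%")   # +4.20%, compared against the volatility threshold
print(f"{next_24h_return:+.2f}%")   # -1.06%, a reversal relative to the same-day move
```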
src/sentiment_analyzer.py ADDED
@@ -0,0 +1,120 @@

# FinBERT sentiment analysis module for financial news
"""
This module handles loading the ProsusAI/finbert model and extracting
sentiment predictions with confidence scores from financial news text.
"""

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import streamlit as st
from typing import Dict, Tuple
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FinBERTAnalyzer:
    """
    A wrapper class for the ProsusAI/finbert model to analyze financial sentiment.
    """

    def __init__(self, model_name: str = "ProsusAI/finbert"):
        """
        Initialize the FinBERT analyzer.

        Args:
            model_name: The Hugging Face model identifier
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @st.cache_resource
    def load_model(_self):
        """
        Load the FinBERT model and tokenizer with caching.
        Using _self to avoid hashing issues with the Streamlit cache.
        """
        try:
            logger.info(f"Loading FinBERT model: {_self.model_name}")
            _self.tokenizer = AutoTokenizer.from_pretrained(_self.model_name)
            _self.model = AutoModelForSequenceClassification.from_pretrained(_self.model_name)
            _self.model.to(_self.device)
            _self.model.eval()
            logger.info("FinBERT model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error loading FinBERT model: {str(e)}")
            return False

    def analyze_sentiment(self, text: str) -> Dict[str, float]:
        """
        Analyze sentiment of financial news text.

        Args:
            text: The financial news text to analyze

        Returns:
            Dictionary containing sentiment label, confidence, and raw scores
        """
        if not self.model or not self.tokenizer:
            if not self.load_model():
                raise RuntimeError("Failed to load FinBERT model")

        try:
            # Tokenize input
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
                predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # Extract results
            scores = predictions.cpu().numpy()[0]
            # NOTE: this assumes the logit order negative/neutral/positive; it should
            # be verified against the model's id2label config, since FinBERT variants
            # publish different label orderings.
            labels = ["negative", "neutral", "positive"]

            # Find the predicted sentiment and confidence
            predicted_idx = scores.argmax()
            predicted_sentiment = labels[predicted_idx]
            confidence = float(scores[predicted_idx])

            return {
                "sentiment": predicted_sentiment,
                "confidence": confidence,
                "scores": {
                    "negative": float(scores[0]),
                    "neutral": float(scores[1]),
                    "positive": float(scores[2])
                }
            }

        except Exception as e:
            logger.error(f"Error analyzing sentiment: {str(e)}")
            raise RuntimeError(f"Sentiment analysis failed: {str(e)}")

    def get_sentiment_direction(self, sentiment: str) -> int:
        """
        Convert sentiment label to numerical direction for evaluation.

        Args:
            sentiment: The sentiment label ("positive", "negative", "neutral")

        Returns:
            1 for positive, -1 for negative, 0 for neutral
        """
        sentiment_map = {
            "positive": 1,
            "negative": -1,
            "neutral": 0
        }
        return sentiment_map.get(sentiment.lower(), 0)
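The analyzer can be driven directly from a Python shell. A minimal sketch: the headline is illustrative, src/ is assumed to be on the import path, and because load_model is wrapped in st.cache_resource, running it outside a Streamlit session may log a missing-runtime warning while still executing:

```python
from sentiment_analyzer import FinBERTAnalyzer  # assumes src/ is on sys.path

analyzer = FinBERTAnalyzer()  # downloads ProsusAI/finbert on first use

result = analyzer.analyze_sentiment("Apple reports record quarterly earnings")
print(result["sentiment"], f"{result['confidence']:.2f}")
print(result["scores"])  # raw softmax scores keyed negative / neutral / positive
```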
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,374 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
| 4 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
| 1 |
+
# FinBERT Market Evaluation - Main Streamlit Application
|
| 2 |
+
"""
|
| 3 |
+
A confidence-aware, volatility-adjusted post-market evaluator for FinBERT sentiment
|
| 4 |
+
predictions against actual stock market movements.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
import streamlit as st
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
import plotly.graph_objects as go
|
| 11 |
+
import plotly.express as px
|
| 12 |
+
from datetime import datetime, timedelta, date
|
| 13 |
+
import time
|
| 14 |
+
import logging
|
| 15 |
|
| 16 |
+
# Import our custom modules
|
| 17 |
+
from sentiment_analyzer import FinBERTAnalyzer
|
| 18 |
+
from market_data import MarketDataService
|
| 19 |
+
from evaluation import EvaluationEngine
|
| 20 |
|
| 21 |
+
# Configure logging
|
| 22 |
+
logging.basicConfig(level=logging.INFO)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
+
# Page configuration
|
| 26 |
+
st.set_page_config(
|
| 27 |
+
page_title="FinBERT Market Evaluation",
|
| 28 |
+
page_icon="🚀",
|
| 29 |
+
layout="wide",
|
| 30 |
+
initial_sidebar_state="expanded"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Initialize session state for rate limiting
|
| 34 |
+
if 'last_request_time' not in st.session_state:
|
| 35 |
+
st.session_state.last_request_time = 0
|
| 36 |
+
|
| 37 |
+
if 'evaluation_history' not in st.session_state:
|
| 38 |
+
st.session_state.evaluation_history = []
|
| 39 |
+
|
| 40 |
+
# Initialize services
|
| 41 |
+
@st.cache_resource
|
| 42 |
+
def initialize_services():
|
| 43 |
+
"""Initialize all services with caching."""
|
| 44 |
+
analyzer = FinBERTAnalyzer()
|
| 45 |
+
market_service = MarketDataService()
|
| 46 |
+
evaluation_engine = EvaluationEngine()
|
| 47 |
+
return analyzer, market_service, evaluation_engine
|
| 48 |
+
|
| 49 |
+
def check_rate_limit():
|
| 50 |
+
"""Check if rate limit allows new request (30 seconds)."""
|
| 51 |
+
current_time = time.time()
|
| 52 |
+
time_since_last = current_time - st.session_state.last_request_time
|
| 53 |
+
return time_since_last >= 30
|
| 54 |
+
|
| 55 |
+
def update_rate_limit():
|
| 56 |
+
"""Update the last request time."""
|
| 57 |
+
st.session_state.last_request_time = time.time()
|
| 58 |
+
|
| 59 |
+
def create_das_chart(das_score: float, confidence: float, impact: float):
|
| 60 |
+
"""Create horizontal bar chart for DAS, confidence, and impact."""
|
| 61 |
+
fig = go.Figure()
|
| 62 |
+
|
| 63 |
+
metrics = ['DAS Score', 'Confidence', 'Impact (scaled)']
|
| 64 |
+
values = [das_score, confidence, min(impact / 5.0, 1.0)] # Scale impact to 0-1
|
| 65 |
+
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
|
| 66 |
+
|
| 67 |
+
fig.add_trace(go.Bar(
|
| 68 |
+
y=metrics,
|
| 69 |
+
x=values,
|
| 70 |
+
orientation='h',
|
| 71 |
+
marker_color=colors,
|
| 72 |
+
text=[f'{v:.3f}' for v in values],
|
| 73 |
+
textposition='inside'
|
| 74 |
+
))
|
| 75 |
+
|
| 76 |
+
fig.update_layout(
|
| 77 |
+
title="Evaluation Metrics",
|
| 78 |
+
xaxis_title="Score",
|
| 79 |
+
height=200,
|
| 80 |
+
margin=dict(l=100, r=50, t=50, b=50)
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
return fig
|
| 84 |
+
|
| 85 |
+
def display_evaluation_result(result: dict):
|
| 86 |
+
"""Display comprehensive evaluation results."""
|
| 87 |
+
if "error" in result:
|
| 88 |
+
st.error(f"Evaluation Error: {result['error']}")
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
# Prominent evaluation summary first
|
| 92 |
+
st.markdown(f"### {result['evaluation_summary']}")
|
| 93 |
+
|
| 94 |
+
# Key insights in a highlighted box
|
| 95 |
+
alignment_color = "green" if result['is_correct'] else "red"
|
| 96 |
+
volatility_note = "🔥 Extremely High" if result['volatility_14d'] > 100 else "📊 High" if result['volatility_14d'] > 50 else "📈 Normal"
|
| 97 |
+
|
| 98 |
+
# Calculate if movement was significant
|
| 99 |
+
movement_significant = result['impact'] > result['threshold']
|
| 100 |
+
significance_text = "exceeded" if movement_significant else "was below"
|
| 101 |
+
|
| 102 |
+
st.markdown(f"""
|
| 103 |
+
<div style="background-color: rgba(0,0,0,0.1); padding: 15px; border-radius: 10px; margin: 10px 0;">
|
| 104 |
+
<h4>📊 Volatility-Aware Analysis:</h4>
|
| 105 |
+
<ul>
|
| 106 |
+
<li><strong>Stock's 14-day volatility:</strong> {result['volatility_14d']:.1f}% ({volatility_note.lower()})</li>
|
| 107 |
+
<li><strong>Significance threshold:</strong> {result['threshold']:.1f}% (= 1.0 × volatility)</li>
|
| 108 |
+
<li><strong>Actual movement:</strong> {result['return_24h']:+.2f}% ({result['impact']:.2f}% magnitude)</li>
|
| 109 |
+
<li><strong>Movement significance:</strong> {significance_text} threshold → {'Significant' if movement_significant else 'Not significant'}</li>
|
| 110 |
+
<li><strong>Directional alignment:</strong> <span style="color: {alignment_color};">{'✅ Correct direction' if result['is_correct'] else '❌ Wrong direction or insufficient magnitude'}</span></li>
|
| 111 |
+
<li><strong>Model confidence:</strong> {'High' if result['confidence'] > 0.8 else 'Medium' if result['confidence'] > 0.6 else 'Low'} ({result['confidence']:.1%})</li>
|
| 112 |
+
</ul>
|
| 113 |
+
</div>
|
| 114 |
+
""", unsafe_allow_html=True)
|
| 115 |
+
|
| 116 |
+
# Main metrics in columns
|
| 117 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 118 |
+
|
| 119 |
+
with col1:
|
| 120 |
+
st.metric("DAS Score", f"{result['das_score']:.3f}", help="Directional Alignment Score (0-1, higher is better)")
|
| 121 |
+
|
| 122 |
+
with col2:
|
| 123 |
+
sentiment_emoji = {"positive": "📈", "negative": "📉", "neutral": "➡️"}
|
| 124 |
+
st.metric("Sentiment", f"{sentiment_emoji.get(result['sentiment'], '❓')} {result['sentiment'].title()}")
|
| 125 |
+
|
| 126 |
+
with col3:
|
| 127 |
+
st.metric("Confidence", f"{result['confidence']:.1%}")
|
| 128 |
+
|
| 129 |
+
with col4:
|
| 130 |
+
return_color = "normal" if abs(result['return_24h']) < result['threshold'] else "inverse"
|
| 131 |
+
st.metric("Same-Day Return", f"{result['return_24h']:+.2f}%", delta=f"vs {result['threshold']:.1f}% threshold")
|
| 132 |
+
|
| 133 |
+
# Additional metrics for 24h return if available
|
| 134 |
+
if result.get('return_next_24h') is not None:
|
| 135 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 136 |
+
with col1:
|
| 137 |
+
st.metric("Next 24h Return", f"{result['return_next_24h']:+.2f}%", help="Return from close of news day to close of next trading day")
|
| 138 |
+
with col2:
|
| 139 |
+
if result.get('alpha_24h') is not None:
|
| 140 |
+
st.metric("24h Alpha", f"{result['alpha_24h']:+.2f}%", help="24h return vs market performance")
|
| 141 |
+
with col3:
|
| 142 |
+
# Show combined impact
|
| 143 |
+
combined_impact = abs(result['return_24h']) + abs(result.get('return_next_24h', 0))
|
| 144 |
+
st.metric("Combined Impact", f"{combined_impact:.2f}%", help="Total magnitude of price movement")
|
| 145 |
+
with col4:
|
| 146 |
+
# Show follow-through consistency
|
| 147 |
+
same_direction = (result['return_24h'] * result.get('return_next_24h', 0)) > 0
|
| 148 |
+
consistency = "✅ Consistent" if same_direction else "🔄 Reversal"
|
| 149 |
+
st.metric("Follow-through", consistency, help="Whether 24h movement continued same direction")
|
| 150 |
+
|
| 151 |
+
# Visualization
|
| 152 |
+
chart = create_das_chart(result['das_score'], result['confidence'], result['impact'])
|
| 153 |
+
# Use session state to create unique chart counter
|
| 154 |
+
if 'chart_counter' not in st.session_state:
|
| 155 |
+
st.session_state.chart_counter = 0
|
| 156 |
+
st.session_state.chart_counter += 1
|
| 157 |
+
chart_key = f"chart_{st.session_state.chart_counter}"
|
| 158 |
+
st.plotly_chart(chart, use_container_width=True, key=chart_key)
|
| 159 |
+
|
| 160 |
+
# Technical metrics (always visible)
|
| 161 |
+
st.subheader("📊 Technical Metrics")
|
| 162 |
+
|
| 163 |
+
col1, col2, col3 = st.columns(3)
|
| 164 |
+
|
| 165 |
+
with col1:
|
| 166 |
+
st.metric("Ticker", result['ticker'])
|
| 167 |
+
st.metric("News Date", result['news_date'])
|
| 168 |
+
st.metric("14-day Volatility", f"{result['volatility_14d']:.2f}%")
|
| 169 |
+
st.metric("Significance Threshold", f"{result['threshold']:.2f}%")
|
| 170 |
+
|
| 171 |
+
with col2:
|
| 172 |
+
st.metric("Same-Day Impact", f"{result['impact']:.2f}%")
|
| 173 |
+
if result.get('return_next_24h') is not None:
|
| 174 |
+
st.metric("24h Impact", f"{abs(result['return_next_24h']):.2f}%")
|
| 175 |
+
st.metric("WAT Weight", f"{result['wat_weight']:.3f}")
|
| 176 |
+
alignment_text = "✅ Yes" if result['is_correct'] else "❌ No"
|
| 177 |
+
st.metric("Alignment", alignment_text)
|
| 178 |
+
|
| 179 |
+
with col3:
|
| 180 |
+
alpha_val = result.get('alpha_adjusted', 'N/A')
|
| 181 |
+
alpha_str = f"{alpha_val:+.2f}%" if isinstance(alpha_val, (int, float)) else str(alpha_val)
|
| 182 |
+
st.metric("Same-Day Alpha", alpha_str)
|
| 183 |
+
|
| 184 |
+
if result.get('alpha_24h') is not None:
|
| 185 |
+
st.metric("24h Alpha", f"{result['alpha_24h']:+.2f}%")
|
| 186 |
+
|
| 187 |
+
# Market context
|
| 188 |
+
market_same = result.get('market_return', 'N/A')
|
| 189 |
+
market_str = f"{market_same:+.2f}%" if isinstance(market_same, (int, float)) else str(market_same)
|
| 190 |
+
st.metric("Market Return", market_str)
|
| 191 |
+
|
| 192 |
+
def main():
|
| 193 |
+
"""Main application function."""
|
| 194 |
+
# Header
|
| 195 |
+
st.title("🚀 FinBERT Market Evaluation")
|
| 196 |
+
st.markdown("""
|
| 197 |
+
A confidence-aware, volatility-adjusted post-market evaluator for FinBERT sentiment predictions.
|
| 198 |
+
Evaluate how well FinBERT's financial news sentiment aligns with actual stock market movements.
|
| 199 |
+
""")
|
| 200 |
+
|
| 201 |
+
# Sidebar info (no user configuration needed)
|
| 202 |
+
st.sidebar.header("📊 Evaluation Framework")
|
| 203 |
+
st.sidebar.markdown("""
|
| 204 |
+
**Dual-Period Analysis:**
|
| 205 |
+
- **Same-Day**: Intraday return (Close - Open)
|
| 206 |
+
- **Next 24h**: Close-to-close follow-through
|
| 207 |
+
- **Combined**: Complete market reaction picture
|
| 208 |
+
|
| 209 |
+
**Volatility-Aware Evaluation:**
|
| 210 |
+
- Uses each stock's 14-day volatility
|
| 211 |
+
- Threshold = 1.0 × volatility (k=1.0)
|
| 212 |
+
- Adapts to stock movement patterns
|
| 213 |
+
|
| 214 |
+
**Directional Alignment Score:**
|
| 215 |
+
- Graded 0-1 score (not binary)
|
| 216 |
+
- Based on same-day return vs threshold
|
| 217 |
+
- Higher = better alignment
|
| 218 |
+
|
| 219 |
+
**Alpha Analysis:**
|
| 220 |
+
- Stock return vs market performance
|
| 221 |
+
- Isolates stock-specific impact
|
| 222 |
+
- Available for both time periods
|
| 223 |
+
""")
|
| 224 |
+
|
| 225 |
+
# Fixed research parameters (not user-configurable)
|
| 226 |
+
volatility_multiplier = 1.0 # k = 1.0 as per your framework
|
| 227 |
+
confidence_threshold = 0.7 # Reasonable default
|
| 228 |
+
|
| 229 |
+
# Initialize services
|
| 230 |
+
try:
|
| 231 |
+
analyzer, market_service, evaluation_engine = initialize_services()
|
| 232 |
+
evaluation_engine.volatility_multiplier = volatility_multiplier
|
| 233 |
+
evaluation_engine.confidence_threshold = confidence_threshold
|
| 234 |
+
except Exception as e:
|
| 235 |
+
st.error(f"Failed to initialize services: {str(e)}")
|
| 236 |
+
st.stop()
|
| 237 |
+
|
| 238 |
+
# Main input form
|
| 239 |
+
st.header("📰 News Analysis")
|
src/streamlit_app.py (continued):

```python
    with st.form("evaluation_form"):
        # News text input
        news_text = st.text_area(
            "Financial News Text",
            height=150,
            placeholder="Enter financial news headline or summary here...",
            help="Paste the financial news text you want to analyze"
        )

        col1, col2 = st.columns(2)

        with col1:
            ticker = st.text_input(
                "Stock Ticker",
                placeholder="e.g., TSLA, AAPL, MSFT",
                help="Enter the stock ticker symbol"
            ).upper()

        with col2:
            news_date = st.date_input(
                "News Publication Date",
                value=date.today() - timedelta(days=1),
                max_value=date.today() - timedelta(days=1),
                help="Date when the news was published (must be at least 1 day ago)"
            )

        submitted = st.form_submit_button("🔍 Evaluate Prediction")

    # Process evaluation
    if submitted:
        if not news_text.strip():
            st.error("Please enter some news text to analyze.")
            return

        if not ticker:
            st.error("Please enter a stock ticker symbol.")
            return

        # Rate-limiting check
        if not check_rate_limit():
            remaining_time = 30 - (time.time() - st.session_state.last_request_time)
            st.warning(f"Rate limit: Please wait {remaining_time:.0f} more seconds before the next request.")
            return

        # Record this request for rate limiting
        update_rate_limit()

        # Show progress
        progress_bar = st.progress(0)
        status_text = st.empty()

        try:
            # Step 1: Sentiment analysis
            status_text.text("🤖 Analyzing sentiment with FinBERT...")
            progress_bar.progress(25)

            sentiment_result = analyzer.analyze_sentiment(news_text)

            # Step 2: Market data
            status_text.text("📊 Fetching market data...")
            progress_bar.progress(50)

            news_datetime = datetime.combine(news_date, datetime.min.time())
            market_result = market_service.get_stock_evaluation_data(ticker, news_datetime)

            # Step 3: Evaluation
            status_text.text("⚖️ Evaluating prediction...")
            progress_bar.progress(75)

            evaluation_result = evaluation_engine.evaluate_prediction(
                sentiment_result, market_result, news_datetime
            )

            # Step 4: Display results
            status_text.text("✅ Evaluation complete!")
            progress_bar.progress(100)

            # Clear progress indicators
            time.sleep(0.5)
            progress_bar.empty()
            status_text.empty()

            # Display results
            st.header("📊 Evaluation Results")
            display_evaluation_result(evaluation_result)

            # Add to history
            if "error" not in evaluation_result:
                st.session_state.evaluation_history.append(evaluation_result)

        except Exception as e:
            progress_bar.empty()
            status_text.empty()
            st.error(f"Evaluation failed: {str(e)}")
            logger.error(f"Evaluation error: {str(e)}")

    # Evaluation history section
    if st.session_state.evaluation_history:
        st.header("📋 Previous Evaluations")

        # Show most recent evaluations first (reverse chronological)
        recent_evaluations = list(reversed(st.session_state.evaluation_history))

        # Show recent evaluations in expandable cards
        for i, result in enumerate(recent_evaluations):
            # Build a concise title for each evaluation
            alignment_icon = "✅" if result['is_correct'] else "❌"
            sentiment_icon = {"positive": "📈", "negative": "📉", "neutral": "➡️"}.get(result['sentiment'], "❓")

            title = (
                f"{alignment_icon} {result['ticker']} ({result['news_date']}) - "
                f"{sentiment_icon} {result['sentiment'].title()} → "
                f"{result['return_24h']:+.1f}% | DAS: {result['das_score']:.3f}"
            )

            with st.expander(title, expanded=(i == 0)):  # Expand the most recent one
                display_evaluation_result(result)

        st.markdown("---")

        # Simple action buttons
        col1, col2 = st.columns([1, 3])

        with col1:
            if st.button("🗑️ Clear All History"):
                st.session_state.evaluation_history = []
                st.rerun()

        with col2:
            st.caption(f"📊 {len(st.session_state.evaluation_history)} evaluation(s) completed")

    # Footer
    st.markdown("---")
    st.caption("🚀 **FinBERT Market Evaluation** | Rate limit: 30s | Model: ProsusAI/finbert | Data: Yahoo Finance")


if __name__ == "__main__":
    main()
```
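The flow above calls two rate-limiting helpers, `check_rate_limit()` and `update_rate_limit()`, which are defined earlier in `src/streamlit_app.py` (outside this excerpt). A minimal sketch of what they need to provide — not the file's actual definitions — assuming `st.session_state.last_request_time` is initialized to `0.0` at startup:

```python
import time
import streamlit as st

RATE_LIMIT_SECONDS = 30  # assumed constant, matching the 30s limit shown in the footer

def check_rate_limit() -> bool:
    """True when at least RATE_LIMIT_SECONDS have passed since the last request."""
    return time.time() - st.session_state.last_request_time >= RATE_LIMIT_SECONDS

def update_rate_limit() -> None:
    """Record the current request time for the next check."""
    st.session_state.last_request_time = time.time()
```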
src/visualizations.py ADDED
@@ -0,0 +1,302 @@
```python
# Visualization components for FinBERT Market Evaluation
"""
This module provides additional visualization components, including
calibration plots, performance-over-time charts, sentiment distributions,
and confidence-impact scatter plots.
"""

import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
from typing import List, Dict
import streamlit as st


def create_calibration_plot(evaluations: List[Dict]) -> go.Figure:
    """
    Create a calibration plot showing confidence vs actual accuracy.

    Args:
        evaluations: List of evaluation results

    Returns:
        Plotly figure for the calibration plot
    """
    if not evaluations:
        return go.Figure()

    # Extract confidence and correctness
    confidences = [e['confidence'] for e in evaluations if 'confidence' in e]
    correctness = [e['is_correct'] for e in evaluations if 'is_correct' in e]

    if len(confidences) != len(correctness) or len(confidences) < 5:
        return go.Figure()

    # Create confidence bins
    bins = np.linspace(0.5, 1.0, 6)  # 5 bins from 0.5 to 1.0
    bin_centers = (bins[:-1] + bins[1:]) / 2

    # Calculate accuracy for each bin
    bin_accuracies = []
    bin_counts = []

    for i in range(len(bins) - 1):
        mask = (np.array(confidences) >= bins[i]) & (np.array(confidences) < bins[i + 1])
        if i == len(bins) - 2:  # Last bin includes the upper bound
            mask = (np.array(confidences) >= bins[i]) & (np.array(confidences) <= bins[i + 1])

        bin_correct = np.array(correctness)[mask]
        if len(bin_correct) > 0:
            bin_accuracies.append(np.mean(bin_correct))
            bin_counts.append(len(bin_correct))
        else:
            bin_accuracies.append(0)
            bin_counts.append(0)

    # Create figure
    fig = go.Figure()

    # Perfect calibration line
    fig.add_trace(go.Scatter(
        x=[0.5, 1.0],
        y=[0.5, 1.0],
        mode='lines',
        name='Perfect Calibration',
        line=dict(dash='dash', color='gray')
    ))

    # Actual calibration
    fig.add_trace(go.Scatter(
        x=bin_centers,
        y=bin_accuracies,
        mode='markers+lines',
        name='Actual Calibration',
        marker=dict(size=[c / 2 + 5 for c in bin_counts]),  # Marker size scales with bin count
        text=[f'Count: {c}' for c in bin_counts],
        hovertemplate='Confidence: %{x:.2f}<br>Accuracy: %{y:.2f}<br>%{text}<extra></extra>'
    ))

    fig.update_layout(
        title='Calibration Plot: Confidence vs Accuracy',
        xaxis_title='Predicted Confidence',
        yaxis_title='Actual Accuracy',
        xaxis=dict(range=[0.5, 1.0]),
        yaxis=dict(range=[0.0, 1.0]),
        height=400
    )

    return fig
```
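For orientation: the loop above assigns each prediction's confidence to one of five equal-width bins on [0.5, 1.0] and compares the mean correctness in each bin against the bin's confidence level. A confidence of exactly 1.0 would fall outside a half-open last bin, which is why the final bin is widened to include its upper bound. A tiny illustration of the same bin assignment, using invented confidence values (not app output):

```python
import numpy as np

# Hypothetical confidences; bin edges match create_calibration_plot above.
confidences = np.array([0.55, 0.62, 0.71, 0.83, 0.94, 0.97])
bins = np.linspace(0.5, 1.0, 6)  # edges: 0.5, 0.6, 0.7, 0.8, 0.9, 1.0

# np.digitize returns the 1-based bin index for each value
print(np.digitize(confidences, bins))  # -> [1 2 3 4 5 5]
```

The module continues with the remaining chart builders: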
```python
def create_performance_over_time(evaluations: List[Dict]) -> go.Figure:
    """
    Create a time series plot of performance metrics.

    Args:
        evaluations: List of evaluation results

    Returns:
        Plotly figure for performance over time
    """
    if not evaluations:
        return go.Figure()

    # Convert to DataFrame
    df = pd.DataFrame(evaluations)
    df['news_date'] = pd.to_datetime(df['news_date'])
    df = df.sort_values('news_date')

    # Calculate rolling metrics
    window = min(5, len(df))  # 5-day rolling window, or fewer if less data
    df['rolling_das'] = df['das_score'].rolling(window=window, min_periods=1).mean()
    df['rolling_accuracy'] = df['is_correct'].rolling(window=window, min_periods=1).mean()

    fig = go.Figure()

    # DAS score over time
    fig.add_trace(go.Scatter(
        x=df['news_date'],
        y=df['rolling_das'],
        mode='lines+markers',
        name='Rolling DAS Score',
        line=dict(color='blue'),
        yaxis='y'
    ))

    # Accuracy over time
    fig.add_trace(go.Scatter(
        x=df['news_date'],
        y=df['rolling_accuracy'],
        mode='lines+markers',
        name='Rolling Accuracy',
        line=dict(color='red'),
        yaxis='y2'
    ))

    fig.update_layout(
        title=f'Performance Over Time (Rolling {window}-day average)',
        xaxis_title='Date',
        yaxis=dict(
            title='DAS Score',
            side='left',
            range=[0, 1]
        ),
        yaxis2=dict(
            title='Accuracy',
            side='right',
            overlaying='y',
            range=[0, 1]
        ),
        height=400,
        hovermode='x unified'
    )

    return fig


def create_sentiment_distribution(evaluations: List[Dict]) -> go.Figure:
    """
    Create a distribution plot of sentiments and their performance.

    Args:
        evaluations: List of evaluation results

    Returns:
        Plotly figure for sentiment distribution
    """
    if not evaluations:
        return go.Figure()

    df = pd.DataFrame(evaluations)

    # Group by sentiment
    sentiment_stats = df.groupby('sentiment').agg({
        'das_score': ['mean', 'count'],
        'is_correct': 'mean',
        'confidence': 'mean'
    }).round(3)

    sentiment_stats.columns = ['avg_das', 'count', 'accuracy', 'avg_confidence']
    sentiment_stats = sentiment_stats.reset_index()

    # Create a dual-axis figure
    fig = go.Figure()

    # Bar chart for counts
    fig.add_trace(go.Bar(
        x=sentiment_stats['sentiment'],
        y=sentiment_stats['count'],
        name='Count',
        marker_color='lightblue',
        yaxis='y',
        text=sentiment_stats['count'],
        textposition='auto'
    ))

    # Line chart for accuracy
    fig.add_trace(go.Scatter(
        x=sentiment_stats['sentiment'],
        y=sentiment_stats['accuracy'],
        mode='lines+markers',
        name='Accuracy',
        line=dict(color='red'),
        yaxis='y2',
        marker=dict(size=10)
    ))

    fig.update_layout(
        title='Sentiment Distribution and Performance',
        xaxis_title='Sentiment',
        yaxis=dict(
            title='Count',
            side='left'
        ),
        yaxis2=dict(
            title='Accuracy',
            side='right',
            overlaying='y',
            range=[0, 1]
        ),
        height=400
    )

    return fig


def create_confidence_impact_scatter(evaluations: List[Dict]) -> go.Figure:
    """
    Create a scatter plot of confidence vs impact with DAS score coloring.

    Args:
        evaluations: List of evaluation results

    Returns:
        Plotly figure for the confidence-impact scatter
    """
    if not evaluations:
        return go.Figure()

    df = pd.DataFrame(evaluations)

    # Create scatter plot
    fig = px.scatter(
        df,
        x='confidence',
        y='impact',
        color='das_score',
        size='wat_weight',
        hover_data=['ticker', 'sentiment', 'return_24h'],
        color_continuous_scale='RdYlBu_r',
        title='Confidence vs Impact (colored by DAS Score)'
    )

    fig.update_layout(
        xaxis_title='Confidence',
        yaxis_title='Impact (|Return %|)',
        height=400
    )

    return fig


def display_advanced_visualizations(evaluations: List[Dict]):
    """
    Display advanced visualization components in Streamlit.

    Args:
        evaluations: List of evaluation results
    """
    if len(evaluations) < 3:
        st.info("Need at least 3 evaluations for advanced visualizations.")
        return

    st.subheader("📊 Advanced Analytics")

    # Create tabs for different visualizations
    tab1, tab2, tab3, tab4 = st.tabs([
        "Calibration", "Performance Over Time",
        "Sentiment Analysis", "Confidence vs Impact"
    ])

    with tab1:
        st.plotly_chart(
            create_calibration_plot(evaluations),
            use_container_width=True
        )
        st.caption("Shows how well confidence scores align with actual accuracy. Points closer to the diagonal line indicate better calibration.")

    with tab2:
        st.plotly_chart(
            create_performance_over_time(evaluations),
            use_container_width=True
        )
        st.caption("Rolling average of DAS scores and accuracy over time.")

    with tab3:
        st.plotly_chart(
            create_sentiment_distribution(evaluations),
            use_container_width=True
        )
        st.caption("Distribution of sentiment predictions and their respective performance.")

    with tab4:
        st.plotly_chart(
            create_confidence_impact_scatter(evaluations),
            use_container_width=True
        )
        st.caption("Relationship between model confidence and market impact, colored by DAS score.")
```
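For reference, each evaluation dict must carry the keys this module reads: `ticker`, `news_date`, `sentiment`, `confidence`, `is_correct`, `das_score`, `return_24h`, `impact`, and `wat_weight`. A hypothetical, self-contained usage sketch — the records below are invented; in the app the list comes from `st.session_state.evaluation_history`:

```python
from src.visualizations import display_advanced_visualizations

# Invented records for illustration only; real ones come from the evaluation engine.
sample_evaluations = [
    {"ticker": "NVDA", "news_date": "2024-03-01", "sentiment": "positive",
     "confidence": 0.88, "is_correct": True, "das_score": 0.74,
     "return_24h": 2.6, "impact": 2.6, "wat_weight": 0.7},
    {"ticker": "INTC", "news_date": "2024-03-04", "sentiment": "negative",
     "confidence": 0.69, "is_correct": False, "das_score": 0.22,
     "return_24h": 1.3, "impact": 1.3, "wat_weight": 0.4},
    {"ticker": "AMD", "news_date": "2024-03-05", "sentiment": "neutral",
     "confidence": 0.61, "is_correct": True, "das_score": 0.58,
     "return_24h": -0.2, "impact": 0.2, "wat_weight": 0.3},
]

# Inside a running Streamlit app, this renders the four analytics tabs:
display_advanced_visualizations(sample_evaluations)
```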