File size: 4,727 Bytes
99bdd87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
"""
Expand Vector Database with Comprehensive Data
==============================================

This script loads data from multiple sources to create a comprehensive
vector database with better domain coverage:

1. Full MMLU dataset (all domains, no sampling) -- NOTE: expand_database() below has no loading phase for this source
2. MMLU-Pro (harder questions)
3. GPQA Diamond (graduate-level questions)
4. MATH dataset (competition mathematics)

Target: 20,000+ questions across 20+ domains
"""

from pathlib import Path
from benchmark_vector_db import BenchmarkVectorDB
import logging

# Configure logging when the module is loaded: INFO level, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)

def expand_database():
    """Create the expanded benchmark database, index it, and log a summary.

    A ``BenchmarkVectorDB`` is opened at ``./data/benchmark_vector_db_expanded``
    with the ``all-MiniLM-L6-v2`` embedding model. Questions are then pulled
    from three loaders in order -- MMLU-Pro (``max_samples=5000``), GPQA
    (``fetch_real_scores=False``) and MATH (``max_samples=2000``) --
    concatenated, and handed to ``index_questions`` unless the combined
    collection is empty. Finally the result of ``get_statistics()`` is logged.

    Returns:
        tuple: ``(db, stats)`` -- the database object and the statistics
        mapping returned by ``db.get_statistics()``.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 40

    def section(title, opening_rule=heavy_rule):
        # Three-line banner: rule, title, rule.
        logger.info(opening_rule)
        logger.info(title)
        logger.info(heavy_rule)

    def phase(title):
        # Phase heading followed by a thin separator line.
        logger.info(title)
        logger.info(light_rule)

    def log_counts(heading, key, top=None):
        # Log one "label: count" line per entry of stats[key]; when ``top`` is
        # given, only the ``top`` largest counts are shown, biggest first.
        logger.info(heading)
        entries = stats.get(key, {}).items()
        if top is not None:
            entries = sorted(entries, key=lambda pair: pair[1], reverse=True)[:top]
        for label, count in entries:
            logger.info(f"  {label}: {count}")

    section("Expanding Vector Database with Comprehensive Data")

    # Open the database under its dedicated "expanded" path.
    db = BenchmarkVectorDB(
        db_path=Path("./data/benchmark_vector_db_expanded"),
        embedding_model="all-MiniLM-L6-v2",
    )

    # Load each source with the raised sample limits.
    phase("\nPhase 1: Loading MMLU-Pro (harder subset)")
    mmlu_pro_batch = db.load_mmlu_pro_dataset(max_samples=5000)
    mmlu_pro_count = len(mmlu_pro_batch)
    logger.info(f"Loaded {mmlu_pro_count} MMLU-Pro questions")

    phase("\nPhase 2: Loading GPQA Diamond (graduate-level)")
    gpqa_batch = db.load_gpqa_dataset(fetch_real_scores=False)
    gpqa_count = len(gpqa_batch)
    logger.info(f"Loaded {gpqa_count} GPQA questions")

    phase("\nPhase 3: Loading MATH dataset (competition math)")
    math_batch = db.load_math_dataset(max_samples=2000)
    math_count = len(math_batch)
    logger.info(f"Loaded {math_count} MATH questions")

    # Merge the three batches, keeping the load order.
    all_questions = mmlu_pro_batch + gpqa_batch + math_batch
    combined_count = len(all_questions)
    logger.info(f"\nTotal questions to index: {combined_count}")

    # Indexing is skipped entirely when nothing was loaded.
    if all_questions:
        logger.info("\nIndexing questions into vector database...")
        logger.info("This may take several minutes...")
        db.index_questions(all_questions)

    section("Database Statistics", opening_rule="\n" + heavy_rule)

    stats = db.get_statistics()
    total = stats['total_questions']
    logger.info(f"\nTotal Questions: {total}")

    log_counts("\nSources:", 'sources')
    log_counts("\nDomains:", 'domains', top=20)
    log_counts("\nDifficulty Levels:", 'difficulty_levels')

    section("✅ Database expansion complete!", opening_rule="\n" + heavy_rule)

    return db, stats


def test_expanded_database(db):
    """Run a fixed set of sample prompts against the database and log results.

    For each ``(category, prompt)`` pair this calls
    ``db.query_similar_questions(prompt, k=3)`` and logs the ``risk_level``,
    ``weighted_success_rate`` (formatted as a percentage) and
    ``recommendation`` entries of the returned result.

    Args:
        db: Object exposing ``query_similar_questions(prompt, k=...)``. In
            this script it is the database returned by ``expand_database``.

    Returns:
        None. All output goes to the module logger.
    """
    # Maximum number of prompt characters echoed in the log heading.
    preview_limit = 50

    logger.info("\n" + "=" * 60)
    logger.info("Testing Expanded Database")
    logger.info("=" * 60)

    test_prompts = [
        # Hard prompts
        ("Graduate-level physics", "Calculate the quantum correction to the partition function for a 3D harmonic oscillator"),
        ("Abstract mathematics", "Prove that every field is also a ring"),
        ("Competition math", "Find all zeros of the polynomial x^3 + 2x + 2 in Z_7"),

        # Easy prompts
        ("Basic arithmetic", "What is 2 + 2?"),
        ("General knowledge", "What is the capital of France?"),

        # Domain-specific
        ("Medical reasoning", "Diagnose a patient with acute chest pain"),
        ("Legal knowledge", "Explain the doctrine of precedent in common law"),
        ("Computer science", "Implement a binary search tree"),
    ]

    for category, prompt in test_prompts:
        # Append the ellipsis only when the prompt was actually cut; otherwise
        # short prompts such as "What is 2 + 2?" would be logged as if they
        # had been truncated.
        if len(prompt) > preview_limit:
            preview = prompt[:preview_limit] + "..."
        else:
            preview = prompt
        logger.info(f"\n{category}: '{preview}'")

        result = db.query_similar_questions(prompt, k=3)
        logger.info(f"  Risk Level: {result['risk_level']}")
        logger.info(f"  Success Rate: {result['weighted_success_rate']:.1%}")
        logger.info(f"  Recommendation: {result['recommendation']}")


if __name__ == "__main__":
    # Build the expanded database first, then exercise it with sample queries.
    db, stats = expand_database()
    test_expanded_database(db)

    # Closing hints on how to point other code at the new database.
    for closing_line in (
        "\n🎉 All done! You can now use the expanded database.",
        "To switch to the expanded database, update your demo files:",
        "  db_path=Path('./data/benchmark_vector_db_expanded')",
    ):
        logger.info(closing_line)