Upload 5 files

Files changed:
- data_loader.py +449 -0
- dockerfile +16 -0
- main.py +124 -0
- models.py +36 -0
- requirements.txt +11 -0
data_loader.py
ADDED (+449 -0)

```python
from datasets import load_dataset
from typing import List, Optional, Dict, Any
from datetime import datetime
from models import ArticleResponse, ArticleDetail, Argument, FiltersResponse
from collections import Counter
from functools import lru_cache
from whoosh import index
from whoosh.fields import Schema, TEXT, ID
from whoosh.qparser import QueryParser
from whoosh.filedb.filestore import RamStorage
from dateutil import parser as date_parser
import numpy as np
from sentence_transformers import SentenceTransformer

# Constants
SEARCH_CACHE_MAX_SIZE = 1000
LABOR_SCORE_WEIGHT = 0.1  # Weight for labor score in relevance calculation
DATE_RANGE_START = "2022-01-01"
DATE_RANGE_END = "2025-12-31"

class DataLoader:
    """
    Handles loading, indexing, and searching of AI labor economy articles.

    Uses Whoosh for full-text search and maintains in-memory data structures
    for fast filtering and pagination.
    """

    def __init__(self):
        self.dataset = None
        self.articles = []
        self.articles_by_id = {}  # ID -> article mapping
        self.filter_options = None

        # Initialize Whoosh search index for full-text search
        self.search_schema = Schema(
            id=ID(stored=True),
            title=TEXT(stored=False),
            summary=TEXT(stored=False),
            content=TEXT(stored=False)  # Combined title + summary for search
        )
        # Create in-memory index using RamStorage
        storage = RamStorage()
        self.search_index = storage.create_index(self.search_schema)
        self.query_parser = QueryParser("content", self.search_schema)

        # Dense retrieval components (lazy-loaded for efficiency)
        self.embeddings = None  # Article embeddings from dataset
        self.embedding_model = None  # SentenceTransformer model
        self.model_path = "ibm-granite/granite-embedding-english-r2"
        # Note: Using lru_cache for search caching instead of manual cache management
```
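The constructor's search setup is plain Whoosh. As a minimal, self-contained sketch of the same in-memory pattern (illustrative only, not part of the uploaded files):

```python
# Minimal sketch of the in-memory Whoosh pattern used in __init__.
from whoosh.fields import Schema, TEXT, ID
from whoosh.qparser import QueryParser
from whoosh.filedb.filestore import RamStorage

schema = Schema(id=ID(stored=True), content=TEXT(stored=False))
ix = RamStorage().create_index(schema)

writer = ix.writer()
writer.add_document(id="0", content="AI and the future of work")
writer.add_document(id="1", content="quarterly earnings report")
writer.commit()

with ix.searcher() as searcher:
    query = QueryParser("content", schema).parse("work")
    for hit in searcher.search(query, limit=None):
        print(hit["id"], hit.score)  # only document 0 matches
```

Storing only the id (stored=True) keeps the index small; the full articles live in self.articles and are looked up by id after a search.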
data_loader.py (continued):

```python
    async def load_dataset(self):
        """Load and process the HuggingFace dataset"""
        # Load dataset
        self.dataset = load_dataset("yjernite/ai-economy-labor-articles-annotated-embed", split="train")

        # Convert to list of dicts for easier processing
        self.articles = []

        # Prepare Whoosh index writer
        writer = self.search_index.writer()

        for i, row in enumerate(self.dataset):
            # Parse date using dateutil (more flexible than pandas)
            date = date_parser.parse(row['date']) if isinstance(row['date'], str) else row['date']

            # Parse arguments
            arguments = []
            if row.get('arguments'):
                for arg in row['arguments']:
                    arguments.append(Argument(
                        argument_quote=arg.get('argument_quote', []),
                        argument_summary=arg.get('argument_summary', ''),
                        argument_source=arg.get('argument_source', ''),
                        argument_type=arg.get('argument_type', ''),
                    ))

            article = {
                'id': i,
                'title': row.get('title', ''),
                'source': row.get('source', ''),
                'url': row.get('url', ''),
                'date': date,
                'summary': row.get('summary', ''),
                'ai_labor_relevance': row.get('ai_labor_relevance', 0),
                'document_type': row.get('document_type', ''),
                'author_type': row.get('author_type', ''),
                'document_topics': row.get('document_topics', []),
                'text': row.get('text', ''),
                'arguments': arguments,
            }
            self.articles.append(article)
            self.articles_by_id[i] = article

            # Add to search index
            search_content = f"{article['title']} {article['summary']}"
            writer.add_document(
                id=str(i),
                title=article['title'],
                summary=article['summary'],
                content=search_content
            )

        # Commit search index
        writer.commit()
        print(f"DEBUG: Search index populated with {len(self.articles)} articles")

        # Load pre-computed embeddings for dense retrieval
        print("DEBUG: Loading pre-computed embeddings...")
        raw_embeddings = np.array(self.dataset['embeddings-granite'])
        # Normalize embeddings for cosine similarity
        self.embeddings = raw_embeddings / np.linalg.norm(raw_embeddings, axis=1, keepdims=True)
        print(f"DEBUG: Loaded {len(self.embeddings)} article embeddings")

        # Build filter options
        self._build_filter_options()

    def _build_filter_options(self):
        """Build available filter options from the dataset"""
        document_types = sorted(set(article['document_type'] for article in self.articles if article['document_type']))
        author_types = sorted(set(article['author_type'] for article in self.articles if article['author_type']))

        # Flatten all topics
        all_topics = []
        for article in self.articles:
            if article['document_topics']:
                all_topics.extend(article['document_topics'])
        topics = sorted(set(all_topics))

        # Date range - fixed for research period
        min_date = DATE_RANGE_START
        max_date = DATE_RANGE_END

        # Relevance range
        relevances = [article['ai_labor_relevance'] for article in self.articles if article['ai_labor_relevance'] is not None]
        min_relevance = min(relevances) if relevances else 0
        max_relevance = max(relevances) if relevances else 10

        self.filter_options = FiltersResponse(
            document_types=document_types,
            author_types=author_types,
            topics=topics,
            date_range={"min_date": min_date, "max_date": max_date},
            relevance_range={"min_relevance": min_relevance, "max_relevance": max_relevance}
        )

    def get_filter_options(self) -> FiltersResponse:
        """Get all available filter options"""
        return self.filter_options
```
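Normalizing the pre-computed 'embeddings-granite' vectors once at load time means a plain dot product later is exactly cosine similarity. A toy numeric check (made-up 3-d vectors, not real embeddings):

```python
import numpy as np

raw = np.array([[3.0, 0.0, 0.0],
                [1.0, 1.0, 0.0]])
unit = raw / np.linalg.norm(raw, axis=1, keepdims=True)  # every row now has norm 1

query = np.array([1.0, 0.0, 0.0])
query /= np.linalg.norm(query)

print(unit @ query)  # [1.0, 0.7071]: dot products of unit vectors are cosine similarities
```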
data_loader.py (continued):

```python
    def _filter_articles(
        self,
        document_types: Optional[List[str]] = None,
        author_types: Optional[List[str]] = None,
        min_relevance: Optional[float] = None,
        max_relevance: Optional[float] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        topics: Optional[List[str]] = None,
        search_query: Optional[str] = None,
        search_type: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Filter articles based on criteria"""
        # Work on a copy so callers that sort in place (e.g. get_articles)
        # never mutate the canonical article list.
        filtered = list(self.articles)

        if document_types:
            filtered = [a for a in filtered if a['document_type'] in document_types]

        if author_types:
            filtered = [a for a in filtered if a['author_type'] in author_types]

        if min_relevance is not None:
            filtered = [a for a in filtered if a['ai_labor_relevance'] >= min_relevance]

        if max_relevance is not None:
            filtered = [a for a in filtered if a['ai_labor_relevance'] <= max_relevance]

        if start_date:
            start_dt = date_parser.parse(start_date)
            filtered = [a for a in filtered if a['date'] >= start_dt]

        if end_date:
            end_dt = date_parser.parse(end_date)
            filtered = [a for a in filtered if a['date'] <= end_dt]

        if topics:
            filtered = [a for a in filtered if any(topic in a['document_topics'] for topic in topics)]

        if search_query:
            print(f"DEBUG: Applying search filter for query: '{search_query}' with type: '{search_type}'")

            if search_type == 'dense':
                # For dense search, get similarity scores for all articles
                dense_scores = self._dense_search_all_articles(search_query)
                dense_score_dict = {idx: score for idx, score in dense_scores}

                # Attach dense scores to filtered articles and filter by similarity threshold
                filtered_with_scores = []
                for article in filtered:
                    article_id = article['id']
                    if article_id in dense_score_dict:
                        # Create a copy to avoid modifying the original
                        article_copy = article.copy()
                        article_copy['dense_score'] = dense_score_dict[article_id]
                        # Only include articles with meaningful similarity (> 0.8)
                        if dense_score_dict[article_id] > 0.8:
                            filtered_with_scores.append(article_copy)

                filtered = filtered_with_scores
                print(f"DEBUG: After dense search filtering: {len(filtered)} articles remaining")
            else:
                # Exact (Whoosh) search: keep only articles that matched the query
                search_results = self._search_articles(search_query, search_type)
                filtered = [a for a in filtered if a.get('id') in search_results]
                print(f"DEBUG: After exact search filtering: {len(filtered)} articles remaining")

        return filtered
```
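Each active filter narrows the working list in turn, so criteria combine with AND semantics, and the dense path additionally drops anything at or below the 0.8 cosine threshold. A toy run of the narrowing pattern:

```python
articles = [
    {'id': 0, 'document_type': 'news', 'ai_labor_relevance': 9},
    {'id': 1, 'document_type': 'blog', 'ai_labor_relevance': 7},
    {'id': 2, 'document_type': 'news', 'ai_labor_relevance': 3},
]

filtered = list(articles)
filtered = [a for a in filtered if a['document_type'] in ['news']]  # keeps ids 0 and 2
filtered = [a for a in filtered if a['ai_labor_relevance'] >= 5]    # keeps id 0
print([a['id'] for a in filtered])  # [0]
```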
data_loader.py (continued):

```python
    def _search_articles(self, search_query: str, search_type: Optional[str] = None) -> Dict[int, float]:
        """Search articles using Whoosh and return an article_id -> score mapping.

        Note: Dense search is handled separately in _filter_articles;
        this method only handles exact (Whoosh) search.
        """
        if not search_query:
            return {}

        # Use cached Whoosh search (lru_cache handles caching automatically)
        return self._cached_whoosh_search(search_query)

    @lru_cache(maxsize=SEARCH_CACHE_MAX_SIZE)
    def _cached_whoosh_search(self, search_query: str) -> Dict[int, float]:
        """Cached version of Whoosh search using lru_cache"""
        return self._whoosh_search(search_query)

    def _whoosh_search(self, search_query: str) -> Dict[int, float]:
        """Perform search using the Whoosh index"""
        try:
            with self.search_index.searcher() as searcher:
                # Parse query - Whoosh handles tokenization automatically
                query = self.query_parser.parse(search_query)
                results = searcher.search(query, limit=None)  # Get all results

                print(f"DEBUG: Search query '{search_query}' found {len(results)} results")

                # Return mapping of article_id -> normalized score
                article_scores = {}
                max_score = max((r.score for r in results), default=1.0)

                for result in results:
                    article_id = int(result['id'])
                    # Normalize score to 0-1 range
                    normalized_score = result.score / max_score if max_score > 0 else 0.0
                    article_scores[article_id] = normalized_score

                print(f"DEBUG: Returning {len(article_scores)} scored articles")
                return article_scores
        except Exception as e:
            print(f"Search error: {e}")
            return {}

    def _initialize_embedding_model(self):
        """Lazy initialization of embedding model (CPU-only)"""
        if self.embedding_model is None:
            print("DEBUG: Initializing embedding model (CPU-only)...")
            # Force CPU usage and disable problematic features
            import os
            os.environ['CUDA_VISIBLE_DEVICES'] = ''

            # Initialize model with CPU device and specific config
            self.embedding_model = SentenceTransformer(
                self.model_path,
                device='cpu',
                model_kwargs={
                    'dtype': 'float32',  # 'dtype' replaces the deprecated 'torch_dtype'
                    'attn_implementation': 'eager'  # Use eager attention instead of flash attention
                }
            )
            print("DEBUG: Embedding model initialized")

    @lru_cache(maxsize=100)  # Cache encoded queries (smaller cache for this)
    def _encode_query_cached(self, query: str) -> tuple:
        """Cache-friendly version of query encoding (returns a hashable tuple)"""
        embedding = self._encode_query_internal(query)
        return tuple(embedding.tolist())  # Convert to tuple for caching

    def _encode_query(self, query: str) -> np.ndarray:
        """Encode a query string into an embedding vector"""
        cached_result = self._encode_query_cached(query)
        return np.array(cached_result)  # Convert back to a numpy array

    def _encode_query_internal(self, query: str) -> np.ndarray:
        """Internal method that does the actual encoding"""
        self._initialize_embedding_model()
        query_embedding = self.embedding_model.encode([query])
        # Normalize for cosine similarity
        return query_embedding[0] / np.linalg.norm(query_embedding[0])
```
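One subtlety in the caching above: lru_cache on an instance method includes self in the cache key and keeps the instance alive while entries exist. That is harmless here, since a single long-lived DataLoader serves the whole app, but the behavior is easy to see in isolation:

```python
from functools import lru_cache

class Demo:
    @lru_cache(maxsize=2)
    def expensive(self, query: str) -> str:
        print(f"computing {query!r}")
        return query.upper()

d = Demo()
d.expensive("ai")                # prints: computing 'ai'
d.expensive("ai")                # cache hit, prints nothing
print(d.expensive.cache_info())  # hits=1, misses=1; keys are (self, query) pairs
```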
data_loader.py (continued):

```python
    def _dense_search_all_articles(self, query: str, k: Optional[int] = None) -> List[tuple]:
        """
        Perform dense retrieval across ALL articles and return (index, score) pairs.
        This computes all similarities upfront for maximum flexibility.
        """
        if self.embeddings is None:
            print("ERROR: Embeddings not loaded")
            return []

        print(f"DEBUG: Performing dense search for query: '{query}'")

        # Encode query
        query_embed = self._encode_query(query)

        # Compute similarities with ALL articles
        similarities = np.dot(self.embeddings, query_embed)

        # Get all articles with their similarity scores
        article_scores = [(i, float(similarities[i])) for i in range(len(similarities))]

        # Sort by similarity (highest first)
        article_scores.sort(key=lambda x: x[1], reverse=True)

        # Apply k limit if specified
        if k is not None:
            article_scores = article_scores[:k]

        print(f"DEBUG: Dense search found {len(article_scores)} scored articles")
        return article_scores

    def _calculate_query_score(self, article: Dict[str, Any], search_query: str, search_type: Optional[str] = None) -> float:
        """Calculate query relevance score based on search type"""
        if not search_query:
            return 0.0

        if search_type == 'dense':
            # For dense search, return the pre-computed similarity score
            return article.get('dense_score', 0.0)
        else:
            # Exact search: look up this article's Whoosh score
            search_results = self._search_articles(search_query, search_type)
            article_id = article.get('id')
            # Return Whoosh score, or 0.0 if the article did not match
            return search_results.get(article_id, 0.0)

    def _sort_by_relevance(self, articles: List[Dict[str, Any]], search_query: str, search_type: str = 'exact') -> List[Dict[str, Any]]:
        """Sort articles by relevance score (query score + weighted labor score)"""
        def relevance_key(article):
            query_score = self._calculate_query_score(article, search_query, search_type)
            labor_score = article.get('ai_labor_relevance', 0) / 10.0  # Normalize to 0-1
            # Prioritize query score but include labor score as a tiebreaker
            return query_score + (labor_score * LABOR_SCORE_WEIGHT)

        return sorted(articles, key=relevance_key, reverse=True)
```
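With LABOR_SCORE_WEIGHT = 0.1 and ai_labor_relevance on a 0-10 scale, the labor score can move a ranking by at most 0.1, so it acts as a tiebreaker rather than a driver. The arithmetic, spelled out:

```python
LABOR_SCORE_WEIGHT = 0.1

def relevance_key(query_score: float, ai_labor_relevance: float) -> float:
    return query_score + (ai_labor_relevance / 10.0) * LABOR_SCORE_WEIGHT

print(relevance_key(0.90, 9))  # 0.99 -> labor score breaks ties between equal matches
print(relevance_key(0.90, 2))  # 0.92
print(relevance_key(1.00, 0))  # 1.00 -> a clearly stronger text match still wins
```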
data_loader.py (continued):

```python
    def get_articles(
        self,
        page: int = 1,
        limit: int = 20,
        **filters
    ) -> List[ArticleResponse]:
        """Get filtered and paginated articles"""
        # Extract sort_by, search_query, and search_type for special handling
        sort_by = filters.pop('sort_by', 'date')
        search_query = filters.get('search_query')
        search_type = filters.get('search_type', 'exact')

        filtered_articles = self._filter_articles(**filters)

        # Apply sorting
        if sort_by == 'score' and search_query:
            # Sort by query relevance score descending, then by labor score
            filtered_articles = self._sort_by_relevance(filtered_articles, search_query, search_type)
        else:
            # Sort by date (oldest first) - default
            filtered_articles.sort(key=lambda x: x['date'], reverse=False)

        # Paginate
        start_idx = (page - 1) * limit
        end_idx = start_idx + limit
        page_articles = filtered_articles[start_idx:end_idx]

        # Convert to response models - use the original ID from the sorted/filtered results
        return [
            ArticleResponse(
                id=article['id'],
                title=article['title'],
                source=article['source'],
                url=article['url'],
                date=article['date'],
                summary=article['summary'],
                ai_labor_relevance=article['ai_labor_relevance'],
                query_score=self._calculate_query_score(article, search_query or '', search_type),
                document_type=article['document_type'],
                author_type=article['author_type'],
                document_topics=article['document_topics'],
            )
            for article in page_articles
        ]

    def get_articles_count(self, **filters) -> int:
        """Get count of articles matching filters"""
        filtered_articles = self._filter_articles(**filters)
        return len(filtered_articles)

    def get_filter_counts(self, filter_type: str, **filters) -> Dict[str, int]:
        """Get counts for each option in a specific filter type, given other filters"""
        # Remove the current filter type from filters to avoid circular filtering
        filters_copy = filters.copy()
        filters_copy.pop(filter_type, None)

        # Get base filtered articles (without the current filter type)
        base_filtered = self._filter_articles(**filters_copy)

        # Extract values based on filter type and count with Counter
        if filter_type == 'document_types':
            values = [article.get('document_type') for article in base_filtered
                      if article.get('document_type')]
        elif filter_type == 'author_types':
            values = [article.get('author_type') for article in base_filtered
                      if article.get('author_type')]
        elif filter_type == 'topics':
            values = [topic for article in base_filtered
                      for topic in article.get('document_topics', []) if topic]
        else:
            return {}

        return dict(Counter(values))

    def get_article_detail(self, article_id: int) -> ArticleDetail:
        """Get detailed article by ID"""
        if article_id not in self.articles_by_id:
            raise ValueError(f"Article ID {article_id} not found")

        article = self.articles_by_id[article_id]
        return ArticleDetail(
            id=article['id'],
            title=article['title'],
            source=article['source'],
            url=article['url'],
            date=article['date'],
            summary=article['summary'],
            ai_labor_relevance=article['ai_labor_relevance'],
            query_score=0.0,  # Detail view doesn't have search context
            document_type=article['document_type'],
            author_type=article['author_type'],
            document_topics=article['document_topics'],
            text=article['text'],
            arguments=article['arguments'],
        )
```
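The class can also be exercised outside the API. A rough standalone sketch (it assumes network access to the Hugging Face Hub and enough memory for the full dataset; only load_dataset is async, the query methods are synchronous):

```python
import asyncio
from data_loader import DataLoader

async def main():
    loader = DataLoader()
    await loader.load_dataset()  # downloads, indexes, and loads embeddings on first run

    page = loader.get_articles(
        page=1, limit=5,
        sort_by='score',
        search_query='automation',
        search_type='exact',
    )
    for art in page:
        print(f"{art.query_score:.2f}  {art.title}")

asyncio.run(main())
```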
dockerfile
ADDED (+16 -0)

```dockerfile
# Start from a standard Python 3.9 image
FROM python:3.9-slim

# Set the working directory inside the container
WORKDIR /app

# Copy the requirements file and install dependencies first
# to leverage Docker's layer caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the backend application code
COPY . .

# Run the FastAPI app on port 7860, the app port configured in the Space's README
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
```
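To reproduce the Space locally, the standard Docker flow should work: `docker build -t labor-archive .` followed by `docker run -p 7860:7860 labor-archive` (the image tag is an arbitrary example), after which the API answers on http://localhost:7860.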
main.py
ADDED (+124 -0)

```python
from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Optional
import uvicorn
from contextlib import asynccontextmanager
from data_loader import DataLoader
from models import ArticleResponse, ArticleDetail, FiltersResponse

# Initialize data loader
data_loader = DataLoader()

# Dependency functions for API parameters
def get_filter_params(
    document_type: Optional[List[str]] = Query(None, description="Filter by document types"),
    author_type: Optional[List[str]] = Query(None, description="Filter by author types"),
    min_relevance: Optional[float] = Query(None, ge=0, le=10, description="Minimum AI labor relevance score"),
    max_relevance: Optional[float] = Query(None, ge=0, le=10, description="Maximum AI labor relevance score"),
    start_date: Optional[str] = Query(None, description="Start date (YYYY-MM-DD)"),
    end_date: Optional[str] = Query(None, description="End date (YYYY-MM-DD)"),
    topic: Optional[List[str]] = Query(None, description="Filter by document topics"),
    search_query: Optional[str] = Query(None, description="Search query for text matching"),
    search_type: Optional[str] = Query("exact", description="Search type: 'exact' or 'dense'"),
) -> dict:
    return {
        'document_types': document_type,
        'author_types': author_type,
        'min_relevance': min_relevance,
        'max_relevance': max_relevance,
        'start_date': start_date,
        'end_date': end_date,
        'topics': topic,
        'search_query': search_query,
        'search_type': search_type,
    }

def get_pagination_params(
    page: int = Query(1, ge=1, description="Page number"),
    limit: int = Query(20, ge=1, le=100, description="Items per page"),
    sort_by: Optional[str] = Query("date", description="Sort by 'date' or 'score'"),
) -> dict:
    return {
        'page': page,
        'limit': limit,
        'sort_by': sort_by,
    }
```
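Declaring a parameter as Optional[List[str]] with Query(None) makes FastAPI accept repeated query parameters, so multi-select filters arrive as ?document_type=news&document_type=blog. A client-side sketch using requests (not in requirements.txt; any HTTP client works), assuming the server was started with `python main.py` on port 8000 (the Docker image uses 7860):

```python
import requests

resp = requests.get(
    "http://localhost:8000/articles",
    params={
        "document_type": ["news", "blog"],  # encoded as repeated query parameters
        "min_relevance": 5,
        "search_query": "automation",
        "search_type": "exact",
        "page": 1,
        "limit": 10,
    },
    timeout=30,
)
resp.raise_for_status()
for article in resp.json():
    print(article["date"][:10], article["title"])
```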
main.py (continued):

```python
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    print("Loading dataset from HuggingFace...")
    await data_loader.load_dataset()
    print(f"Dataset loaded: {len(data_loader.articles)} articles")
    yield
    # Shutdown (nothing needed)

app = FastAPI(title="Archive Explorer API: AI, Labor and the Economy", version="1.0.0", lifespan=lifespan)

# Enable CORS for frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:5173",
        "https://yjernite-labor-archive-backend.hf.space",  # Deployed Space origin
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    return {"message": "Archive Explorer API: AI, Labor and the Economy", "articles_count": len(data_loader.articles)}

@app.get("/filters", response_model=FiltersResponse)
async def get_filters():
    """Get all available filter options"""
    return data_loader.get_filter_options()

@app.get("/articles", response_model=List[ArticleResponse])
async def get_articles(
    pagination: dict = Depends(get_pagination_params),
    filters: dict = Depends(get_filter_params),
):
    """Get filtered and paginated articles"""
    return data_loader.get_articles(
        **pagination,
        **filters,
    )

@app.get("/articles/count")
async def get_articles_count(
    filters: dict = Depends(get_filter_params),
):
    """Get count of articles matching filters"""
    return {"count": data_loader.get_articles_count(**filters)}

@app.get("/filter-counts/{filter_type}")
async def get_filter_counts(
    filter_type: str,
    filters: dict = Depends(get_filter_params),
):
    """Get counts for each option in a specific filter type"""
    if filter_type not in ['document_types', 'author_types', 'topics']:
        raise HTTPException(status_code=400, detail="Invalid filter type")

    counts = data_loader.get_filter_counts(
        filter_type=filter_type,
        **filters
    )
    return counts

@app.get("/articles/{article_id}", response_model=ArticleDetail)
async def get_article(article_id: int):
    """Get detailed article by ID"""
    try:
        return data_loader.get_article_detail(article_id)
    except ValueError:
        # Surface a missing ID as a 404 instead of an unhandled 500
        raise HTTPException(status_code=404, detail=f"Article {article_id} not found")

@app.get("/test-search")
async def test_search(q: str):
    """Test search functionality (debug endpoint; calls the internal search directly)"""
    return data_loader._search_articles(q, 'exact')

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
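Because all startup work lives in the lifespan handler, tests have to enter the client as a context manager or the dataset never loads. A sketch with FastAPI's TestClient (requires httpx; and since it really downloads the dataset, this is an integration test rather than a unit test):

```python
from fastapi.testclient import TestClient
from main import app

# Entering the context manager runs the lifespan startup (the dataset load).
with TestClient(app) as client:
    assert client.get("/").json()["articles_count"] > 0
    assert client.get("/filters").status_code == 200
    assert client.get("/articles/999999").status_code == 404  # unknown ID
```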
models.py
ADDED (+36 -0)

```python
from pydantic import BaseModel
from typing import List, Dict
from datetime import datetime

class Argument(BaseModel):
    argument_quote: List[str]
    argument_summary: str
    argument_source: str
    argument_type: str

class ArticleResponse(BaseModel):
    """Response model for article list (minimal data for cards)"""
    id: int
    title: str
    source: str
    url: str
    date: datetime
    summary: str
    ai_labor_relevance: float
    query_score: float = 0.0
    document_type: str
    author_type: str
    document_topics: List[str]

class ArticleDetail(ArticleResponse):
    """Response model for full article details (extends ArticleResponse)"""
    text: str
    arguments: List[Argument]

class FiltersResponse(BaseModel):
    """Available filter options"""
    document_types: List[str]
    author_types: List[str]
    topics: List[str]
    date_range: Dict[str, str]  # min_date, max_date
    relevance_range: Dict[str, float]  # min_relevance, max_relevance
```
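FastAPI serializes these models automatically, but they validate on their own as well. A small sketch (model_dump_json assumes Pydantic v2; on v1 the equivalent is .json()):

```python
from datetime import datetime
from models import ArticleResponse

article = ArticleResponse(
    id=0,
    title="Example headline",
    source="Example Wire",
    url="https://example.com/story",
    date=datetime(2024, 5, 1),
    summary="A short summary.",
    ai_labor_relevance=8.0,
    document_type="news",
    author_type="journalist",
    document_topics=["automation"],
)
print(article.query_score)        # 0.0 - the default applies when no search ran
print(article.model_dump_json())  # Pydantic v2; use article.json() on v1
```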
requirements.txt
ADDED (+11 -0)

```text
fastapi
uvicorn[standard]
datasets
pydantic
python-multipart
pyarrow>=12.0.0
whoosh>=2.7.4
python-dateutil>=2.8.2
sentence-transformers>=2.2.2
torch>=2.0.0
numpy>=1.21.0
```