# Provenance: copied from the Hugging Face file viewer.
# Uploaded by yjernite (HF Staff) in commit 1b21566 (verified), "Upload 5 files".
# File size at upload: 4.42 kB.
from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Optional
import uvicorn
from contextlib import asynccontextmanager
from data_loader import DataLoader
from models import ArticleResponse, ArticleDetail, FiltersResponse
# Module-level DataLoader instance used by the lifespan hook (which awaits
# `load_dataset()` at startup) and by every route handler in this file.
data_loader = DataLoader()
# Dependency functions for API parameters
def get_filter_params(
    document_type: Optional[List[str]] = Query(None, description="Filter by document types"),
    author_type: Optional[List[str]] = Query(None, description="Filter by author types"),
    min_relevance: Optional[float] = Query(None, ge=0, le=10, description="Minimum AI labor relevance score"),
    max_relevance: Optional[float] = Query(None, ge=0, le=10, description="Maximum AI labor relevance score"),
    start_date: Optional[str] = Query(None, description="Start date (YYYY-MM-DD)"),
    end_date: Optional[str] = Query(None, description="End date (YYYY-MM-DD)"),
    topic: Optional[List[str]] = Query(None, description="Filter by document topics"),
    search_query: Optional[str] = Query(None, description="Search query for text matching"),
    search_type: Optional[str] = Query("exact", description="Search type: 'exact' or 'dense'"),
) -> dict:
    """Bundle the filter-related query parameters into a single dict.

    Meant to be used as a FastAPI dependency. The singular query names
    ``document_type``, ``author_type`` and ``topic`` are exposed under the
    plural keys ``document_types``, ``author_types`` and ``topics``; all
    other parameters keep their names. Every filter defaults to ``None``
    except ``search_type``, which defaults to ``"exact"``.

    Returns:
        A dict whose keys are suitable for ``**``-unpacking into the data
        loader's query methods.
    """
    filter_kwargs = dict(
        document_types=document_type,
        author_types=author_type,
        min_relevance=min_relevance,
        max_relevance=max_relevance,
        start_date=start_date,
        end_date=end_date,
        topics=topic,
        search_query=search_query,
        search_type=search_type,
    )
    return filter_kwargs
def get_pagination_params(
    page: int = Query(1, ge=1, description="Page number"),
    limit: int = Query(20, ge=1, le=100, description="Items per page"),
    sort_by: Optional[str] = Query("date", description="Sort by 'date' or 'score'"),
) -> dict:
    """Bundle the pagination and sorting query parameters into a dict.

    Meant to be used as a FastAPI dependency. ``page`` is 1-based (minimum 1),
    ``limit`` is constrained to 1..100, and ``sort_by`` defaults to ``"date"``.

    Returns:
        A dict with the keys ``page``, ``limit`` and ``sort_by``.
    """
    paging_kwargs = dict(page=page, limit=limit, sort_by=sort_by)
    return paging_kwargs
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load the dataset before serving requests.

    Everything before the ``yield`` runs at startup; there is no teardown
    work after it.
    """
    # --- startup ---
    print("Loading dataset from HuggingFace...")
    await data_loader.load_dataset()
    loaded_message = f"Dataset loaded: {len(data_loader.articles)} articles"
    print(loaded_message)

    yield
    # --- shutdown: nothing to release ---
app = FastAPI(
    title="Archive Explorer API: AI, Labor and the Economy",
    version="1.0.0",
    lifespan=lifespan,
)

# Browser origins permitted by CORS: two localhost ports (presumably local
# frontend dev servers -- confirm) and the deployed Hugging Face Space.
_cors_origins = [
    "http://localhost:3000",
    "http://localhost:5173",
    "https://yjernite-labor-archive-backend.hf.space",
]

# Enable CORS for the frontend: any method and any header are accepted from
# the origins listed above, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Return the API title together with the number of loaded articles."""
    article_total = len(data_loader.articles)
    return {
        "message": "Archive Explorer API: AI, Labor and the Economy",
        "articles_count": article_total,
    }
@app.get("/filters", response_model=FiltersResponse)
async def get_filters():
    """Return every available filter option, as reported by the data loader."""
    filter_options = data_loader.get_filter_options()
    return filter_options
@app.get("/articles", response_model=List[ArticleResponse])
async def get_articles(
    pagination: dict = Depends(get_pagination_params),
    filters: dict = Depends(get_filter_params),
):
    """Return one page of articles matching the supplied filters.

    Both dependency dicts are unpacked as keyword arguments into
    ``data_loader.get_articles``.
    """
    page_of_articles = data_loader.get_articles(**pagination, **filters)
    return page_of_articles
@app.get("/articles/count")
async def get_articles_count(
    filters: dict = Depends(get_filter_params),
):
    """Return ``{"count": n}`` where ``n`` is the number of matching articles."""
    matching_total = data_loader.get_articles_count(**filters)
    return {"count": matching_total}
@app.get("/filter-counts/{filter_type}")
async def get_filter_counts(
    filter_type: str,
    filters: dict = Depends(get_filter_params),
):
    """Return per-option counts for one filter type under the active filters.

    Args:
        filter_type: One of ``document_types``, ``author_types`` or ``topics``.
        filters: Filter keyword arguments from ``get_filter_params``.

    Raises:
        HTTPException: 400 when ``filter_type`` is not one of the values above.
    """
    supported_types = ('document_types', 'author_types', 'topics')
    if filter_type in supported_types:
        return data_loader.get_filter_counts(filter_type=filter_type, **filters)

    # Anything else is a client error.
    raise HTTPException(status_code=400, detail="Invalid filter type")
@app.get("/articles/{article_id}", response_model=ArticleDetail)
async def get_article(article_id: int):
    """Return the detailed record for a single article.

    Args:
        article_id: Integer identifier taken from the URL path.

    Raises:
        HTTPException: 404 when the data loader returns ``None`` for this ID.
    """
    article = data_loader.get_article_detail(article_id)
    # Without this guard a ``None`` result fails ``response_model`` validation
    # and the client sees a 500 instead of a 404.
    # NOTE(review): ``get_article_detail`` is defined elsewhere; if it already
    # raises for unknown IDs this check is simply never triggered.
    if article is None:
        raise HTTPException(status_code=404, detail="Article not found")
    return article
@app.get("/test-search")
async def test_search(q: str):
    """Run an 'exact' search for ``q`` and return the loader's raw result.

    NOTE(review): this calls the data loader's private ``_search_articles``
    method directly -- it looks like a debugging endpoint; confirm whether it
    should be exposed in deployment.
    """
    search_mode = 'exact'
    return data_loader._search_articles(q, search_mode)
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )