"""Ocean Hazard Detection API.

FastAPI service that fetches social-media posts through the local ``scraper``
module, classifies them with ``classifier.classify_tweets``, and stores the
posts flagged as hazardous through ``pg_db``.
"""
import json
import logging
import time
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import List, Optional, Tuple

from fastapi import FastAPI, BackgroundTasks, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Import our modules
from scraper import fetch_hazard_tweets, fetch_custom_tweets, get_available_hazards, get_available_locations
from classifier import classify_tweets
from pg_db import init_db, upsert_hazardous_tweet

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Ocean Hazard Detection API",
    description="API for detecting ocean hazards from social media posts",
    version="1.0.0"
)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins to known front-ends before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this properly for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the database at import time. Failure is deliberately non-fatal:
# the API still serves analysis results, it just cannot persist them.
try:
    init_db()
    logger.info("Database initialized successfully")
except Exception as e:
    logger.warning(f"Database initialization failed: {e}. API will work without database persistence.")


# Pydantic models
class TweetAnalysisRequest(BaseModel):
    """Request body for ``POST /analyze``."""

    limit: int = 20  # number of tweets to fetch; not range-checked here
    query: Optional[str] = None  # free-text search; takes precedence over hazard_type/location
    hazard_type: Optional[str] = None  # keyword-search hazard filter
    location: Optional[str] = None  # keyword-search location filter
    days_back: int = 1  # forwarded to fetch_custom_tweets (presumably a window in days)


class TweetAnalysisResponse(BaseModel):
    """Response body for ``POST /analyze``."""

    total_tweets: int  # number of classified results
    hazardous_tweets: int  # results whose "hazardous" flag equals 1
    results: List[dict]  # raw classifier output, one dict per tweet
    processing_time: float  # seconds spent handling the request


class HealthResponse(BaseModel):
    """Response body for the health-check endpoints."""

    status: str
    message: str
    timestamp: str  # naive ISO-8601 UTC time


def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Same output as the deprecated ``datetime.utcnow().isoformat()`` (no
    ``+00:00`` suffix), so clients see an unchanged format.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


# Health check endpoint
@app.get("/", response_model=HealthResponse)
def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        message="Ocean Hazard Detection API is running",
        timestamp=_utc_now_iso()
    )


@app.get("/health", response_model=HealthResponse)
def health():
    """Alternative health check endpoint"""
    return health_check()


@app.post("/warmup")
def warmup_models():
    """Pre-load all models to reduce first request time.

    Declared as a plain ``def`` so FastAPI runs the blocking model loading in
    its worker thread pool instead of stalling the event loop.

    Returns ``{"status": ..., "message": ...}``; a failure is reported with
    ``status="error"`` in the body rather than as an HTTP error code.
    """
    try:
        logger.info("Starting model warmup...")

        # Imported inside the handler so the model modules load on demand.
        from classifier import get_classifier
        from ner import get_ner_pipeline
        from sentiment import get_emotion_classifier
        from translate import get_translator

        classifier = get_classifier()
        ner = get_ner_pipeline()
        emotion_clf = get_emotion_classifier()
        translator = get_translator()

        # Run each model once on a dummy input.
        test_text = "Test tweet for model warmup"
        classifier(test_text, ["test", "not test"])
        if ner:  # get_ner_pipeline() may return a falsy value; skip it then
            ner(test_text)
        emotion_clf(test_text)
        translator(test_text)

        logger.info("Model warmup completed successfully")
        return {"status": "success", "message": "All models loaded and ready"}
    except Exception as e:
        logger.exception(f"Model warmup failed: {str(e)}")
        return {"status": "error", "message": str(e)}


def _fetch_tweets(request: TweetAnalysisRequest):
    """Fetch tweets using the search strategy selected by *request*.

    Priority: free-text ``query``; otherwise a keyword search when
    ``hazard_type`` or ``location`` is set; otherwise the default hazard query.
    """
    if request.query:
        # Use custom query if provided
        from scraper import search_tweets, extract_tweets
        result = search_tweets(request.query, limit=request.limit)
        return extract_tweets(result)
    if request.hazard_type or request.location:
        # Use keyword-based search
        return fetch_custom_tweets(
            hazard_type=request.hazard_type,
            location=request.location,
            limit=request.limit,
            days_back=request.days_back
        )
    # Use default hazard query
    return fetch_hazard_tweets(limit=request.limit)


def _split_created_at(created_at) -> Tuple[str, str]:
    """Split a tweet timestamp into ``(YYYY-MM-DD, HH:MM:SS)`` strings.

    Tries RFC 2822 (``parsedate_to_datetime``) first, then ISO-8601 (with a
    trailing ``Z`` accepted) when the value contains ``T``. Returns
    ``("", "")`` for an empty or unparseable value. The date and time are
    those of the parsed value's own offset; no timezone conversion is done.
    """
    if not created_at:
        return "", ""
    dt = None
    try:
        dt = parsedate_to_datetime(created_at)
    except Exception:
        # The exception type for bad input differs across Python versions
        # (TypeError on older, ValueError on newer), so catch broadly.
        dt = None
    if dt is None and 'T' in created_at:
        try:
            dt = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
        except ValueError:
            dt = None
    if dt is None:
        return "", ""
    return dt.date().isoformat(), dt.time().strftime('%H:%M:%S')


def _hazard_row(r: dict) -> dict:
    """Build the ``upsert_hazardous_tweet`` keyword arguments for one result.

    Hazards and locations come from the result's ``ner`` dict, falling back
    to the result's own ``location`` field, then to ``"unknown"``.
    """
    ner = r.get('ner') or {}
    hazards = ner.get('hazards') or []
    locs = ner.get('locations') or []
    if not locs and r.get('location'):
        locs = [r['location']]
    sentiment = r.get('sentiment') or {"label": "unknown", "score": 0.0}
    tweet_date, tweet_time = _split_created_at(r.get('created_at') or "")
    return {
        "tweet_url": r.get('tweet_url') or "",
        "hazard_type": ", ".join(hazards) if hazards else "unknown",
        "location": ", ".join(locs) if locs else "unknown",
        "sentiment_label": sentiment.get('label', 'unknown'),
        "sentiment_score": float(sentiment.get('score', 0.0)),
        "tweet_date": tweet_date,
        "tweet_time": tweet_time,
    }


def _persist_hazardous(hazardous: List[dict]) -> None:
    """Upsert every hazardous result into the database, best effort.

    The first exception stops storage of the remaining rows and is logged as
    a warning; it is never propagated, so the caller still returns its
    analysis when the database is unavailable.
    """
    try:
        for r in hazardous:
            upsert_hazardous_tweet(**_hazard_row(r))
        logger.info(f"Stored {len(hazardous)} hazardous tweets in database")
    except Exception as db_error:
        logger.warning(f"Database storage failed: {db_error}. Results will not be persisted.")


# Main analysis endpoint
@app.post("/analyze", response_model=TweetAnalysisResponse)
def analyze_tweets(request: TweetAnalysisRequest):
    """
    Analyze tweets for ocean hazards

    - **limit**: Number of tweets to fetch (default 20)
    - **query**: Custom search query (optional; takes precedence over the filters below)
    - **hazard_type**: Hazard keyword filter (optional)
    - **location**: Location keyword filter (optional)
    - **days_back**: Passed to the keyword search (default 1)
    """
    # Plain ``def``: scraping and classification block, so FastAPI runs this
    # in its thread pool rather than on the event loop.
    started = time.perf_counter()
    try:
        logger.info(f"Starting analysis with limit: {request.limit}")

        tweets = _fetch_tweets(request)
        logger.info(f"Fetched {len(tweets)} tweets")

        results = classify_tweets(tweets)
        logger.info(f"Classified {len(results)} tweets")

        # Count hazardous results independently of persistence so a database
        # failure cannot make the reported count wrong.
        hazardous = [r for r in results if r.get('hazardous') == 1]
        _persist_hazardous(hazardous)

        return TweetAnalysisResponse(
            total_tweets=len(results),
            hazardous_tweets=len(hazardous),
            results=results,
            processing_time=time.perf_counter() - started
        )
    except Exception as e:
        logger.exception(f"Analysis failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# Get stored hazardous tweets
@app.get("/hazardous-tweets")
def get_hazardous_tweets(limit: int = Query(100, ge=0), offset: int = Query(0, ge=0)):
    """
    Get stored hazardous tweets from database

    - **limit**: Maximum number of tweets to return (default: 100, must be >= 0)
    - **offset**: Number of tweets to skip (default: 0, must be >= 0)
    """
    try:
        from pg_db import get_conn
        # NOTE(review): whether exiting get_conn() closes the connection depends
        # on pg_db / the driver version; confirm it does not leak connections.
        with get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT tweet_url, hazard_type, location, sentiment_label,
                           sentiment_score, tweet_date, tweet_time, inserted_at
                    FROM hazardous_tweets
                    ORDER BY inserted_at DESC
                    LIMIT %s OFFSET %s
                """, (limit, offset))

                columns = [desc[0] for desc in cur.description]
                results = [dict(zip(columns, row)) for row in cur.fetchall()]

        return {
            "tweets": results,
            "count": len(results),
            "limit": limit,
            "offset": offset
        }
    except Exception as e:
        logger.exception(f"Failed to fetch hazardous tweets: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# Get available keywords
@app.get("/keywords/hazards")
async def get_hazard_keywords():
    """Get available hazard types for keyword search"""
    hazards = get_available_hazards()
    return {
        "hazards": hazards,
        "count": len(hazards)
    }


@app.get("/keywords/locations")
async def get_location_keywords():
    """Get available locations for keyword search"""
    locations = get_available_locations()
    return {
        "locations": locations,
        "count": len(locations)
    }


# Get statistics
@app.get("/stats")
def get_stats():
    """Get aggregate statistics over the stored hazardous tweets.

    Returns the total count plus breakdowns by hazard type, the top 10
    known locations, and sentiment label.
    """
    try:
        from pg_db import get_conn
        with get_conn() as conn:
            with conn.cursor() as cur:
                # Total hazardous tweets
                cur.execute("SELECT COUNT(*) FROM hazardous_tweets")
                total_hazardous = cur.fetchone()[0]

                # By hazard type
                cur.execute("""
                    SELECT hazard_type, COUNT(*) as count
                    FROM hazardous_tweets
                    GROUP BY hazard_type
                    ORDER BY count DESC
                """)
                hazard_types = [{"type": row[0], "count": row[1]} for row in cur.fetchall()]

                # By location (top 10, excluding the "unknown" placeholder)
                cur.execute("""
                    SELECT location, COUNT(*) as count
                    FROM hazardous_tweets
                    WHERE location != 'unknown'
                    GROUP BY location
                    ORDER BY count DESC
                    LIMIT 10
                """)
                locations = [{"location": row[0], "count": row[1]} for row in cur.fetchall()]

                # By sentiment
                cur.execute("""
                    SELECT sentiment_label, COUNT(*) as count
                    FROM hazardous_tweets
                    GROUP BY sentiment_label
                    ORDER BY count DESC
                """)
                sentiments = [{"sentiment": row[0], "count": row[1]} for row in cur.fetchall()]

        return {
            "total_hazardous_tweets": total_hazardous,
            "hazard_types": hazard_types,
            "top_locations": locations,
            "sentiment_distribution": sentiments
        }
    except Exception as e:
        logger.exception(f"Failed to fetch statistics: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)