import json

from transformers import pipeline

from scraper import fetch_hazard_tweets
from translate import translate_to_english
from sentiment import classify_emotion_text
from ner import extract_hazard_and_locations

model_name = "cross-encoder/nli-deberta-v3-base"

# Lazy loading - only instantiate the pipeline when it is first needed
classifier = None


def get_classifier():
    """Lazily load the zero-shot classifier so the model is not loaded at import time."""
    global classifier
    if classifier is None:
        classifier = pipeline("zero-shot-classification", model=model_name, framework="pt")
    return classifier


def classify_with_model(tweet_text):
    """
    Classify a tweet with the DeBERTa-v3 cross-encoder via zero-shot classification.

    Returns 1 if the top label is the hazard label with score above 0.75, else 0.
    """
    if not tweet_text or not tweet_text.strip():
        return 0
    candidate_labels = ["report of an ocean hazard", "not an ocean hazard"]
    classifier_instance = get_classifier()
    result = classifier_instance(tweet_text, candidate_labels)
    top_label = result['labels'][0]
    top_score = result['scores'][0]
    if top_label == "report of an ocean hazard" and top_score > 0.75:
        return 1
    return 0


def classify_tweets(tweets):
    """
    Accept a list of tweet dicts with a 'text' field.

    Pipeline: translate -> classify hazard -> if hazardous, sentiment + NER.
    Returns enriched copies of the input dicts.
    """
    classified = []
    for t in tweets:
        text = t.get('text', '')
        item = dict(t)

        # Step 1: Translate the tweet so every downstream model sees English text
        translated = translate_to_english(text)
        item['translated_text'] = translated

        # Step 2: Classify using the translated text (more accurate than raw text)
        hazardous = classify_with_model(translated)
        item['hazardous'] = hazardous

        # Step 3: Run sentiment and NER only on hazardous tweets
        if hazardous == 1:
            sentiment = classify_emotion_text(translated)
            item['sentiment'] = sentiment
            ner_info = extract_hazard_and_locations(translated)
            item['ner'] = ner_info
        else:
            # Fill neutral defaults so every item has the same output schema
            item['sentiment'] = {"label": "neutral", "score": 0.0}
            item['ner'] = {"hazards": [], "locations": []}

        classified.append(item)
    return classified


if __name__ == "__main__":
    tweets = fetch_hazard_tweets(limit=20)
    classified = classify_tweets(tweets)
    print(json.dumps(classified, indent=2, ensure_ascii=False))
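

# --- Optional batched variant (a sketch, not used by the pipeline above) ---
# The Hugging Face zero-shot pipeline also accepts a list of texts, which
# amortizes per-call overhead across many tweets instead of invoking the
# model once per tweet. classify_with_model_batch is a hypothetical helper
# mirroring the labels and 0.75 threshold of classify_with_model; treat it
# as a starting point rather than a drop-in replacement.
def classify_with_model_batch(tweet_texts):
    """Classify a list of texts in one pipeline call; returns a list of 0/1 flags."""
    candidate_labels = ["report of an ocean hazard", "not an ocean hazard"]
    # Blank texts are skipped up front and flagged 0, matching the single-text path.
    non_empty = [(i, t) for i, t in enumerate(tweet_texts) if t and t.strip()]
    flags = [0] * len(tweet_texts)
    if not non_empty:
        return flags
    classifier_instance = get_classifier()
    results = classifier_instance([t for _, t in non_empty], candidate_labels)
    # The pipeline may return a bare dict for a single input; normalize to a list.
    if isinstance(results, dict):
        results = [results]
    for (i, _), result in zip(non_empty, results):
        if result['labels'][0] == "report of an ocean hazard" and result['scores'][0] > 0.75:
            flags[i] = 1
    return flags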