"""Gradio demo: sentence-level sentiment plus canonical entity detection.

At import time the module loads a sentiment-classification pipeline and a
token-classification (NER) pipeline, then exposes ``predict`` through a
Gradio interface. Entity spans returned by the NER pipeline are fuzzy-matched
against ``CANONICAL_ENTITIES`` so that spelling variants collapse onto one
canonical name.
"""

from typing import Any, Dict, Optional, Sequence

import gradio as gr
import torch
from rapidfuzz import process
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

# ------------------ Load models once ------------------

# Pipeline device index: first CUDA device when available, otherwise CPU (-1).
_DEVICE = 0 if torch.cuda.is_available() else -1

# Sentiment model
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="sreejith8100/indian_output",
    tokenizer="sreejith8100/indian_output",
    device=_DEVICE,
)

# NER model
ner_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicNER", use_fast=True)
ner_model = AutoModelForTokenClassification.from_pretrained("ai4bharat/IndicNER")
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=_DEVICE,
)

# Canonical entity list ("<English name> / <Malayalam name>")
CANONICAL_ENTITIES = [
    "V Abdurahiman / വി അബ്ദുറഹിമാൻ",
    "P A Mohamed Riyas / പി എ മുഹമ്മദ് റിയാസ്",
    "P Rajeev / പി രാജീവ്",
    "Saji Cherian / സജി ചെറിയാൻ",
    "Roshy Augustine / റോഷി ഓഗസ്റ്റിൻ",
    "R Bindu / ആർ ബിന്ദു",
    "A K Saseendran / എ കെ സസീന്ദ്രൻ",
    "O R Kelu / ഒ ആർ കെലു",
    "J Chinchurani / ജെ ചിഞ്ചുറാണി",
    "K N Balagopal / കെ എൻ ബാലഗോപാൽ",
    "K Krishnankutty / കെ കൃഷ്ണൻകുട്ടി",
    "Veena George / വീണാ ജോർജ്",
    "Antony Raju / ആന്റണി രാജു",
    "K Rajan / കെ രാജൻ",
    "M B Rajesh / എം ബി രാജേഷ്",
    "Chittayam Gopakumar / ചിറ്റയം ഗോപകുമാർ",
    "K Radhakrishnan / കെ രാധാകൃഷ്ണൻ",
    "Pinarayi Vijayan / പിണറായി വിജയൻ",
    "V Sivankutty / വി ശിവൻകുട്ടി",
    "K K Shailaja / കെ കെ ശൈലജ",
]


def map_entity(
    entity_text: str,
    known_entities: Sequence[str] = CANONICAL_ENTITIES,
    threshold: float = 70,
) -> Optional[str]:
    """Return the best fuzzy match for *entity_text*, or ``None``.

    Args:
        entity_text: Raw entity string to look up.
        known_entities: Candidate canonical names to match against.
        threshold: Minimum similarity score a candidate must reach.

    Returns:
        The best-scoring entry of *known_entities* whose score is at least
        *threshold*; ``None`` when nothing qualifies or there are no
        candidates.
    """
    # With score_cutoff, extractOne returns None instead of a below-threshold
    # tuple; it also returns None for an empty candidate list, which the
    # previous unconditional tuple-unpacking turned into a TypeError.
    best = process.extractOne(entity_text, known_entities, score_cutoff=threshold)
    if best is None:
        return None
    return best[0]


# Map raw model labels to readable ones
label_map = {
    "LABEL_0": "POSITIVE",
    "LABEL_1": "NEGATIVE",
    "LABEL_2": "NEUTRAL",
}


# ------------------ Prediction function ------------------
def predict(sentence: str) -> Dict[str, Any]:
    """Classify the sentiment of *sentence* and list the canonical entities in it.

    Args:
        sentence: Text entered by the user.

    Returns:
        A dict with the keys ``sentence`` (the input, unchanged),
        ``prediction`` (readable sentiment label, or the raw model label when
        it is not in ``label_map``), ``score`` (the model's score as a float)
        and ``mapped_entities`` (unique canonical names, in order of first
        appearance). For blank input no model is run: ``prediction`` is
        ``None``, ``score`` is ``0.0`` and ``mapped_entities`` is empty.
    """
    # Blank input carries nothing to classify; skip both model calls.
    if not sentence or not sentence.strip():
        return {
            "sentence": sentence,
            "prediction": None,
            "score": 0.0,
            "mapped_entities": [],
        }

    # Run sentiment
    sent_pred = sentiment_pipeline(sentence)[0]
    raw_label = sent_pred["label"]
    human_label = label_map.get(raw_label, raw_label)

    # Run NER + map. Each span is matched exactly once; a dict is used as an
    # insertion-ordered set so the output order is stable between calls.
    unique_matches: Dict[str, None] = {}
    for ent in ner_pipeline(sentence):
        canonical = map_entity(ent["word"])
        if canonical is not None:
            unique_matches[canonical] = None

    return {
        "sentence": sentence,
        "prediction": human_label,  # use mapped label
        "score": float(sent_pred["score"]),
        "mapped_entities": list(unique_matches),
    }


# ------------------ Gradio Interface ------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter a sentence"),
    outputs=gr.JSON(label="Result"),
    title="Entity + Sentiment Analysis",
    description="Upload a sentence in Malayalam/English. The app detects entities and predicts sentiment.",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)