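"""Gradio demo for the ClimateBERT model family.

Pick a model from a dropdown, enter text, and get text-classification
scores, with raw LABEL_* outputs mapped to human-readable names where
a mapping is known.
"""
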
import gradio as gr
from transformers import pipeline
import torch

# Short display names mapped to Hugging Face Hub repo IDs
MODELS = {
    "econbert": "climatebert/econbert",
    "controversy-classification": "climatebert/ClimateControversyBERT_classification",
    "controversy-bert": "climatebert/ClimateControversyBert",
    "netzero-reduction": "climatebert/netzero-reduction",
    "transition-physical": "climatebert/transition-physical",
    "renewable": "climatebert/renewable",
    "climate-detector": "climatebert/distilroberta-base-climate-detector",
    "climate-commitment": "climatebert/distilroberta-base-climate-commitment",
    "climate-tcfd": "climatebert/distilroberta-base-climate-tcfd",
    "climate-s": "climatebert/distilroberta-base-climate-s",
    "climate-specificity": "climatebert/distilroberta-base-climate-specificity",
    "climate-sentiment": "climatebert/distilroberta-base-climate-sentiment",
    "environmental-claims": "climatebert/environmental-claims",
    "climate-f": "climatebert/distilroberta-base-climate-f",
    "climate-d-s": "climatebert/distilroberta-base-climate-d-s",
    "climate-d": "climatebert/distilroberta-base-climate-d"
}
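# NOTE: some entries above (e.g. the distilroberta-base-climate-f/-s/-d/-d-s
# language models) are pretrained encoders without fine-tuned classification
# heads; loading them in a text-classification pipeline attaches a randomly
# initialized head, so their scores are not meaningful.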

# Human-readable label mappings
LABEL_MAPS = {
    "climate-commitment": {
        "LABEL_0": "Not about climate commitments",
        "LABEL_1": "About climate commitments",
    },
    "climate-detector": {
        "LABEL_0": "Not climate-related",
        "LABEL_1": "Climate-related",
    },
    "climate-sentiment": {
        "LABEL_0": "Negative",
        "LABEL_1": "Neutral",
        "LABEL_2": "Positive",
    },
    "climate-specificity": {
        "LABEL_0": "Low specificity",
        "LABEL_1": "Medium specificity",
        "LABEL_2": "High specificity",
    },
    "netzero-reduction": {
        "LABEL_0": "No net-zero / reduction commitment",
        "LABEL_1": "Net-zero / reduction commitment",
    },
    "transition-physical": {
        "LABEL_0": "Transition risk",
        "LABEL_1": "Physical risk",
    },
    "renewable": {
        "LABEL_0": "Not about renewables",
        "LABEL_1": "About renewables",
    },
    # You can expand mappings for other models after checking their model cards
}

# Cache for loaded pipelines
pipelines = {}

def load_model(model_key):
    """Load pipeline for the selected model with truncation enabled."""
    if model_key not in pipelines:
        repo_id = MODELS[model_key]
        device = 0 if torch.cuda.is_available() else -1
        pipelines[model_key] = pipeline(
            "text-classification",
            model=repo_id,
            device=device,
            torch_dtype=torch.float16 if device == 0 else None,
            truncation=True,
            max_length=512
        )
    return pipelines[model_key]
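
# Illustrative direct use of a cached pipeline, outside the UI:
#   clf = load_model("climate-sentiment")
#   clf("Extreme weather poses risks to our supply chain.")
#   # -> a list of {'label': ..., 'score': ...} dicts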

def predict(model_key, text):
    """Run inference on selected model with truncation and readable labels."""
    if not text.strip():
        return "Please enter some text."
    
    try:
        model = load_model(model_key)
        results = model(text, top_k=None)  # score every label, not just the argmax

        label_map = LABEL_MAPS.get(model_key, {})
        formatted = "\n".join([
            f"{label_map.get(r['label'], r['label'])}: {r['score']:.2f}"
            for r in results
        ])

        return f"Predictions for '{text[:50]}...':\n{formatted}"
    except Exception as e:
        return f"Error: {str(e)} (Check input length or model card for details)."

# Gradio interface
with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
    gr.Markdown("# 🌍 ClimateBERT Models Demo\nSelect a model and input text for climate-related analysis (e.g., sentiment, classification).")
    
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            label="Select Model",
            value=list(MODELS.keys())[0]
        )
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="E.g., 'Companies must reduce emissions to net zero by 2050.'",
            lines=2
        )
    
    output = gr.Textbox(label="Output", lines=5)
    
    predict_btn = gr.Button("Run Inference")
    predict_btn.click(predict, inputs=[model_dropdown, text_input], outputs=output)

if __name__ == "__main__":
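    # 0.0.0.0:7860 matches the Hugging Face Spaces convention for exposed apps.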
    demo.launch(server_name="0.0.0.0", server_port=7860)