|
|
import logging

import gradio as gr
import torch
from transformers import pipeline
|
|
|
|
|
|
|
|
# Selectable models: dropdown key -> Hugging Face Hub repo id, which is passed
# as ``model=`` to the text-classification pipeline. Insertion order defines
# the dropdown order, and the first entry is the default selection.
# NOTE(review): "climate-f", "climate-d", "climate-s" and "climate-d-s" look like
# the base pretrained ClimateBERT language models rather than fine-tuned
# classifiers; loading them as text-classification may attach an untrained
# head. Confirm against the model cards.
MODELS = dict([
    ("econbert", "climatebert/econbert"),
    ("controversy-classification", "climatebert/ClimateControversyBERT_classification"),
    ("controversy-bert", "climatebert/ClimateControversyBert"),
    ("netzero-reduction", "climatebert/netzero-reduction"),
    ("transition-physical", "climatebert/transition-physical"),
    ("renewable", "climatebert/renewable"),
    ("climate-detector", "climatebert/distilroberta-base-climate-detector"),
    ("climate-commitment", "climatebert/distilroberta-base-climate-commitment"),
    ("climate-tcfd", "climatebert/distilroberta-base-climate-tcfd"),
    ("climate-s", "climatebert/distilroberta-base-climate-s"),
    ("climate-specificity", "climatebert/distilroberta-base-climate-specificity"),
    ("climate-sentiment", "climatebert/distilroberta-base-climate-sentiment"),
    ("environmental-claims", "climatebert/environmental-claims"),
    ("climate-f", "climatebert/distilroberta-base-climate-f"),
    ("climate-d-s", "climatebert/distilroberta-base-climate-d-s"),
    ("climate-d", "climatebert/distilroberta-base-climate-d"),
])
|
|
|
|
|
|
|
|
# Per-model display names for generic "LABEL_<i>" ids emitted by a checkpoint.
# Models without an entry, and ids without a mapping, are displayed unchanged
# by the prediction formatter (it falls back to the raw label).
# NOTE(review): the index -> meaning assignments below have not been checked
# against each checkpoint's config `id2label`; confirm before trusting them.
LABEL_MAPS = {
    "climate-commitment": dict(
        LABEL_0="Not about climate commitments",
        LABEL_1="About climate commitments",
    ),
    "climate-detector": dict(
        LABEL_0="Not climate-related",
        LABEL_1="Climate-related",
    ),
    "climate-sentiment": dict(
        LABEL_0="Negative",
        LABEL_1="Neutral",
        LABEL_2="Positive",
    ),
    "climate-specificity": dict(
        LABEL_0="Low specificity",
        LABEL_1="Medium specificity",
        LABEL_2="High specificity",
    ),
    "netzero-reduction": dict(
        LABEL_0="No net-zero / reduction commitment",
        LABEL_1="Net-zero / reduction commitment",
    ),
    "transition-physical": dict(
        LABEL_0="Transition risk",
        LABEL_1="Physical risk",
    ),
    "renewable": dict(
        LABEL_0="Not about renewables",
        LABEL_1="About renewables",
    ),
}
|
|
|
|
|
|
|
|
# Module-level cache of constructed pipelines, keyed by MODELS key.
pipelines = {}


def load_model(model_key):
    """Return the text-classification pipeline for ``model_key``, building it once.

    The first call for a key resolves its Hub repo id through ``MODELS`` and
    constructs a ``transformers`` pipeline. Later calls return the instance
    already stored in ``pipelines``. When CUDA is available, the pipeline is
    placed on device 0 with float16 weights. Otherwise it runs on CPU
    (``device=-1``). ``truncation=True`` and ``max_length=512`` are forwarded
    to the pipeline so that over-long inputs are truncated.

    Raises:
        KeyError: ``model_key`` is not a key of ``MODELS``.
    """
    if model_key in pipelines:
        return pipelines[model_key]

    repo_id = MODELS[model_key]
    on_gpu = torch.cuda.is_available()
    pipelines[model_key] = pipeline(
        "text-classification",
        model=repo_id,
        device=0 if on_gpu else -1,
        # Half precision only on GPU; None leaves dtype selection to transformers.
        torch_dtype=torch.float16 if on_gpu else None,
        truncation=True,
        max_length=512,
    )
    return pipelines[model_key]
|
|
|
|
|
def predict(model_key, text): |
|
|
"""Run inference on selected model with truncation and readable labels.""" |
|
|
if not text.strip(): |
|
|
return "Please enter some text." |
|
|
|
|
|
try: |
|
|
model = load_model(model_key) |
|
|
results = model(text) |
|
|
|
|
|
label_map = LABEL_MAPS.get(model_key, {}) |
|
|
formatted = "\n".join([ |
|
|
f"{label_map.get(r['label'], r['label'])}: {r['score']:.2f}" |
|
|
for r in results |
|
|
]) |
|
|
|
|
|
return f"Predictions for '{text[:50]}...':\n{formatted}" |
|
|
except Exception as e: |
|
|
return f"Error: {str(e)} (Check input length or model card for details)." |
|
|
|
|
|
|
|
|
# Gradio UI: the model selector and input text feed `predict`, and the result
# string is shown in the output box when the button is clicked.
with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
    # The title previously contained a mis-encoded character (a mangled
    # emoji rendered as "π"); it has been removed.
    gr.Markdown(
        "# ClimateBERT Models Demo\n"
        "Select a model and input text for climate-related analysis "
        "(e.g., sentiment, classification)."
    )

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            label="Select Model",
            # The first registered model is the default selection.
            value=next(iter(MODELS)),
        )
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="E.g., 'Companies must reduce emissions to net zero by 2050.'",
            lines=2,
        )

    output = gr.Textbox(label="Output", lines=5)

    predict_btn = gr.Button("Run Inference")
    predict_btn.click(predict, inputs=[model_dropdown, text_input], outputs=output)
|
|
|
|
|
if __name__ == "__main__":
    # Serve on port 7860, bound to every network interface so the app is
    # reachable from outside the local host (e.g. from outside a container).
    demo.launch(server_port=7860, server_name="0.0.0.0")
|
|
|