import gradio as gr
from transformers import pipeline
import torch

# Models available in the demo: short key -> Hugging Face Hub repo id
MODELS = {
    "econbert": "climatebert/econbert",
    "controversy-classification": "climatebert/ClimateControversyBERT_classification",
    "controversy-bert": "climatebert/ClimateControversyBert",
    "netzero-reduction": "climatebert/netzero-reduction",
    "transition-physical": "climatebert/transition-physical",
    "renewable": "climatebert/renewable",
    "climate-detector": "climatebert/distilroberta-base-climate-detector",
    "climate-commitment": "climatebert/distilroberta-base-climate-commitment",
    "climate-tcfd": "climatebert/distilroberta-base-climate-tcfd",
    "climate-s": "climatebert/distilroberta-base-climate-s",
    "climate-specificity": "climatebert/distilroberta-base-climate-specificity",
    "climate-sentiment": "climatebert/distilroberta-base-climate-sentiment",
    "environmental-claims": "climatebert/environmental-claims",
    "climate-f": "climatebert/distilroberta-base-climate-f",
    "climate-d-s": "climatebert/distilroberta-base-climate-d-s",
    "climate-d": "climatebert/distilroberta-base-climate-d",
}

# Human-readable label mappings
LABEL_MAPS = {
    "climate-commitment": {
        "LABEL_0": "Not about climate commitments",
        "LABEL_1": "About climate commitments",
    },
    "climate-detector": {
        "LABEL_0": "Not climate-related",
        "LABEL_1": "Climate-related",
    },
    "climate-sentiment": {
        "LABEL_0": "Negative",
        "LABEL_1": "Neutral",
        "LABEL_2": "Positive",
    },
    "climate-specificity": {
        "LABEL_0": "Low specificity",
        "LABEL_1": "Medium specificity",
        "LABEL_2": "High specificity",
    },
    "netzero-reduction": {
        "LABEL_0": "No net-zero / reduction commitment",
        "LABEL_1": "Net-zero / reduction commitment",
    },
    "transition-physical": {
        "LABEL_0": "Transition risk",
        "LABEL_1": "Physical risk",
    },
    "renewable": {
        "LABEL_0": "Not about renewables",
        "LABEL_1": "About renewables",
    },
    # Mappings for the remaining models can be added after checking their
    # model cards; see the sketch just below this dict.
}
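
# Sketch for extending LABEL_MAPS. The mapping below is an assumption, not
# taken from a model card; check a model's raw labels first via its config:
#   from transformers import AutoConfig
#   AutoConfig.from_pretrained("climatebert/environmental-claims").id2label
# and then add an entry such as:
#   "environmental-claims": {
#       "LABEL_0": "No environmental claim",  # hypothetical wording
#       "LABEL_1": "Environmental claim",     # hypothetical wording
#   },
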
# Cache for loaded pipelines
pipelines = {}

def load_model(model_key):
    """Load (and cache) a pipeline for the selected model, with truncation enabled."""
    if model_key not in pipelines:
        repo_id = MODELS[model_key]
        # Use the first GPU if available, otherwise fall back to CPU (-1).
        device = 0 if torch.cuda.is_available() else -1
        pipelines[model_key] = pipeline(
            "text-classification",
            model=repo_id,
            device=device,
            # Half precision saves GPU memory; keep full precision on CPU.
            torch_dtype=torch.float16 if device == 0 else None,
            # Truncate long inputs to the models' 512-token limit.
            truncation=True,
            max_length=512,
        )
    return pipelines[model_key]
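
# Illustrative direct use of the cache (scores below are made up): the first
# call for a key downloads the checkpoint and builds the pipeline; later calls
# with the same key reuse the cached object.
#   clf = load_model("climate-sentiment")
#   clf("We are committed to cutting emissions in half by 2030.")
#   # -> [{'label': 'LABEL_2', 'score': 0.93}]  # example values only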

def predict(model_key, text):
    """Run inference with the selected model and return readable labels."""
    if not text.strip():
        return "Please enter some text."
    try:
        model = load_model(model_key)
        # top_k=None returns a score for every label, not just the top one.
        results = model(text, top_k=None)
        label_map = LABEL_MAPS.get(model_key, {})
        formatted = "\n".join(
            f"{label_map.get(r['label'], r['label'])}: {r['score']:.2f}"
            for r in results
        )
        # Append an ellipsis only when the preview is actually truncated.
        preview = text[:50] + ("..." if len(text) > 50 else "")
        return f"Predictions for '{preview}':\n{formatted}"
    except Exception as e:
        return f"Error: {e} (check the input length or the model card for details)."

# Gradio interface
with gr.Blocks(title="ClimateBERT Multi-Model Demo") as demo:
    gr.Markdown(
        "# 🌍 ClimateBERT Models Demo\n"
        "Select a model and enter text for climate-related analysis "
        "(e.g., sentiment, classification)."
    )
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            label="Select Model",
            value=list(MODELS.keys())[0],
        )
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="E.g., 'Companies must reduce emissions to net zero by 2050.'",
            lines=2,
        )
    output = gr.Textbox(label="Output", lines=5)
    predict_btn = gr.Button("Run Inference")
    predict_btn.click(predict, inputs=[model_dropdown, text_input], outputs=output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)