Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import plotly.express as px | |
# Label displayed on the citation textbox.
CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"

# BibTeX entry offered for copying. Each line is a raw string so that the
# backslash in "\url" is kept literally; the lines are joined with newlines.
CITATION_BUTTON_TEXT = "\n".join([
    r'@misc{aienergyscore-leaderboard,',
    r'author = {Sasha Luccioni and Boris Gamazaychikov and Emma Strubell and Sara Hooker and Yacine Jernite and Carole-Jean Wu and Margaret Mitchell},',
    r'title = {AI Energy Score Leaderboard - February 2025},',
    r'year = {2025},',
    r'publisher = {Hugging Face},',
    r'howpublished = "\url{https://huggingface.co/spaces/AIEnergyScore/Leaderboard}",',
    r'}',
])
| # List of tasks (CSV filenames) | |
# One CSV file per benchmarked task; the order here is the order in which the
# "all tasks" views read the files.
tasks = [
    "asr.csv",
    "object_detection.csv",
    "text_classification.csv",
    "image_captioning.csv",
    "question_answering.csv",
    "text_generation.csv",
    "image_classification.csv",
    "sentence_similarity.csv",
    "image_generation.csv",
    "summarization.csv",
]
def format_stars(score):
    """Render an energy score as a run of star characters inside styled HTML.

    Args:
        score: Any value accepted by ``int()``. Values that cannot be
            converted (e.g. ``None``, non-numeric text, NaN, infinity) are
            rendered as zero stars.

    Returns:
        An HTML ``<span>`` string containing one "★" per score point.
    """
    try:
        star_count = int(score)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to zero stars; any other error
        # is left to propagate instead of being silently hidden.
        star_count = 0
    # Render stars in black with a slightly larger font
    return f'<span style="color: black !important; font-size:1.5em !important;">{"★" * star_count}</span>'
def make_link(mname):
    """Build a Markdown link to a model's Hugging Face page.

    The link text is the second '/'-separated segment of the model id when
    there is one; otherwise the value is used as given. The link target is
    always the full value appended to the huggingface.co base URL.
    """
    segments = str(mname).split('/')
    if len(segments) > 1:
        label = segments[1]
    else:
        label = mname
    return f'[{label}](https://huggingface.co/{mname})'
| # --- Plot Functions (Bar Chart) --- | |
# Bar color for each energy score. Keys are strings because the score column
# is cast to str before plotting so it is used as a discrete color category.
_SCORE_COLOR_MAP = {"1": "red", "2": "orange", "3": "yellow", "4": "lightgreen", "5": "green"}


def _read_plot_csv(filename):
    """Load ``data/energy/<filename>``, dropping a leading "Unnamed:" column if present."""
    df = pd.read_csv('data/energy/' + filename)
    if df.columns[0].startswith("Unnamed:"):
        df = df.iloc[:, 1:]
    return df


def _prepare_plot_columns(df):
    """Return a copy of ``df`` with the columns the bar chart reads.

    Working on a copy keeps the column assignments from writing into a
    filtered slice of another frame.
    """
    df = df.copy()
    # Use the raw numeric value from the CSV for GPU Energy
    df['total_gpu_energy'] = pd.to_numeric(df['total_gpu_energy'], errors='raise')
    df['energy_score'] = df['energy_score'].astype(int).astype(str)
    # Label bars with the last path segment of the model id
    df['Display Model'] = df['model'].apply(lambda m: m.split('/')[-1])
    return df


def _energy_bar_chart(df):
    """Build the GPU-energy bar chart used by every plot function in this file."""
    fig = px.bar(
        df,
        x="Display Model",
        y="total_gpu_energy",
        color="energy_score",
        custom_data=['energy_score'],
        height=500,
        width=800,
        color_discrete_map=_SCORE_COLOR_MAP,
    )
    # Hover text shows the model, GPU Energy (with 4 decimals) and the score
    fig.update_traces(
        hovertemplate="<br>".join([
            "Model: %{x}",
            "GPU Energy (Wh): %{y:.4f}",
            "Energy Score: %{customdata[0]}",
        ])
    )
    # The tick format is set once; the original set the same value twice.
    fig.update_layout(
        xaxis_title="Model",
        yaxis_title="GPU Energy (Wh)",
        yaxis_tickformat=".4f",
    )
    return fig


def get_plots(task):
    """Return the energy bar chart for one task.

    Args:
        task: CSV filename inside ``data/energy/`` (e.g. ``'asr.csv'``).

    Returns:
        The Plotly figure produced by ``px.bar``.
    """
    return _energy_bar_chart(_prepare_plot_columns(_read_plot_csv(task)))


def get_all_plots():
    """Return one bar chart covering every file listed in ``tasks``.

    Rows are concatenated in ``tasks`` order and de-duplicated on the
    ``model`` column, so a model present in several files is plotted with the
    row from the first file that contains it.
    """
    frames = [_prepare_plot_columns(_read_plot_csv(task)) for task in tasks]
    # Concatenate once instead of growing a frame inside the loop.
    all_df = pd.concat(frames, ignore_index=True)
    all_df = all_df.drop_duplicates(subset=['model'])
    return _energy_bar_chart(all_df)


# --- Text Generation filtering by model class (with Bar Chart) ---
def get_text_generation_plots(model_class):
    """Return the text-generation bar chart restricted to one model class.

    Args:
        model_class: Value compared against the CSV's ``class`` column. When
            the file has no ``class`` column, no filtering is applied.

    Returns:
        The Plotly figure produced by ``px.bar``.
    """
    df = _read_plot_csv('text_generation.csv')
    # Filter by the selected model class if the "class" column exists
    if 'class' in df.columns:
        df = df[df['class'] == model_class]
    return _energy_bar_chart(_prepare_plot_columns(df))
| # --- Leaderboard Table Functions --- | |
| # Each function below returns a DataFrame with the columns Model, GPU Energy (Wh) and Score. | |
def _read_leaderboard_csv(filename):
    """Load ``data/energy/<filename>``, dropping a leading "Unnamed:" column if present."""
    df = pd.read_csv('data/energy/' + filename)
    if df.columns[0].startswith("Unnamed:"):
        df = df.iloc[:, 1:]
    return df


def _build_leaderboard(df):
    """Convert raw task rows into the three-column leaderboard table.

    Rows are ordered by the numeric ``total_gpu_energy`` value *before* it is
    formatted as text. Sorting the formatted strings (as the previous code
    did) orders them lexicographically, which puts "10.0000" ahead of
    "2.0000".

    Returns:
        A DataFrame with the columns ``Model``, ``GPU Energy (Wh)`` and
        ``Score``, lowest energy first.
    """
    # Copy so that assignments never write into a filtered slice.
    table = df.copy()
    table['energy_score'] = table['energy_score'].astype(int)
    table['total_gpu_energy'] = pd.to_numeric(table['total_gpu_energy'], errors='raise')
    # Stable sort keeps the file order for rows with equal energy.
    table = table.sort_values(by='total_gpu_energy', kind='stable')
    # For leaderboard display, format GPU Energy to 4 decimals
    table['GPU Energy (Wh)'] = table['total_gpu_energy'].apply(lambda x: f"{x:.4f}")
    table['Model'] = table['model'].apply(make_link)
    table['Score'] = table['energy_score'].apply(format_stars)
    # Keep only the display columns (this also drops any class column)
    return table[['Model', 'GPU Energy (Wh)', 'Score']]


def get_model_names(task):
    """Return the leaderboard table for one task.

    Args:
        task: CSV filename inside ``data/energy/`` (e.g. ``'asr.csv'``).

    Returns:
        DataFrame with ``Model``, ``GPU Energy (Wh)`` and ``Score`` columns,
        sorted by ascending numeric GPU energy.
    """
    return _build_leaderboard(_read_leaderboard_csv(task))


def get_all_model_names():
    """Return one leaderboard table covering every file listed in ``tasks``.

    Rows are concatenated in ``tasks`` order and de-duplicated on the
    ``model`` column, so a model present in several files keeps the row from
    the first file that contains it.
    """
    frames = [_read_leaderboard_csv(task) for task in tasks]
    # Concatenate once instead of growing a frame inside the loop.
    all_df = pd.concat(frames, ignore_index=True)
    all_df = all_df.drop_duplicates(subset=['model'])
    return _build_leaderboard(all_df)


def get_text_generation_model_names(model_class):
    """Return the text-generation leaderboard restricted to one model class.

    Args:
        model_class: Value compared against the CSV's ``class`` column. When
            the file has no ``class`` column, no filtering is applied.

    Returns:
        DataFrame with ``Model``, ``GPU Energy (Wh)`` and ``Score`` columns,
        sorted by ascending numeric GPU energy.
    """
    df = _read_leaderboard_csv('text_generation.csv')
    if 'class' in df.columns:
        df = df[df['class'] == model_class]
    return _build_leaderboard(df)
def update_text_generation(model_class):
    """Recompute both Text Generation outputs for the selected model class.

    Returns:
        A ``(plot, table)`` tuple: the bar chart first, then the leaderboard
        table, both filtered by ``model_class``.
    """
    return (
        get_text_generation_plots(model_class),
        get_text_generation_model_names(model_class),
    )
| # --- Build the Gradio Interface --- | |
# (tab label, CSV filename) for every tab that shows a single task with the
# standard plot + table layout, in display order.
_SINGLE_TASK_TABS = [
    ("Image Generation 📷", 'image_generation.csv'),
    ("Text Classification 🎭", 'text_classification.csv'),
    ("Image Classification 🖼️", 'image_classification.csv'),
    ("Image Captioning 📝", 'image_captioning.csv'),
    ("Summarization 📃", 'summarization.csv'),
    ("Automatic Speech Recognition 💬", 'asr.csv'),
    ("Object Detection 🚘", 'object_detection.csv'),
    ("Sentence Similarity 📚", 'sentence_similarity.csv'),
    ("Extractive QA ❔", 'question_answering.csv'),
]

# Custom CSS: fixed table layout with truncated (ellipsis) cells.
demo = gr.Blocks(css="""
.gr-dataframe table {
table-layout: fixed;
width: 100%;
}
.gr-dataframe th, .gr-dataframe td {
max-width: 150px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
""")
with demo:
    gr.Markdown(
        """# AI Energy Score Leaderboard
### Welcome to the leaderboard for the [AI Energy Score Project!](https://huggingface.co/AIEnergyScore)
Select different tasks to see scored models. Submit open models for testing and learn about testing proprietary models via the [submission portal](https://huggingface.co/spaces/AIEnergyScore/submission_portal)"""
    )
    with gr.Tabs():
        # --- Text Generation Tab with Dropdown for Model Class ---
        with gr.TabItem("Text Generation 💬"):
            # Dropdown sits above the plot and leaderboard
            model_class_dropdown = gr.Dropdown(choices=["A", "B", "C"],
                                               label="Select Model Class",
                                               value="A")
            with gr.Row():
                # NOTE(review): Gradio documents `scale` as an integer; confirm
                # that 1.3 is accepted without a warning by the pinned version.
                with gr.Column(scale=1.3):
                    tg_plot = gr.Plot(get_text_generation_plots("A"))
                with gr.Column(scale=1):
                    tg_table = gr.Dataframe(get_text_generation_model_names("A"), datatype="markdown")
            # Update plot and table when the dropdown value changes
            model_class_dropdown.change(fn=update_text_generation,
                                        inputs=model_class_dropdown,
                                        outputs=[tg_plot, tg_table])
        # --- One tab per single-task CSV: plot on the left, table on the right ---
        for tab_label, task_csv in _SINGLE_TASK_TABS:
            with gr.TabItem(tab_label):
                with gr.Row():
                    with gr.Column():
                        plot = gr.Plot(get_plots(task_csv))
                    with gr.Column():
                        table = gr.Dataframe(get_model_names(task_csv), datatype="markdown")
        # --- Combined view across every task file ---
        with gr.TabItem("All Tasks 💡"):
            with gr.Row():
                with gr.Column():
                    plot = gr.Plot(get_all_plots())
                with gr.Column():
                    table = gr.Dataframe(get_all_model_names(), datatype="markdown")
    with gr.Accordion("📙 Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            lines=10,
            show_copy_button=True,
        )
    gr.Markdown(
        """Last updated: February 2025"""
    )
demo.launch()