Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Commit 
							
							·
						
						c94926c
	
1
								Parent(s):
							
							65e4811
								
Preserve scroll when re-generating plots
Browse files- app.py +46 -8
- css_html_js.py +13 -0
- logo.png +0 -0
- main.py +0 -6
- pyproject.toml +1 -0
    	
        app.py
    CHANGED
    
    | @@ -2,7 +2,7 @@ import json | |
| 2 | 
             
            import pandas as pd
         | 
| 3 | 
             
            import gradio as gr
         | 
| 4 | 
             
            from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns
         | 
| 5 | 
            -
            from css_html_js import custom_css
         | 
| 6 | 
             
            from parse import read_json, read_data
         | 
| 7 | 
             
            from utils import model_hyperlink, filter_RTLRepo, filter_bench, handle_special_cases
         | 
| 8 | 
             
            from typing import Union
         | 
| @@ -45,6 +45,7 @@ def generate_scatter_plot(benchmark, metric): | |
| 45 | 
             
                scatter_data['y'] = scatter_data[metric]
         | 
| 46 | 
             
                scatter_data['size'] = (scatter_data['x'] ** 0.3) * 40
         | 
| 47 |  | 
|  | |
| 48 | 
             
                type_colors = {"General": "green", "Coding": "yellow", "RTL-Specific": "blue"}
         | 
| 49 | 
             
                scatter_data['color'] = scatter_data['Model Type'].map(type_colors).fillna('gray')
         | 
| 50 |  | 
| @@ -78,9 +79,26 @@ def generate_scatter_plot(benchmark, metric): | |
| 78 |  | 
| 79 | 
             
                return fig
         | 
| 80 |  | 
| 81 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 82 | 
             
                df, benchmarks, metrics, default_metric = read_data()
         | 
| 83 | 
            -
                gr.Markdown("""# TuRTLe 🐢 Model Leaderboard
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 84 | 
             
                Welcome to the TuRTLe Model Leaderboard! Use the filters below to explore different RTL benchmarks and models.
         | 
| 85 | 
             
                [GitHub Repository](https://github.com/https://github.com/HPAI-BSC) | [arXiv Preprint](https://arxiv.org/) | [How to submit](https://github.com/https://github.com/HPAI-BSC)<br/>
         | 
| 86 | 
             
                Contact us: hpai@bsc.es
         | 
| @@ -147,14 +165,34 @@ with gr.Blocks(css=custom_css) as app: | |
| 147 | 
             
                bubble_benchmark.change(
         | 
| 148 | 
             
                    fn=on_benchmark_change, 
         | 
| 149 | 
             
                    inputs=[bubble_benchmark, bubble_metric],
         | 
| 150 | 
            -
                    outputs=[bubble_metric, scatter_plot]
         | 
| 151 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 152 |  | 
| 153 | 
             
                bubble_metric.change(
         | 
| 154 | 
             
                    fn=on_metric_change,
         | 
| 155 | 
             
                    inputs=[bubble_benchmark, bubble_metric],
         | 
| 156 | 
            -
                    outputs=[bubble_benchmark, scatter_plot]
         | 
| 157 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 158 |  | 
| 159 |  | 
| 160 | 
            -
            app.launch()
         | 
|  | |
| 2 | 
             
            import pandas as pd
         | 
| 3 | 
             
            import gradio as gr
         | 
| 4 | 
             
            from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns
         | 
| 5 | 
            +
            from css_html_js import custom_css, trigger_plot
         | 
| 6 | 
             
            from parse import read_json, read_data
         | 
| 7 | 
             
            from utils import model_hyperlink, filter_RTLRepo, filter_bench, handle_special_cases
         | 
| 8 | 
             
            from typing import Union
         | 
|  | |
| 45 | 
             
                scatter_data['y'] = scatter_data[metric]
         | 
| 46 | 
             
                scatter_data['size'] = (scatter_data['x'] ** 0.3) * 40
         | 
| 47 |  | 
| 48 | 
            +
                
         | 
| 49 | 
             
                type_colors = {"General": "green", "Coding": "yellow", "RTL-Specific": "blue"}
         | 
| 50 | 
             
                scatter_data['color'] = scatter_data['Model Type'].map(type_colors).fillna('gray')
         | 
| 51 |  | 
|  | |
| 79 |  | 
| 80 | 
             
                return fig
         | 
| 81 |  | 
| 82 | 
            +
            js_func = """
         | 
| 83 | 
            +
            function refresh() {
         | 
| 84 | 
            +
                const url = new URL(window.location);
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                if (url.searchParams.get('__theme') !== 'light') {
         | 
| 87 | 
            +
                    url.searchParams.set('__theme', 'light');
         | 
| 88 | 
            +
                    window.location.href = url.href;
         | 
| 89 | 
            +
                }
         | 
| 90 | 
            +
            }
         | 
| 91 | 
            +
            """
         | 
| 92 | 
            +
                    
         | 
| 93 | 
            +
            with gr.Blocks(css=custom_css, js=js_func) as app:
         | 
| 94 | 
             
                df, benchmarks, metrics, default_metric = read_data()
         | 
| 95 | 
            +
                gr.Markdown("""# TuRTLe 🐢 Model Leaderboard""")
         | 
| 96 | 
            +
                gr.HTML("""
         | 
| 97 | 
            +
                <p align="center">
         | 
| 98 | 
            +
                    <img src='/gradio_api/file=logo.png' alt='TuRTLe Logo' width='220'/> <br/>
         | 
| 99 | 
            +
                </p>
         | 
| 100 | 
            +
                """)
         | 
| 101 | 
            +
                gr.Markdown("""
         | 
| 102 | 
             
                Welcome to the TuRTLe Model Leaderboard! Use the filters below to explore different RTL benchmarks and models.
         | 
| 103 | 
             
                [GitHub Repository](https://github.com/https://github.com/HPAI-BSC) | [arXiv Preprint](https://arxiv.org/) | [How to submit](https://github.com/https://github.com/HPAI-BSC)<br/>
         | 
| 104 | 
             
                Contact us: hpai@bsc.es
         | 
|  | |
| 165 | 
             
                bubble_benchmark.change(
         | 
| 166 | 
             
                    fn=on_benchmark_change, 
         | 
| 167 | 
             
                    inputs=[bubble_benchmark, bubble_metric],
         | 
| 168 | 
            +
                    outputs=[bubble_metric, scatter_plot],
         | 
| 169 | 
            +
                    js=""" // this is to avoid resetting user scroll each time a plot is re-generated
         | 
| 170 | 
            +
                    (benchmark, metric) => {
         | 
| 171 | 
            +
                        let scrollY = window.scrollY;  
         | 
| 172 | 
            +
                        const observer = new MutationObserver(() => {
         | 
| 173 | 
            +
                            window.scrollTo(0, scrollY);
         | 
| 174 | 
            +
                            observer.disconnect();
         | 
| 175 | 
            +
                        });
         | 
| 176 | 
            +
                        observer.observe(document.getElementById('full-width-plot'), { childList: true });
         | 
| 177 | 
            +
                        return [benchmark, metric];  
         | 
| 178 | 
            +
                    }
         | 
| 179 | 
            +
                    """)
         | 
| 180 |  | 
| 181 | 
             
                bubble_metric.change(
         | 
| 182 | 
             
                    fn=on_metric_change,
         | 
| 183 | 
             
                    inputs=[bubble_benchmark, bubble_metric],
         | 
| 184 | 
            +
                    outputs=[bubble_benchmark, scatter_plot],
         | 
| 185 | 
            +
                    js=""" // this is to avoid resetting user scroll each time a plot is re-generated
         | 
| 186 | 
            +
                    (benchmark, metric) => {
         | 
| 187 | 
            +
                        let scrollY = window.scrollY;  
         | 
| 188 | 
            +
                        const observer = new MutationObserver(() => {
         | 
| 189 | 
            +
                            window.scrollTo(0, scrollY);
         | 
| 190 | 
            +
                            observer.disconnect();
         | 
| 191 | 
            +
                        });
         | 
| 192 | 
            +
                        observer.observe(document.getElementById('full-width-plot'), { childList: true });
         | 
| 193 | 
            +
                        return [benchmark, metric];  
         | 
| 194 | 
            +
                    }
         | 
| 195 | 
            +
                    """)
         | 
| 196 |  | 
| 197 |  | 
| 198 | 
            +
            app.launch(allowed_paths=["logo.png"])
         | 
    	
        css_html_js.py
    CHANGED
    
    | @@ -1,4 +1,7 @@ | |
| 1 | 
             
            custom_css = """
         | 
|  | |
|  | |
|  | |
| 2 | 
             
            #component-0 {
         | 
| 3 | 
             
                width: 75vw;
         | 
| 4 | 
             
                margin: 0 auto;
         | 
| @@ -108,3 +111,13 @@ get_window_url_params = """ | |
| 108 | 
             
                    return url_params;
         | 
| 109 | 
             
                }
         | 
| 110 | 
             
                """
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
             
            custom_css = """
         | 
| 2 | 
            +
            #component-1 {
         | 
| 3 | 
            +
                text-align: center;
         | 
| 4 | 
            +
            }
         | 
| 5 | 
             
            #component-0 {
         | 
| 6 | 
             
                width: 75vw;
         | 
| 7 | 
             
                margin: 0 auto;
         | 
|  | |
| 111 | 
             
                    return url_params;
         | 
| 112 | 
             
                }
         | 
| 113 | 
             
                """
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            trigger_plot = """
         | 
| 116 | 
            +
            window.scrollY_before_update = window.scrollY; // Store scroll position
         | 
| 117 | 
            +
            console.log("Saved ScrollY:", window.scrollY_before_update);
         | 
| 118 | 
            +
             | 
| 119 | 
            +
            setTimeout(function() {
         | 
| 120 | 
            +
                console.log("Restoring ScrollY:", window.scrollY_before_update);
         | 
| 121 | 
            +
                window.scrollTo(0, window.scrollY_before_update);
         | 
| 122 | 
            +
            }, 50);
         | 
| 123 | 
            +
            """
         | 
    	
        logo.png
    ADDED
    
    |   | 
    	
        main.py
    DELETED
    
    | @@ -1,6 +0,0 @@ | |
| 1 | 
            -
            def main():
         | 
| 2 | 
            -
                print("Hello from tortuga!")
         | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
            if __name__ == "__main__":
         | 
| 6 | 
            -
                main()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        pyproject.toml
    CHANGED
    
    | @@ -8,4 +8,5 @@ dependencies = [ | |
| 8 | 
             
                "gradio>=5.21.0",
         | 
| 9 | 
             
                "gradio-leaderboard>=0.0.13",
         | 
| 10 | 
             
                "pandas>=2.2.3",
         | 
|  | |
| 11 | 
             
            ]
         | 
|  | |
| 8 | 
             
                "gradio>=5.21.0",
         | 
| 9 | 
             
                "gradio-leaderboard>=0.0.13",
         | 
| 10 | 
             
                "pandas>=2.2.3",
         | 
| 11 | 
            +
                "plotly>=6.0.1",
         | 
| 12 | 
             
            ]
         | 

