Commit 201434d
Parent(s): 6a00f6c

add app.py

- app.py +133 -0
- requirements.txt +2 -0
    	
app.py ADDED

@@ -0,0 +1,133 @@
import gradio as gr
import pandas as pd

MODEL_MAPPINGS = {
    "gpt-4o-2024-05-13": "GPT-4o",
    "gpt-4-0613": "GPT-4",
    "gpt-4-turbo-2024-04-09": "GPT-4 Turbo",
    "gpt-4-0125-preview": "GPT-4 Preview",
    "gpt-3.5": "GPT-3.5",
    "gpt-3.5-turbo-0125": "GPT-3.5 Turbo",
    "claude-3-opus-20240229": "Claude-3 O",
    "claude-3-sonnet-20240229": "Claude-3 S",
    "claude-3-haiku-20240307": "Claude-3 H",
    "claude-3-5-sonnet-20240620": "Claude-3.5 S",
    "llama-2-70b-chat": "Llama-2 70b",
    "llama-2-13b-chat": "Llama-2 13b",
    "llama-2-7b-chat": "Llama-2 7b",
    "llama-3-8b-chat": "Llama-3 8b",
    "llama-3-70b-chat": "Llama-3 70b",
    "codellama-70b-instruct": "Codellama 70b",
    "mistral-large-2402": "Mistral Large",
    "mistral-medium-2312": "Mistral Medium",
    "open-mixtral-8x22b-instruct-v0.1": "Mixtral 8x22b",
    "open-mixtral-8x7b-instruct": "Mixtral 8x7b",
    "open-mistral-7b-instruct": "Mistral 7b",
    "open-mistral-7b": "Mistral 7b",
    "open-mixtral-8x22b": "Mixtral 8x22b",
    "open-mixtral-8x7b": "Mixtral 8x7b",
    "open-mistral-7b-instruct-v0.1": "Mistral 7b",
    "dbrx-instruct": "DBRX",
    "command-r-plus": "Command R Plus",
    "gemma-7b-it": "Gemma 7b",
    "gemma-2b-it": "Gemma 2b",
    "gemini-1.5-pro-latest": "Gemini 1.5",
    "gemini-pro": "Gemini 1.0",
    "qwen1.5-7b-chat": "Qwen 1.5 7b",
    "qwen1.5-14b-chat": "Qwen 1.5 14b",
    "qwen1.5-32b-chat": "Qwen 1.5 32b",
    "qwen1.5-72b-chat": "Qwen 1.5 72b",
    "qwen1.5-0.5b-chat": "Qwen 1.5 0.5b",
    "qwen1.5-1.8b-chat": "Qwen 1.5 1.8b",
    "qwen2-72b-instruct": "Qwen 2 72b",
    "codestral-2405": "Codestral"
}

df = pd.read_json('responses_data/responses.jsonl', lines=True)
df['model'] = df['model'].map(MODEL_MAPPINGS)
df['prompt'] = df[['prompt', 'prompt_id']].apply(lambda x: f"{x['prompt']} [{x['prompt_id']}]", axis=1)


model_list = df['model'].unique()
prompt_id_list = list(df['prompt'].unique())
prompt_id_list = sorted(prompt_id_list, key=lambda x: int(x.split('[')[1].split(']')[0]))


def response(num_responses, model, correct, prompt_ids):
    responses = df
    if model:
        responses = responses[responses['model'].isin(model)]
    if correct:
        responses = responses[responses['correct'].isin(correct)]
    if prompt_ids:
        responses = responses[responses['prompt'].isin(prompt_ids)]
    if num_responses > len(responses):
        num_responses = len(responses)
    return responses.sample(num_responses)[['model', 'prompt', 'model_response', 'correct']]


def barplot_for_prompt_id(prompt_ids, models):
    responses = df
    if prompt_ids:
        responses = responses[responses['prompt'].isin(prompt_ids)]
    if models:
        responses = responses[responses['model'].isin(models)]
    means = responses.groupby(['model', 'prompt_id'])['correct'].mean()
    means = means.reset_index()
    means['prompt_id'] = means['prompt_id'].astype(str)
    prompt_ids = list(set([p for p in means['prompt_id']]))
    prompt_ids_str = ', '.join(prompt_ids)
    return gr.BarPlot(
        means,
        x='prompt_id',
        y='correct',
        group='model',
        color='prompt_id',
        group_title="",
        title=f'Correctness for Prompt IDs: {prompt_ids_str}',
        x_title="",
    )

title = "🎩🐇 Alice in Wonderland: Simple Tasks Showing Complete Reasoning Breakdown in State-Of-the-Art Large Language Models"

with gr.Blocks() as demo:
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            f"""<h1 style='font-size: 30px; font-weight: bold; text-align: center;'>{title}</h1>
            <h4 align="center"><a href="https://marianna13.github.io/aiw/" target="_blank">🌐Homepage</a> | <a href="https://arxiv.org/pdf/2406.02061" target="_blank"> 📝Paper</a> | <a href="https://github.com/LAION-AI/AIW" target="_blank">🛠️Code</a></h4>
            <p style='color: #000000; font-size: 20px; text-align: center;'>This demo shows the responses of different models to a set of prompts. The responses are categorized as correct or incorrect. You can choose the number of responses, the model, the correctness of the responses, and the prompt IDs to see the responses.</p>
            <p style='color: #000000; font-size: 20px; text-align: center;'>You can also see the correctness of the responses for different prompt IDs using the robustness plot tab.</p>
            """
        )
    with gr.Tab("Responses"):
        gr.Interface(
            response,
            [
                gr.Slider(2, 20, value=4, label="Number of responses", info="Choose between 2 and 20"),
                gr.Dropdown(
                    list(model_list), label="Model", info="Choose to see responses", multiselect=True
                ),
                gr.CheckboxGroup([("Correct", True), ("Incorrect", False)], label="Correct or not", info="Choose to see correct or incorrect responses"),
                gr.Dropdown(
                    prompt_id_list, multiselect=True, label="Prompt IDs", info="Choose to see responses for a specific prompt ID(s)"
                ),
            ],
            gr.DataFrame(type="pandas", wrap=True, label="Responses"),
        )

    with gr.Tab("Robustness plot"):
        gr.Interface(
            barplot_for_prompt_id,
            [
                gr.Dropdown(
                    prompt_id_list, multiselect=True, label="Prompt IDs", info="Choose to see responses for a specific prompt ID(s)"
                ),
                gr.Dropdown(
                    list(model_list), label="Model", info="Choose to see responses", multiselect=True
                ),
            ],
            gr.BarPlot(title="Correctness for Prompt IDs"),
        )

if __name__ == "__main__":
    demo.launch()
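Note that the commit does not include responses_data/responses.jsonl, which app.py reads at import time. Judging from the columns the app accesses (model, prompt, prompt_id, model_response, correct), a minimal synthetic file for local testing might look like the sketch below; the file layout is inferred and all values are hypothetical placeholders, not AIW benchmark data.

    # Hypothetical stand-in for responses_data/responses.jsonl (not part of this
    # commit). Column names are inferred from what app.py reads; values are made up.
    import json
    import os

    os.makedirs("responses_data", exist_ok=True)
    rows = [
        {"model": "gpt-4o-2024-05-13", "prompt_id": 1,
         "prompt": "Example AIW question text", "model_response": "Example answer",
         "correct": True},
        {"model": "claude-3-opus-20240229", "prompt_id": 1,
         "prompt": "Example AIW question text", "model_response": "Example answer",
         "correct": False},
    ]
    with open("responses_data/responses.jsonl", "w") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")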
    	
requirements.txt ADDED

@@ -0,0 +1,2 @@
gradio
pandas
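With gradio and pandas installed and a responses.jsonl in place, the filtering logic can be sanity-checked without launching the UI. A minimal sketch, assuming app.py is importable from the working directory (importing it loads the data and builds model_list and prompt_id_list):

    # Hypothetical smoke test; assumes responses_data/responses.jsonl exists
    # alongside app.py so that importing the module succeeds.
    from app import response, model_list, prompt_id_list

    # Sample up to 3 incorrect responses for the first prompt ID across all models.
    sample = response(3, list(model_list), [False], prompt_id_list[:1])
    print(sample[["model", "prompt", "correct"]])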