Spaces:
Running
Running
| import os | |
| from plotly import graph_objects as go | |
| import pandas as pd | |
| ## Evaluation Graphs | |
# Load the data
# Every "*.csv" file in the evaluation directory holds the results for one
# metric. The result maps metric name -> {series key -> values}, plus a
# "token" entry that serves as the shared x-axis.
EVAL_DATA_DIR = "data/txt360_eval"

# Positional CSV column -> key under which that dataset's scores are stored.
# NOTE: slimpajama results are not loaded.
EVAL_COLUMNS = {
    1: "fineweb",
    2: "txt360-dedup-only",
    3: "txt360-web-only-upsampled",
    4: "txt360-all-upsampled + stackv2",
}

all_eval_results = {}
# Sorted so the metric order does not depend on the OS directory listing order.
for fname in sorted(os.listdir(EVAL_DATA_DIR)):
    if not fname.endswith(".csv"):
        continue
    metric_name = fname.replace("CKPT Eval - ", "").replace(".csv", "")
    df = pd.read_csv(os.path.join(EVAL_DATA_DIR, fname))
    metric_results = {}
    for col_idx, dataset_key in EVAL_COLUMNS.items():
        # The first two data rows are skipped (presumably non-numeric header
        # rows -- TODO confirm against the CSVs). Missing values are
        # back-filled from the next valid row; Series.bfill() replaces the
        # deprecated fillna(method="bfill").
        metric_results[dataset_key] = df.iloc[2:, col_idx].astype(float).bfill()
    # each row is 20B tokens.
    metric_results["token"] = [
        20 * i for i in range(len(metric_results["fineweb"]))
    ]
    all_eval_results[metric_name] = metric_results
# Eval Result Plots
# Build one line chart per metric, stored under the metric's name.
all_eval_res_figs = {}
for metric_name, res in all_eval_results.items():
    token_axis = res["token"]
    fig_res = go.Figure()
    # Add lines: (key in the loaded results, legend label), in legend order.
    for series_key, legend_label in (
        ("fineweb", 'FineWeb'),
        ("txt360-web-only-upsampled", 'TxT360 - CC Data Upsampled'),
        ("txt360-dedup-only", 'TxT360 - CC Data Dedup'),
        ("txt360-all-upsampled + stackv2", 'TxT360 - Full Upsampled + Stack V2'),
    ):
        line = go.Scatter(
            x=token_axis,
            y=res[series_key],
            mode='lines',
            name=legend_label,
        )
        fig_res.add_trace(line)
    # Update layout
    fig_res.update_layout(
        title=f"{metric_name} Performance",
        title_x=0.5,  # Centers the title
        xaxis_title="Billion Tokens",
        yaxis_title=metric_name,
        legend_title="Dataset",
    )
    all_eval_res_figs[metric_name] = fig_res