Spaces:
Running
Running
Update src/vis_utils.py
Browse files- src/vis_utils.py +85 -44
src/vis_utils.py
CHANGED
|
@@ -17,9 +17,67 @@ from about import *
|
|
| 17 |
|
| 18 |
global data_component, filter_component
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def get_method_color(method):
|
| 21 |
return color_dict.get(method, 'black') # If method is not in color_dict, use black
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
|
| 24 |
df = pd.read_csv(CSV_RESULT_PATH)
|
| 25 |
# Filter the dataframe based on selected methods
|
|
@@ -64,50 +122,33 @@ def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
|
|
| 64 |
|
| 65 |
return filename
|
| 66 |
|
| 67 |
-
def
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
|
| 71 |
-
elif benchmark_type == 'similarity':
|
| 72 |
-
title = f"{x_metric} vs {y_metric}"
|
| 73 |
-
return draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title)
|
| 74 |
-
elif benchmark_type == 'Benchmark 3':
|
| 75 |
-
return benchmark_3_plot(x_metric, y_metric)
|
| 76 |
-
elif benchmark_type == 'Benchmark 4':
|
| 77 |
-
return benchmark_4_plot(x_metric, y_metric)
|
| 78 |
-
else:
|
| 79 |
-
return "Invalid benchmark type selected."
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
def get_baseline_df(selected_methods, selected_metrics):
|
| 83 |
-
df = pd.read_csv(CSV_RESULT_PATH)
|
| 84 |
-
present_columns = ["method_name"] + selected_metrics
|
| 85 |
-
df = df[df['method_name'].isin(selected_methods)][present_columns]
|
| 86 |
-
return df
|
| 87 |
-
|
| 88 |
-
def general_visualizer(methods_selected, x_metric, y_metric):
|
| 89 |
-
df = pd.read_csv(CSV_RESULT_PATH)
|
| 90 |
-
filtered_df = df[df['method_name'].isin(methods_selected)]
|
| 91 |
-
|
| 92 |
-
# Create a Seaborn lineplot with method as hue
|
| 93 |
-
plt.figure(figsize=(10, 8)) # Increase figure size
|
| 94 |
-
sns.lineplot(
|
| 95 |
-
data=filtered_df,
|
| 96 |
-
x=x_metric,
|
| 97 |
-
y=y_metric,
|
| 98 |
-
hue="method_name", # Different colors for different methods
|
| 99 |
-
marker="o", # Add markers to the line plot
|
| 100 |
-
)
|
| 101 |
|
| 102 |
-
#
|
| 103 |
-
|
| 104 |
-
plt.ylabel(y_metric)
|
| 105 |
-
plt.title(f'{y_metric} vs {x_metric} for selected methods')
|
| 106 |
-
plt.grid(True)
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
return
|
|
|
|
| 17 |
|
| 18 |
global data_component, filter_component
|
| 19 |
|
| 20 |
+
def get_baseline_df(selected_methods, selected_metrics):
|
| 21 |
+
df = pd.read_csv(CSV_RESULT_PATH)
|
| 22 |
+
present_columns = ["method_name"] + selected_metrics
|
| 23 |
+
df = df[df['method_name'].isin(selected_methods)][present_columns]
|
| 24 |
+
return df
|
| 25 |
+
|
| 26 |
def get_method_color(method):
|
| 27 |
return color_dict.get(method, 'black') # If method is not in color_dict, use black
|
| 28 |
|
| 29 |
+
def set_colors_and_marks_for_representation_groups(ax):
|
| 30 |
+
for label in ax.get_xticklabels():
|
| 31 |
+
text = label.get_text()
|
| 32 |
+
color = group_color_dict.get(text, 'black') # Default to black if label not in dict
|
| 33 |
+
label.set_color(color)
|
| 34 |
+
label.set_fontweight('bold')
|
| 35 |
+
|
| 36 |
+
# Add a caret symbol to specific labels
|
| 37 |
+
if text in {'MUT2VEC', 'PFAM', 'GENE2VEC', 'BERT-PFAM'}:
|
| 38 |
+
label.set_text(f"^ {text}")
|
| 39 |
+
|
| 40 |
+
def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
|
| 41 |
+
if benchmark_type == 'flexible':
|
| 42 |
+
# Use general visualizer logic
|
| 43 |
+
return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
|
| 44 |
+
elif benchmark_type == 'similarity':
|
| 45 |
+
title = f"{x_metric} vs {y_metric}"
|
| 46 |
+
return draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title)
|
| 47 |
+
elif benchmark_type == 'Benchmark 3':
|
| 48 |
+
return benchmark_3_plot(x_metric, y_metric)
|
| 49 |
+
elif benchmark_type == 'Benchmark 4':
|
| 50 |
+
return benchmark_4_plot(x_metric, y_metric)
|
| 51 |
+
else:
|
| 52 |
+
return "Invalid benchmark type selected."
|
| 53 |
+
|
| 54 |
+
def general_visualizer(methods_selected, x_metric, y_metric):
|
| 55 |
+
df = pd.read_csv(CSV_RESULT_PATH)
|
| 56 |
+
filtered_df = df[df['method_name'].isin(methods_selected)]
|
| 57 |
+
|
| 58 |
+
# Create a Seaborn lineplot with method as hue
|
| 59 |
+
plt.figure(figsize=(10, 8)) # Increase figure size
|
| 60 |
+
sns.lineplot(
|
| 61 |
+
data=filtered_df,
|
| 62 |
+
x=x_metric,
|
| 63 |
+
y=y_metric,
|
| 64 |
+
hue="method_name", # Different colors for different methods
|
| 65 |
+
marker="o", # Add markers to the line plot
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# Add labels and title
|
| 69 |
+
plt.xlabel(x_metric)
|
| 70 |
+
plt.ylabel(y_metric)
|
| 71 |
+
plt.title(f'{y_metric} vs {x_metric} for selected methods')
|
| 72 |
+
plt.grid(True)
|
| 73 |
+
|
| 74 |
+
# Save the plot to display it in Gradio
|
| 75 |
+
plot_path = "plot.png"
|
| 76 |
+
plt.savefig(plot_path)
|
| 77 |
+
plt.close()
|
| 78 |
+
|
| 79 |
+
return plot_path
|
| 80 |
+
|
| 81 |
def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
|
| 82 |
df = pd.read_csv(CSV_RESULT_PATH)
|
| 83 |
# Filter the dataframe based on selected methods
|
|
|
|
| 122 |
|
| 123 |
return filename
|
| 124 |
|
| 125 |
+
def visualize_aspect_metric_clustermap(file_path, aspect, metric, method_names):
|
| 126 |
+
# Load data
|
| 127 |
+
df = pd.read_csv(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
+
# Filter for selected methods
|
| 130 |
+
df = df[df['Method'].isin(method_names)]
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
+
# Filter columns for specified aspect and metric
|
| 133 |
+
columns_to_plot = [col for col in df.columns if col.startswith(f"{aspect}_") and col.endswith(f"_{metric}")]
|
| 134 |
+
df = df[['Method'] + columns_to_plot]
|
| 135 |
+
df.set_index('Method', inplace=True)
|
| 136 |
+
|
| 137 |
+
# Create clustermap
|
| 138 |
+
g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=False, col_cluster=False, figsize=(15, 15))
|
| 139 |
+
|
| 140 |
+
# Get heatmap axis and customize labels
|
| 141 |
+
ax = g.ax_heatmap
|
| 142 |
+
ax.set_xlabel("")
|
| 143 |
+
ax.set_ylabel("")
|
| 144 |
+
|
| 145 |
+
# Apply color and caret adjustments to x-axis labels
|
| 146 |
+
set_colors_and_marks_for_representation_groups(ax)
|
| 147 |
+
|
| 148 |
+
# Save the plot as an image
|
| 149 |
+
os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
|
| 150 |
+
filename = os.path.join(save_path, f"{aspect}_{metric}_heatmap.png")
|
| 151 |
+
plt.savefig(filename, dpi=400, bbox_inches='tight')
|
| 152 |
+
plt.close() # Close the plot to free memory
|
| 153 |
|
| 154 |
+
return filename
|