reab5555 commited on
Commit
6642d96
·
verified ·
1 Parent(s): b0f4511

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -58
app.py CHANGED
@@ -36,77 +36,76 @@ def create_plots(df, feature_columns, target_column):
36
  # Add target to each feature set
37
  features = features + [target_column]
38
 
39
- # Create pair plot (scatter)
 
40
  if is_numeric_target:
41
- pair_plot = sns.pairplot(df[features], kind='scatter',
42
- plot_kws={'alpha': 0.6})
43
  norm = plt.Normalize(df[target_column].min(), df[target_column].max())
44
- for i, ax in enumerate(pair_plot.axes.flatten()):
45
- if i < len(features) * (len(features) - 1): # Check if it's not a diagonal plot
46
- x_var = features[i // len(features)]
47
- y_var = features[i % len(features)]
48
- if x_var != y_var:
49
- scatter = ax.scatter(df[x_var], df[y_var], c=df[target_column], cmap='viridis', norm=norm)
50
- fig = ax.figure
51
- fig.colorbar(scatter, ax=ax, label=target_column)
52
  else:
53
- pair_plot = sns.pairplot(df[features], hue=target_column, kind='scatter')
54
 
55
- pair_plot.fig.suptitle(f'Features Distributions - Group {group} (Scatter)', y=1.02)
56
 
57
  # Adjust label size and spacing
58
- for ax in pair_plot.axes.flatten():
59
- ax.tick_params(labelsize=8)
60
- ax.set_xlabel(ax.get_xlabel(), fontsize=10)
61
- ax.set_ylabel(ax.get_ylabel(), fontsize=10)
62
 
63
- pair_plot.fig.tight_layout()
64
- plt.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.9, hspace=0.5, wspace=0.5)
65
 
66
  buf = io.BytesIO()
67
- pair_plot.savefig(buf, format='png', dpi=300)
68
  buf.seek(0)
69
  plots.append(buf)
70
- plt.close(pair_plot.fig)
71
 
72
- # Create pair plot (histogram)
 
73
  if is_numeric_target:
74
- pair_plot_hist = sns.pairplot(df[features], kind='hist',
75
- plot_kws={'alpha': 0.6})
76
- for i, ax in enumerate(pair_plot_hist.axes.flatten()):
77
- x_var = features[i // len(features)]
78
- y_var = features[i % len(features)]
79
- if x_var == y_var:
80
  ax.clear()
81
- sns.histplot(df[x_var], ax=ax, kde=True)
82
  else:
83
- scatter = ax.scatter(df[x_var], df[y_var], c=df[target_column], cmap='viridis', norm=norm)
84
- fig = ax.figure
85
- fig.colorbar(scatter, ax=ax, label=target_column)
 
 
86
  else:
87
- pair_plot_hist = sns.pairplot(df[features], kind='hist', hue=target_column)
88
 
89
- pair_plot_hist.fig.suptitle(f'Features Distributions - Group {group} (Histogram)', y=1.02)
90
 
91
  # Adjust label size and spacing
92
- for ax in pair_plot_hist.axes.flatten():
93
- ax.tick_params(labelsize=8)
94
- ax.set_xlabel(ax.get_xlabel(), fontsize=10)
95
- ax.set_ylabel(ax.get_ylabel(), fontsize=10)
96
 
97
- pair_plot_hist.fig.tight_layout()
98
- plt.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.9, hspace=0.5, wspace=0.5)
99
 
100
  buf = io.BytesIO()
101
- pair_plot_hist.savefig(buf, format='png', dpi=300)
102
  buf.seek(0)
103
  plots.append(buf)
104
- plt.close(pair_plot_hist.fig)
105
 
106
- # Create regplot grid
107
  n_features = len(features) - 1 # Exclude target column
108
- fig, axes = plt.subplots(n_features, n_features, figsize=(20, 20))
109
- fig.suptitle(f'Regression Plots - Group {group}', y=1.02)
110
 
111
  for i, feature1 in enumerate(features[:-1]):
112
  for j, feature2 in enumerate(features[:-1]):
@@ -121,23 +120,22 @@ def create_plots(df, feature_columns, target_column):
121
  plt.colorbar(scatter, ax=ax, label=target_column)
122
  else:
123
  sns.regplot(x=feature1, y=feature2, data=df, ax=ax,
124
- line_kws={"color": "black"})
125
  else:
126
  sns.histplot(df[feature1], ax=ax, kde=True)
127
 
128
- ax.set_xlabel(feature1 if i == n_features - 1 else '')
129
- ax.set_ylabel(feature2 if j == 0 else '')
130
  ax.tick_params(labelsize=8)
131
- ax.set_title(f'{feature1} vs {feature2}', fontsize=10)
132
 
133
- fig.tight_layout()
134
- plt.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.9, hspace=0.5, wspace=0.5)
135
 
136
  buf = io.BytesIO()
137
  plt.savefig(buf, format='png', dpi=300)
138
  buf.seek(0)
139
  plots.append(buf)
140
- plt.close(fig)
141
 
142
  # Calculate Pearson correlation values
143
  correlation_matrix = df[feature_columns + [target_column]].corr()
@@ -213,16 +211,16 @@ with gr.Blocks() as iface:
213
  analyze_btn = gr.Button("Analyze")
214
 
215
  with gr.Row():
216
- plot1 = gr.Image(label="Pair Plot (Scatter) - Group 1")
217
- plot4 = gr.Image(label="Pair Plot (Scatter) - Group 2")
218
 
219
  with gr.Row():
220
- plot2 = gr.Image(label="Pair Plot (Histogram) - Group 1")
221
- plot5 = gr.Image(label="Pair Plot (Histogram) - Group 2")
222
 
223
  with gr.Row():
224
- plot3 = gr.Image(label="Regression Plot - Group 1")
225
- plot6 = gr.Image(label="Regression Plot - Group 2")
226
 
227
  with gr.Row():
228
  heatmap = gr.Image(label="Pearson Correlation Heatmap")
 
36
  # Add target to each feature set
37
  features = features + [target_column]
38
 
39
+ # Create scatter plot
40
+ plt.figure(figsize=(12, 10))
41
  if is_numeric_target:
42
+ scatter_plot = sns.pairplot(df[features], kind='scatter',
43
+ plot_kws={'alpha': 0.6}, corner=True)
44
  norm = plt.Normalize(df[target_column].min(), df[target_column].max())
45
+ for ax in scatter_plot.axes.flatten():
46
+ if ax.get_xlabel() != ax.get_ylabel():
47
+ scatter = ax.collections[0]
48
+ scatter.set_cmap('viridis')
49
+ scatter.set_norm(norm)
50
+ scatter.set_array(df[target_column])
51
+ plt.colorbar(scatter, ax=ax, label=target_column)
 
52
  else:
53
+ scatter_plot = sns.pairplot(df[features], hue=target_column, kind='scatter', corner=True)
54
 
55
+ scatter_plot.fig.suptitle(f'Scatter Plots - Group {group}', y=1.02, fontsize=16)
56
 
57
  # Adjust label size and spacing
58
+ for ax in scatter_plot.axes.flatten():
59
+ ax.tick_params(labelsize=10)
60
+ ax.set_xlabel(ax.get_xlabel(), fontsize=12)
61
+ ax.set_ylabel(ax.get_ylabel(), fontsize=12)
62
 
63
+ plt.tight_layout()
 
64
 
65
  buf = io.BytesIO()
66
+ plt.savefig(buf, format='png', dpi=300)
67
  buf.seek(0)
68
  plots.append(buf)
69
+ plt.close()
70
 
71
+ # Create histogram plot
72
+ plt.figure(figsize=(12, 10))
73
  if is_numeric_target:
74
+ hist_plot = sns.pairplot(df[features], kind='hist',
75
+ plot_kws={'alpha': 0.6}, corner=True)
76
+ for ax in hist_plot.axes.flatten():
77
+ if ax.get_xlabel() == ax.get_ylabel():
 
 
78
  ax.clear()
79
+ sns.histplot(df[ax.get_xlabel()], ax=ax, kde=True)
80
  else:
81
+ scatter = ax.collections[0]
82
+ scatter.set_cmap('viridis')
83
+ scatter.set_norm(norm)
84
+ scatter.set_array(df[target_column])
85
+ plt.colorbar(scatter, ax=ax, label=target_column)
86
  else:
87
+ hist_plot = sns.pairplot(df[features], kind='hist', hue=target_column, corner=True)
88
 
89
+ hist_plot.fig.suptitle(f'Histogram Plots - Group {group}', y=1.02, fontsize=16)
90
 
91
  # Adjust label size and spacing
92
+ for ax in hist_plot.axes.flatten():
93
+ ax.tick_params(labelsize=10)
94
+ ax.set_xlabel(ax.get_xlabel(), fontsize=12)
95
+ ax.set_ylabel(ax.get_ylabel(), fontsize=12)
96
 
97
+ plt.tight_layout()
 
98
 
99
  buf = io.BytesIO()
100
+ plt.savefig(buf, format='png', dpi=300)
101
  buf.seek(0)
102
  plots.append(buf)
103
+ plt.close()
104
 
105
+ # Create regression plot
106
  n_features = len(features) - 1 # Exclude target column
107
+ fig, axes = plt.subplots(n_features, n_features, figsize=(16, 14))
108
+ fig.suptitle(f'Regression Plots - Group {group}', y=1.02, fontsize=16)
109
 
110
  for i, feature1 in enumerate(features[:-1]):
111
  for j, feature2 in enumerate(features[:-1]):
 
120
  plt.colorbar(scatter, ax=ax, label=target_column)
121
  else:
122
  sns.regplot(x=feature1, y=feature2, data=df, ax=ax,
123
+ scatter_kws={'alpha': 0.6}, line_kws={'color': 'red'})
124
  else:
125
  sns.histplot(df[feature1], ax=ax, kde=True)
126
 
127
+ ax.set_xlabel(feature1, fontsize=10)
128
+ ax.set_ylabel(feature2, fontsize=10)
129
  ax.tick_params(labelsize=8)
130
+ ax.set_title(f'{feature1} vs {feature2}', fontsize=12)
131
 
132
+ plt.tight_layout()
 
133
 
134
  buf = io.BytesIO()
135
  plt.savefig(buf, format='png', dpi=300)
136
  buf.seek(0)
137
  plots.append(buf)
138
+ plt.close()
139
 
140
  # Calculate Pearson correlation values
141
  correlation_matrix = df[feature_columns + [target_column]].corr()
 
211
  analyze_btn = gr.Button("Analyze")
212
 
213
  with gr.Row():
214
+ plot1 = gr.Image(label="Scatter Plots - Group 1")
215
+ plot4 = gr.Image(label="Scatter Plots - Group 2")
216
 
217
  with gr.Row():
218
+ plot2 = gr.Image(label="Histogram Plots - Group 1")
219
+ plot5 = gr.Image(label="Histogram Plots - Group 2")
220
 
221
  with gr.Row():
222
+ plot3 = gr.Image(label="Regression Plots - Group 1")
223
+ plot6 = gr.Image(label="Regression Plots - Group 2")
224
 
225
  with gr.Row():
226
  heatmap = gr.Image(label="Pearson Correlation Heatmap")