reab5555 committed on
Commit
083effd
·
verified ·
1 Parent(s): fb63a56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -96
app.py CHANGED
@@ -38,96 +38,133 @@ def create_plots(df, feature_columns, target_column):
38
 
39
  # Create scatter plot
40
  plt.figure(figsize=(12, 10))
41
- if is_numeric_target:
42
- scatter_plot = sns.pairplot(df[features], kind='scatter',
43
- plot_kws={'alpha': 0.6}, corner=True)
44
- norm = plt.Normalize(df[target_column].min(), df[target_column].max())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  for ax in scatter_plot.axes.flatten():
46
- if ax.get_xlabel() != ax.get_ylabel():
47
- scatter = ax.collections[0]
48
- scatter.set_cmap('viridis')
49
- scatter.set_norm(norm)
50
- scatter.set_array(df[target_column])
51
- plt.colorbar(scatter, ax=ax, label=target_column)
52
- else:
53
- scatter_plot = sns.pairplot(df[features], hue=target_column, kind='scatter', corner=True)
54
-
55
- scatter_plot.fig.suptitle(f'Scatter Plots - Group {group}', y=1.02, fontsize=16)
56
-
57
- # Adjust label size and spacing
58
- for ax in scatter_plot.axes.flatten():
59
- ax.tick_params(labelsize=10)
60
- ax.set_xlabel(ax.get_xlabel(), fontsize=12)
61
- ax.set_ylabel(ax.get_ylabel(), fontsize=12)
62
-
63
- plt.tight_layout()
64
-
65
- buf = io.BytesIO()
66
- plt.savefig(buf, format='png', dpi=300)
67
- buf.seek(0)
68
- plots.append(buf)
69
- plt.close()
70
 
71
  # Create histogram plot
72
  plt.figure(figsize=(12, 10))
73
- if is_numeric_target:
74
- hist_plot = sns.pairplot(df[features], kind='hist',
75
- plot_kws={'alpha': 0.6}, corner=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  for ax in hist_plot.axes.flatten():
77
- if ax.get_xlabel() == ax.get_ylabel():
78
- ax.clear()
79
- sns.histplot(df[ax.get_xlabel()], ax=ax, kde=True)
80
- else:
81
- scatter = ax.collections[0]
82
- scatter.set_cmap('viridis')
83
- scatter.set_norm(norm)
84
- scatter.set_array(df[target_column])
85
- plt.colorbar(scatter, ax=ax, label=target_column)
86
- else:
87
- hist_plot = sns.pairplot(df[features], kind='hist', hue=target_column, corner=True)
88
-
89
- hist_plot.fig.suptitle(f'Histogram Plots - Group {group}', y=1.02, fontsize=16)
90
-
91
- # Adjust label size and spacing
92
- for ax in hist_plot.axes.flatten():
93
- ax.tick_params(labelsize=10)
94
- ax.set_xlabel(ax.get_xlabel(), fontsize=12)
95
- ax.set_ylabel(ax.get_ylabel(), fontsize=12)
96
-
97
- plt.tight_layout()
98
-
99
- buf = io.BytesIO()
100
- plt.savefig(buf, format='png', dpi=300)
101
- buf.seek(0)
102
- plots.append(buf)
103
- plt.close()
104
 
105
  # Create regression plot
106
  n_features = len(features) - 1 # Exclude target column
107
  fig, axes = plt.subplots(n_features, n_features, figsize=(16, 14))
108
  fig.suptitle(f'Regression Plots - Group {group}', y=1.02, fontsize=16)
109
 
110
- for i, feature1 in enumerate(features[:-1]):
111
- for j, feature2 in enumerate(features[:-1]):
112
- if n_features == 1:
113
- ax = axes
114
- else:
115
- ax = axes[i, j]
116
- if i != j:
117
- if is_numeric_target:
118
- scatter = ax.scatter(df[feature1], df[feature2], c=df[target_column],
119
- cmap='viridis', alpha=0.6)
120
- plt.colorbar(scatter, ax=ax, label=target_column)
 
 
 
 
121
  else:
122
- sns.regplot(x=feature1, y=feature2, data=df, ax=ax,
123
- scatter_kws={'alpha': 0.6}, line_kws={'color': 'red'})
124
- else:
125
- sns.histplot(df[feature1], ax=ax, kde=True)
126
-
127
- ax.set_xlabel(feature1, fontsize=10)
128
- ax.set_ylabel(feature2, fontsize=10)
129
- ax.tick_params(labelsize=8)
130
- ax.set_title(f'{feature1} vs {feature2}', fontsize=12)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  plt.tight_layout()
133
 
@@ -135,25 +172,10 @@ def create_plots(df, feature_columns, target_column):
135
  plt.savefig(buf, format='png', dpi=300)
136
  buf.seek(0)
137
  plots.append(buf)
 
 
 
138
  plt.close()
139
-
140
- # Calculate Pearson correlation values
141
- correlation_matrix = df[feature_columns + [target_column]].corr()
142
-
143
- # Create a heatmap of Pearson correlation values
144
- plt.figure(figsize=(12, 10))
145
- heatmap = sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm', square=True, cbar_kws={'shrink': .8})
146
- heatmap.set_title('Pearson Correlation Heatmap', fontsize=16)
147
- plt.xticks(rotation=45, ha='right', fontsize=10)
148
- plt.yticks(fontsize=10)
149
-
150
- plt.tight_layout()
151
-
152
- buf = io.BytesIO()
153
- plt.savefig(buf, format='png', dpi=300)
154
- buf.seek(0)
155
- plots.append(buf)
156
- plt.close()
157
 
158
  except Exception as e:
159
  print(f"Error in create_plots: {str(e)}")
 
38
 
39
  # Create scatter plot
40
  plt.figure(figsize=(12, 10))
41
+ try:
42
+ if is_numeric_target:
43
+ scatter_plot = sns.pairplot(df[features], kind='scatter',
44
+ plot_kws={'alpha': 0.6}, corner=True)
45
+ norm = plt.Normalize(df[target_column].min(), df[target_column].max())
46
+ for ax in scatter_plot.axes.flatten():
47
+ if ax.get_xlabel() != ax.get_ylabel() and ax.get_xlabel() is not None:
48
+ if len(ax.collections) > 0:
49
+ scatter = ax.collections[0]
50
+ scatter.set_cmap('viridis')
51
+ scatter.set_norm(norm)
52
+ scatter.set_array(df[target_column])
53
+ plt.colorbar(scatter, ax=ax, label=target_column)
54
+ else:
55
+ scatter_plot = sns.pairplot(df[features], hue=target_column, kind='scatter', corner=True)
56
+
57
+ scatter_plot.fig.suptitle(f'Scatter Plots - Group {group}', y=1.02, fontsize=16)
58
+
59
+ # Adjust label size and spacing
60
  for ax in scatter_plot.axes.flatten():
61
+ ax.tick_params(labelsize=10)
62
+ if ax.get_xlabel():
63
+ ax.set_xlabel(ax.get_xlabel(), fontsize=12)
64
+ if ax.get_ylabel():
65
+ ax.set_ylabel(ax.get_ylabel(), fontsize=12)
66
+
67
+ plt.tight_layout()
68
+
69
+ buf = io.BytesIO()
70
+ plt.savefig(buf, format='png', dpi=300)
71
+ buf.seek(0)
72
+ plots.append(buf)
73
+ except Exception as e:
74
+ print(f"Error in scatter plot for group {group}: {str(e)}")
75
+ finally:
76
+ plt.close()
 
 
 
 
 
 
 
 
77
 
78
  # Create histogram plot
79
  plt.figure(figsize=(12, 10))
80
+ try:
81
+ if is_numeric_target:
82
+ hist_plot = sns.pairplot(df[features], kind='hist',
83
+ plot_kws={'alpha': 0.6}, corner=True)
84
+ for ax in hist_plot.axes.flatten():
85
+ if ax.get_xlabel() == ax.get_ylabel() and ax.get_xlabel() is not None:
86
+ ax.clear()
87
+ sns.histplot(df[ax.get_xlabel()], ax=ax, kde=True)
88
+ elif ax.get_xlabel() is not None and ax.get_ylabel() is not None:
89
+ if len(ax.collections) > 0:
90
+ scatter = ax.collections[0]
91
+ scatter.set_cmap('viridis')
92
+ scatter.set_norm(norm)
93
+ scatter.set_array(df[target_column])
94
+ plt.colorbar(scatter, ax=ax, label=target_column)
95
+ else:
96
+ hist_plot = sns.pairplot(df[features], kind='hist', hue=target_column, corner=True)
97
+
98
+ hist_plot.fig.suptitle(f'Histogram Plots - Group {group}', y=1.02, fontsize=16)
99
+
100
+ # Adjust label size and spacing
101
  for ax in hist_plot.axes.flatten():
102
+ ax.tick_params(labelsize=10)
103
+ if ax.get_xlabel():
104
+ ax.set_xlabel(ax.get_xlabel(), fontsize=12)
105
+ if ax.get_ylabel():
106
+ ax.set_ylabel(ax.get_ylabel(), fontsize=12)
107
+
108
+ plt.tight_layout()
109
+
110
+ buf = io.BytesIO()
111
+ plt.savefig(buf, format='png', dpi=300)
112
+ buf.seek(0)
113
+ plots.append(buf)
114
+ except Exception as e:
115
+ print(f"Error in histogram plot for group {group}: {str(e)}")
116
+ finally:
117
+ plt.close()
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # Create regression plot
120
  n_features = len(features) - 1 # Exclude target column
121
  fig, axes = plt.subplots(n_features, n_features, figsize=(16, 14))
122
  fig.suptitle(f'Regression Plots - Group {group}', y=1.02, fontsize=16)
123
 
124
+ try:
125
+ for i, feature1 in enumerate(features[:-1]):
126
+ for j, feature2 in enumerate(features[:-1]):
127
+ if n_features == 1:
128
+ ax = axes
129
+ else:
130
+ ax = axes[i, j]
131
+ if i != j:
132
+ if is_numeric_target:
133
+ scatter = ax.scatter(df[feature1], df[feature2], c=df[target_column],
134
+ cmap='viridis', alpha=0.6)
135
+ plt.colorbar(scatter, ax=ax, label=target_column)
136
+ else:
137
+ sns.regplot(x=feature1, y=feature2, data=df, ax=ax,
138
+ scatter_kws={'alpha': 0.6}, line_kws={'color': 'red'})
139
  else:
140
+ sns.histplot(df[feature1], ax=ax, kde=True)
141
+
142
+ ax.set_xlabel(feature1, fontsize=10)
143
+ ax.set_ylabel(feature2, fontsize=10)
144
+ ax.tick_params(labelsize=8)
145
+ ax.set_title(f'{feature1} vs {feature2}', fontsize=12)
146
+
147
+ plt.tight_layout()
148
+
149
+ buf = io.BytesIO()
150
+ plt.savefig(buf, format='png', dpi=300)
151
+ buf.seek(0)
152
+ plots.append(buf)
153
+ except Exception as e:
154
+ print(f"Error in regression plot for group {group}: {str(e)}")
155
+ finally:
156
+ plt.close()
157
+
158
+ # Calculate Pearson correlation values
159
+ correlation_matrix = df[feature_columns + [target_column]].corr()
160
+
161
+ # Create a heatmap of Pearson correlation values
162
+ plt.figure(figsize=(12, 10))
163
+ try:
164
+ heatmap = sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm', square=True, cbar_kws={'shrink': .8})
165
+ heatmap.set_title('Pearson Correlation Heatmap', fontsize=16)
166
+ plt.xticks(rotation=45, ha='right', fontsize=10)
167
+ plt.yticks(fontsize=10)
168
 
169
  plt.tight_layout()
170
 
 
172
  plt.savefig(buf, format='png', dpi=300)
173
  buf.seek(0)
174
  plots.append(buf)
175
+ except Exception as e:
176
+ print(f"Error in correlation heatmap: {str(e)}")
177
+ finally:
178
  plt.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  except Exception as e:
181
  print(f"Error in create_plots: {str(e)}")