reab5555 commited on
Commit
7ae6a8e
·
verified ·
1 Parent(s): c04659b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from scipy.stats import kstest, norm, expon, lognorm, chi2, beta, gamma, uniform, pareto, cauchy, t, \
7
+ weibull_min, laplace, logistic, burr, invgamma, invgauss, gompertz, triang, loglaplace, levy, gumbel_r, gumbel_l, \
8
+ rayleigh, powerlaw
9
+ import warnings
10
+ import tempfile
11
+
12
+ # Suppress specific runtime warnings
13
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
14
+
15
+
16
+ # Function to check distribution type
17
+ def check_distribution(target_column):
18
+ data = target_column.dropna()
19
+
20
+ # Distribution dictionaries
21
+ distributions = {
22
+ 'Normal': norm,
23
+ 'Exponential': expon,
24
+ 'Lognormal': lognorm,
25
+ 'Chi-square': chi2,
26
+ 'Beta': beta,
27
+ 'Gamma': gamma,
28
+ 'Uniform': uniform,
29
+ 'Pareto': pareto,
30
+ 'Cauchy': cauchy,
31
+ 'Student\'s t': t,
32
+ 'Weibull': weibull_min,
33
+ 'Laplace': laplace,
34
+ 'Logistic': logistic,
35
+ 'Burr': burr,
36
+ 'Inverse Gamma': invgamma,
37
+ 'Inverse Gaussian': invgauss,
38
+ 'Gompertz': gompertz,
39
+ 'Triangular': triang,
40
+ 'Log-Laplace': loglaplace,
41
+ 'Levy': levy,
42
+ 'Gumbel Right': gumbel_r,
43
+ 'Gumbel Left': gumbel_l,
44
+ 'Rayleigh': rayleigh,
45
+ 'Powerlaw': powerlaw
46
+ }
47
+
48
+ # Colors for top 3 distributions and the normal distribution
49
+ distribution_colors = ['gold', 'silver', 'brown']
50
+ normal_color = 'red'
51
+ actual_data_color = 'black'
52
+
53
+ # List to store results
54
+ results = []
55
+
56
+ # Test each distribution with error handling
57
+ for name, distribution in distributions.items():
58
+ try:
59
+ params = distribution.fit(data)
60
+ D, p_value = kstest(data, distribution.cdf, args=params)
61
+ results.append((name, p_value, params))
62
+ except Exception as e:
63
+ print(f"Skipping {name} distribution due to an error: {e}")
64
+
65
+ # Sort by p-value and get top 3
66
+ results.sort(key=lambda x: x[1], reverse=True)
67
+ top_3_results = results[:3]
68
+
69
+ # Create a single plot for all distributions
70
+ plt.figure(figsize=(12, 8), dpi=400)
71
+
72
+ # Plot the original data distribution as a histogram
73
+ sns.histplot(data, kde=False, stat="density", bins=30, color=actual_data_color, label='Actual Data Distribution')
74
+
75
+ # Overlay the actual data KDE line
76
+ sns.kdeplot(data, color=actual_data_color, lw=2, label='Actual Data Distribution Line')
77
+
78
+ # Overlay the top 3 best fit distributions
79
+ for i, (name, p_value, params) in enumerate(top_3_results):
80
+ best_fit_data = np.linspace(min(data), max(data), 1000)
81
+ pdf = distributions[name].pdf(best_fit_data, *params)
82
+ p_value_text = "<0.0001" if p_value < 0.0001 else f"{p_value:.12f}"
83
+ plt.plot(best_fit_data, pdf, color=distribution_colors[i], lw=2, label=f'{name} Fit (p-value={p_value_text})')
84
+
85
+ plt.title("Top 3 Best Fit Distributions Overlaid")
86
+ plt.legend()
87
+
88
+ # Save plot to temporary file
89
+ temp_file_fit = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
90
+ plt.savefig(temp_file_fit.name)
91
+ plt.close()
92
+
93
+ best_fit_plot = temp_file_fit.name
94
+
95
+ # Prepare result text with top 3 distributions and p-values
96
+ result_text = "Top 3 best matched distributions:\n"
97
+ for i, (name, p_value, _) in enumerate(top_3_results):
98
+ p_value_text = "<0.0001" if p_value < 0.0001 else f"{p_value:.12f}"
99
+ result_text += f"Top {i + 1}: {name} with a p-value of {p_value_text}\n"
100
+
101
+ # Add disclaimer about p-value significance
102
+ result_text += "\nDisclaimer: A significant p-value (below 0.05) indicates that the distribution does not conform well to the actual data distribution."
103
+
104
+ # Normal distribution comparison
105
+ mean, std = norm.fit(data)
106
+ normal_best_fit_data = np.linspace(min(data), max(data), 1000)
107
+ normal_pdf = norm.pdf(normal_best_fit_data, mean, std)
108
+ ks_stat, normal_p_value = kstest(data, 'norm', args=(mean, std))
109
+
110
+ p_value_text = "<0.0001" if normal_p_value < 0.0001 else f"{normal_p_value:.12f}"
111
+
112
+ plt.figure(figsize=(12, 8), dpi=400)
113
+ sns.histplot(data, kde=False, stat="density", bins=30, color=actual_data_color, label='Actual Data Distribution')
114
+ sns.kdeplot(data, color=actual_data_color, lw=2, label='Actual Data Distribution Line')
115
+ plt.plot(normal_best_fit_data, normal_pdf, color=normal_color, lw=2, label=f'Normal Fit (p-value={p_value_text})')
116
+ plt.title("Comparison with Normal Distribution")
117
+ plt.legend()
118
+
119
+ # Save plot to temporary file
120
+ temp_file_normal = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
121
+ plt.savefig(temp_file_normal.name)
122
+ plt.close()
123
+
124
+ normal_comparison_plot = temp_file_normal.name
125
+
126
+ return result_text, best_fit_plot, normal_comparison_plot
127
+
128
+
129
+ # Function to load the CSV file and extract numeric column names
130
+ def load_file(file):
131
+ df = pd.read_csv(file.name)
132
+ numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()
133
+ return gr.update(choices=numeric_columns), df
134
+
135
+
136
+ # Function to analyze the selected column
137
+ def analyze_column(selected_column, df):
138
+ result_text, best_fit_plot, normal_comparison_plot = check_distribution(df[selected_column])
139
+ return result_text, best_fit_plot, normal_comparison_plot
140
+
141
+
142
+ # Define the Gradio app layout
143
+ with gr.Blocks() as demo:
144
+ gr.Markdown("# Distribution Fitting\n")
145
+
146
+ file_input = gr.File(label="Upload CSV File")
147
+ column_selector = gr.Dropdown(label="Select Target Column", choices=[])
148
+ analyze_button = gr.Button("Fit")
149
+ output_text = gr.Textbox(label="Results")
150
+ best_fit_plot_output = gr.Image(label="Best Fit Distributions")
151
+ normal_comparison_output = gr.Image(label="Comparison with Normal Distribution")
152
+
153
+ # State management
154
+ df_state = gr.State(None)
155
+
156
+ # Load the file and populate the dropdown
157
+ file_input.upload(load_file, inputs=file_input, outputs=[column_selector, df_state])
158
+
159
+ # Perform analysis on the selected column
160
+ analyze_button.click(analyze_column, inputs=[column_selector, df_state],
161
+ outputs=[output_text, best_fit_plot_output, normal_comparison_output])
162
+
163
+ demo.launch()