Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
from scipy.stats import kstest, norm, expon, lognorm, chi2, beta, gamma, uniform, pareto, cauchy, t, \
|
| 7 |
+
weibull_min, laplace, logistic, burr, invgamma, invgauss, gompertz, triang, loglaplace, levy, gumbel_r, gumbel_l, \
|
| 8 |
+
rayleigh, powerlaw
|
| 9 |
+
import warnings
|
| 10 |
+
import tempfile
|
| 11 |
+
|
| 12 |
+
# Suppress all RuntimeWarnings process-wide (the filter below is not limited to specific warnings)
|
| 13 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Function to check distribution type
|
| 17 |
+
# Candidate SciPy distributions, keyed by the label shown to the user.
_CANDIDATE_DISTRIBUTIONS = {
    'Normal': norm,
    'Exponential': expon,
    'Lognormal': lognorm,
    'Chi-square': chi2,
    'Beta': beta,
    'Gamma': gamma,
    'Uniform': uniform,
    'Pareto': pareto,
    'Cauchy': cauchy,
    'Student\'s t': t,
    'Weibull': weibull_min,
    'Laplace': laplace,
    'Logistic': logistic,
    'Burr': burr,
    'Inverse Gamma': invgamma,
    'Inverse Gaussian': invgauss,
    'Gompertz': gompertz,
    'Triangular': triang,
    'Log-Laplace': loglaplace,
    'Levy': levy,
    'Gumbel Right': gumbel_r,
    'Gumbel Left': gumbel_l,
    'Rayleigh': rayleigh,
    'Powerlaw': powerlaw,
}

# Line colours for the first, second and third ranked fits, in that order.
_TOP_FIT_COLORS = ['gold', 'silver', 'brown']
_NORMAL_COLOR = 'red'
_ACTUAL_DATA_COLOR = 'black'


def _format_p_value(p_value):
    """Render a p-value for display, collapsing anything below 1e-4 to "<0.0001"."""
    return "<0.0001" if p_value < 0.0001 else f"{p_value:.12f}"


def _save_density_plot(data, grid, curves, title):
    """Plot the data's histogram and KDE with fitted PDF curves; return the PNG path.

    Args:
        data: The observed values (already stripped of missing entries).
        grid: x-coordinates at which every curve in ``curves`` was evaluated.
        curves: Iterable of ``(pdf_values, color, label)`` triples to overlay.
        title: Title drawn above the axes.

    Returns:
        Path of a newly created ``.png`` file. The file is not deleted
        automatically (``delete=False``); the caller owns it.
    """
    # Draw on an explicit Figure/Axes instead of pyplot's implicit
    # "current figure", so the drawing calls cannot land on another figure.
    fig, ax = plt.subplots(figsize=(12, 8), dpi=400)
    try:
        sns.histplot(data, kde=False, stat="density", bins=30, color=_ACTUAL_DATA_COLOR,
                     label='Actual Data Distribution', ax=ax)
        sns.kdeplot(data, color=_ACTUAL_DATA_COLOR, lw=2, label='Actual Data Distribution Line', ax=ax)
        for pdf_values, color, label in curves:
            ax.plot(grid, pdf_values, color=color, lw=2, label=label)
        ax.set_title(title)
        ax.legend()

        # Reserve a unique file name, then close the handle *before* writing:
        # this releases the descriptor and lets savefig reopen the path on
        # platforms that forbid opening a still-open temporary file by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as handle:
            image_path = handle.name
        fig.savefig(image_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return image_path


def check_distribution(target_column):
    """Fit a set of SciPy distributions to a column and report the best three.

    Every candidate distribution is fitted to the non-missing values and scored
    with a Kolmogorov-Smirnov test; candidates are ranked by descending p-value.
    Two plots are written to temporary PNG files: the data with the top three
    fitted PDFs overlaid, and the data with a fitted normal PDF overlaid.

    NOTE(review): the KS test is run against parameters estimated from the
    same sample, so the reported p-values should be read as a relative ranking
    aid rather than exact significance levels.

    Args:
        target_column: Object providing ``dropna()`` whose result supports
            ``len()``, ``min()`` and ``max()`` (presumably a numeric pandas
            Series — verify against the caller).

    Returns:
        Tuple ``(result_text, best_fit_plot, normal_comparison_plot)``: the
        textual summary, the path of the top-3 overlay image, and the path of
        the normal-comparison image.

    Raises:
        ValueError: If the column has no non-missing values.
    """
    data = target_column.dropna()
    if len(data) == 0:
        raise ValueError("The selected column contains no non-missing values to fit.")

    # Fit and score each candidate; a failing candidate is skipped (best effort).
    results = []
    for name, distribution in _CANDIDATE_DISTRIBUTIONS.items():
        try:
            params = distribution.fit(data)
            _, p_value = kstest(data, distribution.cdf, args=params)
        except Exception as e:
            print(f"Skipping {name} distribution due to an error: {e}")
            continue
        # A NaN p-value cannot be ranked (comparisons with NaN are always
        # False, so sorting becomes order-dependent); treat it as a failed fit.
        if not np.isfinite(p_value):
            print(f"Skipping {name} distribution due to an error: non-finite p-value ({p_value})")
            continue
        results.append((name, p_value, params))

    # Sort by p-value and get top 3
    results.sort(key=lambda x: x[1], reverse=True)
    top_3_results = results[:3]

    # One shared evaluation grid spanning the observed range.
    grid = np.linspace(data.min(), data.max(), 1000)

    top_curves = [
        (
            _CANDIDATE_DISTRIBUTIONS[name].pdf(grid, *params),
            color,
            f'{name} Fit (p-value={_format_p_value(p_value)})',
        )
        for (name, p_value, params), color in zip(top_3_results, _TOP_FIT_COLORS)
    ]
    best_fit_plot = _save_density_plot(data, grid, top_curves, "Top 3 Best Fit Distributions Overlaid")

    # Prepare result text with top 3 distributions and p-values
    text_parts = ["Top 3 best matched distributions:\n"]
    text_parts.extend(
        f"Top {rank}: {name} with a p-value of {_format_p_value(p_value)}\n"
        for rank, (name, p_value, _) in enumerate(top_3_results, start=1)
    )
    # Add disclaimer about p-value significance
    text_parts.append(
        "\nDisclaimer: A significant p-value (below 0.05) indicates that the distribution does not conform well to the actual data distribution."
    )
    result_text = "".join(text_parts)

    # Normal distribution comparison
    mean, std = norm.fit(data)
    _, normal_p_value = kstest(data, 'norm', args=(mean, std))
    normal_curve = (
        norm.pdf(grid, mean, std),
        _NORMAL_COLOR,
        f'Normal Fit (p-value={_format_p_value(normal_p_value)})',
    )
    normal_comparison_plot = _save_density_plot(
        data, grid, [normal_curve], "Comparison with Normal Distribution"
    )

    return result_text, best_fit_plot, normal_comparison_plot
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# Function to load the CSV file and extract numeric column names
|
| 130 |
+
def load_file(file):
    """Read an uploaded CSV file and list its numeric columns.

    Args:
        file: The upload handed over by the ``gr.File`` component. Accepted
            forms are a path (``str`` or ``os.PathLike``), an object exposing
            the on-disk path as ``.name``, or ``None`` when nothing is uploaded.

    Returns:
        Tuple ``(dropdown_update, df)``: a ``gr.update`` carrying the numeric
        column names as dropdown choices, and the loaded DataFrame (``None``
        when no file was given).
    """
    import os  # local import: only needed for the path-type check below

    if file is None:
        return gr.update(choices=[]), None

    # A path is used as-is; checking this first matters because
    # ``pathlib.Path.name`` is only the base name, not a usable location.
    if isinstance(file, (str, os.PathLike)):
        csv_path = file
    else:
        csv_path = file.name

    df = pd.read_csv(csv_path)
    numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()
    return gr.update(choices=numeric_columns), df
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# Function to analyze the selected column
|
| 137 |
+
def analyze_column(selected_column, df):
    """Run the distribution analysis for the column chosen in the UI.

    Args:
        selected_column: Name of the column to analyse, or ``None``/``""``
            when nothing has been selected yet.
        df: The table loaded from the uploaded file, or ``None`` when no file
            has been loaded. Must support ``in`` and ``[]`` by column name.

    Returns:
        Tuple ``(result_text, best_fit_plot, normal_comparison_plot)`` as
        produced by ``check_distribution``. When the inputs are not usable,
        ``result_text`` is an explanatory message and both plot entries are
        ``None`` instead of raising.
    """
    # "Fit" can be clicked before any upload/selection; explain instead of
    # failing on ``None[...]``.
    if df is None or selected_column in (None, ""):
        return "Please upload a CSV file and select a target column before fitting.", None, None

    # The selection may be stale (e.g. left over from a previously uploaded file).
    if selected_column not in df:
        return f"Column '{selected_column}' was not found in the uploaded file.", None, None

    return check_distribution(df[selected_column])
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Define the Gradio app layout
|
| 143 |
+
with gr.Blocks() as demo:
    gr.Markdown("# Distribution Fitting\n")

    # Components are rendered in the order they are created here.
    file_input = gr.File(label="Upload CSV File")
    column_selector = gr.Dropdown(label="Select Target Column", choices=[])
    analyze_button = gr.Button("Fit")
    output_text = gr.Textbox(label="Results")
    best_fit_plot_output = gr.Image(label="Best Fit Distributions")
    normal_comparison_output = gr.Image(label="Comparison with Normal Distribution")

    # State management: holds the DataFrame returned by load_file so that
    # analyze_column can read it on a later click.
    df_state = gr.State(None)

    # Load the file and populate the dropdown
    file_input.upload(load_file, inputs=file_input, outputs=[column_selector, df_state])

    # Perform analysis on the selected column
    analyze_button.click(analyze_column, inputs=[column_selector, df_state],
                         outputs=[output_text, best_fit_plot_output, normal_comparison_output])

# Start the server only when this file is executed as a script, so that
# importing the module (e.g. to reuse the functions or `demo`) has no side effect.
if __name__ == "__main__":
    demo.launch()
|