File size: 6,416 Bytes
7ae6a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04e9fe
7ae6a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04e9fe
 
 
7ae6a8e
b04e9fe
7ae6a8e
 
 
 
 
 
d398f50
7ae6a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d398f50
7ae6a8e
 
 
458422b
7ae6a8e
 
 
 
 
 
 
d398f50
7ae6a8e
 
b04e9fe
 
7ae6a8e
 
458422b
7ae6a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04e9fe
 
7ae6a8e
 
 
 
9046777
7ae6a8e
 
 
4333385
7ae6a8e
 
 
 
 
 
 
 
 
 
 
 
b04e9fe
7ae6a8e
 
b04e9fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kstest, norm, expon, lognorm, chi2, beta, gamma, uniform, pareto, cauchy, t, \
    weibull_min, laplace, logistic, burr, invgamma, invgauss, gompertz, triang, loglaplace, levy, gumbel_r, gumbel_l, \
    rayleigh, powerlaw
import warnings
import tempfile

# Silence every RuntimeWarning for the whole process (the filter is not limited
# to a particular module or message).
# NOTE(review): presumably meant to hide numerical warnings raised while
# fitting distributions below - confirm nothing important is being masked.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Candidate scipy.stats distributions, keyed by the name shown to the user.
_CANDIDATE_DISTRIBUTIONS = {
    'Normal': norm,
    'Exponential': expon,
    'Lognormal': lognorm,
    'Chi-square': chi2,
    'Beta': beta,
    'Gamma': gamma,
    'Uniform': uniform,
    'Pareto': pareto,
    'Cauchy': cauchy,
    'Student\'s t': t,
    'Weibull': weibull_min,
    'Laplace': laplace,
    'Logistic': logistic,
    'Burr': burr,
    'Inverse Gamma': invgamma,
    'Inverse Gaussian': invgauss,
    'Gompertz': gompertz,
    'Triangular': triang,
    'Log-Laplace': loglaplace,
    'Levy': levy,
    'Gumbel Right': gumbel_r,
    'Gumbel Left': gumbel_l,
    'Rayleigh': rayleigh,
    'Powerlaw': powerlaw,
}

# Line colors for the first, second and third ranked fits, in that order.
_TOP_FIT_COLORS = ('gold', 'silver', 'brown')
_NORMAL_COLOR = 'red'
_ACTUAL_DATA_COLOR = 'black'


def _format_p_value(p_value):
    """Return the p-value as display text, collapsing tiny values to "<0.001"."""
    return "<0.001" if p_value < 0.001 else f"{p_value:.5f}"


def _fit_candidates(data):
    """Fit every candidate distribution and return (name, p_value, params) tuples.

    Candidates whose fit or Kolmogorov-Smirnov test raises, or whose p-value is
    not finite, are reported on stdout and left out of the result.
    """
    fits = []
    for name, distribution in _CANDIDATE_DISTRIBUTIONS.items():
        try:
            params = distribution.fit(data)
            _, p_value = kstest(data, distribution.cdf, args=params)
        except Exception as e:
            # Best effort: one failing candidate must not abort the others.
            print(f"Skipping {name} distribution due to an error: {e}")
            continue
        if not np.isfinite(p_value):
            # NaN compares false against everything and would corrupt the ranking.
            print(f"Skipping {name} distribution due to an error: non-finite p-value")
            continue
        fits.append((name, p_value, params))
    return fits


def _save_overlay_plot(data, grid, curves, title, show_histogram):
    """Plot the data with fitted PDF curves on top and return the PNG path.

    ``curves`` is an iterable of ``(pdf_values, color, label)`` evaluated on
    ``grid``. The figure is always closed, and the temporary file handle is
    closed before matplotlib writes to the path.
    """
    # Work on an explicit figure/axes instead of pyplot's global current figure.
    fig, ax = plt.subplots(figsize=(12, 8), dpi=400)
    try:
        if show_histogram:
            sns.histplot(data, kde=False, stat="density", bins=50, color=_ACTUAL_DATA_COLOR,
                         label='Actual Data Distribution', ax=ax)
        # The KDE line is drawn regardless of the histogram toggle.
        sns.kdeplot(data, color=_ACTUAL_DATA_COLOR, lw=2, label='Actual Data Distribution Line', ax=ax)
        for pdf_values, color, label in curves:
            ax.plot(grid, pdf_values, color=color, lw=2, label=label)
        ax.set_title(title)
        ax.legend()

        # delete=False keeps the file for the caller; leaving the ``with`` block
        # closes our handle so nothing is leaked and the path can be reopened.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            path = tmp.name
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def check_distribution(target_column, show_histogram):
    """Rank candidate distributions against a column and plot the results.

    Every candidate distribution is fitted to the non-missing values and scored
    with a Kolmogorov-Smirnov test; the three highest p-values are reported and
    overlaid on the data. A second plot compares the data with a fitted normal
    distribution.

    Args:
        target_column: pandas Series of numeric values; missing values are dropped.
        show_histogram: When true, draw a density histogram under the KDE line.

    Returns:
        A ``(result_text, best_fit_plot, normal_comparison_plot)`` tuple where
        the two plots are paths to PNG files in the temporary directory. The
        files are not deleted by this function.

    Raises:
        ValueError: If the column has no non-missing values.
    """
    data = target_column.dropna()
    if data.empty:
        # Fail early with a clear message instead of deep inside min()/plotting.
        raise ValueError("The selected column has no non-missing values to analyze.")

    # Highest p-value first; the sort is stable so ties keep dictionary order.
    ranked = _fit_candidates(data)
    ranked.sort(key=lambda fit: fit[1], reverse=True)
    top_3_results = ranked[:3]

    # One shared x grid spanning the observed range, used for every PDF curve.
    grid = np.linspace(data.min(), data.max(), 1000)

    best_fit_curves = [
        (
            _CANDIDATE_DISTRIBUTIONS[name].pdf(grid, *params),
            color,
            f'{name} Fit (p-value={_format_p_value(p_value)})',
        )
        for (name, p_value, params), color in zip(top_3_results, _TOP_FIT_COLORS)
    ]

    # Normal distribution comparison
    mean, std = norm.fit(data)
    _, normal_p_value = kstest(data, 'norm', args=(mean, std))
    normal_curves = [
        (
            norm.pdf(grid, mean, std),
            _NORMAL_COLOR,
            f'Normal Fit (p-value={_format_p_value(normal_p_value)})',
        )
    ]

    # Prepare result text with top 3 distributions and p-values
    text_parts = ["Top 3 best matched distributions:\n"]
    text_parts.extend(
        f"Top {rank}: {name} with a p-value of {_format_p_value(p_value)}\n"
        for rank, (name, p_value, _) in enumerate(top_3_results, start=1)
    )
    # Disclaimer about p-value significance
    text_parts.append(
        "\nKolmogorov-Smirnov Test - A significant p-value (below 0.05) indicates that the distribution does not conform well to the actual data distribution."
    )
    result_text = "".join(text_parts)

    best_fit_plot = _save_overlay_plot(
        data, grid, best_fit_curves, "Top 3 Best Fit Distributions Overlaid", show_histogram
    )
    normal_comparison_plot = _save_overlay_plot(
        data, grid, normal_curves, "Comparison with Normal Distribution (Kolmogorov-Smirnov)", show_histogram
    )

    return result_text, best_fit_plot, normal_comparison_plot

# Function to load the CSV file and extract numeric column names
def load_file(file):
    """Read the uploaded CSV and offer its numeric columns in the dropdown.

    Args:
        file: The upload handed over by the File component. Either an object
            exposing the path as ``.name`` or a plain path string is accepted;
            ``None`` (no file) clears the dropdown choices.

    Returns:
        A ``(dropdown_update, dataframe)`` pair: a Gradio update carrying the
        numeric column names as choices, and the parsed DataFrame (``None``
        when no file was given).
    """
    if file is None:
        return gr.update(choices=[]), None

    # Path strings are used directly; file-like wrappers carry the path in .name.
    path = file if isinstance(file, str) else file.name
    df = pd.read_csv(path)
    numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()
    return gr.update(choices=numeric_columns), df

# Function to analyze the selected column
def analyze_column(selected_column, df, show_histogram):
    """Run the distribution analysis for one column of the loaded DataFrame.

    Args:
        selected_column: Name of the column chosen in the dropdown.
        df: DataFrame stored by ``load_file``; ``None`` until a file is loaded.
        show_histogram: Forwarded to ``check_distribution``.

    Returns:
        The ``(result_text, best_fit_plot, normal_comparison_plot)`` tuple
        produced by ``check_distribution``.

    Raises:
        gr.Error: If no file has been loaded, no valid column is selected, or
            the analysis rejects the data with a ``ValueError``. Gradio shows
            the message to the user instead of a bare traceback.
    """
    if df is None:
        raise gr.Error("Please upload a CSV file before fitting.")
    if not selected_column or selected_column not in df.columns:
        raise gr.Error("Please select a numeric column to analyze.")

    try:
        return check_distribution(df[selected_column], show_histogram)
    except ValueError as err:
        # Surface data problems (e.g. a column with no usable values) in the UI.
        raise gr.Error(str(err)) from err

# Define the Gradio app layout.
# Components are created inside the Blocks context in the order written here.
with gr.Blocks() as demo:
    gr.Markdown("# Data Distribution Fit\n")

    # Inputs: the CSV upload, the column to analyze, and the histogram toggle.
    file_input = gr.File(label="Upload CSV File")
    # Starts with no choices; load_file fills them in after an upload.
    column_selector = gr.Dropdown(label="Select Target Column", choices=[])
    show_histogram = gr.Checkbox(label="Show Bars", value=True)
    analyze_button = gr.Button("Fit")
    # Outputs: the textual ranking and the two plot images.
    output_text = gr.Textbox(label="Results")
    best_fit_plot_output = gr.Image(label="Best Fit Distributions")
    normal_comparison_output = gr.Image(label="Comparison with Normal Distribution")

    # State management: holds the DataFrame returned by load_file so that
    # analyze_column can receive it as an input; None until a file is loaded.
    df_state = gr.State(None)

    # Load the file and populate the dropdown (load_file returns the dropdown
    # update and the DataFrame, matching the two outputs listed here).
    file_input.upload(load_file, inputs=file_input, outputs=[column_selector, df_state])

    # Perform analysis on the selected column; the three values returned by
    # analyze_column map onto the text box and the two image components.
    analyze_button.click(analyze_column, inputs=[column_selector, df_state, show_histogram],
                         outputs=[output_text, best_fit_plot_output, normal_comparison_output])

# NOTE(review): launch() is called unconditionally, so it also runs when this
# module is imported rather than executed as a script.
demo.launch()