# Spaces: Build error
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
def choose_model(model):
    """Return a new, unfitted classifier for the given display name.

    Args:
        model: One of ``"Logistic Regression"``, ``"Random Forest"`` or
            ``"Gaussian Naive Bayes"`` (exact, case-sensitive match).

    Returns:
        A freshly constructed ``LogisticRegression``,
        ``RandomForestClassifier`` or ``GaussianNB`` instance.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    # Factories rather than instances, so every call yields a new estimator
    # and nothing is constructed for the names that were not requested.
    factories = {
        "Logistic Regression": lambda: LogisticRegression(
            max_iter=1000, random_state=123
        ),
        "Random Forest": lambda: RandomForestClassifier(
            n_estimators=100, random_state=123
        ),
        "Gaussian Naive Bayes": lambda: GaussianNB(),
    }
    try:
        factory = factories[model]
    except (KeyError, TypeError):
        # TypeError covers unhashable input (e.g. a list), which must be
        # reported the same way as any other unsupported value.
        raise ValueError("Model is not supported.") from None
    return factory()
def get_proba_plots(
    model_1, model_2, model_3, model_1_weight, model_2_weight, model_3_weight
):
    """Plot class probabilities of three classifiers and their soft-voting ensemble.

    The three selected models, plus a soft-voting ``VotingClassifier`` built
    from them with the given weights, are fitted on a fixed four-sample toy
    dataset. The resulting grouped bar chart shows, for the first sample of
    that dataset, the probability each of the four classifiers assigns to
    class 1 and to class 2.

    Args:
        model_1: Name of the first model, as accepted by ``choose_model``.
        model_2: Name of the second model, as accepted by ``choose_model``.
        model_3: Name of the third model, as accepted by ``choose_model``.
        model_1_weight: Voting weight of the first model.
        model_2_weight: Voting weight of the second model.
        model_3_weight: Voting weight of the third model.

    Returns:
        The matplotlib ``Figure`` holding the bar chart.

    Raises:
        ValueError: If a model name is not supported (from ``choose_model``).
        gr.Error: If the three weights sum to zero, because a weighted
            average of the probabilities cannot be computed in that case.
    """
    clf1 = choose_model(model_1)
    clf2 = choose_model(model_2)
    clf3 = choose_model(model_3)

    weights = [model_1_weight, model_2_weight, model_3_weight]
    if sum(weights) == 0:
        # Soft voting averages the probabilities with these weights; a zero
        # total would otherwise surface as an opaque ZeroDivisionError.
        raise gr.Error("At least one model weight must be greater than 0.")

    X = np.array([[-1.0, -1.0], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]])
    y = np.array([1, 1, 2, 2])

    eclf = VotingClassifier(
        estimators=[("clf1", clf1), ("clf2", clf2), ("clf3", clf3)],
        voting="soft",
        weights=weights,
    )

    # Predict class probabilities with all four classifiers.
    probas = [c.fit(X, y).predict_proba(X) for c in (clf1, clf2, clf3, eclf)]

    # Keep only the probabilities for the first sample in the dataset.
    class1_1 = [pr[0, 0] for pr in probas]
    class2_1 = [pr[0, 1] for pr in probas]

    n_groups = 4  # three base classifiers + the ensemble
    ind = np.arange(n_groups)  # group positions
    width = 0.35  # bar width

    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive in global state, which would accumulate one figure per callback
    # in a long-running server.
    fig = Figure()
    ax = fig.subplots()

    # Bars for classifiers 1-3; the ensemble slot is left at height 0 here so
    # it can be drawn separately in a different colour.
    p1 = ax.bar(ind, [*class1_1[:-1], 0], width, color="green", edgecolor="k")
    p2 = ax.bar(
        ind + width,
        [*class2_1[:-1], 0],
        width,
        color="lightgreen",
        edgecolor="k",
    )

    # Bars for the VotingClassifier.
    ax.bar(ind, [0, 0, 0, class1_1[-1]], width, color="blue", edgecolor="k")
    ax.bar(
        ind + width, [0, 0, 0, class2_1[-1]], width, color="steelblue", edgecolor="k"
    )

    # Plot annotations; the dashed line separates the base classifiers from
    # the ensemble.
    ax.axvline(2.8, color="k", linestyle="dashed")
    ax.set_xticks(ind + width)
    ax.set_xticklabels(
        [
            f"{model_1}\nweight {model_1_weight}",
            f"{model_2}\nweight {model_2_weight}",
            f"{model_3}\nweight {model_3_weight}",
            "VotingClassifier\n(average probabilities)",
        ],
        rotation=40,
        ha="right",
    )
    ax.set_ylim(0, 1)
    ax.set_title("Class probabilities for sample 1 by different classifiers")
    ax.legend([p1[0], p2[0]], ["class 1", "class 2"], loc="upper left")
    fig.tight_layout()
    return fig
# Introductory Markdown shown at the top of the page. Kept flush-left so the
# text does not depend on any indentation stripping; the blank lines separate
# the Markdown paragraphs and the table.
DESCRIPTION = """
# Class probabilities by the `VotingClassifier`

This space shows the effect of the weight of different classifiers when using sklearn's [VotingClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html#sklearn.ensemble.VotingClassifier).

For example, suppose you set the weights as in the table below, and the models have the following predicted probabilities:

|         | Weights | Predicted Probabilities |
|---------|:-------:|:----------------:|
| Model 1 | 1 | 0.5 |
| Model 2 | 2 | 0.8 |
| Model 3 | 5 | 0.9 |

The predicted probability by the `VotingClassifier` will be $(1*0.5 + 2*0.8 + 5*0.9) / (1 + 2 + 5)$

You can experiment with different model types and weights and see their effect on the VotingClassifier's prediction.

This space is based on [sklearn’s original demo](https://scikit-learn.org/stable/auto_examples/ensemble/plot_voting_probas.html#sphx-glr-auto-examples-ensemble-plot-voting-probas-py).
"""

# Names offered by every model dropdown; each must be accepted by
# choose_model().
MODEL_CHOICES = [
    "Logistic Regression",
    "Random Forest",
    "Gaussian Naive Bayes",
]

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: one row per model (type dropdown + weight slider).
        with gr.Column(scale=3):
            with gr.Row():
                model_1 = gr.Dropdown(
                    MODEL_CHOICES,
                    label="Model 1",
                    value="Logistic Regression",
                )
                model_1_weight = gr.Slider(
                    value=1, label="Model 1 Weight", minimum=0, maximum=10, step=1
                )
            with gr.Row():
                model_2 = gr.Dropdown(
                    MODEL_CHOICES,
                    label="Model 2",
                    value="Random Forest",
                )
                model_2_weight = gr.Slider(
                    value=1, label="Model 2 Weight", minimum=0, maximum=10, step=1
                )
            with gr.Row():
                model_3 = gr.Dropdown(
                    MODEL_CHOICES,
                    label="Model 3",
                    value="Gaussian Naive Bayes",
                )
                model_3_weight = gr.Slider(
                    value=5, label="Model 3 Weight", minimum=0, maximum=10, step=1
                )
        # Right column: the resulting probability chart.
        with gr.Column(scale=4):
            proba_plots = gr.Plot()

    # Every control feeds the same callback with the same full set of inputs,
    # so the listeners are registered in one loop; the argument order matches
    # the signature of get_proba_plots().
    plot_inputs = [
        model_1,
        model_2,
        model_3,
        model_1_weight,
        model_2_weight,
        model_3_weight,
    ]
    for control in plot_inputs:
        control.change(get_proba_plots, plot_inputs, proba_plots, queue=False)

    # Draw the chart once for the default selections when the page loads.
    demo.load(get_proba_plots, plot_inputs, proba_plots, queue=False)
# Start the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()