| from time import time | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import plotly.graph_objects as go | |
| from sklearn import manifold, datasets | |
| from sklearn.cluster import AgglomerativeClustering | |
# Select the non-interactive Agg backend; in the visible code pyplot is used
# only for its colormaps, never to open a window.
import matplotlib
matplotlib.use('Agg')

# Seed NumPy's global random generator once at import time.
SEED = 0
np.random.seed(SEED)

# Load the scikit-learn digits dataset and expose its pieces as module globals.
digits = datasets.load_digits()
X = digits.data
y = digits.target
n_samples, n_features = X.shape
def _css_rgba(rgba_rows, alpha=0.8):
    """Format matplotlib RGBA rows (channels in [0, 1]) as CSS ``rgba()`` strings.

    CSS expects the red/green/blue channels on a 0-255 scale, so the
    fractional colormap values are scaled and rounded. ``alpha`` replaces
    the colormap's own alpha channel and stays on the [0, 1] scale.
    """
    channels = np.rint(np.asarray(rgba_rows)[:, :3] * 255).astype(int)
    return [f'rgba({r}, {g}, {b}, {alpha})' for r, g, b in channels]


def plot_clustering(linkage, dim):
    """Embed the digits data, cluster the embedding and plot the result.

    The module-level ``X`` is reduced with a spectral embedding, clustered
    into 10 groups with agglomerative clustering, rescaled to the unit
    square/cube and drawn as text markers: every sample is shown as its
    true digit, coloured by the cluster it was assigned to.

    Args:
        linkage: Linkage criterion forwarded to ``AgglomerativeClustering``.
        dim: ``'3D'`` for a 3-component embedding and a 3-D scatter; any
            other value yields a 2-component embedding and a 2-D scatter.

    Returns:
        A ``plotly.graph_objects.Figure`` with one text trace per digit class.

    Side effects:
        Prints the linkage name and the clustering fit time to stdout.
    """
    is_3d = dim == '3D'

    # Pass the seed explicitly so repeated calls give the same embedding
    # instead of depending on the current state of NumPy's global RNG.
    embedder = manifold.SpectralEmbedding(
        n_components=3 if is_3d else 2,
        random_state=SEED,
    )
    X_red = embedder.fit_transform(X)

    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)
    t0 = time()
    clustering.fit(X_red)
    print(f"{linkage} :\t{time() - t0:.2f}s")
    labels = clustering.labels_

    # Min-max normalise each embedding axis to [0, 1]; a constant axis would
    # otherwise divide by zero, so its span is replaced by 1.
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    span = x_max - x_min
    span[span == 0] = 1.0
    X_red = (X_red - x_min) / span

    fig = go.Figure()
    for digit in digits.target_names:
        mask = y == digit
        subset = X_red[mask]
        # Cluster labels 0..9 are mapped into [0, 0.9] on the colormap.
        color = _css_rgba(plt.cm.nipy_spectral(labels[mask] / 10))
        textfont = {'size': 16, 'color': color}
        if is_3d:
            trace = go.Scatter3d(
                x=subset[:, 0], y=subset[:, 1], z=subset[:, 2],
                mode='text', text=str(digit), textfont=textfont,
            )
        else:
            trace = go.Scatter(
                x=subset[:, 0], y=subset[:, 1],
                mode='text', text=str(digit), textfont=textfont,
            )
        fig.add_trace(trace)

    fig.update_traces(showlegend=False)
    return fig
# Markdown heading shown at the top of the demo page.
# NOTE(review): the data is loaded with sklearn's `load_digits`, not MNIST —
# confirm whether the title should say "digits" instead.
title = '# Agglomerative Clustering on MNIST'

# Markdown body explaining what the demo illustrates.
description = """
An illustration of various linkage options for [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) on the digits dataset.
The goal of this example is to show intuitively how the metrics behave, and not to find good clusters for the digits.
What this example shows us is the behavior of "rich getting richer" in agglomerative clustering, which tends to create uneven cluster sizes.
This behavior is pronounced for the average linkage strategy, which ends up with a couple of clusters having few data points.
The case of single linkage is even more pathological, with a very large cluster covering most digits, an intermediate-sized (clean) cluster with mostly zero digits, and all other clusters being drawn from noise points around the fringes.
The other linkage strategies lead to more evenly distributed clusters, which are therefore likely to be less sensitive to random resampling of the dataset.
"""

# Markdown attribution line rendered below the description.
author = '''
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_linkage.html)
'''
# Build the Gradio page: explanatory text on top, then a row with the
# controls on the left and the resulting plot on the right.
# `title` is a Markdown heading ("# ..."); strip the leading marker so the
# browser tab does not show a literal '#'.
with gr.Blocks(analytics_enabled=False, title=title.lstrip('# ')) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(author)

    with gr.Row():
        with gr.Column():
            linkage = gr.Radio(
                ["ward", "average", "complete", "single"],
                value="average",
                interactive=True,
                label="Linkage Method",
            )
            dim = gr.Radio(['2D', '3D'], label='Embedding Dimensionality', value='2D')
            btn = gr.Button('Submit')
        with gr.Column():
            plot = gr.Plot(label='MNIST Embeddings')

    # Redraw on every button press, and once when the page first loads so the
    # plot area is not empty.
    btn.click(plot_clustering, inputs=[linkage, dim], outputs=[plot])
    demo.load(plot_clustering, inputs=[linkage, dim], outputs=[plot])

# Start the server only when run as a script, so importing this module
# (e.g. to reuse `demo`) does not launch it.
if __name__ == "__main__":
    demo.launch()