Variational Autoencoder Toolkit

Clustering Notebook

Imports

import os
import wget
import h5py
import gdown

import numpy as np
from scipy.ndimage import zoom
import skimage
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import pickle
from IPython.display import display, HTML
import ipywidgets
from ipywidgets import interact

Load Data

# filedir = "data/"
# zenodo_url = "https://zenodo.org/record/4555979/files/"

# model_files = [
#     'Sm_0_1_HAADF.h5',
#     'Sm_0_1_UCParameterization.h5',
#     'Sm_7_0_HAADF.h5',
#     'Sm_7_0_UCParameterization.h5',
#     'SM_10_0_HAADF.h5',
#     'Sm_10_0_UCParameterization.h5',
#     'Sm_13_0_HAADF.h5',
#     'Sm_13_0_UCParameterization.h5',
#     'Sm_20_1_HAADF.h5',
#     'Sm_20_1_UCParameterization.h5'
# ]

# for model_file in model_files:
    
#     file_name = os.path.join(filedir, model_file)
#     zenodo_name = zenodo_url + model_file + "?download=1"
    
#     print(file_name)
#     if os.path.exists(file_name):
#         continue
    
#     wget.download(zenodo_name, out=file_name)
# #image files
# composition_tags = [0,7,10,13,20]    #Sm composition %


# img_filename = ['Sm_0_1_HAADF.h5',
#                 'Sm_7_0_HAADF.h5',
#                 'SM_10_0_HAADF.h5',
#                 'Sm_13_0_HAADF.h5',
#                 'Sm_20_1_HAADF.h5']

# imnum = len(img_filename)

# #parameterization files

# UCparam_filename = ['Sm_0_1_UCParameterization.h5',
#                     'Sm_7_0_UCParameterization.h5',
#                     'Sm_10_0_UCParameterization.h5',
#                     'Sm_13_0_UCParameterization.h5',
#                     'Sm_20_1_UCParameterization.h5']

# #load parameter files
# UCparam = []
# for x in UCparam_filename:
#   print('loading parameterization file: ', os.path.join(filedir, x))
#   temp = h5py.File(os.path.join(filedir, x), 'r')
#   UCparam.append(temp)

# #load images
# imgdata = []
# for x in img_filename:
#   print('loading image file: ', os.path.join(filedir, x))
#   temp = h5py.File(os.path.join(filedir, x), 'r')['MainImage']
#   imgdata.append(temp)

# print('UC parameterization:', [k for k in UCparam[0].keys()])

Physical descriptors: polarization, strain, lattice parameters

# #function maps x,y grid positions into a matrix data format
# def map2grid(inab, inVal):

#   default_val = np.nan
#   abrng = [int(np.min(inab[:,0])), int(np.max(inab[:,0])), int(np.min(inab[:,1])), int(np.max(inab[:,1]))]
#   abind = inab
#   abind[:,0] -= abrng[0]
#   abind[:,1] -= abrng[2]
#   Valgrid = np.empty((abrng[1]-abrng[0]+1,abrng[3]-abrng[2]+1))
#   Valgrid[:] = default_val
#   Valgrid[abind[:,0].astype(int),abind[:,1].astype(int)]=inVal[:]
#   return Valgrid, abrng
# SBFOdata = []     #this will be the output list of dictionaries for each dataset

# for i in np.arange(imnum):
#   temp_dict = {'Index': i}
#   temp_dict['Composition'] = composition_tags[i]
#   temp_dict['Image'] = imgdata[i]
#   temp_dict['Filename'] = img_filename[i]

#   for k in UCparam[i].keys():       #add labels for UC parameterization
#     temp_dict[k] = UCparam[i][k][()]

#   #select values mapped to ab grid
#   temp_dict['ab_a'] = map2grid(UCparam[i]['ab'][()].T, UCparam[i]['ab'][()].T[:,0])[0]       #a array
#   temp_dict['ab_b'] = map2grid(UCparam[i]['ab'][()].T, UCparam[i]['ab'][()].T[:,1])[0]       #b array
#   temp_dict['ab_x'] = map2grid(UCparam[i]['ab'][()].T, UCparam[i]['xy_COM'][()].T[:,0])[0]   #x array
#   temp_dict['ab_y'] = map2grid(UCparam[i]['ab'][()].T, UCparam[i]['xy_COM'][()].T[:,1])[0]   #y array
#   temp_dict['ab_Px'] = map2grid(UCparam[i]['ab'][()].T, UCparam[i]['Pxy'][0])[0]             #Px array
#   temp_dict['ab_Py'] = map2grid(UCparam[i]['ab'][()].T, UCparam[i]['Pxy'][1])[0]        #Py array
#   temp_dict['Vol'] = map2grid(UCparam[i]['ab'][()].T, UCparam[i]['Vol'])[0]     #Vol array

#   SBFOdata.append(temp_dict)
## Define the main area to be highlighted with a red square
# main = [1000, 3000, 400, 2400]  # [y_start, y_end, x_start, x_end]

# # Example: Resizing ab_Px and ab_Py, and plotting the selected region
# for i in np.arange(imnum):
#     img_shape = SBFOdata[i]["Image"].shape  # Target shape
#     px_shape = SBFOdata[i]["ab_Px"].shape  # Current shape of ab_Px
#     py_shape = SBFOdata[i]["ab_Py"].shape  # Current shape of ab_Py

#     # Calculate the zoom factor for resizing ab_Px and ab_Py
#     zoom_factors_px = [img_shape[0] / px_shape[0], img_shape[1] / px_shape[1]]
#     zoom_factors_py = [img_shape[0] / py_shape[0], img_shape[1] / py_shape[1]]

#     # Resize ab_Px and ab_Py to match the Image shape and ensure they are exactly the same size
#     SBFOdata[i]["ab_Px_resized"] = zoom(SBFOdata[i]["ab_Px"], zoom_factors_px, order=1)
#     SBFOdata[i]["ab_Py_resized"] = zoom(SBFOdata[i]["ab_Py"], zoom_factors_py, order=1)

#     # Ensure the resized arrays match the image shape exactly (if rounding issues occur)
#     SBFOdata[i]["ab_Px_resized"] = SBFOdata[i]["ab_Px_resized"][:img_shape[0], :img_shape[1]]
#     SBFOdata[i]["ab_Py_resized"] = SBFOdata[i]["ab_Py_resized"][:img_shape[0], :img_shape[1]]

# # Create figure with subplots for the selected data points
# fig, ax = plt.subplots(nrows=3, ncols=5, figsize=(3*5, 3*3), dpi=200)

# for j, idx in enumerate(np.arange(imnum)):
#     k = SBFOdata[idx]

#     # Image - select the region
#     selected_image = k['Image'][main[0]:main[1], main[2]:main[3]]
#     ax[0, j].imshow(selected_image, origin='upper', cmap='gray')
#     ax[0, j].set_title(f"{k['Index']}: {k['Composition']}%", fontsize=24, fontweight="bold")
#     ax[0, j].set_axis_off()
#     ax[0, j].invert_yaxis()

#     # Px - select the region
#     selected_px = k['ab_Px_resized'][main[0]:main[1], main[2]:main[3]]
#     ax[1, j].imshow(selected_px, origin='upper', cmap='jet')
#     ax[1, j].set_axis_off()
#     ax[1, j].invert_yaxis()

#     # Py - select the region
#     selected_py = k['ab_Py_resized'][main[0]:main[1], main[2]:main[3]]
#     ax[2, j].imshow(selected_py, origin='upper', cmap='jet')
#     ax[2, j].set_axis_off()
#     ax[2, j].invert_yaxis()

# plt.tight_layout()
# plt.show()
# # Let's create new lists to store the selected images, Px, and Py data
# selected_images = []
# ground_truth_px = []
# ground_truth_py = []

# # Define the main area to be highlighted
# main = [1000, 3000, 400, 2400]  # [y_start, y_end, x_start, x_end]

# # Loop over the selected indices, extract the region, and store it
# for i in np.arange(imnum):
#     k = SBFOdata[i]

#     # Select the region from the image, Px, and Py
#     selected_image = k['Image'][main[0]:main[1], main[2]:main[3]]
#     selected_px = k['ab_Px_resized'][main[0]:main[1], main[2]:main[3]]
#     selected_py = k['ab_Py_resized'][main[0]:main[1], main[2]:main[3]]

#     # Append the selected regions to the corresponding lists
#     selected_images.append(selected_image)
#     ground_truth_px.append(selected_px)
#     ground_truth_py.append(selected_py)
# # Convert lists to NumPy arrays with lower precision (float32) to reduce size
# selected_images_array = np.array(selected_images, dtype=np.float32)
# ground_truth_px_array = np.array(ground_truth_px, dtype=np.float32)
# ground_truth_py_array = np.array(ground_truth_py, dtype=np.float32)
# # I will now display a confirmation of the stored data.
# print(len(selected_images), len(ground_truth_px), len(ground_truth_py))

# images_data = "/Users/kbarakat/variational-autoencoders/notebooks/data/images_data.pkl"
# with open(images_data, "wb") as f:
#     pickle.dump((selected_images, ground_truth_px, ground_truth_py), f)
# Google Drive file id of the pickled dataset that the next cell loads from
# "data/images_data.pkl".
# NOTE(review): `id` shadows the builtin `id()`; the name is kept because other
# cells may read it -- consider renaming once all usages are known.
id="1AHlk5xxXiuiTtYNr8fk0YQ8Uxjbf8bfT"
if not os.path.exists("data/images_data.pkl"):
    # Make sure the target directory exists before gdown tries to write into it.
    os.makedirs("data", exist_ok=True)
    # Download to the exact path the guard above checks, so the skip-if-present
    # test and the loader do not depend on the file name the server reports.
    gdown.download(id=id,fuzzy=True,output="data/images_data.pkl")
Downloading...
From (original): https://drive.google.com/uc?id=1AHlk5xxXiuiTtYNr8fk0YQ8Uxjbf8bfT
From (redirected): https://drive.google.com/uc?id=1AHlk5xxXiuiTtYNr8fk0YQ8Uxjbf8bfT&confirm=t&uuid=194937c3-d491-493d-a2ea-cbd0f0a82e1a
To: /Users/gvarnavides/Documents/myst-sites/variational-autoencoders/notebooks/data/images_data.pkl
100%|█| 480M/480M [00:53<00:00
# ! gdown --fuzzy --id 1AHlk5xxXiuiTtYNr8fk0YQ8Uxjbf8bfT
# Load the lists from the pickle file
# NOTE(review): pickle.load can execute arbitrary code embedded in the file, and
# this file is fetched from Google Drive -- only run this against a trusted source.
images_data = "data/images_data.pkl"

with open(images_data, "rb") as f:
    # The pickle is unpacked as a 3-tuple: images, Px ground truth, Py ground truth.
    selected_images, ground_truth_px, ground_truth_py = pickle.load(f)

# Confirm successful loading by checking the lengths of the lists
print(len(selected_images), len(ground_truth_px), len(ground_truth_py))
5 5 5
# min-max normalization:
def norm2d(img: np.ndarray) -> np.ndarray:
    """Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    img : np.ndarray
        Array to normalise (any shape; must be non-empty).

    Returns
    -------
    np.ndarray
        ``(img - min) / (max - min)``. When every element is equal the range is
        zero, so an all-zero float array of the same shape is returned instead
        of the NaNs a 0/0 division would produce.

    Raises
    ------
    ValueError
        If ``img`` is empty (raised by ``np.min``).
    """
    lowest = np.min(img)
    value_range = np.max(img) - lowest
    if value_range == 0:
        # Constant image: avoid dividing by zero.
        return np.zeros_like(img, dtype=np.float64)
    return (img - lowest) / value_range
def _window_bounds(center, window, shape):
    """Return ``(top, bottom, left, right)`` for a window centred on ``center``.

    ``center`` is ``(row, col)`` and is rounded to the nearest pixel; ``window``
    is ``(height, width)``. Returns ``None`` if the window would extend past an
    image of the given ``shape``.
    """
    row = int(np.around(center[0]))
    col = int(np.around(center[1]))
    top = row - window[0] // 2
    left = col - window[1] // 2
    # Derive the far edges from the full extent so odd sizes keep their last row/column.
    bottom = top + window[0]
    right = left + window[1]
    if top < 0 or left < 0 or bottom > shape[0] or right > shape[1]:
        return None
    return top, bottom, left, right


def custom_extract_subimages(imgdata, coordinates, w_prime, large_window_size=(64, 64)):
    """Crop fixed-size windows around the given centres.

    A centre is kept only if BOTH a ``large_window_size`` window and a
    ``w_prime`` window around it lie fully inside the image. The first test
    does not depend on ``w_prime``.

    Parameters
    ----------
    imgdata : array-like
        2-D image supporting ``.shape`` and ``[rows, cols]`` slicing.
    coordinates : iterable of (row, col)
        Candidate centres; values are rounded to the nearest pixel.
    w_prime : (int, int)
        ``(height, width)`` of the windows to extract. Odd sizes are supported.
    large_window_size : (int, int), optional
        ``(height, width)`` of the gating window; defaults to 64x64.

    Returns
    -------
    subimages : np.ndarray
        Stack of the extracted windows, one per accepted centre.
    centers : np.ndarray
        The accepted entries of ``coordinates`` (unrounded), in input order.
        Both arrays are empty when no centre qualifies.
    """
    image_shape = imgdata.shape
    subimages_target = []
    coms_target = []

    for coord in coordinates:
        # Stage 1: the centre must admit a full gating window.
        if _window_bounds(coord, large_window_size, image_shape) is None:
            continue
        # Stage 2: crop the requested window around the same centre.
        bounds = _window_bounds(coord, w_prime, image_shape)
        if bounds is None:
            continue
        top, bottom, left, right = bounds
        subimages_target.append(imgdata[top:bottom, left:right])
        coms_target.append(coord)

    return np.array(subimages_target), np.array(coms_target)
def build_descriptor(window_size, min_sigma, max_sigma, threshold, overlap, image=None):
    """Detect blobs in an image and build one flattened descriptor per blob.

    Parameters
    ----------
    window_size : (int, int)
        Window size forwarded to ``custom_extract_subimages``.
    min_sigma, max_sigma, threshold, overlap
        Forwarded to ``skimage.feature.blob_log`` (with ``num_sigma=30``).
    image : np.ndarray, optional
        Image to analyse. When omitted, the module-level ``img`` is used, which
        is what existing callers rely on.

    Returns
    -------
    descriptors : np.ndarray
        One flattened subimage per row.
    coms_target : np.ndarray
        Centres of the subimages that were kept.
    all_atoms : np.ndarray
        Raw ``blob_log`` output.
    coordinates : np.ndarray
        ``all_atoms`` without its last column (the blob sigma, per the
        ``blob_log`` documentation).
    subimages_target : np.ndarray
        The extracted subimages before flattening.
    """
    # Import the submodule explicitly so this does not depend on `import skimage`
    # exposing `skimage.feature` as an attribute.
    from skimage.feature import blob_log

    processed_img = img if image is None else image

    all_atoms = blob_log(
        processed_img,
        min_sigma=min_sigma,
        max_sigma=max_sigma,
        num_sigma=30,
        threshold=threshold,
        overlap=overlap,
    )
    coordinates = all_atoms[:, :-1]
    # Extract subimages
    subimages_target, coms_target = custom_extract_subimages(processed_img, coordinates, window_size)
    # Build descriptors
    descriptors = np.array([subimage.flatten() for subimage in subimages_target])

    return descriptors, coms_target, all_atoms, coordinates, subimages_target
# Fit a GMM directly on the flattened descriptors (no PCA), with a selectable covariance type
def Fit_GMM(descriptors, components, covariance_type):
    """Cluster descriptors with a two-pass Gaussian mixture model.

    Rows containing NaN are excluded from fitting. A preliminary GMM is fitted
    first; its means and weights then initialise the final GMM.

    Parameters
    ----------
    descriptors : np.ndarray
        Array whose first axis indexes samples; remaining axes are flattened.
    components : int
        Number of mixture components.
    covariance_type : str
        Passed to ``GaussianMixture`` ('full', 'tied', 'diag' or 'spherical').

    Returns
    -------
    full_labels : np.ndarray
        One label per input row; rows that contained NaN are labelled ``-1``.
        Without NaN rows this equals the plain ``predict`` output.
    valid_subimages : np.ndarray
        The flattened, NaN-free rows that were actually fitted.
    """
    # Flatten each subimage into a 1D vector
    flattened_descriptors = descriptors.reshape(descriptors.shape[0], -1)
    # Remove subimages with NaN values
    mask = ~np.isnan(flattened_descriptors).any(axis=1)
    valid_subimages = flattened_descriptors[mask]

    # First pass of GMM to estimate initial parameters
    preliminary_gmm = GaussianMixture(n_components=components, covariance_type=covariance_type, random_state=42)
    preliminary_gmm.fit(valid_subimages)
    initial_means = preliminary_gmm.means_
    initial_weights = preliminary_gmm.weights_

    # Initialize and fit the GMM using the parameters from the preliminary GMM
    gmm = GaussianMixture(n_components=components,
                          means_init=initial_means,
                          weights_init=initial_weights,
                          covariance_type=covariance_type,
                          random_state=42)

    gmm.fit(valid_subimages)

    labels = gmm.predict(valid_subimages)

    # Map the labels back onto every input row. The array must have one entry per
    # input row (not per valid row) for the boolean mask to index it.
    full_labels = np.full(flattened_descriptors.shape[0], -1)
    full_labels[mask] = labels

    return full_labels, valid_subimages
def Fit_PCA_GMM(descriptors, n_clusters, components, covariance_type):
    """Project descriptors with PCA, then cluster them with a two-pass GMM.

    Parameters
    ----------
    descriptors : np.ndarray
        Array whose first axis indexes samples; remaining axes are flattened.
    n_clusters : int
        Despite its name, this is the number of PCA components kept.
    components : int
        Number of Gaussian mixture components.
    covariance_type : str
        Passed to ``GaussianMixture``.

    Returns
    -------
    full_labels : np.ndarray
        One label per input row; rows that contained NaN are labelled ``-1``.
    reduced_data : np.ndarray
        PCA projection of the NaN-free rows only.
    """
    sample_count = descriptors.shape[0]
    flat = descriptors.reshape(sample_count, -1)

    # Rows with any NaN are left out of both the PCA and the mixture fit.
    keep = ~np.isnan(flat).any(axis=1)
    projected = PCA(n_components=n_clusters).fit_transform(flat[keep])

    shared_args = dict(n_components=components,
                       covariance_type=covariance_type,
                       random_state=42)

    # Pass 1: a plain fit, used only to obtain starting means and weights.
    warm_start = GaussianMixture(**shared_args)
    warm_start.fit(projected)

    # Pass 2: refit, seeded with the parameters found in pass 1.
    final_model = GaussianMixture(means_init=warm_start.means_,
                                  weights_init=warm_start.weights_,
                                  **shared_args)
    final_model.fit(projected)

    # Scatter the predictions back to input order; dropped rows keep -1.
    full_labels = np.full(sample_count, -1)
    full_labels[keep] = final_model.predict(projected)

    return full_labels, projected

Select Image of Interest

In this analysis, we focus on a selected STEM image from a larger dataset.

  1. Detecting Atomic Features: the blob_log function from skimage.feature identifies potential atomic positions in the STEM image using a Laplacian of Gaussian (LoG) approach.

  2. Extracting Subimages: using the detected coordinates, the function custom_extract_subimages is called to generate fixed-size subimages around each detected atomic feature.

  3. Flattening Subimages for Descriptor Generation: Each subimage is flattened into a one-dimensional array, creating a consistent descriptor format for further analysis or machine learning applications.

# Analyse the first loaded image. `img` is its min-max normalised copy and is the
# module-level global that build_descriptor reads when no image is passed in.
image = selected_images[0]
img = norm2d(image)

Example descriptor to test

# window_size = (40,40)
# min_sigma = 1
# max_sigma = 5
# threshold = 0.025
# overlap = 0.0
# descriptors, coms_target, all_atoms, coordinates, subimages_target = build_descriptor(window_size, min_sigma, max_sigma, threshold, overlap)
# print(descriptors.shape)
# print(coms_target.shape)
# print(all_atoms.shape)
# print(coordinates.shape)
# print(subimages_target.shape)
# plt.figure(figsize=(6, 6))
# plt.imshow(image, cmap='gray')
# plt.scatter(coms_target[:, 1], coms_target[:, 0], c='r', marker='o', s = 20
#             )
# plt.axis('off')
# plt.title('Image with Subimage Centers')
# plt.xlim([300, 700])
# plt.ylim([300, 700])
# plt.show()

# # Plot a few example subimages with their centers
# fig, axes = plt.subplots(1, 4, figsize=(12, 3))
# for i, ax in enumerate(axes):
#     ax.imshow(subimages_target[i], cmap='gray')
#     ax.scatter(window_size[1] // 2, window_size[0] // 2, c='r', marker='o', s = 100)
#     ax.set_title(f'Subimage {i+1}', fontweight = "bold")
# plt.show()
# # Calculate and visualize the average descriptor
# average_descriptor = descriptors.mean(axis=0).reshape(window_size)
# # plt.figure(figsize=(6, 6))
# # plt.imshow(average_descriptor, cmap='viridis')
# # plt.colorbar()
# # plt.title("Average Descriptor of All Subimages")
# # # plt.axis('off')
# # plt.show()

GMM clustering

Here, we’re applying clustering to the descriptors extracted from subimages to group similar structural features and then visualizing the clusters to better understand the patterns in the data.

Clustering with Gaussian Mixture Model (GMM)

  1. GMM on Raw Descriptors (Fit_GMM):

    The function Fit_GMM is applied directly to the descriptors, with 5 clusters specified and a “full” covariance type. This model assumes the data can be represented by a mixture of 5 Gaussian distributions with unrestricted covariance matrices (i.e., allowing each cluster to have its unique shape).

    Outputs: labels: Cluster labels assigned to each subimage, indicating the group each subimage belongs to.

    valid_subimages: The subset of descriptors that were successfully clustered.

  2. PCA and GMM Combined (Fit_PCA_GMM):

    The Fit_PCA_GMM function applies Principal Component Analysis (PCA) to reduce the dimensionality of the descriptors from high-dimensional space down to 2 principal components (PCs). A GMM with the same number of components and covariance type is then fitted in this reduced space.

# labels, valid_subimages  = Fit_GMM(descriptors, 5, "full")
# labels_pca, reduced_data_pca  = Fit_PCA_GMM(descriptors, 2, 5, "full")
# fig , axes = plt.subplots(1, 2 , figsize = (12, 5))

# axes[0].scatter(valid_subimages[:, 0], valid_subimages[:, 1], c=labels, s=20, cmap='jet', edgecolor='k')
# axes[0].set_title('GMM clustering' , fontsize = 16, fontweight = "bold")
# axes[0].set_xlabel('C1')
# axes[0].set_ylabel('C2')

# axes[1].scatter(reduced_data_pca[:, 0], reduced_data_pca[:, 1], c=labels_pca, s=20, cmap='jet', edgecolor='k')
# axes[1].set_title('GMM PCA clustering' , fontsize = 16, fontweight = "bold")
# axes[1].set_xlabel('PC1')
# axes[1].set_ylabel('PC2')

# plt.show()
# # Initialize lists to store the centroids and dispersions
# cluster_centroids = []
# cluster_dispersions = []

# # Calculate the overall mean descriptor
# overall_mean = descriptors.mean(axis=0)

# # Calculate the centroids and dispersions for each cluster
# for cluster_label in np.unique(labels):
#     # Get descriptors for the current cluster
#     cluster_descriptors = descriptors[labels == cluster_label]
#     # Calculate centroid and dispersion (standard deviation) for each cluster
#     centroid = cluster_descriptors.mean(axis=0) - overall_mean
#     dispersion = cluster_descriptors.std(axis=0)
#     # Append results
#     cluster_centroids.append(centroid)
#     cluster_dispersions.append(dispersion)

# # Convert to NumPy arrays for easy handling
# cluster_centroids = np.array(cluster_centroids)
# cluster_dispersions = np.array(cluster_dispersions)

# # Plot the centroids and dispersions
# fig, axes = plt.subplots(2, len(cluster_centroids), figsize=(15, 10))

# # First row: Cluster centroids (centered)
# for i, ax in enumerate(axes[0]):
#     centroid_image = cluster_centroids[i].reshape(window_size)
#     im_centroid = ax.imshow(centroid_image, cmap='coolwarm', aspect='auto')
#     ax.set_title(f'Cluster {i+1} Centroid', fontweight="bold")
#     ax.axis('off')

# # Second row: Cluster dispersions (standard deviation within each cluster)
# for i, ax in enumerate(axes[1]):
#     dispersion_image = cluster_dispersions[i].reshape(window_size)
#     im_dispersion = ax.imshow(dispersion_image, cmap='viridis', aspect='auto')
#     ax.set_title(f'Cluster {i+1} Dispersion', fontweight="bold")
#     ax.axis('off')

# # # Add colorbars for interpretation
# # fig.colorbar(im_centroid, ax=axes[0], orientation='vertical', fraction=0.02, pad=0.04)
# # fig.colorbar(im_dispersion, ax=axes[1], orientation='vertical', fraction=0.02, pad=0.04)

# plt.suptitle("Cluster Centroids and Dispersion (Standard Deviation) Across Descriptors", fontsize=16, fontweight="bold")
# plt.tight_layout()
# plt.show()
# fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# # First subplot
# ax[0].scatter(coms_target[:, 1], coms_target[:, 0], c=labels, s=10, cmap='jet', marker="s")
# ax[0].set_title('GMM' , fontsize = 16, fontweight = "bold")
# ax[0].axis('off')

# ax[1].scatter(coms_target[:, 1], coms_target[:, 0], c=labels_pca, s=10, cmap='jet', marker="s")
# ax[1].set_title('PCA GMM' , fontsize = 16, fontweight = "bold")
# ax[1].axis('off')  # Just to leave it empty for now

# plt.show()
# # # interactive_clustering with all involved hyperparameters

def _embolden_ticks(ax, fontsize=12):
    """Render the x and y tick labels of ``ax`` in bold at ``fontsize``."""
    for tick_label in (*ax.get_xticklabels(), *ax.get_yticklabels()):
        tick_label.set_fontsize(fontsize)
        tick_label.set_weight("bold")


def _plot_feature_space(ax, points, labels, title, xlabel, ylabel):
    """Scatter the first two columns of ``points`` coloured by cluster label."""
    ax.scatter(points[:, 0], points[:, 1], c=labels, s=20, cmap='jet', edgecolor='k')
    ax.set_title(title, fontsize=16, fontweight="bold")
    ax.set_xlabel(xlabel, fontsize=16, fontweight="bold")
    ax.set_ylabel(ylabel, fontsize=16, fontweight="bold")
    _embolden_ticks(ax)


def _plot_label_map(ax, centers, labels, title):
    """Scatter the (row, col) centres as (x=col, y=row), coloured by cluster label."""
    ax.scatter(centers[:, 1], centers[:, 0], c=labels, s=8, cmap='jet', marker="s")
    ax.set_title(title, fontsize=16, fontweight="bold")
    ax.axis('off')


def interactive_clustering(window_width, window_height, n_components, covariance_type):
    """Build descriptors, cluster them with GMM and PCA+GMM, and plot a 2x2 summary.

    Parameters
    ----------
    window_width, window_height : int
        Size in pixels of the subimage cut around each detected blob.
    n_components : int
        Number of Gaussian mixture components for both models.
    covariance_type : str
        Covariance type passed to both mixture fits.

    The top row shows the clusters in feature space (first two descriptor
    columns, and the two principal components); the bottom row shows the labels
    at the subimage centres.
    """
    # custom_extract_subimages reads index 0 as the row extent (height) and
    # index 1 as the column extent (width), so the tuple is (height, width).
    window_size = (window_height, window_width)
    min_sigma = 1
    max_sigma = 5
    threshold = 0.025
    overlap = 0.0

    # Build descriptors and fit models
    descriptors, coms_target, _, _, _ = build_descriptor(window_size, min_sigma, max_sigma, threshold, overlap)
    descriptors = np.array(descriptors, dtype=np.float32)
    coms_target = np.array(coms_target, dtype=np.float32)

    # Use the chosen covariance_type here; PCA keeps 2 components for plotting.
    labels_gmm, subimages_gmm = Fit_GMM(descriptors, n_components, covariance_type)
    labels_pca, subimages_pca = Fit_PCA_GMM(descriptors, 2, n_components, covariance_type)

    # Create a 2x2 plot
    fig, axes = plt.subplots(2, 2, figsize=(12, 12))

    _plot_feature_space(axes[0, 0], subimages_gmm, labels_gmm, 'GMM Clusters', 'C1', 'C2')
    _plot_feature_space(axes[0, 1], subimages_pca, labels_pca, 'PCA GMM Clusters', 'PC1', 'PC2')

    _plot_label_map(axes[1, 0], coms_target, labels_gmm, 'GMM Final Map')
    _plot_label_map(axes[1, 1], coms_target, labels_pca, 'PCA Final Map')

    # Adjust layout
    fig.tight_layout()
# Inject CSS so widget labels and value read-outs render large and bold.
display(HTML("""
    <style>
        .widget-label { font-size: 16px; font-weight: bold; }
        .widget-readout { font-size: 16px; font-weight: bold; } /* This makes the value bold */
    </style>
"""))

# One Layout instance shared by every control so they all get the same width.
slider_layout = ipywidgets.Layout(width='400px')


def _int_slider(description, lower, upper, step):
    """Build an IntSlider that starts at ``lower`` and updates only on release."""
    return ipywidgets.IntSlider(
        min=lower, max=upper, step=step, value=lower,
        description=description, continuous_update=False,
        layout=slider_layout,
        style={'description_width': 'initial'},
    )


# Wire the controls to interactive_clustering; the trailing semicolon suppresses
# the notebook's echo of interact's return value.
interact(
    interactive_clustering,
    window_width=_int_slider('Width', 2, 64, 2),
    window_height=_int_slider('Height', 2, 64, 2),
    n_components=_int_slider('Components', 2, 6, 1),
    covariance_type=ipywidgets.SelectionSlider(
        options=['full', 'tied', 'diag', 'spherical'],
        value="full", description="Covariance Type",
        continuous_update=False,
        layout=slider_layout,
        style={'description_width': 'initial'}
    )
);