Strategies for effective transfer learning for HRTEM analysis with neural networks

Nanoparticle structure

%matplotlib ipympl

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from itertools import chain, product
from functools import reduce
import colorspacious

from matplotlib.colors import LinearSegmentedColormap

from ipywidgets import (
    Select,
    SelectMultiple,
    IntSlider,
    ToggleButton,
    FloatRangeSlider,
    Label,
    Layout,
)

from IPython.display import display
from ipywidgets import HBox, VBox

# Figure-wide matplotlib defaults: frameless legends and 8 pt text for
# legend entries, legend titles and tick labels.
plt.rcParams.update(
    {
        "legend.frameon": False,
        "legend.fontsize": 8,
        "legend.title_fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
    }
)

# Load the pickled results table used by every plot below.
# NOTE: unpickling can execute arbitrary code; only load a trusted file.
with open("mini_size_df.pkl", "rb") as f:
    size_df = pickle.load(f)
def reduction(series, mode):
    match mode:
        case "mean":
            return series.mean()
        case "min":
            return series.min()
        case "max":
            return series.max()
        case "median":
            return series.median()
        case _:
            return ValueError


def weight_freezing_select():
    """Build the multi-select widget for the weight-freezing modes.

    All three modes ("none", "decoder", "encoder") are offered and all are
    selected initially. The widget carries no description of its own.
    """
    modes = ["none", "decoder", "encoder"]
    widget_kwargs = {
        "options": modes,
        "value": list(modes),  # start with every mode selected
        "rows": 3,
        "disabled": False,
        "layout": {"width": "100px"},
    }
    return SelectMultiple(**widget_kwargs)


def reduction_select():
    """Build the single-select widget for the data-reduction mode.

    Offers "min", "mean", "median" and "max", with "mean" preselected.
    The widget carries no description of its own.
    """
    widget_kwargs = {
        "options": ["min", "mean", "median", "max"],
        "value": "mean",
        "rows": 4,
        "disabled": False,
        "layout": {"width": "100px"},
    }
    return Select(**widget_kwargs)
# Base structure names, written as "<lattice type>_<material>".
lattice_data = [
    "bcc_Fe",
    "dc_C",
    "dc_Si",
    "fcc_Ag",
    "fcc_Cu",
    "spinel_Co3O4",
    "spinel_Fe3O4",
    "w_CdSe",
    "w_ZnS",
]

# Every structure in a "_small" and a "_large" variant, except the two large
# spinel variants, as an alphabetically sorted list.
lattices = sorted(
    f"{name}_{size}"
    for name in lattice_data
    for size in ("small", "large")
    if f"{name}_{size}" not in {"spinel_Co3O4_large", "spinel_Fe3O4_large"}
)
def get_plot_data(df, lattices, reduction_mode):
    """Build the pretraining and transfer heat-map matrices.

    Parameters
    ----------
    df
        Results table with ``pretrain_structure`` and ``transfer_structure``
        fields plus ``best_pretrain_performance_validation_<lattice>`` and
        ``best_transfer_performance_validation_<lattice>`` columns.
    lattices
        Sequence of lattice names; fixes both the row and column order.
    reduction_mode
        Mode name forwarded to ``reduction``.

    Returns
    -------
    tuple of two ``(N, N)`` arrays ``(pretrain, transfer)`` where row ``i`` is
    the pretraining structure and column ``j`` the evaluated structure. A
    transfer cell with no matching rows falls back to the pretraining value.
    """
    n_lattices = len(lattices)
    shape = (n_lattices, n_lattices)
    pretrain_map = np.zeros(shape)
    transfer_map = np.zeros(shape)

    for row, source in enumerate(lattices):
        source_df = df.query(f"pretrain_structure == '{source}'")

        for col, target in enumerate(lattices):
            pretrain_column = f"best_pretrain_performance_validation_{target}"
            pretrain_map[row, col] = reduction(
                source_df[pretrain_column], reduction_mode
            )

            target_df = source_df.query(f"transfer_structure=='{target}'")
            if len(target_df) == 0:
                # No transfer runs for this pair: reuse the pretraining value.
                transfer_map[row, col] = pretrain_map[row, col]
                continue

            transfer_column = f"best_transfer_performance_validation_{target}"
            transfer_map[row, col] = reduction(
                target_df[transfer_column], reduction_mode
            )

    return pretrain_map, transfer_map
# Static figure scaffolding. Everything defined here is a module-level global
# that update_plot() reads when it (re)draws; nothing is plotted yet.
height = 8
width = 6.4
fig_style = {
    "figsize": (width, height),  # inches
    "constrained_layout": True,
    "dpi":96
}

fig = plt.figure(**fig_style)

# Number of grid columns given to each of the two heat maps.
N_per_mat = 12

# 21-row outer grid: rows 0-7 hold the heat-map subfigure and rows 9-20 the
# histogram subfigure; row 8 is not assigned to either.
gs = fig.add_gridspec(21, 1)
subfig_structure = fig.add_subfigure(gs[:8])

# 1 x 25 grid: columns 0-11 -> pretraining heat map, columns 12-23 -> transfer
# heat map, last column -> colorbar.
structure_gs = subfig_structure.add_gridspec(1, N_per_mat * 2 + 1)
ax_p = subfig_structure.add_subplot(structure_gs[:N_per_mat])
ax_t = subfig_structure.add_subplot(structure_gs[N_per_mat:-1])
ax_cbar = subfig_structure.add_subplot(structure_gs[-1])

# Colormap and color limits shared by both heat maps.
cmap = sns.color_palette("crest_r", as_cmap=True)
pmin = 0.0
pmax = 0.1

subfig_size = fig.add_subfigure(gs[9:])

# 2 x 25 grid of histograms: row 0 is used for small NPs and row 1 for large
# NPs; columns 1-10 hold the pretraining panel and columns 11-20 the transfer
# panel. Column 0 and the last four columns are left unassigned
# (presumably room for axis labels and the legends moved outside the axes).
size_gs = subfig_size.add_gridspec(2, 2 * N_per_mat + 1)
ax_ps = subfig_size.add_subplot(size_gs[0, 1 : N_per_mat - 1])
ax_ts = subfig_size.add_subplot(size_gs[0, N_per_mat - 1 : -4])
ax_pl = subfig_size.add_subplot(size_gs[1, 1 : N_per_mat - 1])
ax_tl = subfig_size.add_subplot(size_gs[1, N_per_mat - 1 : -4])

# Discrete 11-step version of the same "crest_r" palette; two of its entries
# (index 6 and index 0) color the two hue levels of each histogram.
Nc = 11
palette = sns.color_palette("crest_r", Nc)
# Keyword arguments passed unchanged to every sns.histplot call in update_plot.
common_style = {
    "kde": True,
    "fill": True,
    "multiple": "layer",
    "stat": "percent",
    "common_bins": True,
    "element": "step",
    "cumulative": False,
    "alpha": 0.3,
    "binwidth": 5e-3,
    "kde_kws": {
        "bw_adjust": 1.0,
    },
    "palette": [
        palette[Nc // 2 + 1],
        palette[0],
    ],
}


# Selection widgets that drive the plot.
freeze_select = weight_freezing_select()
rmode_select = reduction_select()

# Stack a caption above each selector, centered, then place the two columns
# side by side and show them.
center_align = Layout(align_items="center")
freeze_box, rmode_box = (
    VBox([Label(caption), selector], layout=center_align)
    for caption, selector in (
        ("Weight Freezing", freeze_select),
        ("Reduction Mode", rmode_select),
    )
)

full_box = HBox([freeze_box, rmode_box])
display(full_box)


def _melt_validation_losses(df, stage, lattices, size, offset):
    """Return long-form validation losses for one training stage and NP size.

    Melts the ``best_<stage>_performance_validation_<lattice>`` columns of
    every entry of ``lattices`` whose name contains ``size``, keeping
    "Pretrain Size" and "Transfer Size" as identifiers, and subtracts
    ``offset`` from the resulting "value" column.
    """
    melted = df.melt(
        id_vars=["Pretrain Size", "Transfer Size"],
        value_vars=[
            f"best_{stage}_performance_validation_{lattice}"
            for lattice in lattices
            if size in lattice
        ],
    )
    melted["value"] = melted["value"] - offset
    return melted


def update_plot(*args):
    """Redraw the heat maps and histograms for the current widget selection.

    Reads the module-level widgets (``freeze_select``, ``rmode_select``), the
    results table ``size_df`` and the pre-built figure/axes, and draws into
    those axes in place.

    Parameters
    ----------
    *args
        Ignored. Present so the function can be registered directly as a
        widget observer, which calls it with a change notification.
    """
    # NOTE(review): this constant is subtracted from every loss before it is
    # plotted; its meaning is not documented here (presumably a baseline
    # loss) -- confirm.
    loss_offset = 0.3133
    # Upper x limit shared by all four histograms.
    xlim = 0.2

    weight_freezing = freeze_select.value
    reduction_mode = rmode_select.value

    local_df = size_df.query(f"freezing_mode in {weight_freezing}")

    # ---- Heat maps (small nanoparticles only) ----
    heatmap_lattices = (
        "w_CdSe_small",
        "w_ZnS_small",
        "bcc_Fe_small",
        "fcc_Ag_small",
        "fcc_Cu_small",
        "spinel_Co3O4_small",
        "spinel_Fe3O4_small",
        "dc_C_small",
        "dc_Si_small",
    )

    p_heatmap, t_heatmap = get_plot_data(local_df, heatmap_lattices, reduction_mode)

    # Drop the images and colorbar from the previous call; otherwise every
    # widget change stacks another image/colorbar on top of the old ones.
    for ax in (ax_p, ax_t, ax_cbar):
        ax.clear()

    cim = ax_p.matshow(p_heatmap - loss_offset, cmap=cmap, vmin=pmin, vmax=pmax)
    ax_t.matshow(t_heatmap - loss_offset, cmap=cmap, vmin=pmin, vmax=pmax)
    fig.colorbar(cim, cax=ax_cbar)

    ax_t.set_yticklabels([])

    ax_p.set_ylabel("Pretrain structure", fontsize=10)
    subfig_structure.supxlabel(
        "Target Structure", fontsize=10, x=N_per_mat / (2 * N_per_mat + 1)
    )
    ax_p.xaxis.set_ticks_position("bottom")
    ax_t.xaxis.set_ticks_position("bottom")

    ax_p.set_title("After pretraining", fontsize=10)
    ax_t.set_title("After transfer learning", fontsize=10)

    # Tick labels are the material part of each name; the two oxides get
    # mathtext subscripts (looked up by name rather than by list position).
    mathtext_labels = {"Co3O4": "Co$_3$O$_4$", "Fe3O4": "Fe$_3$O$_4$"}
    materials = [name.split("_")[1] for name in heatmap_lattices]
    lattice_labels = [mathtext_labels.get(m, m) for m in materials]
    tick_positions = list(range(len(heatmap_lattices)))
    ax_p.set_xticks(tick_positions, lattice_labels, rotation=45)
    ax_p.set_yticks(tick_positions, lattice_labels)
    ax_t.set_xticks(tick_positions, lattice_labels, rotation=45)

    ax_cbar.set_ylabel("Loss")
    ax_cbar.set_yticks(np.arange(6) * pmax / 5)

    for ax in [ax_t, ax_p, ax_cbar]:
        _ = plt.setp(ax.spines.values(), linewidth=1.25)

    # ---- Histograms (small and large nanoparticles) ----
    hist_lattices = (
        "bcc_Fe_small",
        "fcc_Ag_small",
        "fcc_Cu_small",
        "w_CdSe_small",
        "w_ZnS_small",
        "dc_C_small",
        "dc_Si_small",
        "bcc_Fe_large",
        "fcc_Ag_large",
        "fcc_Cu_large",
        "w_CdSe_large",
        "w_ZnS_large",
        "dc_C_large",
        "dc_Si_large",
    )

    # (axes, training stage, evaluated NP size, hue column) for each panel.
    panels = (
        (ax_ps, "pretrain", "small", "Pretrain Size"),
        (ax_ts, "transfer", "small", "Transfer Size"),
        (ax_pl, "pretrain", "large", "Pretrain Size"),
        (ax_tl, "transfer", "large", "Transfer Size"),
    )

    for ax, stage, size, hue in panels:
        panel_df = _melt_validation_losses(
            local_df, stage, hist_lattices, size, loss_offset
        )

        ax.clear()
        sns.histplot(ax=ax, data=panel_df, x="value", hue=hue, **common_style)

        ax.set_xlim([0, xlim])
        ax.set_xticks(np.arange(5) * xlim / 4)
        ax.set_ylim([0, 10])
        _ = plt.setp(ax.spines.values(), linewidth=1.25)

        # move_legend raises if no legend was created (e.g. nothing to plot).
        if ax.get_legend() is not None:
            sns.move_legend(ax, 5, bbox_to_anchor=(1.025, 0.75))

        if size == "small":
            # Top row: no x label or tick labels.
            ax.set_xlabel("")
            ax.set_xticklabels([])
        else:
            ax.set_xlabel("Loss")

        ax.text(
            x=xlim / 2,
            y=9.25,
            s=f"Performance on {size} NPs",
            fontsize=8,
            horizontalalignment="center",
        )

        if stage == "transfer":
            # Right column: no y label or tick labels.
            ax.set_ylabel("")
            ax.set_yticklabels([])


# Draw once for the initial selections, then redraw whenever the value of
# either selector changes.
update_plot()
freeze_select.observe(update_plot, names="value")
rmode_select.observe(update_plot, names="value")