Contents
Nanoparticle structure
%matplotlib ipympl
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from itertools import chain, product
from functools import reduce
import colorspacious
from matplotlib.colors import LinearSegmentedColormap
from ipywidgets import (
Select,
SelectMultiple,
IntSlider,
ToggleButton,
FloatRangeSlider,
Label,
Layout,
)
from IPython.display import display
from ipywidgets import HBox, VBox
plt.rcParams["legend.frameon"] = False
plt.rcParams["legend.fontsize"] = 8
plt.rcParams["legend.title_fontsize"] = 8
plt.rcParams["xtick.labelsize"] = 8
plt.rcParams["ytick.labelsize"] = 8with open("mini_size_df.pkl", "rb") as f:
size_df = pickle.load(f)def reduction(series, mode):
match mode:
case "mean":
return series.mean()
case "min":
return series.min()
case "max":
return series.max()
case "median":
return series.median()
case _:
return ValueError
def weight_freezing_select():
    """Create the multi-select widget used to filter runs by weight-freezing mode.

    All three modes start out selected, so the initial plot includes every run.
    The widget is 100px wide and shows one row per mode.
    """
    modes = ["none", "decoder", "encoder"]
    widget_kwargs = dict(
        options=modes,
        value=list(modes),
        rows=len(modes),
        disabled=False,
        layout={"width": "100px"},
    )
    return SelectMultiple(**widget_kwargs)
def reduction_select():
    """Create the single-select widget that chooses how repeated runs are reduced.

    The options are the modes understood by ``reduction``; "mean" is selected
    initially. The selected value is passed to ``get_plot_data`` as
    ``reduction_mode``.
    """
    return Select(
        options=["min", "mean", "median", "max"],
        value="mean",
        rows=4,
        disabled=False,
        layout={"width": "100px"},
    )


# Base lattice identifiers; each gets a "_small" and a "_large" variant below.
lattice_data = [
    "bcc_Fe",
    "dc_C",
    "dc_Si",
    "fcc_Ag",
    "fcc_Cu",
    "spinel_Co3O4",
    "spinel_Fe3O4",
    "w_CdSe",
    "w_ZnS",
]
# Sorted list of every small/large variant, minus the two large spinels
# (explicitly excluded; NOTE(review): presumably no data exists for them).
lattices = sorted(
    {base + suffix for base in lattice_data for suffix in ("_small", "_large")}
    - {"spinel_Co3O4_large", "spinel_Fe3O4_large"}
)


def get_plot_data(df, lattices, reduction_mode):
    """Build pretrain and transfer performance matrices for all lattice pairs.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per run. Must contain ``pretrain_structure`` and
        ``transfer_structure`` columns and, for every ``L`` in ``lattices``,
        the columns ``best_pretrain_performance_validation_L`` and
        ``best_transfer_performance_validation_L``.
    lattices : sequence of str
        Structures used both as rows (pretrain structure) and as columns
        (target structure).
    reduction_mode : str
        Passed to ``reduction`` to collapse the matching runs to one value.

    Returns
    -------
    p_heatmap, t_heatmap : numpy.ndarray
        Arrays of shape ``(len(lattices), len(lattices))``.
        ``p_heatmap[i, j]`` is the reduced pretrain validation performance on
        ``lattices[j]`` of runs pretrained on ``lattices[i]``.
        ``t_heatmap[i, j]`` is the reduced validation performance on
        ``lattices[j]`` after transferring from ``lattices[i]`` to
        ``lattices[j]``. Where no such transfer run exists, the pretrain value
        is used instead.
    """
    n = len(lattices)
    p_heatmap = np.zeros((n, n))
    t_heatmap = np.zeros((n, n))
    for i, starting_lattice in enumerate(lattices):
        # Boolean masks instead of string-built ``query`` expressions: no quoting
        # pitfalls and no expression parsing on every iteration.
        p_df = df[df["pretrain_structure"] == starting_lattice]
        for j, transfer_lattice in enumerate(lattices):
            p_heatmap[i, j] = reduction(
                p_df[f"best_pretrain_performance_validation_{transfer_lattice}"],
                reduction_mode,
            )
            t_df = p_df[p_df["transfer_structure"] == transfer_lattice]
            if t_df.empty:
                # No transfer run for this pair (presumably the diagonal):
                # fall back to the performance right after pretraining.
                t_heatmap[i, j] = p_heatmap[i, j]
            else:
                t_heatmap[i, j] = reduction(
                    t_df[f"best_transfer_performance_validation_{transfer_lattice}"],
                    reduction_mode,
                )
    return p_heatmap, t_heatmap


# Figure height in inches (used by the figure set-up below).
height = 8
# Figure width in inches (``height`` is set in the preceding statement).
width = 6.4
fig_style = {
"figsize": (width, height), # inches
"constrained_layout": True,
"dpi":96
}
fig = plt.figure(**fig_style)
# Gridspec column unit: each heatmap / histogram column spans about this many
# gridspec columns.
N_per_mat = 12
# 21 figure rows: rows 0-7 hold the structure heatmaps, row 8 is left empty as
# a spacer, rows 9-20 hold the size histograms.
gs = fig.add_gridspec(21, 1)
subfig_structure = fig.add_subfigure(gs[:8])
# 1 x (2*N_per_mat + 1) grid: pretrain heatmap in columns 0..N_per_mat-1,
# transfer heatmap in columns N_per_mat..2*N_per_mat-1, colorbar in the last column.
structure_gs = subfig_structure.add_gridspec(1, N_per_mat * 2 + 1)
ax_p = subfig_structure.add_subplot(structure_gs[:N_per_mat])
ax_t = subfig_structure.add_subplot(structure_gs[N_per_mat:-1])
ax_cbar = subfig_structure.add_subplot(structure_gs[-1])
# Colormap and fixed colour limits shared by both heatmaps.
cmap = sns.color_palette("crest_r", as_cmap=True)
pmin = 0.0
pmax = 0.1
subfig_size = fig.add_subfigure(gs[9:])
# 2 x (2*N_per_mat + 1) grid. Row 0 holds the small-NP histograms and row 1 the
# large-NP ones (see update_plot). Pretrain panels use columns 1..N_per_mat-2 and
# transfer panels use N_per_mat-1..2*N_per_mat-4. The last 4 columns stay empty
# (NOTE(review): presumably room for the legends moved to the right in update_plot).
size_gs = subfig_size.add_gridspec(2, 2 * N_per_mat + 1)
ax_ps = subfig_size.add_subplot(size_gs[0, 1 : N_per_mat - 1])
ax_ts = subfig_size.add_subplot(size_gs[0, N_per_mat - 1 : -4])
ax_pl = subfig_size.add_subplot(size_gs[1, 1 : N_per_mat - 1])
ax_tl = subfig_size.add_subplot(size_gs[1, N_per_mat - 1 : -4])
# Discrete version of the same colormap; two of its colours become the hue
# palette of the histograms.
Nc = 11
palette = sns.color_palette("crest_r", Nc)
# Keyword arguments shared by every sns.histplot call in update_plot.
common_style = {
"kde": True,
"fill": True,
"multiple": "layer",
"stat": "percent",
"common_bins": True,
"element": "step",
"cumulative": False,
"alpha": 0.3,
"binwidth": 5e-3,
"kde_kws": {
"bw_adjust": 1.0,
},
# Two hue colours (index Nc // 2 + 1 and index 0);
# NOTE(review): presumably one per size category in the hue column — confirm.
"palette": [
palette[Nc // 2 + 1],
palette[0],
],
}
# Selection widgets, each shown under a label and placed side by side.
freeze_select = weight_freezing_select()
rmode_select = reduction_select()
center_align = Layout(align_items="center")
freeze_box = VBox([Label("Weight Freezing"), freeze_select], layout=center_align)
rmode_box = VBox([Label("Reduction Mode"), rmode_select], layout=center_align)
full_box = HBox([freeze_box, rmode_box])
display(full_box)
# Constant subtracted from every plotted performance value before display.
# NOTE(review): its meaning is not visible here (presumably a baseline/irreducible
# loss); the colour scale [pmin, pmax] and histogram x-range start at 0 after it.
_LOSS_OFFSET = 0.3133

# Lattices shown in the structure heatmaps, in display order
# (rows = pretrain structure, columns = target structure).
_HEATMAP_LATTICES = (
    "w_CdSe_small",
    "w_ZnS_small",
    "bcc_Fe_small",
    "fcc_Ag_small",
    "fcc_Cu_small",
    "spinel_Co3O4_small",
    "spinel_Fe3O4_small",
    "dc_C_small",
    "dc_Si_small",
)

# Lattices whose "_small" / "_large" variants feed the size histograms.
_SIZE_LATTICE_BASES = (
    "bcc_Fe",
    "fcc_Ag",
    "fcc_Cu",
    "w_CdSe",
    "w_ZnS",
    "dc_C",
    "dc_Si",
)


def _lattice_label(lattice):
    """Return the axis label for a lattice id, e.g. "spinel_Co3O4_small" -> Co$_{3}$O$_{4}$."""
    import re

    # Subscript every digit run so formulas render as chemical formulas in mathtext.
    return re.sub(r"(\d+)", r"$_{\1}$", lattice.split("_")[1])


def _melt_size_losses(local_df, stage, size):
    """Return ``stage`` ("pretrain"/"transfer") values on ``size`` NPs in long format, offset-corrected."""
    melted = local_df.melt(
        id_vars=["Pretrain Size", "Transfer Size"],
        value_vars=[
            f"best_{stage}_performance_validation_{base}_{size}"
            for base in _SIZE_LATTICE_BASES
        ],
    )
    melted["value"] = melted["value"] - _LOSS_OFFSET
    return melted


def _draw_structure_heatmaps(local_df, reduction_mode):
    """Draw the pretrain/transfer heatmaps on the first call and refresh their data afterwards."""
    p_heatmap, t_heatmap = get_plot_data(local_df, _HEATMAP_LATTICES, reduction_mode)
    p_data = p_heatmap - _LOSS_OFFSET
    t_data = t_heatmap - _LOSS_OFFSET

    if ax_p.images and ax_t.images:
        # Already drawn. The colour limits, colormap and labels are fixed, so only
        # the data changes. Update in place rather than stacking a new image (and
        # a new colorbar) on every widget change.
        ax_p.images[-1].set_data(p_data)
        ax_t.images[-1].set_data(t_data)
        return

    cim = ax_p.matshow(p_data, cmap=cmap, vmin=pmin, vmax=pmax)
    ax_t.matshow(t_data, cmap=cmap, vmin=pmin, vmax=pmax)
    fig.colorbar(cim, cax=ax_cbar)
    ax_t.set_yticklabels([])
    ax_p.set_ylabel("Pretrain structure", fontsize=10)
    # x is the fraction of the subfigure width at the boundary between the two
    # heatmaps (N_per_mat of 2*N_per_mat+1 gridspec columns).
    subfig_structure.supxlabel(
        "Target Structure", fontsize=10, x=N_per_mat / (2 * N_per_mat + 1)
    )
    # matshow puts x ticks on top; move them below the matrices.
    ax_p.xaxis.set_ticks_position("bottom")
    ax_t.xaxis.set_ticks_position("bottom")
    ax_p.set_title("After pretraining", fontsize=10)
    ax_t.set_title("After transfer learning", fontsize=10)
    labels = [_lattice_label(lattice) for lattice in _HEATMAP_LATTICES]
    ticks = list(range(len(_HEATMAP_LATTICES)))
    ax_p.set_xticks(ticks, labels, rotation=45)
    ax_p.set_yticks(ticks, labels)
    ax_t.set_xticks(ticks, labels, rotation=45)
    ax_cbar.set_ylabel("Loss")
    ax_cbar.set_yticks(np.arange(6) * pmax / 5)
    for ax in (ax_t, ax_p, ax_cbar):
        plt.setp(ax.spines.values(), linewidth=1.25)


def _draw_size_histograms(local_df):
    """Rebuild the four histograms (pretrain/transfer x small/large NPs) from scratch."""
    xlim = 0.2
    panels = (
        (ax_ps, "pretrain", "small", "Pretrain Size"),
        (ax_ts, "transfer", "small", "Transfer Size"),
        (ax_pl, "pretrain", "large", "Pretrain Size"),
        (ax_tl, "transfer", "large", "Transfer Size"),
    )
    for ax, stage, size, hue in panels:
        # Every decoration below is re-applied, so clearing loses nothing.
        ax.clear()
        sns.histplot(
            ax=ax,
            data=_melt_size_losses(local_df, stage, size),
            x="value",
            hue=hue,
            **common_style,
        )
        ax.set_xlim([0, xlim])
        ax.set_xticks(np.arange(5) * xlim / 4)
        ax.set_ylim([0, 10])
        plt.setp(ax.spines.values(), linewidth=1.25)
        sns.move_legend(ax, "right", bbox_to_anchor=(1.025, 0.75))
        if size == "small":
            # Top row: x tick labels and axis label are shown only on the bottom row.
            ax.set_xlabel("")
            ax.set_xticklabels([])
        else:
            ax.set_xlabel("Loss")
        ax.text(
            x=xlim / 2,
            y=9.25,
            s=f"Performance on {size} NPs",
            fontsize=8,
            horizontalalignment="center",
        )
        if stage == "transfer":
            # Right column: y labels are shown only on the left column.
            ax.set_ylabel("")
            ax.set_yticklabels([])


def update_plot(*args):
    """Redraw all panels for the current widget selections.

    Registered as the ``observe`` callback of both selection widgets. ``*args``
    receives the change notification, which is unused: the current values are
    read from the widgets directly.
    """
    weight_freezing = freeze_select.value
    reduction_mode = rmode_select.value
    if not weight_freezing:
        # No freezing mode selected: there are no runs to show, keep the last plot.
        return
    local_df = size_df[size_df["freezing_mode"].isin(weight_freezing)]
    _draw_structure_heatmaps(local_df, reduction_mode)
    _draw_size_histograms(local_df)
    # Explicitly request a redraw, since this runs from a widget callback.
    fig.canvas.draw_idle()
# Render once with the widgets' initial values ...
update_plot()
# ... then re-render whenever either selection changes.
freeze_select.observe(update_plot, names="value")
rmode_select.observe(update_plot, names="value")