Coverage for biobb_pytorch/mdae/plots.py: 0%
82 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-21 09:06 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-21 09:06 +0000
1import matplotlib.gridspec as gridspec # type: ignore
2import matplotlib.pyplot as plt # type: ignore
3import numpy as np
4from matplotlib.markers import MarkerStyle # type: ignore
def plot_loss(output_train_data_npz_path: str) -> None:
    """
    Plot the training and validation losses from the given npz file.

    Each loss curve is drawn against its epoch index. A downward-triangle
    marker is placed on the minimum of each curve, and the epoch of that
    minimum is shown in the legend label.

    Args:
        output_train_data_npz_path (str): The path to the npz file containing
            the training and validation losses under the keys
            ``"train_losses"`` and ``"validation_losses"``.
    """
    # np.load on an .npz archive returns an NpzFile that keeps the underlying
    # file open; using it as a context manager closes the handle once the two
    # arrays have been read into memory.
    with np.load(output_train_data_npz_path) as npz_file:
        train_loss = npz_file["train_losses"]
        val_loss = npz_file["validation_losses"]
    min_train_loss_idx = np.argmin(train_loss)
    min_val_loss_idx = np.argmin(val_loss)
    plt.plot(
        range(len(train_loss)),
        train_loss,
        label=f"Training (min.: {min_train_loss_idx})",
        color="blue",
    )
    plt.plot(
        range(len(val_loss)),
        val_loss,
        label=f"Validation (min.: {min_val_loss_idx})",
        color="orange",
    )
    # Mark the epoch with the lowest loss on each curve.
    plt.scatter(
        min_train_loss_idx,
        train_loss[min_train_loss_idx],
        color="blue",
        marker=MarkerStyle("v"),
        s=50,
    )
    plt.scatter(
        min_val_loss_idx,
        val_loss[min_val_loss_idx],
        color="orange",
        marker=MarkerStyle("v"),
        s=50,
    )
    plt.legend()
    plt.ylabel("Total Loss")
    plt.xlabel("Epochs")
    plt.title("Training/Validation")
    plt.show()
def _numpy_rmsd(reference, trajectory):
    """Return the RMSD of every frame of ``trajectory`` against ``reference``.

    The squared displacement is summed over axis 2 and averaged over axis 1,
    giving one value per entry along axis 0. The inputs are presumably
    ``reference`` of shape (atoms, coords) and ``trajectory`` of shape
    (frames, atoms, coords), broadcast together (TODO confirm with callers).
    """
    squared_displacement = (reference - trajectory) ** 2
    per_atom_sq_distance = np.sum(squared_displacement, axis=2)
    mean_sq_distance = np.mean(per_atom_sq_distance, axis=1)
    return np.sqrt(mean_sq_distance)
def plot_rmsd(traj_file_npy_path, output_reconstructed_traj_npy_path) -> None:
    """Plot per-frame RMSD of the original and reconstructed trajectories.

    Both ``.npy`` files are loaded. The first frame of the original trajectory
    is the reference for both RMSD curves. Values are multiplied by 10 and
    labelled in Å, which presumably means the input coordinates are in nm
    (confirm).

    Args:
        traj_file_npy_path: Path to the ``.npy`` file with the original trajectory.
        output_reconstructed_traj_npy_path: Path to the ``.npy`` file with the
            reconstructed trajectory.
    """
    original = np.load(traj_file_npy_path)
    reconstructed = np.load(output_reconstructed_traj_npy_path)
    reference_frame = original[0]
    scale_to_angstrom = 10  # "Convert to Å" factor used throughout this module
    rmsd_original = _numpy_rmsd(reference_frame, original) * scale_to_angstrom
    rmsd_reconstructed = _numpy_rmsd(reference_frame, reconstructed) * scale_to_angstrom
    frame_numbers = np.arange(len(rmsd_original))
    _, axes = plt.subplots(figsize=(20, 6))
    curves = (
        (rmsd_original, "blue", "Original"),
        (rmsd_reconstructed, "red", "Reconstruction"),
    )
    for values, colour, legend_name in curves:
        axes.plot(frame_numbers, values, color=colour, linewidth=1, label=legend_name)
    axes.set_xlabel("# Frame")
    axes.set_ylabel("RMSD (Å)")
    plt.title("RMSD Plot")
    plt.legend()
    plt.show()
def plot_latent_space(latent_space_npy_path: str) -> None:
    """Scatter the first two latent dimensions with marginal histograms.

    Points are coloured by their row index (labelled "Frames" on the
    colorbar). A 100-bin histogram of column 0 sits above the scatter and a
    horizontal 100-bin histogram of column 1 sits to its right. The colorbar
    is placed just to the right of that side histogram.

    Args:
        latent_space_npy_path: Path to a ``.npy`` array whose columns 0 and 1
            are plotted (presumably shape (frames, latent_dim >= 2) — confirm).
    """
    latent = np.load(latent_space_npy_path)
    layout = gridspec.GridSpec(4, 4)
    figure = plt.figure(figsize=(15, 10))
    main_axes = plt.subplot(layout[1:4, :3])
    top_axes = plt.subplot(layout[0, :3], sharex=main_axes)
    side_axes = plt.subplot(layout[1:4, 3], sharey=main_axes)
    dim0 = latent[:, 0]
    dim1 = latent[:, 1]
    points = main_axes.scatter(
        dim0, dim1, c=np.arange(len(latent)), alpha=1, cmap="jet", s=2
    )
    # Colorbar sits in a thin strip just right of the side histogram,
    # matching its vertical extent.
    side_box = side_axes.get_position()
    colorbar_axes = figure.add_axes(
        (side_box.x1 + 0.01, side_box.y0, 0.02, side_box.height)
    )
    colorbar = plt.colorbar(points, cax=colorbar_axes)
    colorbar.set_label("Frames")
    top_axes.hist(dim0, bins=100, color="blue", alpha=0.7)
    side_axes.hist(dim1, bins=100, color="blue", alpha=0.7, orientation="horizontal")
    main_axes.set_xlabel("z0", labelpad=20)
    main_axes.set_ylabel("z1", labelpad=20)
    plt.show()
def _numpy_rmsf_by_atom(trajectory):
    """Return the root-mean-square fluctuation around the mean structure.

    The mean over axis 0 is the reference structure. Squared deviations from
    it are summed over axis 2, then averaged over axis 0, giving one value per
    entry along axis 1. The input is presumably (frames, atoms, coords) —
    confirm with callers.
    """
    mean_structure = np.mean(trajectory, axis=0)
    sq_deviation = np.sum((trajectory - mean_structure) ** 2, axis=2)
    return np.sqrt(np.mean(sq_deviation, axis=0))
def plot_rmsf(original_traj_npy_file, mutated_reconstructed_traj_npy_file):
    """Plot per-atom RMSF of an original and a reconstructed trajectory.

    Each ``.npy`` file is loaded and reduced with ``_numpy_rmsf_by_atom``. The
    result is multiplied by 10 and labelled in Å (presumably nm input —
    confirm). Both curves are drawn on a single axis and shown.

    Args:
        original_traj_npy_file: Path to the ``.npy`` file with the original trajectory.
        mutated_reconstructed_traj_npy_file: Path to the ``.npy`` file with the
            reconstructed trajectory.
    """
    loaded = (
        np.load(original_traj_npy_file),
        np.load(mutated_reconstructed_traj_npy_file),
    )
    # Factor 10: "Convert to Å", as elsewhere in this module.
    rmsf_original, rmsf_reconstructed = (
        _numpy_rmsf_by_atom(traj) * 10 for traj in loaded
    )
    _, axis = plt.subplots(figsize=(20, 6))
    atom_ids = np.arange(len(rmsf_original))
    for values, colour, legend_name in (
        (rmsf_original, "blue", "Original"),
        (rmsf_reconstructed, "red", "Reconstruction"),
    ):
        axis.plot(atom_ids, values, color=colour, linewidth=1, label=legend_name)
    axis.set_xlabel("# Atom")
    axis.set_ylabel("RMSD (Å) Average structure as reference")
    plt.title("RMSF Plot")
    plt.legend()
    plt.show()
def plot_rmsf_difference(
    original_traj_npy_file,
    mutated_reconstructed_traj_npy_file,
    *,
    threshold=1.0,
):
    """Plot the per-atom RMSF difference between original and reconstruction.

    Both ``.npy`` files are loaded and reduced with ``_numpy_rmsf_by_atom``.
    The result is multiplied by 10 and labelled in Å (presumably nm input —
    confirm). The plotted curve is ``RMSF(original) - RMSF(reconstruction)``.

    Args:
        original_traj_npy_file: Path to the ``.npy`` file with the original trajectory.
        mutated_reconstructed_traj_npy_file: Path to the ``.npy`` file with the
            reconstructed trajectory.
        threshold: Height of the dashed grey reference line, in the plotted
            units. Defaults to ``1.0``, the value previously hard-coded. Pass
            ``None`` to omit the line.
    """
    original_traj = np.load(original_traj_npy_file)
    mutated_reconstructed_traj = np.load(mutated_reconstructed_traj_npy_file)
    rmsf_trajectory = _numpy_rmsf_by_atom(original_traj) * 10  # Convert to Å
    rmsf_output = _numpy_rmsf_by_atom(mutated_reconstructed_traj) * 10  # Convert to Å
    _, ax = plt.subplots(figsize=(20, 6))
    indices = np.arange(len(rmsf_trajectory))
    # Plot the per-atom RMSF difference between input and output
    ax.plot(
        indices,
        rmsf_trajectory - rmsf_output,
        color="orange",
        linewidth=1,
        label="DIO",
    )
    if threshold is not None:
        ax.axhline(y=threshold, color="grey", linestyle="--", linewidth=1)
    ax.set_xlabel("# Atom")
    # The y values are a difference of RMSF values, not an RMSD.
    ax.set_ylabel("RMSF difference (Å)")
    plt.title("RMSF Plot")
    plt.legend()
    plt.show()