Coverage for biobb_pytorch / mdae / plots.py: 93%
105 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import os
2from typing import Union, List
3import matplotlib.pyplot as plt
4import numpy as np
5from matplotlib.markers import MarkerStyle
def plot_loss(output_train_data_npz_path: str) -> None:
    """
    Plot the training and (optional) validation loss curves from an npz file.

    Args:
        output_train_data_npz_path (str): Path to an ``.npz`` archive that must
            contain a ``"train_loss"`` array and may contain a ``"valid_loss"``
            array.

    Each curve's legend entry reports the index of its minimum, and that
    minimum is marked with a downward triangle. The figure is displayed with
    ``plt.show()``.

    Raises:
        KeyError: If ``"train_loss"`` is not present in the archive.
    """
    # Use the archive as a context manager so its file handle is released.
    # Indexing an NpzFile reads the array into memory, so the loaded arrays
    # remain valid after the archive is closed.
    with np.load(output_train_data_npz_path) as npz_file:
        train_loss = npz_file["train_loss"]
        val_loss = npz_file.get("valid_loss", None)

    # An empty validation array cannot be argmin'd; treat it as absent.
    if val_loss is not None and val_loss.size == 0:
        val_loss = None

    min_train_loss_idx = np.argmin(train_loss)
    min_val_loss_idx = np.argmin(val_loss) if val_loss is not None else None

    plt.plot(
        range(len(train_loss)),
        train_loss,
        label=f"Training (min.: {min_train_loss_idx})",
        color="blue",
    )
    if val_loss is not None:
        plt.plot(
            range(len(val_loss)),
            val_loss,
            label=f"Validation (min.: {min_val_loss_idx})",
            color="orange",
        )

    # Mark the minimum of each curve.
    plt.scatter(
        min_train_loss_idx,
        train_loss[min_train_loss_idx],
        color="blue",
        marker=MarkerStyle("v"),
        s=50,
    )
    if val_loss is not None and min_val_loss_idx is not None:
        plt.scatter(
            min_val_loss_idx,
            val_loss[min_val_loss_idx],
            color="orange",
            marker=MarkerStyle("v"),
            s=50,
        )

    plt.legend()
    plt.ylabel("Total Loss")
    plt.xlabel("Epochs")
    plt.title("Training/Validation")
    plt.show()
def plot_rmsd(input_xvg_path: Union[str, List[str]]) -> None:
    """
    Plots RMSD time series from one or more XVG files.

    Parameters:
        input_xvg_path (str or list of str): Path to a single XVG file or list
            of paths to multiple XVG files.

    Each file is parsed with lines starting with ``#`` or ``@`` treated as
    comments. Column 0 is taken as time and column 1 as RMSD. All series are
    drawn on a single figure for comparison, each labelled by its file name
    with a trailing ``.xvg`` removed. Missing files, and files that do not
    contain at least two data columns, are skipped with a warning.
    """
    if isinstance(input_xvg_path, str):
        input_xvg_path = [input_xvg_path]

    plt.figure(figsize=(15, 6))

    for file_path in input_xvg_path:
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} does not exist. Skipping.")
            continue

        # Load data from XVG, skipping comment lines. ndmin=2 keeps a
        # single-row file two-dimensional so column slicing still works.
        data = np.loadtxt(file_path, comments=['#', '@'], ndmin=2)
        if data.shape[0] == 0 or data.shape[1] < 2:
            print(f"Warning: File {file_path} has no time/RMSD columns. Skipping.")
            continue

        # Assume column 0: time, column 1: RMSD
        time = data[:, 0]
        rmsd = data[:, 1]

        # Label with the file name, removing only a trailing '.xvg'.
        label = os.path.basename(file_path)
        if label.endswith('.xvg'):
            label = label[:-len('.xvg')]

        plt.plot(time, rmsd, label=label)

    # NOTE(review): the x-axis is labelled in ns; confirm the XVG files were
    # written with nanosecond time units.
    plt.xlabel('time (ns)')
    plt.ylabel('RMSD (nm)')
    plt.legend()
    plt.grid(True)
    plt.show()
def plot_rmsf(input_xvg_path: Union[str, List[str]]) -> None:
    """
    Plots per-residue RMSF from one or more XVG files.

    Parameters:
        input_xvg_path (str or list of str): Path to a single XVG file or list
            of paths to multiple XVG files.

    Each file is parsed with lines starting with ``#`` or ``@`` treated as
    comments. Column 0 is taken as the residue number and column 1 as RMSF.
    All series are drawn on a single figure for comparison, each labelled by
    its file name with a trailing ``.xvg`` removed. Missing files, and files
    that do not contain at least two data columns, are skipped with a warning.
    """
    if isinstance(input_xvg_path, str):
        input_xvg_path = [input_xvg_path]

    plt.figure(figsize=(15, 6))

    for file_path in input_xvg_path:
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} does not exist. Skipping.")
            continue

        # Load data from XVG, skipping comment lines. ndmin=2 keeps a
        # single-row file two-dimensional so column slicing still works.
        data = np.loadtxt(file_path, comments=['#', '@'], ndmin=2)
        if data.shape[0] == 0 or data.shape[1] < 2:
            print(f"Warning: File {file_path} has no residue/RMSF columns. Skipping.")
            continue

        # Assume column 0: residue, column 1: RMSF
        residues = data[:, 0]
        rmsf = data[:, 1]

        # Label with the file name, removing only a trailing '.xvg'.
        label = os.path.basename(file_path)
        if label.endswith('.xvg'):
            label = label[:-len('.xvg')]

        plt.plot(residues, rmsf, label=label)

    plt.xlabel('Residue Number')
    plt.ylabel('RMSF (nm)')
    plt.legend()
    plt.grid(True)
    plt.show()
def plot_rmsf_difference(input_xvg_path: Union[str, List[str]]) -> None:
    """
    Plots the absolute per-residue RMSF difference between two XVG files.

    Parameters:
        input_xvg_path (str or list of str): Paths to XVG files. Missing files
            (and files without at least two data columns) are skipped with a
            warning; the first two usable files are compared and any further
            files are ignored.

    Each file is parsed with lines starting with ``#`` or ``@`` treated as
    comments. Column 0 is taken as the residue number and column 1 as RMSF.
    The plotted curve is ``|RMSF_first - RMSF_second|`` against the residue
    numbers of the first file.

    Raises:
        ValueError: If fewer than two usable files are provided, or if the two
            RMSF series have different lengths.
    """
    if isinstance(input_xvg_path, str):
        input_xvg_path = [input_xvg_path]

    def _stem(path):
        # File name with only a trailing '.xvg' removed.
        name = os.path.basename(path)
        return name[:-len('.xvg')] if name.endswith('.xvg') else name

    loaded = []  # (file_path, residues, rmsf) for each usable file, in order
    for file_path in input_xvg_path:
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} does not exist. Skipping.")
            continue

        # Load data from XVG, skipping comment lines. ndmin=2 keeps a
        # single-row file two-dimensional so column slicing still works.
        data = np.loadtxt(file_path, comments=['#', '@'], ndmin=2)
        if data.shape[0] == 0 or data.shape[1] < 2:
            print(f"Warning: File {file_path} has no residue/RMSF columns. Skipping.")
            continue

        # Assume column 0: residue, column 1: RMSF
        loaded.append((file_path, data[:, 0], data[:, 1]))

    if len(loaded) < 2:
        raise ValueError(
            f"plot_rmsf_difference needs two readable XVG files, got {len(loaded)}"
        )

    (path_a, residues, rmsf_a), (path_b, _, rmsf_b) = loaded[0], loaded[1]
    if rmsf_a.shape != rmsf_b.shape:
        raise ValueError(
            f"RMSF series lengths differ: {path_a} has {rmsf_a.shape[0]} rows, "
            f"{path_b} has {rmsf_b.shape[0]} rows"
        )

    diff_rmsf = np.abs(rmsf_a - rmsf_b)

    # Label names both compared files.
    # NOTE(review): the 'DIO: ' prefix is hard-coded; confirm it is intended
    # for all inputs.
    label = f"DIO: {_stem(path_a)} vs {_stem(path_b)}"

    plt.figure(figsize=(10, 6))
    plt.plot(residues, diff_rmsf, label=label)
    plt.xlabel('Residue Number')
    plt.ylabel('RMSF (nm)')
    plt.title('RMSF Difference')
    plt.legend()
    plt.grid(True)
    plt.show()
def plot_latent_space(results_npz_path: str,
                      projection_dim: Union[list, None],
                      snapshot_freq_ps=10) -> None:
    """
    Scatter-plot two latent dimensions, coloured by frame time.

    Parameters:
        results_npz_path (str): Path to a file loadable with ``np.load`` that
            must contain a ``'z'`` array; rows are frames and columns are
            latent dimensions (indexed as ``z[:, dim]``).
        projection_dim (list or None): The two latent dimension indices to put
            on the x and y axes. ``None`` selects ``[0, 1]``.
        snapshot_freq_ps (float): The per-frame time step is computed as
            ``1 / snapshot_freq_ps`` and used for the colour scale and the
            colorbar ticks, which are labelled in ns.

    Raises:
        KeyError: If ``'z'`` is not present in the loaded results.
        ValueError: If ``projection_dim`` does not have exactly two entries,
            or if ``z`` contains no frames.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle objects, which can
    # execute arbitrary code; only load result files from trusted sources.
    results = np.load(results_npz_path, allow_pickle=True)
    try:
        if 'z' not in results:
            raise KeyError(f"'z' not found in {results_npz_path}")
        z = results['z']
    finally:
        # An NpzFile keeps its file open until closed; other objects np.load
        # can return (arrays, unpickled objects) have no close() method.
        if hasattr(results, 'close'):
            results.close()

    if projection_dim is None:
        projection_dim = [0, 1]

    if len(projection_dim) != 2:
        raise ValueError(f"projection_dim must have length 2, got {projection_dim}")

    dim1, dim2 = projection_dim
    n_frames = z.shape[0]
    if n_frames == 0:
        raise ValueError(f"'z' in {results_npz_path} contains no frames")

    # Aim for roughly ten colorbar ticks; clamp to 1 so trajectories with
    # fewer than 10 frames do not produce a zero slice step.
    tick_step = max(1, n_frames // 10)
    # NOTE(review): 1 / snapshot_freq_ps is used as the per-frame step in ns
    # (0.1 with the default of 10); confirm the intended units of
    # snapshot_freq_ps, whose name suggests picoseconds.
    timestep_ns = 1 / snapshot_freq_ps
    frame_times = np.arange(n_frames) * timestep_ns

    plt.figure(figsize=(10, 6))
    plt.scatter(z[:, dim1], z[:, dim2], c=frame_times, s=10, alpha=1.0)
    plt.xlabel(f'latent_dim {dim1}')
    plt.ylabel(f'latent_dim {dim2}')

    # Always include the first and last frame times among the ticks.
    ticks = frame_times[::tick_step]
    if 0 not in ticks:
        ticks = np.insert(ticks, 0, 0)
    last_time = (n_frames - 1) * timestep_ns
    if last_time not in ticks:
        ticks = np.append(ticks, last_time)

    plt.colorbar(ticks=ticks, label='Time (ns)')
    plt.title('Latent Space Visualization')
    plt.show()