Coverage for biobb_pytorch / mdae / plots.py: 93%

105 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import os 

2from typing import Union, List 

3import matplotlib.pyplot as plt 

4import numpy as np 

5from matplotlib.markers import MarkerStyle 

6 

7 

def plot_loss(output_train_data_npz_path: str) -> None:
    """
    Plot the training and validation losses from the given npz file.

    The index of the lowest value of each curve is marked with a triangle and
    reported in the legend. The validation curve is drawn only when the file
    holds a non-empty "valid_loss" array.

    Args:
        output_train_data_npz_path (str): The path to the npz file containing the training loss
            under "train_loss" and, optionally, the validation loss under "valid_loss".

    Raises:
        KeyError: If the file has no "train_loss" entry.
    """
    # The context manager closes the archive's file handle once the arrays
    # have been read into memory.
    with np.load(output_train_data_npz_path) as npz_file:
        train_loss = npz_file["train_loss"]
        # Membership test on .files works on every NumPy version, unlike .get().
        val_loss = npz_file["valid_loss"] if "valid_loss" in npz_file.files else None

    # An empty validation array carries no curve and would make argmin raise.
    if val_loss is not None and np.size(val_loss) == 0:
        val_loss = None

    min_train_loss_idx = np.argmin(train_loss)
    min_val_loss_idx = np.argmin(val_loss) if val_loss is not None else None

    plt.plot(
        range(len(train_loss)),
        train_loss,
        label=f"Training (min.: {min_train_loss_idx})",
        color="blue",
    )
    if val_loss is not None:
        plt.plot(
            range(len(val_loss)),
            val_loss,
            label=f"Validation (min.: {min_val_loss_idx})",
            color="orange",
        )
    plt.scatter(
        min_train_loss_idx,
        train_loss[min_train_loss_idx],
        color="blue",
        marker=MarkerStyle("v"),
        s=50,
    )
    if val_loss is not None:
        plt.scatter(
            min_val_loss_idx,
            val_loss[min_val_loss_idx],
            color="orange",
            marker=MarkerStyle("v"),
            s=50,
        )
    plt.legend()
    plt.ylabel("Total Loss")
    plt.xlabel("Epochs")
    plt.title("Training/Validation")
    plt.show()

54 

55 

def plot_rmsd(input_xvg_path: Union[str, List[str]]) -> None:
    """
    Plots RMSD from one or more XVG files.

    Parameters:
        input_xvg_path (str or list of str): Path to a single XVG file or list of paths to multiple XVG files.

    The function parses each XVG file, reads the first column as time and the
    second column as RMSD, and plots every series on a single figure for comparison.
    Files that do not exist, or that hold fewer than two data columns, are
    skipped with a warning.
    """
    if isinstance(input_xvg_path, str):
        input_xvg_path = [input_xvg_path]

    plt.figure(figsize=(15, 6))

    for file_path in input_xvg_path:
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} does not exist. Skipping.")
            continue

        # Load data from XVG, skipping comment lines. ndmin=2 keeps a file with
        # a single data row two-dimensional so the column slicing below works.
        data = np.loadtxt(file_path, comments=['#', '@'], ndmin=2)
        if data.shape[1] < 2:
            print(f"Warning: File {file_path} has fewer than two data columns. Skipping.")
            continue

        # Assume column 0: time, column 1: RMSD
        time = data[:, 0]
        rmsd = data[:, 1]

        # Get label from filename
        label = os.path.basename(file_path).replace('.xvg', '')

        plt.plot(time, rmsd, label=label)

    plt.xlabel('time (ns)')
    plt.ylabel('RMSD (nm)')
    plt.legend()
    plt.grid(True)
    plt.show()

93 

94 

def plot_rmsf(input_xvg_path: Union[str, List[str]]) -> None:
    """
    Plots RMSF from one or more XVG files.

    Parameters:
        input_xvg_path (str or list of str): Path to a single XVG file or list of paths to multiple XVG files.

    The function parses each XVG file, reads the first column as residue number
    and the second column as RMSF, and plots every series on a single figure
    for comparison. Files that do not exist, or that hold fewer than two data
    columns, are skipped with a warning.
    """
    if isinstance(input_xvg_path, str):
        input_xvg_path = [input_xvg_path]

    plt.figure(figsize=(15, 6))

    for file_path in input_xvg_path:
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} does not exist. Skipping.")
            continue

        # Load data from XVG, skipping comment lines. ndmin=2 keeps a file with
        # a single data row two-dimensional so the column slicing below works.
        data = np.loadtxt(file_path, comments=['#', '@'], ndmin=2)
        if data.shape[1] < 2:
            print(f"Warning: File {file_path} has fewer than two data columns. Skipping.")
            continue

        # Assume column 0: residue, column 1: RMSF
        residues = data[:, 0]
        rmsf = data[:, 1]

        # Get label from filename
        label = os.path.basename(file_path).replace('.xvg', '')

        plt.plot(residues, rmsf, label=label)

    plt.xlabel('Residue Number')
    plt.ylabel('RMSF (nm)')
    plt.legend()
    plt.grid(True)
    plt.show()

132 

133 

def plot_rmsf_difference(input_xvg_path: Union[str, List[str]]) -> None:
    """
    Plots the absolute per-residue RMSF difference between two XVG files.

    Parameters:
        input_xvg_path (str or list of str): Paths to the XVG files to compare. At least two
            loadable files are required; the difference is taken between the first two.

    The function parses each XVG file, reads the first column as residue number
    and the second column as RMSF, and plots |RMSF_first - RMSF_second| against
    the residue numbers of the first file. Files that do not exist, or that hold
    fewer than two data columns, are skipped with a warning.

    Raises:
        ValueError: If fewer than two files could be loaded, or if the two RMSF
            series do not have the same length.
    """
    if isinstance(input_xvg_path, str):
        input_xvg_path = [input_xvg_path]

    # One (name, residues, rmsf) entry per successfully loaded file.
    loaded = []
    for file_path in input_xvg_path:
        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} does not exist. Skipping.")
            continue

        # Load data from XVG, skipping comment lines. ndmin=2 keeps a file with
        # a single data row two-dimensional so the column slicing below works.
        data = np.loadtxt(file_path, comments=['#', '@'], ndmin=2)
        if data.shape[1] < 2:
            print(f"Warning: File {file_path} has fewer than two data columns. Skipping.")
            continue

        # Assume column 0: residue, column 1: RMSF
        name = os.path.basename(file_path).replace('.xvg', '')
        loaded.append((name, data[:, 0], data[:, 1]))

    if len(loaded) < 2:
        raise ValueError(
            f"plot_rmsf_difference needs two readable XVG files, got {len(loaded)}"
        )

    (first_name, residues, first_rmsf), (second_name, _, second_rmsf) = loaded[:2]
    if first_rmsf.shape != second_rmsf.shape:
        raise ValueError(
            f"RMSF series differ in length: {first_name} has {first_rmsf.shape[0]} values, "
            f"{second_name} has {second_rmsf.shape[0]}"
        )

    diff_rmsf = np.abs(first_rmsf - second_rmsf)
    # Label names both compared files (previously the same file appeared twice).
    label = f"DIO: {first_name} vs {second_name}"

    plt.figure(figsize=(10, 6))
    plt.plot(residues, diff_rmsf, label=label)
    plt.xlabel('Residue Number')
    plt.ylabel('RMSF (nm)')
    plt.title('RMSF Difference')
    plt.legend()
    plt.grid(True)
    plt.show()

175 

176 

def plot_latent_space(results_npz_path: str,
                      projection_dim: list,
                      snapshot_freq_ps=10) -> None:
    """
    Scatter-plot two latent dimensions, colouring every frame by its time.

    Args:
        results_npz_path (str): Path to the npz file holding the latent coordinates under the key 'z'.
        projection_dim (list): The two column indices of 'z' to plot. None selects [0, 1].
        snapshot_freq_ps: Value from which the per-frame time step is derived as
            1 / snapshot_freq_ps. Defaults to 10.

    Raises:
        KeyError: If the file has no 'z' entry.
        ValueError: If projection_dim does not hold exactly two entries, or if 'z' has no frames.
    """
    # NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
    # execute arbitrary code -- only pass files from a trusted source.
    # The context manager closes the archive's file handle after 'z' is read.
    with np.load(results_npz_path, allow_pickle=True) as results:
        if 'z' not in results:
            raise KeyError(f"'z' not found in {results_npz_path}")
        z = results['z']

    if projection_dim is None:
        projection_dim = [0, 1]

    if len(projection_dim) != 2:
        raise ValueError(f"projection_dim must have length 2, got {projection_dim}")

    dim1, dim2 = projection_dim
    n_frames = z.shape[0]
    if n_frames == 0:
        raise ValueError(f"'z' in {results_npz_path} contains no frames")

    # NOTE(review): the result is labelled as nanoseconds below; confirm that
    # 1 / snapshot_freq_ps is the intended conversion.
    timestep_ns = 1 / snapshot_freq_ps

    plt.figure(figsize=(10, 6))
    plt.scatter(z[:, dim1], z[:, dim2], c=np.arange(n_frames) * timestep_ns, s=10, alpha=1.0)
    plt.xlabel(f'latent_dim {dim1}')
    plt.ylabel(f'latent_dim {dim2}')

    # Roughly ten colour-bar ticks; the step is clamped to 1 so trajectories
    # shorter than ten frames do not produce a zero slice step.
    tick_step = max(1, n_frames // 10)
    tick_frames = np.arange(0, n_frames, tick_step)
    # Frame 0 is always included; make sure the final frame is labelled too.
    if tick_frames[-1] != n_frames - 1:
        tick_frames = np.append(tick_frames, n_frames - 1)
    ticks = tick_frames * timestep_ns

    plt.colorbar(ticks=ticks, label='Time (ns)')
    plt.title('Latent Space Visualization')
    plt.show()