Coverage for biobb_pytorch/mdae/plots.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-21 09:06 +0000

1import matplotlib.gridspec as gridspec # type: ignore 

2import matplotlib.pyplot as plt # type: ignore 

3import numpy as np 

4from matplotlib.markers import MarkerStyle # type: ignore 

5 

6 

def plot_loss(output_train_data_npz_path: str) -> None:
    """
    Plot the training and validation losses from the given npz file.

    Each curve is drawn against its epoch index, and its minimum is marked
    with a downward triangle; the epoch of that minimum is shown in the
    legend entry.

    Args:
        output_train_data_npz_path (str): The path to the npz file containing
            the training and validation losses, stored under the keys
            ``"train_losses"`` and ``"validation_losses"``.
    """
    # For .npz archives np.load returns an NpzFile that holds the file open;
    # use it as a context manager so the handle is released once both arrays
    # have been read into memory.
    with np.load(output_train_data_npz_path) as npz_file:
        train_loss = npz_file["train_losses"]
        val_loss = npz_file["validation_losses"]

    min_train_loss_idx = np.argmin(train_loss)
    min_val_loss_idx = np.argmin(val_loss)

    plt.plot(
        range(len(train_loss)),
        train_loss,
        label=f"Training (min.: {min_train_loss_idx})",
        color="blue",
    )
    plt.plot(
        range(len(val_loss)),
        val_loss,
        label=f"Validation (min.: {min_val_loss_idx})",
        color="orange",
    )
    # Mark the minimum of each curve; these markers carry no legend label.
    plt.scatter(
        min_train_loss_idx,
        train_loss[min_train_loss_idx],
        color="blue",
        marker=MarkerStyle("v"),
        s=50,
    )
    plt.scatter(
        min_val_loss_idx,
        val_loss[min_val_loss_idx],
        color="orange",
        marker=MarkerStyle("v"),
        s=50,
    )
    plt.legend()
    plt.ylabel("Total Loss")
    plt.xlabel("Epochs")
    plt.title("Training/Validation")
    plt.show()

50 

51 

def _numpy_rmsd(reference, trajectory):
    """Return the per-frame RMSD of ``trajectory`` against ``reference``.

    No superposition is performed. Squared differences are summed over
    axis 2, averaged over axis 1, and square-rooted, so the result has one
    value per entry of axis 0 of ``trajectory``.
    (Presumably trajectory is (frames, atoms, 3) and reference is
    (atoms, 3) — confirm against callers.)
    """
    displacement = reference - trajectory
    squared_distance = np.square(displacement).sum(axis=2)
    return np.sqrt(squared_distance.mean(axis=1))

54 

55 

def plot_rmsd(traj_file_npy_path: str, output_reconstructed_traj_npy_path: str) -> None:
    """
    Plot per-frame RMSD of the original and reconstructed trajectories.

    Both curves use the first frame of the ORIGINAL trajectory as the
    reference structure (no superposition is applied). Values are multiplied
    by 10 for display in Å (presumably the inputs are in nm — confirm).

    Args:
        traj_file_npy_path (str): Path to the ``.npy`` file with the original
            trajectory coordinates.
        output_reconstructed_traj_npy_path (str): Path to the ``.npy`` file
            with the reconstructed trajectory coordinates.
    """
    perf_data = np.load(traj_file_npy_path)
    output = np.load(output_reconstructed_traj_npy_path)
    reference = perf_data[0]
    rmsd_trajectory = _numpy_rmsd(reference, perf_data) * 10  # Convert to Å
    rmsd_output = _numpy_rmsd(reference, output) * 10  # Convert to Å

    _, ax = plt.subplots(figsize=(20, 6))
    # Each curve gets its own frame index so a length mismatch between the
    # two trajectories does not raise a shape error inside ax.plot.
    ax.plot(
        np.arange(len(rmsd_trajectory)),
        rmsd_trajectory,
        color="blue",
        linewidth=1,
        label="Original",
    )
    ax.plot(
        np.arange(len(rmsd_output)),
        rmsd_output,
        color="red",
        linewidth=1,
        label="Reconstruction",
    )
    # Labels, title, and legend
    ax.set_xlabel("# Frame")
    ax.set_ylabel("RMSD (Å)")
    ax.set_title("RMSD Plot")
    ax.legend()
    plt.show()

71 

72 

def plot_latent_space(latent_space_npy_path: str) -> None:
    """
    Scatter the first two latent dimensions with marginal histograms.

    Points are coloured by their row index (frame order) with the ``jet``
    colormap, and a colorbar labelled "Frames" is placed to the right of the
    y-marginal histogram.

    Args:
        latent_space_npy_path (str): Path to a ``.npy`` file holding the
            latent coordinates. Columns 0 and 1 are plotted as z0 and z1
            (presumably shape (n_frames, n_latent) with n_latent >= 2 —
            confirm against the producer of this file).
    """
    z = np.load(latent_space_npy_path)
    z0 = z[:, 0]
    z1 = z[:, 1]

    fig = plt.figure(figsize=(15, 10))
    gs = gridspec.GridSpec(4, 4, figure=fig)
    # Main scatter fills the lower-left 3x3 cells; marginals sit above it
    # (sharing x) and to its right (sharing y).
    ax_main = fig.add_subplot(gs[1:4, :3])
    ax_x_hist = fig.add_subplot(gs[0, :3], sharex=ax_main)
    ax_y_hist = fig.add_subplot(gs[1:4, 3], sharey=ax_main)

    sc = ax_main.scatter(z0, z1, c=np.arange(len(z)), alpha=1, cmap="jet", s=2)

    # Position and size of colorbar based on ax_y_hist: a thin axis just to
    # its right, spanning the same height.
    pos = ax_y_hist.get_position()
    cbar_ax = fig.add_axes((pos.x1 + 0.01, pos.y0, 0.02, pos.height))
    cbar = fig.colorbar(sc, cax=cbar_ax)
    cbar.set_label("Frames")

    # Marginal distributions of each latent dimension.
    ax_x_hist.hist(z0, bins=100, color="blue", alpha=0.7)
    ax_y_hist.hist(z1, bins=100, color="blue", alpha=0.7, orientation="horizontal")

    ax_main.set_xlabel("z0", labelpad=20)
    ax_main.set_ylabel("z1", labelpad=20)
    plt.show()

97 

98 

def _numpy_rmsf_by_atom(trajectory):
    """Return the RMSF of ``trajectory`` around its time-averaged structure.

    The mean over axis 0 is subtracted, squared deviations are summed over
    axis 2, averaged over axis 0, and square-rooted, yielding one value per
    entry of axis 1. (Presumably trajectory is (frames, atoms, 3), giving
    one RMSF value per atom — confirm against callers.)
    """
    average_structure = np.mean(trajectory, axis=0)
    deviation = trajectory - average_structure
    squared_norm = np.square(deviation).sum(axis=2)
    return np.sqrt(squared_norm.mean(axis=0))

103 

104 

def plot_rmsf(original_traj_npy_file: str, mutated_reconstructed_traj_npy_file: str) -> None:
    """
    Plot per-atom RMSF of the original and the mutated/reconstructed trajectory.

    Each trajectory's RMSF is computed around its own average structure.
    Values are multiplied by 10 for display in Å (presumably the inputs are
    in nm — confirm).

    Args:
        original_traj_npy_file (str): Path to the ``.npy`` file with the
            original trajectory coordinates.
        mutated_reconstructed_traj_npy_file (str): Path to the ``.npy`` file
            with the mutated/reconstructed trajectory coordinates.
    """
    original_traj = np.load(original_traj_npy_file)
    mutated_reconstructed_traj = np.load(mutated_reconstructed_traj_npy_file)
    rmsf_trajectory = _numpy_rmsf_by_atom(original_traj) * 10  # Convert to Å
    rmsf_output = _numpy_rmsf_by_atom(mutated_reconstructed_traj) * 10  # Convert to Å

    _, ax = plt.subplots(figsize=(20, 6))
    # Separate atom indices per curve so differing atom counts do not raise a
    # shape error inside ax.plot.
    ax.plot(
        np.arange(len(rmsf_trajectory)),
        rmsf_trajectory,
        color="blue",
        linewidth=1,
        label="Original",
    )
    ax.plot(
        np.arange(len(rmsf_output)),
        rmsf_output,
        color="red",
        linewidth=1,
        label="Reconstruction",
    )
    ax.set_xlabel("# Atom")
    # The plotted quantity is RMSF (fluctuation around the mean structure),
    # not RMSD.
    ax.set_ylabel("RMSF (Å) Average structure as reference")
    ax.set_title("RMSF Plot")
    ax.legend()
    plt.show()

119 

120 

def plot_rmsf_difference(
    original_traj_npy_file: str,
    mutated_reconstructed_traj_npy_file: str,
    threshold: float = 1.0,
) -> None:
    """
    Plot the per-atom RMSF difference (original minus mutated/reconstructed).

    Each trajectory's RMSF is computed around its own average structure, and
    values are multiplied by 10 for display in Å (presumably the inputs are
    in nm — confirm). A dashed grey horizontal guide line is drawn at
    ``threshold``.

    Args:
        original_traj_npy_file (str): Path to the ``.npy`` file with the
            original trajectory coordinates.
        mutated_reconstructed_traj_npy_file (str): Path to the ``.npy`` file
            with the mutated/reconstructed trajectory coordinates. It must
            have the same number of atoms as the original for the difference
            to be defined.
        threshold (float): y-value (in Å) of the dashed guide line.
            Defaults to 1.0, the value previously hard-coded.
    """
    original_traj = np.load(original_traj_npy_file)
    mutated_reconstructed_traj = np.load(mutated_reconstructed_traj_npy_file)
    rmsf_trajectory = _numpy_rmsf_by_atom(original_traj) * 10  # Convert to Å
    rmsf_output = _numpy_rmsf_by_atom(mutated_reconstructed_traj) * 10  # Convert to Å
    rmsf_difference = rmsf_trajectory - rmsf_output

    _, ax = plt.subplots(figsize=(20, 6))
    # Plot RMSF difference between input and output.
    # NOTE(review): the "DIO" legend label is kept as-is; presumably it stands
    # for "difference input/output" — confirm.
    ax.plot(
        np.arange(len(rmsf_difference)),
        rmsf_difference,
        color="orange",
        linewidth=1,
        label="DIO",
    )
    ax.axhline(y=threshold, color="grey", linestyle="--", linewidth=1)
    ax.set_xlabel("# Atom")
    # The plotted quantity is a difference of RMSF values, not an RMSD.
    ax.set_ylabel("RMSF difference (Å)")
    ax.set_title("RMSF Plot")
    ax.legend()
    plt.show()