Coverage for biobb_pytorch / test / unitests / test_mdae / test_plots.py: 100%

135 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1# type: ignore 

2import pytest 

3import numpy as np 

4import tempfile 

5from pathlib import Path 

6import matplotlib 

7import matplotlib.pyplot as plt 

8from biobb_pytorch.mdae.plots import ( 

9 plot_loss, 

10 plot_rmsd, 

11 plot_rmsf, 

12 plot_rmsf_difference, 

13 plot_latent_space 

14) 

15 

# Force matplotlib's non-interactive Agg backend for the test session.
matplotlib.use(backend='Agg')

17 

18 

class TestPlots:
    """Tests for the plotting helpers in ``biobb_pytorch.mdae.plots``.

    Most tests are smoke tests: they pass as long as the plotting call does
    not raise. The error-path tests check for a specific exception or a
    printed warning. Every test gets a private temporary directory for its
    input files. It is created in ``setup_method`` and removed in
    ``teardown_method``.
    """

    def setup_method(self):
        """Create a fresh temporary working directory for this test."""
        # A private directory replaces NamedTemporaryFile(delete=False).
        # With that approach, the still-open handle stops np.savez from
        # reopening the path on Windows, and each test has to unlink its
        # own files by hand.
        self._tmpdir = tempfile.TemporaryDirectory()
        self.workdir = Path(self._tmpdir.name)

    def teardown_method(self):
        """Close all matplotlib figures and delete the working directory."""
        try:
            plt.close('all')
        finally:
            self._tmpdir.cleanup()

    def _write_xvg(self, name, lines):
        """Write ``lines`` (one per row) to ``name`` in the working directory.

        Returns the path as a ``str``, the same type the plot functions
        received in the original tests.
        """
        path = self.workdir / name
        path.write_text(''.join(f'{line}\n' for line in lines))
        return str(path)

    def _write_npz(self, name, **arrays):
        """Save ``arrays`` with ``np.savez`` to ``name``; return the path as ``str``."""
        path = self.workdir / name
        np.savez(path, **arrays)
        return str(path)

    @staticmethod
    def _latent_samples(n_frames=100, n_dims=5):
        """Return a seeded ``(n_frames, n_dims)`` standard-normal array."""
        # Seeded so every run of the test sees the same latent data.
        return np.random.default_rng(0).standard_normal((n_frames, n_dims))

    def test_plot_loss_with_validation(self):
        """plot_loss handles an npz holding training and validation losses."""
        loss_file = self._write_npz(
            'losses.npz',
            train_loss=np.array([1.0, 0.8, 0.6, 0.5, 0.45]),
            valid_loss=np.array([1.1, 0.9, 0.65, 0.55, 0.5]),
        )
        # Smoke test: passes if plotting does not raise.
        plot_loss(loss_file)

    def test_plot_loss_without_validation(self):
        """plot_loss handles an npz holding only the training loss."""
        loss_file = self._write_npz(
            'losses.npz',
            train_loss=np.array([1.0, 0.8, 0.6, 0.5, 0.45]),
        )
        plot_loss(loss_file)

    def test_plot_rmsd_single_file(self):
        """plot_rmsd accepts a single XVG path with comment/@ header lines."""
        xvg_file = self._write_xvg('rmsd.xvg', [
            '# Comment line',
            '@ xaxis label "Time (ns)"',
            '0.0 0.1',
            '1.0 0.15',
            '2.0 0.12',
        ])
        plot_rmsd(xvg_file)

    def test_plot_rmsd_multiple_files(self):
        """plot_rmsd accepts a list of XVG paths."""
        xvg_files = [
            self._write_xvg(f'rmsd_{i}.xvg', ['# Comment line', '0.0 0.1', '1.0 0.15'])
            for i in range(2)
        ]
        plot_rmsd(xvg_files)

    def test_plot_rmsd_nonexistent_file(self, capsys):
        """plot_rmsd prints a warning instead of raising for a missing file."""
        # Inside the fresh working directory this path is guaranteed absent.
        missing = str(self.workdir / 'does_not_exist.xvg')
        plot_rmsd(missing)
        captured = capsys.readouterr()
        assert "Warning" in captured.out or "does not exist" in captured.out

    def test_plot_rmsf_single_file(self):
        """plot_rmsf accepts a single XVG path with comment/@ header lines."""
        xvg_file = self._write_xvg('rmsf.xvg', [
            '# Comment line',
            '@ xaxis label "Residue"',
            '1 0.1',
            '2 0.15',
            '3 0.12',
        ])
        plot_rmsf(xvg_file)

    def test_plot_rmsf_multiple_files(self):
        """plot_rmsf accepts a list of XVG paths."""
        xvg_files = [
            self._write_xvg(f'rmsf_{i}.xvg', ['# Comment line', '1 0.1', '2 0.15'])
            for i in range(2)
        ]
        plot_rmsf(xvg_files)

    def test_plot_rmsf_difference(self):
        """plot_rmsf_difference accepts two XVG files with matching residues."""
        rmsf_series = ([0.1, 0.15, 0.12], [0.12, 0.14, 0.11])
        xvg_files = [
            self._write_xvg(
                f'rmsf_{i}.xvg',
                ['# Comment line',
                 *(f'{res} {val}' for res, val in enumerate(values, 1))],
            )
            for i, values in enumerate(rmsf_series)
        ]
        plot_rmsf_difference(xvg_files)

    def test_plot_latent_space_basic(self):
        """plot_latent_space projects a 100x5 latent array onto dims 0 and 1."""
        latent_file = self._write_npz('latent.npz', z=self._latent_samples())
        plot_latent_space(latent_file, projection_dim=[0, 1])

    def test_plot_latent_space_custom_dims(self):
        """plot_latent_space accepts other projection dims and snapshot_freq_ps."""
        latent_file = self._write_npz('latent.npz', z=self._latent_samples())
        plot_latent_space(latent_file, projection_dim=[2, 3], snapshot_freq_ps=20)

    def test_plot_latent_space_missing_z(self):
        """plot_latent_space raises KeyError when the npz has no 'z' entry."""
        latent_file = self._write_npz('latent.npz', other_data=np.array([1, 2, 3]))
        with pytest.raises(KeyError, match="'z' not found"):
            plot_latent_space(latent_file, projection_dim=[0, 1])

    def test_plot_latent_space_invalid_projection_dim(self):
        """plot_latent_space raises ValueError unless projection_dim has 2 entries."""
        latent_file = self._write_npz('latent.npz', z=self._latent_samples())
        with pytest.raises(ValueError, match="projection_dim must have length 2"):
            plot_latent_space(latent_file, projection_dim=[0, 1, 2])

197 Path(tmp_path).unlink()