Coverage for biobb_pytorch / test / unitests / test_mdae / test_plots.py: 100%
135 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
# type: ignore
import tempfile
from pathlib import Path

import matplotlib
import numpy as np
import pytest

# Select the non-interactive backend for testing BEFORE pyplot (or any project
# module that may import pyplot) is loaded, so no test ever needs a display.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402
from biobb_pytorch.mdae.plots import (  # noqa: E402
    plot_loss,
    plot_rmsd,
    plot_rmsf,
    plot_rmsf_difference,
    plot_latent_space
)
class TestPlots:
    """Smoke tests for the ``biobb_pytorch.mdae.plots`` plotting helpers.

    Input fixtures are written into pytest's per-test ``tmp_path`` directory,
    so they are cleaned up automatically even when a test fails while its
    inputs are still being created (the previous manual ``unlink`` in a
    ``finally`` did not cover failures before the ``try``).
    """

    def teardown_method(self):
        """Close all matplotlib figures after each test."""
        plt.close('all')

    @staticmethod
    def _write_lines(path, lines):
        """Write each entry of *lines* as one newline-terminated line.

        Args:
            path: Destination ``pathlib.Path``.
            lines: Iterable of text lines (without trailing newlines).

        Returns:
            ``str(path)``, the form the plotting functions are called with.
        """
        path.write_text("".join(f"{line}\n" for line in lines), encoding="utf-8")
        return str(path)

    @staticmethod
    def _write_latent_npz(path, n_frames=100, n_dims=5):
        """Save a reproducible random latent array under key ``'z'``.

        A fixed seed keeps the test input identical across runs.

        Returns:
            ``str(path)``.
        """
        rng = np.random.default_rng(0)
        np.savez(path, z=rng.standard_normal((n_frames, n_dims)))
        return str(path)

    def test_plot_loss_with_validation(self, tmp_path):
        """Test plot_loss with training and validation losses."""
        npz_path = tmp_path / 'losses.npz'
        np.savez(
            npz_path,
            train_loss=np.array([1.0, 0.8, 0.6, 0.5, 0.45]),
            valid_loss=np.array([1.1, 0.9, 0.65, 0.55, 0.5]),
        )
        # Smoke test: passes as long as plotting does not raise.
        plot_loss(str(npz_path))

    def test_plot_loss_without_validation(self, tmp_path):
        """Test plot_loss with only training loss."""
        npz_path = tmp_path / 'losses.npz'
        np.savez(npz_path, train_loss=np.array([1.0, 0.8, 0.6, 0.5, 0.45]))
        # Smoke test: passes as long as plotting does not raise.
        plot_loss(str(npz_path))

    def test_plot_rmsd_single_file(self, tmp_path):
        """Test plot_rmsd with a single XVG file."""
        xvg = self._write_lines(tmp_path / 'rmsd.xvg', [
            "# Comment line",
            "@ xaxis label \"Time (ns)\"",
            "0.0 0.1",
            "1.0 0.15",
            "2.0 0.12",
        ])
        plot_rmsd(xvg)

    def test_plot_rmsd_multiple_files(self, tmp_path):
        """Test plot_rmsd with multiple XVG files."""
        xvgs = [
            self._write_lines(tmp_path / f'rmsd_{i}.xvg', [
                "# Comment line",
                "0.0 0.1",
                "1.0 0.15",
            ])
            for i in range(2)
        ]
        plot_rmsd(xvgs)

    def test_plot_rmsd_nonexistent_file(self, tmp_path, capsys):
        """Test plot_rmsd with non-existent file."""
        # A path inside the fresh tmp_path directory is guaranteed not to exist.
        plot_rmsd(str(tmp_path / 'nonexistent' / 'file.xvg'))
        captured = capsys.readouterr()
        assert "Warning" in captured.out or "does not exist" in captured.out

    def test_plot_rmsf_single_file(self, tmp_path):
        """Test plot_rmsf with a single XVG file."""
        xvg = self._write_lines(tmp_path / 'rmsf.xvg', [
            "# Comment line",
            "@ xaxis label \"Residue\"",
            "1 0.1",
            "2 0.15",
            "3 0.12",
        ])
        plot_rmsf(xvg)

    def test_plot_rmsf_multiple_files(self, tmp_path):
        """Test plot_rmsf with multiple XVG files."""
        xvgs = [
            self._write_lines(tmp_path / f'rmsf_{i}.xvg', [
                "# Comment line",
                "1 0.1",
                "2 0.15",
            ])
            for i in range(2)
        ]
        plot_rmsf(xvgs)

    def test_plot_rmsf_difference(self, tmp_path):
        """Test plot_rmsf_difference with two XVG files."""
        xvgs = [
            self._write_lines(
                tmp_path / f'rmsf_{i}.xvg',
                ["# Comment line", *(f"{res} {val}" for res, val in enumerate(values, 1))],
            )
            for i, values in enumerate([[0.1, 0.15, 0.12], [0.12, 0.14, 0.11]])
        ]
        plot_rmsf_difference(xvgs)

    def test_plot_latent_space_basic(self, tmp_path):
        """Test plot_latent_space with basic data."""
        npz = self._write_latent_npz(tmp_path / 'latent.npz')  # 100 frames, 5 dims
        plot_latent_space(npz, projection_dim=[0, 1])

    def test_plot_latent_space_custom_dims(self, tmp_path):
        """Test plot_latent_space with custom projection dimensions."""
        npz = self._write_latent_npz(tmp_path / 'latent.npz')
        plot_latent_space(npz, projection_dim=[2, 3], snapshot_freq_ps=20)

    def test_plot_latent_space_missing_z(self, tmp_path):
        """Test plot_latent_space raises KeyError when 'z' is missing."""
        npz_path = tmp_path / 'no_z.npz'
        # Save data without the 'z' key.
        np.savez(npz_path, other_data=np.array([1, 2, 3]))
        with pytest.raises(KeyError, match="'z' not found"):
            plot_latent_space(str(npz_path), projection_dim=[0, 1])

    def test_plot_latent_space_invalid_projection_dim(self, tmp_path):
        """Test plot_latent_space raises ValueError for invalid projection_dim."""
        npz = self._write_latent_npz(tmp_path / 'latent.npz')
        with pytest.raises(ValueError, match="projection_dim must have length 2"):
            plot_latent_space(npz, projection_dim=[0, 1, 2])