Coverage for biobb_pytorch/test/unitests/test_mdae/test_all_models.py: 97% (123 statements)
# type: ignore
"""
Comprehensive test suite for all model types in biobb_pytorch.mdae.models.

This test file focuses on models that can be instantiated with standard
list-based encoder/decoder layer configurations (AutoEncoder, VAE).

Models with specialized configurations (GMVAE, SPIB, CNNAutoEncoder) require
dictionaries for encoder/decoder layers and have their own specific tests.
"""
import tempfile
from pathlib import Path

import pytest
import torch
from biobb_common.tools import test_fixtures as fx
from biobb_pytorch.mdae.build_model import buildModel, BuildModel
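
# For orientation, the builder under test is typically driven as in the
# sketch below (based on the calls in this file; 'stats.pt' and 'model.pth'
# are placeholder paths, not files shipped with the test data):
#
#   buildModel(
#       properties={'model_type': 'AutoEncoder'},
#       input_stats_pt_path='stats.pt',
#       output_model_pth_path='model.pth',
#   )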


class TestAllModels:
    """Test suite for model architectures with standard configurations."""

    def setup_class(self):
        """Set up test fixtures using the buildModel configuration."""
        fx.test_setup(self, 'buildModel')

    def teardown_class(self):
        """Clean up after tests."""
        fx.test_teardown(self)

    @pytest.mark.parametrize(
        "model_type,extra_props,expected_attrs", [
            ('AutoEncoder', {},
             ['encoder', 'decoder', 'norm_in']),
            ('VariationalAutoEncoder', {},
             ['encoder', 'decoder', 'mean_nn', 'log_var_nn']),
        ])
    def test_build_all_model_types(
            self, model_type, extra_props, expected_attrs):
        """
        Test building all supported model types.

        Args:
            model_type: Name of the model class to instantiate
            extra_props: Additional properties specific to the model type
            expected_attrs: Expected attributes that should exist in the model
        """
        props = self.properties.copy()
        props['model_type'] = model_type
        props.update(extra_props)
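
        # delete=False keeps the path alive after the 'with' block closes the
        # handle, so buildModel can write to it; cleanup is handled manually
        # in the 'finally' clause below.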
        with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            # Build the model
            buildModel(
                properties=props,
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=tmp_path
            )

            # Verify file was created
            assert Path(tmp_path).exists(), \
                f"Model file should exist for {model_type}"
            # Load and verify model
            model = torch.load(tmp_path, weights_only=False)

            # Verify model class name
            assert model.__class__.__name__ == model_type, \
                f"Model class should be {model_type}, got {model.__class__.__name__}"

            # Verify basic PyTorch model properties
            assert hasattr(model, 'state_dict'), \
                f"{model_type} should have state_dict method"
            assert hasattr(model, 'forward'), \
                f"{model_type} should have forward method"
            assert hasattr(model, '_hparams'), \
                f"{model_type} should have _hparams attribute"

            # Verify model-specific attributes
            for attr in expected_attrs:
                assert hasattr(model, attr), \
                    f"{model_type} should have {attr} attribute"

            # Verify model can be put in eval mode
            model.eval()

        finally:
            if Path(tmp_path).exists():
                Path(tmp_path).unlink()

    @pytest.mark.parametrize("model_type", [
        'AutoEncoder',
        'VariationalAutoEncoder',
    ])
    def test_model_forward_pass(self, model_type):
        """
        Test that each model type can perform a forward pass.

        Args:
            model_type: Name of the model class to test
        """
        props = self.properties.copy()
        props['model_type'] = model_type

        # Build model without saving
        instance = BuildModel(
            input_stats_pt_path=self.paths['input_stats_pt_path'],
            output_model_pth_path=None,
            properties=props
        )

        model = instance.model
        assert model is not None, f"{model_type} should be instantiated"

        # Load stats to get input dimensions
        stats = torch.load(
            self.paths['input_stats_pt_path'],
            weights_only=False)
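        # stats['shape'] is assumed to hold the (n_samples, n_features) shape
        # of the training data, so the feature count is its second entry.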
        n_features = stats['shape'][1]

        # Create dummy input
        batch_size = 4
        dummy_input = torch.randn(batch_size, n_features)

        # Set model to eval mode
        model.eval()

        # Perform forward pass
        with torch.no_grad():
            try:
                output = model(dummy_input)

                # Verify output - AutoEncoder and VAE both return dict
                if isinstance(output, dict):
                    assert 'z' in output, \
                        f"{model_type} output dict should contain latent representation 'z'"
                    assert output['z'].shape[0] == batch_size, \
                        f"{model_type} latent batch size should match input"
                else:
                    # Some models might return tensor directly
                    assert output.shape[0] == batch_size, \
                        f"{model_type} output batch size should match input"

            except Exception as e:
                pytest.fail(f"{model_type} forward pass failed: {str(e)}")

    @pytest.mark.parametrize(
        "model_type,custom_layers", [
            ('AutoEncoder',
             {'encoder_layers': [32, 16, 8], 'decoder_layers': [8, 16, 32]}),
            ('VariationalAutoEncoder',
             {'encoder_layers': [32, 16], 'decoder_layers': [16, 32]}),
        ])
    def test_custom_layer_configurations(self, model_type, custom_layers):
        """
        Test building models with custom layer configurations.

        Args:
            model_type: Name of the model class to test
            custom_layers: Custom encoder and decoder layer configurations
        """
        props = self.properties.copy()
        props['model_type'] = model_type
        props.update(custom_layers)

        with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            buildModel(
                properties=props,
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=tmp_path
            )

            assert Path(tmp_path).exists(), \
                f"Model with custom layers should be created for {model_type}"

            model = torch.load(tmp_path, weights_only=False)
            assert model.__class__.__name__ == model_type

        finally:
            if Path(tmp_path).exists():
                Path(tmp_path).unlink()

    @pytest.mark.parametrize("model_type", [
        'AutoEncoder',
        'VariationalAutoEncoder',
    ])
    def test_model_with_custom_loss(self, model_type):
        """
        Test building models with custom loss functions.

        Args:
            model_type: Name of the model class to test
        """
        props = self.properties.copy()
        props['model_type'] = model_type
        props['options'] = props.get('options', {}).copy()
        props['options']['loss_function'] = {
            'loss_type': 'MSELoss'
        }

        with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            buildModel(
                properties=props,
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=tmp_path
            )

            assert Path(tmp_path).exists(), \
                f"Model with custom loss should be created for {model_type}"

            model = torch.load(tmp_path, weights_only=False)
            assert hasattr(model, '_hparams'), \
                f"{model_type} should contain hparams with loss configuration"

        finally:
            if Path(tmp_path).exists():
                Path(tmp_path).unlink()

    def test_model_state_dict_compatibility(self):
        """Test that all models produce compatible state dicts."""
        model_types = [
            'AutoEncoder',
            'VariationalAutoEncoder',
        ]

        for model_type in model_types:
            props = self.properties.copy()
            props['model_type'] = model_type

            instance = BuildModel(
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=None,
                properties=props
            )

            model = instance.model
            state_dict = model.state_dict()

            # Verify state dict contains parameters
            assert len(state_dict) > 0, \
                f"{model_type} state_dict should contain parameters"

            # Verify all parameters are tensors
            for value in state_dict.values():
                assert isinstance(value, torch.Tensor), \
                    f"All state_dict values should be tensors in {model_type}"

    @pytest.mark.parametrize("n_cvs", [1, 2, 5, 10])
    def test_different_latent_dimensions(self, n_cvs):
        """
        Test AutoEncoder with different latent space dimensions.

        Args:
            n_cvs: Number of collective variables (latent dimensions)
        """
        props = self.properties.copy()
        props['model_type'] = 'AutoEncoder'
        props['n_cvs'] = n_cvs

        with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            buildModel(
                properties=props,
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=tmp_path
            )

            assert Path(tmp_path).exists(), \
                f"Model should be created with n_cvs={n_cvs}"

            model = torch.load(tmp_path, weights_only=False)
            assert model._hparams.get('n_cvs') == n_cvs, \
                f"Model should have n_cvs={n_cvs} in hparams"

        finally:
            if Path(tmp_path).exists():
                Path(tmp_path).unlink()

    def test_vae_encode_decode(self):
        """Test that VariationalAutoEncoder produces proper encode/decode output."""
        props = self.properties.copy()
        props['model_type'] = 'VariationalAutoEncoder'

        instance = BuildModel(
            input_stats_pt_path=self.paths['input_stats_pt_path'],
            output_model_pth_path=None,
            properties=props
        )

        model = instance.model
        stats = torch.load(
            self.paths['input_stats_pt_path'],
            weights_only=False)
        n_features = stats['shape'][1]
        n_cvs = self.properties.get('n_cvs', 2)

        # Create dummy input
        batch_size = 4
        dummy_input = torch.randn(batch_size, n_features)

        model.eval()
        with torch.no_grad():
            # Test the encode_decode method, which is specific to the VAE
            z, mean, log_var, x_hat = model.encode_decode(dummy_input)
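
            # In a standard VAE, z is drawn via the reparameterization trick,
            # z = mean + exp(0.5 * log_var) * eps with eps ~ N(0, I); that is
            # an assumption about this implementation, not verified here.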

            # Verify outputs
            assert z.shape == (batch_size, n_cvs), \
                f"Latent z should have shape ({batch_size}, {n_cvs})"
            assert mean.shape == (batch_size, n_cvs), \
                f"Mean should have shape ({batch_size}, {n_cvs})"
            assert log_var.shape == (batch_size, n_cvs), \
                f"Log variance should have shape ({batch_size}, {n_cvs})"
            assert x_hat.shape == (batch_size, n_features), \
                f"Reconstruction should have shape ({batch_size}, {n_features})"