Coverage for biobb_pytorch / test / unitests / test_mdae / test_all_models.py: 97%

123 statements  

coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

# type: ignore
"""
Comprehensive test suite for all model types in biobb_pytorch.mdae.models.

This test file focuses on models that can be instantiated with standard
list-based encoder/decoder layer configurations (AutoEncoder, VAE).

Models with specialized configurations (GMVAE, SPIB, CNNAutoEncoder) require
dictionaries for encoder/decoder layers and have their own specific tests.
"""

import tempfile
from pathlib import Path

import pytest
import torch

from biobb_common.tools import test_fixtures as fx
from biobb_pytorch.mdae.build_model import buildModel, BuildModel

class TestAllModels:
    """Test suite for model architectures with standard configurations."""

    def setup_class(self):
        """Set up test fixtures using the buildModel configuration."""
        fx.test_setup(self, 'buildModel')
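        # test_setup populates self.properties and self.paths from the
        # 'buildModel' test configuration; every test below relies on both.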

    def teardown_class(self):
        """Cleanup after tests."""
        fx.test_teardown(self)

    @pytest.mark.parametrize(
        "model_type,extra_props,expected_attrs",
        [
            ('AutoEncoder', {},
             ['encoder', 'decoder', 'norm_in']),
            ('VariationalAutoEncoder', {},
             ['encoder', 'decoder', 'mean_nn', 'log_var_nn']),
        ])
    def test_build_all_model_types(self, model_type, extra_props,
                                   expected_attrs):
        """
        Test building all supported model types.

        Args:
            model_type: Name of the model class to instantiate
            extra_props: Additional properties specific to the model type
            expected_attrs: Expected attributes that should exist in the model
        """
        props = self.properties.copy()
        props['model_type'] = model_type
        props.update(extra_props)

        with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp:
            tmp_path = tmp.name
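            # delete=False keeps the file on disk after the handle closes so
            # buildModel can write to the path; it is removed in `finally`.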

        try:
            # Build the model
            buildModel(
                properties=props,
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=tmp_path
            )

            # Verify file was created
            assert Path(tmp_path).exists(), \
                f"Model file should exist for {model_type}"

            # Load and verify model
            model = torch.load(tmp_path, weights_only=False)
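            # weights_only=False unpickles the full model object rather than
            # just tensors; PyTorch >= 2.6 defaults torch.load to
            # weights_only=True, so this must be explicit.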

            # Verify model class name
            assert model.__class__.__name__ == model_type, \
                f"Model class should be {model_type}, got {model.__class__.__name__}"

            # Verify basic PyTorch model properties
            assert hasattr(model, 'state_dict'), \
                f"{model_type} should have state_dict method"
            assert hasattr(model, 'forward'), \
                f"{model_type} should have forward method"
            assert hasattr(model, '_hparams'), \
                f"{model_type} should have _hparams attribute"

            # Verify model-specific attributes
            for attr in expected_attrs:
                assert hasattr(model, attr), \
                    f"{model_type} should have {attr} attribute"

            # Verify model can be put in eval mode
            model.eval()
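            # eval() disables training-only behaviour such as dropout and
            # batch-norm running-statistics updates.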

        finally:
            if Path(tmp_path).exists():
                Path(tmp_path).unlink()

    @pytest.mark.parametrize("model_type", [
        'AutoEncoder',
        'VariationalAutoEncoder',
    ])
    def test_model_forward_pass(self, model_type):
        """
        Test that each model type can perform a forward pass.

        Args:
            model_type: Name of the model class to test
        """
        props = self.properties.copy()
        props['model_type'] = model_type

        # Build model without saving
        instance = BuildModel(
            input_stats_pt_path=self.paths['input_stats_pt_path'],
            output_model_pth_path=None,
            properties=props
        )

        model = instance.model
        assert model is not None, f"{model_type} should be instantiated"

        # Load stats to get input dimensions
        stats = torch.load(
            self.paths['input_stats_pt_path'],
            weights_only=False)
        n_features = stats['shape'][1]
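        # stats['shape'] is assumed to hold (n_frames, n_features), so
        # index 1 gives the width of the model input.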

        # Create dummy input
        batch_size = 4
        dummy_input = torch.randn(batch_size, n_features)
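        # Random input suffices here: only output shapes are checked, not
        # reconstruction quality.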

        # Set model to eval mode
        model.eval()

        # Perform forward pass
        with torch.no_grad():
            try:
                output = model(dummy_input)

                # Verify output - AutoEncoder and VAE both return dict
                if isinstance(output, dict):
                    assert 'z' in output, \
                        f"{model_type} output dict should contain latent representation 'z'"
                    assert output['z'].shape[0] == batch_size, \
                        f"{model_type} latent batch size should match input"
                else:
                    # Some models might return tensor directly
                    assert output.shape[0] == batch_size, \
                        f"{model_type} output batch size should match input"

            except Exception as e:
                pytest.fail(f"{model_type} forward pass failed: {str(e)}")

    @pytest.mark.parametrize(
        "model_type,custom_layers",
        [
            ('AutoEncoder',
             {'encoder_layers': [32, 16, 8], 'decoder_layers': [8, 16, 32]}),
            ('VariationalAutoEncoder',
             {'encoder_layers': [32, 16], 'decoder_layers': [16, 32]}),
        ])
    def test_custom_layer_configurations(self, model_type, custom_layers):
        """
        Test building models with custom layer configurations.

        Args:
            model_type: Name of the model class to test
            custom_layers: Custom encoder and decoder layer configurations
        """
        props = self.properties.copy()
        props['model_type'] = model_type
        props.update(custom_layers)

        with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            buildModel(
                properties=props,
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=tmp_path
            )

            assert Path(tmp_path).exists(), \
                f"Model with custom layers should be created for {model_type}"

            model = torch.load(tmp_path, weights_only=False)
            assert model.__class__.__name__ == model_type

        finally:
            if Path(tmp_path).exists():
                Path(tmp_path).unlink()

    @pytest.mark.parametrize("model_type", [
        'AutoEncoder',
        'VariationalAutoEncoder',
    ])
    def test_model_with_custom_loss(self, model_type):
        """
        Test building models with custom loss functions.

        Args:
            model_type: Name of the model class to test
        """
        props = self.properties.copy()
        props['model_type'] = model_type
        props['options'] = props.get('options', {}).copy()
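        # The nested dict is copied above so the shared fixture properties
        # are not mutated by the per-test loss configuration below.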

        props['options']['loss_function'] = {
            'loss_type': 'MSELoss'
        }

        with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            buildModel(
                properties=props,
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=tmp_path
            )

            assert Path(tmp_path).exists(), \
                f"Model with custom loss should be created for {model_type}"

            model = torch.load(tmp_path, weights_only=False)
            assert hasattr(model, '_hparams'), \
                f"{model_type} should contain hparams with loss configuration"

        finally:
            if Path(tmp_path).exists():
                Path(tmp_path).unlink()

    def test_model_state_dict_compatibility(self):
        """Test that all models produce compatible state dicts."""
        model_types = [
            'AutoEncoder',
            'VariationalAutoEncoder',
        ]

        for model_type in model_types:
            props = self.properties.copy()
            props['model_type'] = model_type

            instance = BuildModel(
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=None,
                properties=props
            )

            model = instance.model
            state_dict = model.state_dict()
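            # state_dict() maps parameter/buffer names to tensors; it is what
            # gets serialized when a model is checkpointed.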

            # Verify state dict contains parameters
            assert len(state_dict) > 0, \
                f"{model_type} state_dict should contain parameters"

            # Verify all parameters are tensors
            for key, value in state_dict.items():
                assert isinstance(value, torch.Tensor), \
                    f"All state_dict values should be tensors in {model_type}"

    @pytest.mark.parametrize("n_cvs", [1, 2, 5, 10])
    def test_different_latent_dimensions(self, n_cvs):
        """
        Test AutoEncoder with different latent space dimensions.

        Args:
            n_cvs: Number of collective variables (latent dimensions)
        """
        props = self.properties.copy()
        props['model_type'] = 'AutoEncoder'
        props['n_cvs'] = n_cvs

        with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            buildModel(
                properties=props,
                input_stats_pt_path=self.paths['input_stats_pt_path'],
                output_model_pth_path=tmp_path
            )

            assert Path(tmp_path).exists(), \
                f"Model should be created with n_cvs={n_cvs}"

            model = torch.load(tmp_path, weights_only=False)
            assert model._hparams.get('n_cvs') == n_cvs, \
                f"Model should have n_cvs={n_cvs} in hparams"

        finally:
            if Path(tmp_path).exists():
                Path(tmp_path).unlink()

    def test_vae_encode_decode(self):
        """Test that VariationalAutoEncoder produces proper encode/decode output."""
        props = self.properties.copy()
        props['model_type'] = 'VariationalAutoEncoder'

        instance = BuildModel(
            input_stats_pt_path=self.paths['input_stats_pt_path'],
            output_model_pth_path=None,
            properties=props
        )

        model = instance.model
        stats = torch.load(
            self.paths['input_stats_pt_path'],
            weights_only=False)
        n_features = stats['shape'][1]
        n_cvs = self.properties.get('n_cvs', 2)

        # Create dummy input
        batch_size = 4
        dummy_input = torch.randn(batch_size, n_features)

        model.eval()
        with torch.no_grad():
            # Test encode_decode method, which is specific to the VAE
            z, mean, log_var, x_hat = model.encode_decode(dummy_input)
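            # For a VAE, z is typically sampled from N(mean, exp(log_var))
            # via the reparameterization trick, and x_hat is the
            # reconstruction of the input.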

            # Verify outputs
            assert z.shape == (batch_size, n_cvs), \
                f"Latent z should have shape ({batch_size}, {n_cvs})"
            assert mean.shape == (batch_size, n_cvs), \
                f"Mean should have shape ({batch_size}, {n_cvs})"
            assert log_var.shape == (batch_size, n_cvs), \
                f"Log variance should have shape ({batch_size}, {n_cvs})"
            assert x_hat.shape == (batch_size, n_features), \
                f"Reconstruction should have shape ({batch_size}, {n_features})"

327 f"Reconstruction should have shape ({batch_size}, {n_features})"