Coverage for biobb_pytorch / test / unitests / test_mdae / test_train_model.py: 87%

23 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1# type: ignore 

2from biobb_common.tools import test_fixtures as fx 

3from biobb_pytorch.mdae.train_model import trainModel 

4import torch 

5import numpy as np 

6 

7 

class TestTrainModel:
    """Integration test for the ``trainModel`` building block.

    The test reads ``self.properties`` and ``self.paths``; these are
    presumably populated by ``fx.test_setup`` from the ``'trainModel'``
    test configuration (verify against ``biobb_common.tools.test_fixtures``).
    """

    def setup_class(self):
        """Prepare the fixture state for the ``'trainModel'`` test case."""
        fx.test_setup(self, 'trainModel')

    def teardown_class(self):
        """Release whatever ``fx.test_setup`` created for this test case."""
        fx.test_teardown(self)

    def test_trainModel(self):
        """Run ``trainModel`` and sanity-check each configured output.

        Checks are conditional on the corresponding key being present in
        ``self.paths``:

        * ``output_model_pth_path``: file is non-empty and unpickles to an
          object exposing ``state_dict``.
        * ``output_metrics_npz_path``: file is non-empty and contains a
          ``train_loss`` or ``loss`` entry.
        * ``ref_output_metrics_npz_path``: when both the produced and the
          reference metrics contain ``train_loss``, the produced entry must
          be an ``np.ndarray`` or ``float``.
        """
        trainModel(properties=self.properties, **self.paths)

        if 'output_model_pth_path' in self.paths:
            model_path = self.paths['output_model_pth_path']
            assert fx.not_empty(model_path)
            # The model is saved directly as an object, not in a dictionary,
            # so it has to be fully unpickled (weights_only=False).
            # NOTE(review): full unpickling executes arbitrary code; this is
            # acceptable only because the file was produced by the call above.
            model = torch.load(model_path, weights_only=False)
            assert hasattr(model, 'state_dict'), "Model file should contain a PyTorch model"

        # Without a metrics output there is nothing to check against the
        # reference; returning here also avoids reading an unbound `metrics`.
        if 'output_metrics_npz_path' not in self.paths:
            return

        metrics_path = self.paths['output_metrics_npz_path']
        assert fx.not_empty(metrics_path)

        # np.load on an .npz archive keeps the file open until closed, so use
        # it as a context manager and pull out what is needed while it is open.
        with np.load(metrics_path, allow_pickle=True) as metrics:
            assert 'train_loss' in metrics or 'loss' in metrics, "Metrics should contain loss information"
            train_loss = metrics['train_loss'] if 'train_loss' in metrics else None

        if 'ref_output_metrics_npz_path' in self.paths:
            ref_path = self.paths['ref_output_metrics_npz_path']
            with np.load(ref_path, allow_pickle=True) as ref_metrics:
                ref_has_train_loss = 'train_loss' in ref_metrics

            # Only the type of the produced loss is checked; its values are
            # not compared against the reference.
            if train_loss is not None and ref_has_train_loss:
                assert isinstance(train_loss, (np.ndarray, float)), "Train loss should be numeric"