Coverage for biobb_pytorch / test / unitests / test_mdae / test_build_model.py: 98%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1# type: ignore 

2import pytest 

3from biobb_common.tools import test_fixtures as fx 

4from biobb_pytorch.mdae.build_model import buildModel, BuildModel 

5from biobb_pytorch.mdae.utils.model_utils import assert_valid_kwargs 

6import torch 

7import tempfile 

8from pathlib import Path 

9 

10 

class TestBuildModel:
    """Tests for the ``buildModel`` / ``BuildModel`` building block and helpers."""

    def setup_class(self):
        """Load the ``'buildModel'`` test configuration via biobb_common fixtures.

        The tests below rely on this populating ``self.properties`` and
        ``self.paths`` (presumably from the package's test config — verify
        against ``biobb_common.tools.test_fixtures``).
        """
        fx.test_setup(self, 'buildModel')

    def teardown_class(self):
        """Undo ``setup_class`` via the biobb_common teardown fixture."""
        fx.test_teardown(self)

    def test_build_model(self):
        """Build a model with the default config and check the saved object."""
        buildModel(properties=self.properties, **self.paths)
        assert fx.not_empty(self.paths['output_model_pth_path'])

        # Load and verify model structure.
        # The model is saved directly as an object, not in a dictionary, so
        # full unpickling (weights_only=False) is required; the file was just
        # produced by this test, so it is trusted input.
        model = torch.load(self.paths['output_model_pth_path'], weights_only=False)
        # Verify it's a PyTorch model (has state_dict method).
        assert hasattr(model, 'state_dict'), "Model file should contain a PyTorch model"
        assert hasattr(model, '_hparams'), "Model file should contain _hparams attribute"

    def test_build_model_with_custom_loss(self):
        """Test building a model with an explicitly configured loss function.

        The output goes to a path inside a fresh temporary directory that
        does not exist beforehand. The final assertions therefore prove that
        ``buildModel`` actually wrote the file. A ``NamedTemporaryFile`` path
        would already exist on disk, which made an ``exists()`` check
        trivially true.
        """
        props = self.properties.copy()
        # Copy the nested options dict as well so the shared class-level
        # properties used by the other tests are not mutated. ``or {}``
        # tolerates an explicit ``options: None`` in the config.
        props['options'] = dict(props.get('options') or {})
        props['options']['loss_function'] = {
            'loss_type': 'MSELoss'
        }

        # TemporaryDirectory removes the directory and its contents on exit,
        # replacing the manual try/finally unlink.
        with tempfile.TemporaryDirectory() as tmp_dir:
            out_path = str(Path(tmp_dir) / 'model.pth')
            buildModel(properties=props,
                       input_stats_pt_path=self.paths['input_stats_pt_path'],
                       output_model_pth_path=out_path)
            assert Path(out_path).exists()
            assert fx.not_empty(out_path)

    def test_build_model_no_output(self):
        """Test building a model without saving it (``output_model_pth_path=None``)."""
        instance = BuildModel(
            input_stats_pt_path=self.paths['input_stats_pt_path'],
            output_model_pth_path=None,
            properties=self.properties
        )
        # Even without an output path, the instance should expose a usable
        # model object.
        assert instance.model is not None
        assert hasattr(instance.model, 'forward')

    def test_assert_valid_kwargs(self):
        """Test the ``assert_valid_kwargs`` utility function."""
        class DummyClass:
            def __init__(self, a, b, c=None):
                pass

        # Valid kwargs (with and without the optional one) should not raise.
        assert_valid_kwargs(DummyClass, {'a': 1, 'b': 2}, context="test")
        assert_valid_kwargs(DummyClass, {'a': 1, 'b': 2, 'c': 3}, context="test")

        # A kwarg not accepted by the constructor should raise.
        with pytest.raises(AssertionError):
            assert_valid_kwargs(DummyClass, {'a': 1, 'b': 2, 'invalid': 3}, context="test")

    def test_load_full(self):
        """Test the ``BuildModel.load_full`` static method on a freshly built model."""
        buildModel(properties=self.properties, **self.paths)
        loaded_model = BuildModel.load_full(self.paths['output_model_pth_path'])
        assert hasattr(loaded_model, 'state_dict')