Coverage for biobb_pytorch / test / unitests / test_mdae / test_build_model.py: 98%
45 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1# type: ignore
2import pytest
3from biobb_common.tools import test_fixtures as fx
4from biobb_pytorch.mdae.build_model import buildModel, BuildModel
5from biobb_pytorch.mdae.utils.model_utils import assert_valid_kwargs
6import torch
7import tempfile
8from pathlib import Path
class TestBuildModel:
    """Tests for the ``buildModel`` tool, ``BuildModel`` class and related utilities.

    ``fx.test_setup`` is expected to populate ``self.paths`` and
    ``self.properties`` for the ``'buildModel'`` configuration key; every test
    below reads those attributes.
    """

    def setup_class(self):
        # pytest xunit-style hook: runs once before the tests of this class.
        fx.test_setup(self, 'buildModel')

    def teardown_class(self):
        # Counterpart of ``fx.test_setup``; presumably removes whatever it
        # created (TODO confirm against biobb_common.tools.test_fixtures).
        fx.test_teardown(self)

    def test_build_model(self):
        """Build a model with the default properties and inspect the saved file."""
        buildModel(properties=self.properties, **self.paths)
        assert fx.not_empty(self.paths['output_model_pth_path'])

        # Load and verify model structure
        # The model is saved directly as an object, not in a dictionary
        model = torch.load(self.paths['output_model_pth_path'], weights_only=False)
        # Verify it's a PyTorch model (has state_dict method)
        assert hasattr(model, 'state_dict'), "Model file should contain a PyTorch model"
        assert hasattr(model, '_hparams'), "Model file should contain _hparams attribute"

    def test_build_model_with_custom_loss(self):
        """Test building model with custom loss function.

        The model is written to a path inside a fresh temporary directory, so
        the file does not exist before ``buildModel`` runs; the assertions
        therefore genuinely check that the tool produced a non-empty model.
        """
        # Copy the shared properties and the nested ``options`` dict so the
        # class-level fixture is not mutated for the other tests.
        props = self.properties.copy()
        props['options'] = props.get('options', {}).copy()
        props['options']['loss_function'] = {
            'loss_type': 'MSELoss'
        }

        # TemporaryDirectory cleans up after itself, including on failure.
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_model_pth_path = str(Path(tmp_dir) / 'model_custom_loss.pth')
            buildModel(properties=props,
                       input_stats_pt_path=self.paths['input_stats_pt_path'],
                       output_model_pth_path=output_model_pth_path)
            assert fx.not_empty(output_model_pth_path)

            model = torch.load(output_model_pth_path, weights_only=False)
            assert hasattr(model, 'state_dict'), "Model file should contain a PyTorch model"

    def test_build_model_no_output(self):
        """Test building model without saving."""
        instance = BuildModel(
            input_stats_pt_path=self.paths['input_stats_pt_path'],
            output_model_pth_path=None,
            properties=self.properties
        )
        # The model object is still constructed and exposed on the instance.
        assert instance.model is not None
        assert hasattr(instance.model, 'forward')

    def test_assert_valid_kwargs(self):
        """Test assert_valid_kwargs utility function."""
        class DummyClass:
            def __init__(self, a, b, c=None):
                pass

        # Valid kwargs should not raise
        assert_valid_kwargs(DummyClass, {'a': 1, 'b': 2}, context="test")
        assert_valid_kwargs(DummyClass, {'a': 1, 'b': 2, 'c': 3}, context="test")

        # Invalid kwargs should raise
        with pytest.raises(AssertionError):
            assert_valid_kwargs(DummyClass, {'a': 1, 'b': 2, 'invalid': 3}, context="test")

    def test_load_full(self):
        """Test load_full static method."""
        buildModel(properties=self.properties, **self.paths)
        # Make sure there is something to load before exercising load_full.
        assert fx.not_empty(self.paths['output_model_pth_path'])
        loaded_model = BuildModel.load_full(self.paths['output_model_pth_path'])
        assert hasattr(loaded_model, 'state_dict')