Coverage for biobb_pytorch / mdae / build_model.py: 88%
129 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import os
2import torch
3import importlib
4from typing import Dict, Any, Type, Optional
5from biobb_pytorch.mdae.models import __all__ as AVAILABLE_MODELS
6from biobb_pytorch.mdae.utils.model_utils import assert_valid_kwargs
7from biobb_pytorch.mdae.utils.log_utils import get_size
8from biobb_common.tools.file_utils import launchlogger
9from biobb_common.tools import file_utils as fu
10from biobb_common.generic.biobb_object import BiobbObject
class BuildModel(BiobbObject):
    """
    | biobb_pytorch BuildModel
    | Build a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.
    | Builds a PyTorch autoencoder from the given properties.

    Args:
        input_stats_pt_path (str): Path to the input model statistics file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_model_pth_path (str) (Optional): Path to save the model in .pth format. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **model_type** (*str*) - ("AutoEncoder") Name of the model class to instantiate (must exist in biobb_pytorch.mdae.models).
            * **n_cvs** (*int*) - (1) Dimensionality of the latent space.
            * **encoder_layers** (*list*) - ([16]) List of integers representing the number of neurons in each encoder layer.
            * **decoder_layers** (*list*) - ([16]) List of integers representing the number of neurons in each decoder layer.
            * **options** (*dict*) - ({"norm_in": {"mode": "min_max"}}) Additional options (e.g. norm_in, optimizer, loss_function, device, etc.).

    Examples:
        This example shows how to use the BuildModel class to build a PyTorch autoencoder model::

            from biobb_pytorch.mdae.build_model import buildModel

            input_stats_pt_path = "input_stats.pt"
            output_model_pth_file = "model.pth"

            n_features = 128
            prop = {
                'model_type': 'AutoEncoder',
                'n_cvs': 10,
                'encoder_layers': [n_features, 64, 32],
                'decoder_layers': [32, 64, n_features],
                'options': {
                    'norm_in': {"mode": "min_max"},
                    'optimizer': {
                        'lr': 1e-4
                    }
                }
            }

            buildModel(input_stats_pt_path=input_stats_pt_path,
                       output_model_pth_path=None,
                       properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_stats_pt_path: str,
        output_model_pth_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        super().__init__(properties)

        self.input_stats_pt_path = input_stats_pt_path
        self.output_model_pth_path = output_model_pth_path
        self.props = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_stats_pt_path": input_stats_pt_path,
            },
            "out": {}
        }
        if output_model_pth_path:
            self.io_dict["out"]["output_model_pth_path"] = output_model_pth_path

        # Build the per-feature arguments
        self.options: dict = properties.get("options", {})
        self.model_type: str = properties.get("model_type", "AutoEncoder")
        self.n_cvs: int = properties.get("n_cvs", 1)
        self.encoder_layers: list = properties.get("encoder_layers", [16])
        self.decoder_layers: list = properties.get("decoder_layers", [16])
        self.loss_function: Optional[dict] = properties.get("loss_function", None)
        # FIX: dict.get replaces the membership-test-plus-subscript pattern.
        self.device = self.options.get('device', 'cpu')

        # Load the statistics file produced upstream.
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load trusted files.
        self.stats = torch.load(self.io_dict['in']['input_stats_pt_path'],
                                weights_only=False)

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        self._validate_props()
        self.model = self._build_model()
        self.loss_fn = self._build_loss()

        # Store hyperparameters for reproducibility. _validate_props()
        # guarantees every required key is present, so self.props can be
        # read directly (same values as the raw `properties` dict).
        hparams = {
            'model_type': self.props['model_type'],
            'n_cvs': self.props['n_cvs'],
            'encoder_layers': self.props['encoder_layers'],
            'decoder_layers': self.props['decoder_layers'],
            'loss_function': self._hparams_loss_repr(),
            'options': {k: v for k, v in self.props['options'].items() if k != 'loss_function'}
        }
        setattr(self.model, '_hparams', hparams)

        # Attach loss_fn and move model to device
        self.model.loss_fn = self.loss_fn
        self.model.to(self.device)

    def _validate_props(self) -> None:
        """Raise if a required property is missing or the model type is unknown.

        Raises:
            KeyError: if any required property key is absent.
            ValueError: if `model_type` is not in biobb_pytorch.mdae.models.
        """
        required = ['model_type', 'n_cvs', 'encoder_layers', 'decoder_layers', 'options']
        missing = [k for k in required if k not in self.props]
        if missing:
            raise KeyError(f"Missing required properties: {missing}")

        model_type = self.props['model_type']
        if model_type not in AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model_type '{model_type}'. Available: {AVAILABLE_MODELS}"
            )

    def _build_model(self) -> torch.nn.Module:
        """Instantiate the requested model class from biobb_pytorch.mdae.models."""
        module = importlib.import_module('biobb_pytorch.mdae.models')
        ModelClass: Type[torch.nn.Module] = getattr(module, self.props['model_type'])

        init_args = {
            'n_features': self.stats['shape'][1],
            'n_cvs': self.props['n_cvs'],
            'encoder_layers': self.props['encoder_layers'],
            'decoder_layers': self.props['decoder_layers'],
            # loss_function is built separately; norm_in is rebuilt below
            # with the loaded statistics attached.
            'options': {k: v for k, v in self.props['options'].items() if k not in ['loss_function', 'norm_in']}
        }

        if 'norm_in' in self.props.get('options', {}):
            init_args['options']['norm_in'] = {
                'stats': self.stats,
                'mode': self.props['options']['norm_in'].get('mode')
            }

        assert_valid_kwargs(ModelClass, init_args, context="model init")

        return ModelClass(**init_args)

    def _build_loss(self) -> torch.nn.Module:
        """Build the loss function from options['loss_function'].

        Falls back to the model's own default `loss_fn` when no explicit
        loss configuration is given.
        """
        loss_config = self.props['options'].get('loss_function')

        if not loss_config:
            # Use model's default
            return getattr(self.model, 'loss_fn', None)

        if loss_config.get('loss_type') == 'PhysicsLoss':
            # FIX: work on a shallow copy so the caller's properties dict is
            # not mutated when the statistics are injected.
            loss_config = {**loss_config, 'stats': self.stats}

        loss_type = loss_config.get('loss_type')
        if not loss_type:
            raise KeyError("'loss_type' must be specified in options['loss_function']")

        loss_module = importlib.import_module('biobb_pytorch.mdae.loss')
        LossClass = getattr(loss_module, loss_type)

        kwargs = {k: v for k, v in loss_config.items() if k != 'loss_type'}

        assert_valid_kwargs(LossClass, kwargs, context="loss init")
        try:
            return LossClass(**kwargs)
        except Exception:
            # Best effort: some loss classes reject 'stats'; retry without it.
            kwargs = {k: v for k, v in kwargs.items() if k != 'stats'}
            return LossClass(**kwargs)

    def _hparams_loss_repr(self) -> str:
        """Return a short, human-readable description of the configured loss."""
        loss_config = self.props['options'].get('loss_function')
        if loss_config:
            name = loss_config.get('loss_type', '')
            args = [f"{k}={v}" for k, v in loss_config.items() if k not in ['loss_type', 'stats']]
            return f"{name}({', '.join(args)})"
        # fallback to model's representation
        return repr(getattr(self.model, 'loss_fn', ''))

    def save_weights(self, path: str) -> None:
        """Save model.state_dict() to the given path."""
        torch.save(self.model.state_dict(), path)

    @classmethod
    def load_weights(
        cls,
        props: Dict[str, Any],
        path: str
    ) -> 'BuildModel':
        """Instantiate from props and load state_dict from path.

        Args:
            props: Build properties; must also carry 'input_stats_pt_path'
                so the statistics file can be reloaded.
            path: Path to a state_dict produced by :meth:`save_weights`.
        """
        # FIX: the original called cls(props), which passed the properties
        # dict positionally as input_stats_pt_path and always failed in
        # torch.load. The stats path is now taken from props instead.
        # NOTE(review): assumes callers put the path in props — TODO confirm.
        inst = cls(input_stats_pt_path=props['input_stats_pt_path'],
                   properties=props)
        state = torch.load(path, map_location=inst.device)
        inst.model.load_state_dict(state)
        inst.model.to(inst.device)
        return inst

    def save_full(self) -> None:
        """Serialize the full model object (including architecture)."""
        torch.save(self.model, self.output_model_pth_path)

    @staticmethod
    def load_full(path: str) -> torch.nn.Module:
        """Load a model serialized with save_full."""
        return torch.load(path, weights_only=False)

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`BuildModel` build step: optionally serialize the
        model to `output_model_pth_path` and log the hyperparameters and
        architecture.

        Returns:
            int: 0 on success (also when a restart skips the step).
        """
        # Setup Biobb
        if self.check_restart():
            return 0
        self.stage_files()

        if self.output_model_pth_path:
            self.save_full()

        fu.log("## BioBB AutoEncoder Builder ##", self.out_log)
        fu.log("", self.out_log)
        fu.log("Hyperparameters:", self.out_log)
        fu.log("----------------", self.out_log)
        # FIX: getattr with a default instead of __dict__.get(...).items(),
        # which raised if _hparams was ever absent.
        for key, value in getattr(self.model, '_hparams', {}).items():
            if key == 'options':
                fu.log(f"{key}:", self.out_log)
                for sub_key, sub_value in value.items():
                    fu.log(f"  {sub_key}: {sub_value}", self.out_log)
            else:
                fu.log(f"{key}: {value}", self.out_log)
        fu.log("", self.out_log)

        fu.log("Model:", self.out_log)
        fu.log("------", self.out_log)

        for line in str(self.model).splitlines():
            fu.log(line, self.out_log)
        fu.log("", self.out_log)

        if self.output_model_pth_path:
            fu.log(f"Model saved in .pth format in "
                   f'{os.path.abspath(self.io_dict["out"]["output_model_pth_path"])}',
                   self.out_log,
                   )
            fu.log(f'File size: '
                   f'{get_size(self.io_dict["out"]["output_model_pth_path"])}',
                   self.out_log,
                   )

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=bool(self.output_model_pth_path), raise_exception=False)

        return 0
def buildModel(
    properties: dict,
    input_stats_pt_path: str,
    output_model_pth_path: Optional[str] = None,
    **kwargs,
) -> int:
    """Create the :class:`BuildModel <BuildModel.BuildModel>` class and
    execute the :meth:`launch() <BuildModel.BuildModel.launch>` method."""
    # FIX: forward arguments explicitly. The original `**dict(locals())`
    # trick passed any extra keyword arguments nested under a single
    # 'kwargs' keyword (harmless only because __init__ ignores **kwargs)
    # and silently depended on locals() contents.
    return BuildModel(
        input_stats_pt_path=input_stats_pt_path,
        output_model_pth_path=output_model_pth_path,
        properties=properties,
        **kwargs,
    ).launch()
# Expose the class documentation on the functional wrapper so help() and the
# auto-generated CLI show the full argument description.
buildModel.__doc__ = BuildModel.__doc__
# Command-line entry point built by the BiobbObject helper from the wrapper.
main = BuildModel.get_main(buildModel, "Build a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.")

if __name__ == "__main__":
    main()
301# Example usage:
303# n_features = torch.rand(100, 20)
304# n_feat = n_features.shape[1]
306# properties = {
307# 'model_type': 'VariationalAutoEncoder',
308# 'n_cvs': 10,
309# 'encoder_layers': [n_feat, 64, 32],
310# 'decoder_layers': [32, 64, n_feat],
311# 'options': {
312# 'loss_function': {
313# 'loss_type': 'ELBOLoss',
314# 'beta': 1.0,
315# 'reconstruction': 'mse',
316# 'reduction': 'sum'},
317# 'optimizer': {
318# 'lr': 0.001
319# }
320# }
321# }
# model_builder = BuildModel(input_stats_pt_path="input_stats.pt", properties=properties)
324# model_builder.save_full("test_model.pth")
325# model = model_builder.load_full("test_model.pth")
327# print()
328# print("Hyperparameters:")
329# print("----------------")
330# for key, value in model._hparams.items():
331# print(f"{key}: {value}")
332# print()
333# print("Model:")
334# print("------")
335# print(model)