Coverage for biobb_pytorch / mdae / train_model.py: 83%
120 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import torch
2import os
3from typing import Optional
4from biobb_common.tools.file_utils import launchlogger
5from biobb_common.tools import file_utils as fu
6from biobb_pytorch.mdae.utils.log_utils import get_size
7from biobb_common.generic.biobb_object import BiobbObject
8import lightning.pytorch.callbacks as _cbs
9import lightning.pytorch.loggers as _loggers
10import lightning.pytorch.profilers as _profiler
11from mlcolvar.utils.trainer import MetricsCallback
12import lightning
13from mlcolvar.data import DictModule
14from mlcolvar.data import DictDataset
15import numpy as np
class TrainModel(BiobbObject):
    """
    | biobb_pytorch TrainModel
    | Trains a PyTorch autoencoder using the given properties.

    Args:
        input_model_pth_path (str): Path to the input model file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_pt_path (str): Path to the input dataset file (.pt) produced by the MD feature pipeline. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_model_pth_path (str) (Optional): Path to save the trained model (.pth). If omitted, the trained model is only available in memory. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        output_metrics_npz_path (str) (Optional): Path to save training metrics in compressed NumPy format (.npz). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Trainer** (*dict*) - ({}) PyTorch Lightning Trainer options (e.g. max_epochs, callbacks, logger, profiler, accelerator, devices, etc.).
            * **Dataset** (*dict*) - ({}) mlcolvar DictDataset / DictModule options (e.g. batch_size, split proportions and shuffling flags).

    Examples:
        This example shows how to use the TrainModel class to train a PyTorch autoencoder model::

            from biobb_pytorch.mdae.train_model import trainModel

            prop = {
                'Trainer': {
                    'max_epochs': 10,
                    'callbacks': {
                        'EarlyStopping': {'monitor': 'valid_loss'}
                    }
                },
                'Dataset': {
                    'batch_size': 32,
                    'split': {
                        'train_prop': 0.8,
                        'val_prop': 0.2
                    }
                }
            }

            trainModel(input_model_pth_path='input_model.pth',
                       input_dataset_pt_path='input_dataset.pt',
                       output_model_pth_path='output_model.pth',
                       output_metrics_npz_path='output_metrics.npz',
                       properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_pt_path: str,
        output_model_pth_path: Optional[str] = None,
        output_metrics_npz_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_pt_path = input_dataset_pt_path
        self.output_model_pth_path = output_model_pth_path
        self.output_metrics_npz_path = output_metrics_npz_path
        self.properties = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_pt_path": input_dataset_pt_path,
            },
            "out": {},
        }
        # Optional outputs are only registered when a path was actually
        # supplied, so the argument checks do not demand files the caller
        # never asked for.
        if output_model_pth_path:
            self.io_dict["out"]["output_model_pth_path"] = output_model_pth_path
        if output_metrics_npz_path:
            self.io_dict["out"]["output_metrics_npz_path"] = output_metrics_npz_path

        # Defaulted views of the two tool properties; all helper methods read
        # these instead of indexing self.properties directly, so an omitted
        # property falls back to {} instead of raising KeyError.
        self.Trainer = self.properties.get('Trainer', {})
        self.Dataset = self.properties.get('Dataset', {})

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def get_callbacks(self):
        """Build the Lightning callback list from the ``Trainer.callbacks`` property.

        A :class:`MetricsCallback` is always prepended so training metrics can
        be harvested after the fit (see :meth:`launch`). Each key of the
        ``callbacks`` mapping must name a class in
        ``lightning.pytorch.callbacks``; its value is the keyword-argument
        dict used to instantiate it.

        Returns:
            list: The instantiated callbacks, MetricsCallback first.

        Raises:
            KeyError: If a requested callback name does not exist in
                ``lightning.pytorch.callbacks``.
        """
        self.colvars_metrics = MetricsCallback()
        cbs_list = [self.colvars_metrics]

        for cb_name, cb_params in self.Trainer.get('callbacks', {}).items():
            CallbackClass = getattr(_cbs, cb_name, None)
            if CallbackClass is None:
                # Fail loudly on a misspelled callback name instead of
                # silently dropping it (consistent with get_logger/get_profiler).
                raise KeyError(f"No Callback named {cb_name} in lightning.pytorch.callbacks")
            cbs_list.append(CallbackClass(**cb_params))
        return cbs_list

    def get_logger(self):
        """Instantiate the Lightning logger described by ``Trainer.logger``.

        The property is expected to be a one-entry mapping
        ``{LoggerClassName: kwargs}``; only the first entry is used.

        Returns:
            The logger instance, or ``None`` when no logger is configured.

        Raises:
            KeyError: If the requested logger name does not exist in
                ``lightning.pytorch.loggers``.
        """
        logger_prop = self.Trainer.get('logger', False)
        if not logger_prop:
            return None

        logger_type, logger_params = next(iter(logger_prop.items()))
        LoggerClass = getattr(_loggers, logger_type, None)
        if LoggerClass is None:
            raise KeyError(f"No Logger named {logger_type} in lightning.pytorch.loggers")

        return LoggerClass(**logger_params)

    def get_profiler(self):
        """Instantiate the Lightning profiler described by ``Trainer.profiler``.

        The property is expected to be a one-entry mapping
        ``{ProfilerClassName: kwargs}``; only the first entry is used.

        Returns:
            The profiler instance, or ``None`` when no profiler is configured.

        Raises:
            KeyError: If the requested profiler name does not exist in
                ``lightning.pytorch.profilers``.
        """
        profiler_prop = self.Trainer.get('profiler')
        if not profiler_prop:
            return None

        profiler_type, profiler_params = next(iter(profiler_prop.items()))
        ProfilerClass = getattr(_profiler, profiler_type, None)
        if ProfilerClass is None:
            raise KeyError(f"No Profiler named {profiler_type} in lightning.pytorch.profilers")

        return ProfilerClass(**profiler_params)

    def get_trainer(self):
        """Create the :class:`lightning.Trainer` from the ``Trainer`` property.

        The ``callbacks``, ``logger`` and ``profiler`` sub-properties are
        replaced by the instantiated objects built by the dedicated helpers;
        every other key is forwarded verbatim to the Trainer constructor.
        """
        train_params = {k: v for k, v in self.Trainer.items()
                        if k not in ('callbacks', 'logger', 'profiler')}
        train_params['callbacks'] = self.get_callbacks()
        train_params['logger'] = self.get_logger()
        train_params['profiler'] = self.get_profiler()
        return lightning.Trainer(**train_params)

    def load_model(self):
        """Deserialize the full model object from the input .pth file.

        NOTE: ``weights_only=False`` unpickles arbitrary objects — only load
        model files from trusted sources.
        """
        return torch.load(self.io_dict["in"]["input_model_pth_path"],
                          weights_only=False)

    def load_dataset(self):
        """Load the input .pt dataset and wrap it in an mlcolvar DictDataset.

        NOTE: ``weights_only=False`` unpickles arbitrary objects — only load
        dataset files from trusted sources.
        """
        dataset = torch.load(self.io_dict["in"]["input_dataset_pt_path"],
                             weights_only=False)
        return DictDataset(dataset)

    def create_datamodule(self, dataset):
        """Wrap *dataset* in a :class:`DictModule` configured by the ``Dataset`` property.

        Defaults: batch_size 16, 80/20 train/validation split, shuffled random
        split. A test partition is only appended when ``test_prop`` > 0.
        Both the ``Dataset`` property and its ``split`` sub-dict are optional.
        """
        ds_cfg = self.Dataset
        split_cfg = ds_cfg.get('split', {})

        lengths = [split_cfg.get('train_prop', 0.8),
                   split_cfg.get('val_prop', 0.2)]
        if split_cfg.get('test_prop', 0) > 0:
            lengths.append(split_cfg['test_prop'])

        return DictModule(
            dataset,
            batch_size=ds_cfg.get('batch_size', 16),
            lengths=lengths,
            shuffle=split_cfg.get('shuffle', True),
            random_split=split_cfg.get('random_split', True),
        )

    def fit_model(self, trainer, model, datamodule):
        """Fit the model to the data, capturing logs and keeping tqdm clean."""
        trainer.fit(model, datamodule)

    def save_full(self, model) -> None:
        """Serialize the full model object (including architecture)."""
        torch.save(model, self.io_dict["out"]["output_model_pth_path"])

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`TrainModel` class and its `.launch()` method.
        """
        fu.log('## BioBB Model Trainer ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Start Pipeline

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        self.model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_pt_path"])}', self.out_log)
        self.dataset = self.load_dataset()

        # create the datamodule
        fu.log('Start training...', self.out_log)
        self.datamodule = self.create_datamodule(self.dataset)

        # get the trainer
        self.trainer = self.get_trainer()

        # fit the model
        self.fit_model(self.trainer, self.model, self.datamodule)

        # Set the metrics (collected by the MetricsCallback registered in
        # get_callbacks during the fit)
        self.metrics = self.colvars_metrics.metrics

        # Save the metrics if path provided
        if self.output_metrics_npz_path:
            np.savez_compressed(self.io_dict["out"]["output_metrics_npz_path"], **self.metrics)
            fu.log(f'Training Metrics saved to {os.path.abspath(self.io_dict["out"]["output_metrics_npz_path"])}', self.out_log)
            fu.log(f'File size: {get_size(self.io_dict["out"]["output_metrics_npz_path"])}', self.out_log)

        # save the model if path provided
        if self.output_model_pth_path:
            self.save_full(self.model)
            fu.log(f'Trained Model saved to {os.path.abspath(self.io_dict["out"]["output_model_pth_path"])}', self.out_log)
            fu.log(f'File size: {get_size(self.io_dict["out"]["output_model_pth_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        output_created = bool(self.output_model_pth_path or self.output_metrics_npz_path)
        self.check_arguments(output_files_created=output_created, raise_exception=False)

        return 0
def trainModel(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_pt_path: str,
    output_model_pth_path: Optional[str] = None,
    output_metrics_npz_path: Optional[str] = None,
    **kwargs,
) -> int:
    """Create the :class:`TrainModel <TrainModel.TrainModel>` class and
    execute the :meth:`launch() <TrainModel.TrainModel.launch>` method."""
    # Forward arguments explicitly: the previous ``TrainModel(**dict(locals()))``
    # also passed a spurious ``kwargs=<dict>`` keyword (locals() contains the
    # name ``kwargs`` itself), which was only harmless because it was silently
    # absorbed by TrainModel's **kwargs.
    return TrainModel(
        input_model_pth_path=input_model_pth_path,
        input_dataset_pt_path=input_dataset_pt_path,
        output_model_pth_path=output_model_pth_path,
        output_metrics_npz_path=output_metrics_npz_path,
        properties=properties,
        **kwargs,
    ).launch()
# Mirror the class documentation on the functional wrapper so that
# help(trainModel) shows the full argument/property description.
trainModel.__doc__ = TrainModel.__doc__
# Command-line entry point built by the BiobbObject helper.
# NOTE(review): get_main is provided by the biobb_common base class — its
# exact CLI behavior is defined there, not in this file.
main = TrainModel.get_main(trainModel, "Trains a PyTorch autoencoder using the given properties.")

if __name__ == "__main__":
    main()