Coverage for biobb_pytorch / mdae / train_model.py: 83%

120 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import torch 

2import os 

3from typing import Optional 

4from biobb_common.tools.file_utils import launchlogger 

5from biobb_common.tools import file_utils as fu 

6from biobb_pytorch.mdae.utils.log_utils import get_size 

7from biobb_common.generic.biobb_object import BiobbObject 

8import lightning.pytorch.callbacks as _cbs 

9import lightning.pytorch.loggers as _loggers 

10import lightning.pytorch.profilers as _profiler 

11from mlcolvar.utils.trainer import MetricsCallback 

12import lightning 

13from mlcolvar.data import DictModule 

14from mlcolvar.data import DictDataset 

15import numpy as np 

16 

17 

class TrainModel(BiobbObject):
    """
    | biobb_pytorch TrainModel
    | Trains a PyTorch autoencoder using the given properties.

    Args:
        input_model_pth_path (str): Path to the input model file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_pt_path (str): Path to the input dataset file (.pt) produced by the MD feature pipeline. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_model_pth_path (str) (Optional): Path to save the trained model (.pth). If omitted, the trained model is only available in memory. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        output_metrics_npz_path (str) (Optional): Path to save training metrics in compressed NumPy format (.npz). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Trainer** (*dict*) - ({}) PyTorch Lightning Trainer options (e.g. max_epochs, callbacks, logger, profiler, accelerator, devices, etc.).
            * **Dataset** (*dict*) - ({}) mlcolvar DictDataset / DictModule options (e.g. batch_size, split proportions and shuffling flags).

    Examples:
        This example shows how to use the TrainModel class to train a PyTorch autoencoder model::

            from biobb_pytorch.mdae.train_model import trainModel

            input_model_pth_path = 'input_model.pth'
            input_dataset_pt_path = 'input_dataset.pt'
            output_model_pth_path = 'output_model.pth'
            output_metrics_npz_path = 'output_metrics.npz'

            prop = {
                'Trainer': {
                    'max_epochs': 10,
                    'callbacks': {
                        'EarlyStopping': {'monitor': 'valid_loss'}
                    }
                },
                'Dataset': {
                    'batch_size': 32,
                    'split': {
                        'train_prop': 0.8,
                        'val_prop': 0.2
                    }
                }
            }

            trainModel(input_model_pth_path=input_model_pth_path,
                       input_dataset_pt_path=input_dataset_pt_path,
                       output_model_pth_path=output_model_pth_path,
                       output_metrics_npz_path=output_metrics_npz_path,
                       properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_pt_path: str,
        output_model_pth_path: Optional[str] = None,
        output_metrics_npz_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:

        properties = properties or {}

        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_pt_path = input_dataset_pt_path
        self.output_model_pth_path = output_model_pth_path
        self.output_metrics_npz_path = output_metrics_npz_path
        self.properties = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files; optional outputs are only registered when a
        # path was actually provided, so check_arguments does not require them.
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_pt_path": input_dataset_pt_path,
            },
            "out": {},
        }
        if output_model_pth_path:
            self.io_dict["out"]["output_model_pth_path"] = output_model_pth_path
        if output_metrics_npz_path:
            self.io_dict["out"]["output_metrics_npz_path"] = output_metrics_npz_path

        # Tool-specific property groups (both default to empty dicts).
        self.Trainer = self.properties.get('Trainer', {})
        self.Dataset = self.properties.get('Dataset', {})

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def get_callbacks(self):
        """Build the Lightning callback list.

        A mlcolvar ``MetricsCallback`` is always included (it collects the
        training metrics later saved by :meth:`launch`); any extra callbacks
        declared under ``properties['Trainer']['callbacks']`` as
        ``{CallbackClassName: constructor_kwargs}`` are appended.

        Raises:
            KeyError: If a callback name does not exist in
                ``lightning.pytorch.callbacks`` (consistent with
                :meth:`get_logger` / :meth:`get_profiler`; previously a
                misspelled name was silently dropped).
        """
        self.colvars_metrics = MetricsCallback()
        cbs_list = [self.colvars_metrics]

        callbacks_prop = self.properties.get('Trainer', {}).get('callbacks', {})
        for name, params in callbacks_prop.items():
            CallbackClass = getattr(_cbs, name, None)
            if CallbackClass is None:
                raise KeyError(f"No Callback named {name} in lightning.pytorch.callbacks")
            cbs_list.append(CallbackClass(**params))
        return cbs_list

    def get_logger(self):
        """Instantiate the Lightning logger declared in the properties.

        Expects ``properties['Trainer']['logger']`` to be a one-entry dict
        ``{LoggerClassName: constructor_kwargs}``; only the first entry is
        used. Returns ``None`` when no logger is configured.

        Raises:
            KeyError: If the logger name does not exist in
                ``lightning.pytorch.loggers``.
        """
        logger_prop = self.properties.get('Trainer', {}).get('logger', False)
        if not logger_prop:
            return None

        logger_type, logger_params = next(iter(logger_prop.items()))
        LoggerClass = getattr(_loggers, logger_type, None)
        if LoggerClass is None:
            raise KeyError(f"No Logger named {logger_type} in lightning.pytorch.loggers")

        return LoggerClass(**logger_params)

    def get_profiler(self):
        """Instantiate the Lightning profiler declared in the properties.

        Expects ``properties['Trainer']['profiler']`` to be a one-entry dict
        ``{ProfilerClassName: constructor_kwargs}``; only the first entry is
        used. Returns ``None`` when no profiler is configured.

        Raises:
            KeyError: If the profiler name does not exist in
                ``lightning.pytorch.profilers``.
        """
        profiler_prop = self.properties.get('Trainer', {}).get('profiler')
        if not profiler_prop:
            return None

        profiler_type, profiler_params = next(iter(profiler_prop.items()))
        ProfilerClass = getattr(_profiler, profiler_type, None)
        if ProfilerClass is None:
            raise KeyError(f"No Profiler named {profiler_type} in lightning.pytorch.profilers")

        return ProfilerClass(**profiler_params)

    def get_trainer(self):
        """Build the ``lightning.Trainer`` from the 'Trainer' properties.

        The 'callbacks', 'logger' and 'profiler' keys are replaced by their
        instantiated objects; every other key is forwarded verbatim.
        """
        # .get() so a missing 'Trainer' section falls back to Lightning defaults
        # instead of raising KeyError.
        train_params = {k: v for k, v in self.properties.get('Trainer', {}).items()
                        if k not in ['callbacks', 'logger', 'profiler']}
        train_params['callbacks'] = self.get_callbacks()
        train_params['logger'] = self.get_logger()
        train_params['profiler'] = self.get_profiler()
        return lightning.Trainer(**train_params)

    def load_model(self):
        """Deserialize the full model object from the input .pth file.

        NOTE(review): ``weights_only=False`` unpickles arbitrary objects —
        only load model files from trusted sources.
        """
        return torch.load(self.io_dict["in"]["input_model_pth_path"],
                          weights_only=False)

    def load_dataset(self):
        """Load the input .pt file and wrap it in an mlcolvar DictDataset."""
        dataset = torch.load(self.io_dict["in"]["input_dataset_pt_path"],
                             weights_only=False)
        return DictDataset(dataset)

    def create_datamodule(self, dataset):
        """Wrap *dataset* in a DictModule using the 'Dataset' properties.

        Defaults: batch_size 16, 0.8/0.2 train/validation split, shuffled
        random split; a test split is only added when 'test_prop' > 0.
        """
        # .get() guards so an absent 'Dataset' or 'split' section falls back
        # to the documented defaults instead of raising KeyError.
        ds_cfg = self.properties.get('Dataset', {})
        split_cfg = ds_cfg.get('split', {})

        lengths = [split_cfg.get('train_prop', 0.8),
                   split_cfg.get('val_prop', 0.2)]
        if split_cfg.get('test_prop', 0) > 0:
            lengths.append(split_cfg['test_prop'])

        return DictModule(
            dataset,
            batch_size=ds_cfg.get('batch_size', 16),
            lengths=lengths,
            shuffle=split_cfg.get('shuffle', True),
            random_split=split_cfg.get('random_split', True),
        )

    def fit_model(self, trainer, model, datamodule):
        """Fit the model to the data, capturing logs and keeping tqdm clean."""
        trainer.fit(model, datamodule)

    def save_full(self, model) -> None:
        """Serialize the full model object (including architecture)."""
        torch.save(model, self.io_dict["out"]["output_model_pth_path"])

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`TrainModel` class and its `.launch()` method.

        Returns:
            int: 0 on success (also when a restart short-circuits the run).
        """

        fu.log('## BioBB Model Trainer ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Start Pipeline

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        self.model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_pt_path"])}', self.out_log)
        self.dataset = self.load_dataset()

        # create the datamodule
        fu.log('Start training...', self.out_log)
        self.datamodule = self.create_datamodule(self.dataset)

        # get the trainer
        self.trainer = self.get_trainer()

        # fit the model
        self.fit_model(self.trainer, self.model, self.datamodule)

        # Metrics collected by the MetricsCallback registered in get_callbacks()
        self.metrics = self.colvars_metrics.metrics

        # Save the metrics if path provided
        if self.output_metrics_npz_path:
            np.savez_compressed(self.io_dict["out"]["output_metrics_npz_path"], **self.metrics)
            fu.log(f'Training Metrics saved to {os.path.abspath(self.io_dict["out"]["output_metrics_npz_path"])}', self.out_log)
            fu.log(f'File size: {get_size(self.io_dict["out"]["output_metrics_npz_path"])}', self.out_log)

        # save the model if path provided
        if self.output_model_pth_path:
            self.save_full(self.model)
            fu.log(f'Trained Model saved to {os.path.abspath(self.io_dict["out"]["output_model_pth_path"])}', self.out_log)
            fu.log(f'File size: {get_size(self.io_dict["out"]["output_model_pth_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        # Outputs are optional: only require them when at least one was requested.
        output_created = bool(self.output_model_pth_path or self.output_metrics_npz_path)
        self.check_arguments(output_files_created=output_created, raise_exception=False)

        return 0

257 

258 

def trainModel(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_pt_path: str,
    output_model_pth_path: Optional[str] = None,
    output_metrics_npz_path: Optional[str] = None,
    **kwargs,
) -> int:
    """Create the :class:`TrainModel <TrainModel.TrainModel>` class and
    execute the :meth:`launch() <TrainModel.TrainModel.launch>` method.

    Arguments mirror :class:`TrainModel`; returns the launch exit code.
    """
    # Pass arguments explicitly rather than via **dict(locals()): the locals()
    # snapshot also contained the 'kwargs' name itself, which was forwarded as
    # a spurious keyword argument named 'kwargs' (harmless only because
    # __init__ ignores **kwargs).
    return TrainModel(
        properties=properties,
        input_model_pth_path=input_model_pth_path,
        input_dataset_pt_path=input_dataset_pt_path,
        output_model_pth_path=output_model_pth_path,
        output_metrics_npz_path=output_metrics_npz_path,
        **kwargs,
    ).launch()

270 

271 

# Mirror the class docstring onto the functional wrapper so both entry points
# expose the same documentation.
trainModel.__doc__ = TrainModel.__doc__

# Command-line entry point built by BiobbObject.get_main (presumably wires
# argument parsing to trainModel — defined in biobb_common; confirm there).
main = TrainModel.get_main(trainModel, "Trains a PyTorch autoencoder using the given properties.")

if __name__ == "__main__":
    main()