Coverage for biobb_pytorch / mdae / build_model.py: 88%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import os 

2import torch 

3import importlib 

4from typing import Dict, Any, Type, Optional 

5from biobb_pytorch.mdae.models import __all__ as AVAILABLE_MODELS 

6from biobb_pytorch.mdae.utils.model_utils import assert_valid_kwargs 

7from biobb_pytorch.mdae.utils.log_utils import get_size 

8from biobb_common.tools.file_utils import launchlogger 

9from biobb_common.tools import file_utils as fu 

10from biobb_common.generic.biobb_object import BiobbObject 

11 

12 

class BuildModel(BiobbObject):
    """
    | biobb_pytorch BuildModel
    | Build a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.
    | Builds a PyTorch autoencoder from the given properties.

    Args:
        input_stats_pt_path (str): Path to the input model statistics file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_model_pth_path (str) (Optional): Path to save the model in .pth format. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **model_type** (*str*) - ("AutoEncoder") Name of the model class to instantiate (must exist in biobb_pytorch.mdae.models).
            * **n_cvs** (*int*) - (1) Dimensionality of the latent space.
            * **encoder_layers** (*list*) - ([16]) List of integers representing the number of neurons in each encoder layer.
            * **decoder_layers** (*list*) - ([16]) List of integers representing the number of neurons in each decoder layer.
            * **options** (*dict*) - ({"norm_in": {"mode": "min_max"}}) Additional options (e.g. norm_in, optimizer, loss_function, device, etc.).

    Examples:
        This example shows how to use the BuildModel class to build a PyTorch autoencoder model::

            from biobb_pytorch.mdae.build_model import buildModel

            input_stats_pt_path = "input_stats.pt"
            output_model_pth_file = "model.pth"

            n_features = 128
            prop = {
                'model_type': 'AutoEncoder',
                'n_cvs': 10,
                'encoder_layers': [n_features, 64, 32],
                'decoder_layers': [32, 64, n_features],
                'options': {
                    'norm_in': {"mode": "min_max"},
                    'optimizer': {
                        'lr': 1e-4
                    }
                }
            }

            buildModel(input_stats_pt_path=input_stats_pt_path,
                       output_model_pth_path=None,
                       properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_stats_pt_path: str,
        output_model_pth_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:

        properties = properties or {}

        super().__init__(properties)

        self.input_stats_pt_path = input_stats_pt_path
        self.output_model_pth_path = output_model_pth_path
        self.props = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_stats_pt_path": input_stats_pt_path,
            },
            "out": {}
        }
        if output_model_pth_path:
            self.io_dict["out"]["output_model_pth_path"] = output_model_pth_path

        # Build the per-feature arguments (defaults mirror the class docstring)
        self.options: dict = properties.get("options", {})
        self.model_type: str = properties.get("model_type", "AutoEncoder")
        self.n_cvs: int = properties.get("n_cvs", 1)
        self.encoder_layers: list = properties.get("encoder_layers", [16])
        self.decoder_layers: list = properties.get("decoder_layers", [16])
        self.loss_function: Optional[dict] = properties.get("loss_function", None)
        self.device = self.options.get('device', 'cpu')

        # Load the pre-computed dataset statistics.
        # weights_only=False because the file stores a plain Python object
        # (a dict with at least a 'shape' entry), not a state_dict; only
        # load trusted files with this flag.
        self.stats = torch.load(self.io_dict['in']['input_stats_pt_path'],
                                weights_only=False)

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        self._validate_props()
        self.model = self._build_model()
        self.loss_fn = self._build_loss()

        # Store hyperparameters for reproducibility. _validate_props() has
        # already guaranteed these keys exist, so direct access is safe.
        hparams = {
            'model_type': self.props['model_type'],
            'n_cvs': self.props['n_cvs'],
            'encoder_layers': self.props['encoder_layers'],
            'decoder_layers': self.props['decoder_layers'],
            'loss_function': self._hparams_loss_repr(),
            'options': {k: v for k, v in self.props['options'].items() if k != 'loss_function'}
        }
        self.model._hparams = hparams

        # Attach loss_fn and move model to device
        self.model.loss_fn = self.loss_fn
        self.model.to(self.device)

    def _validate_props(self) -> None:
        """Ensure the required properties are present and the model type is known.

        Raises:
            KeyError: If any required property is missing.
            ValueError: If ``model_type`` is not an available model.
        """
        required = ['model_type', 'n_cvs', 'encoder_layers', 'decoder_layers', 'options']
        missing = [k for k in required if k not in self.props]
        if missing:
            raise KeyError(f"Missing required properties: {missing}")

        model_type = self.props['model_type']
        if model_type not in AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model_type '{model_type}'. Available: {AVAILABLE_MODELS}"
            )

    def _build_model(self) -> torch.nn.Module:
        """Resolve the model class by name and instantiate it with validated kwargs."""
        module = importlib.import_module('biobb_pytorch.mdae.models')
        ModelClass: Type[torch.nn.Module] = getattr(module, self.props['model_type'])

        init_args = {
            # stats['shape'] presumably is (n_frames, n_features) — second dim
            # is the model input width.
            'n_features': self.stats['shape'][1],
            'n_cvs': self.props['n_cvs'],
            'encoder_layers': self.props['encoder_layers'],
            'decoder_layers': self.props['decoder_layers'],
            # 'loss_function' and 'norm_in' are handled separately, not
            # forwarded verbatim to the model constructor.
            'options': {k: v for k, v in self.props['options'].items()
                        if k not in ('loss_function', 'norm_in')}
        }

        # Input normalization needs the dataset statistics next to the
        # user-requested mode.
        if 'norm_in' in self.props.get('options', {}):
            init_args['options']['norm_in'] = {
                'stats': self.stats,
                'mode': self.props['options']['norm_in'].get('mode')
            }

        assert_valid_kwargs(ModelClass, init_args, context="model init")

        return ModelClass(**init_args)

    def _build_loss(self) -> torch.nn.Module:
        """Instantiate the loss from options['loss_function'], or return the model default.

        Raises:
            KeyError: If a loss configuration is given without 'loss_type'.
        """
        loss_config = self.props['options'].get('loss_function')

        if not loss_config:
            # Use model's default
            return getattr(self.model, 'loss_fn', None)

        # Work on a shallow copy so the caller's properties dict is not
        # mutated (the original injected 'stats' into the user's dict).
        loss_config = dict(loss_config)
        if loss_config.get('loss_type') == 'PhysicsLoss':
            loss_config['stats'] = self.stats

        loss_type = loss_config.get('loss_type')
        if not loss_type:
            raise KeyError("'loss_type' must be specified in options['loss_function']")

        loss_module = importlib.import_module('biobb_pytorch.mdae.loss')
        LossClass = getattr(loss_module, loss_type)

        kwargs = {k: v for k, v in loss_config.items() if k != 'loss_type'}

        assert_valid_kwargs(LossClass, kwargs, context="loss init")
        try:
            return LossClass(**kwargs)
        except Exception:
            # Some loss classes reject the injected 'stats'; retry without it.
            kwargs.pop('stats', None)
            return LossClass(**kwargs)

    def _hparams_loss_repr(self) -> str:
        """Return a human-readable description of the loss for the hparams dict."""
        loss_config = self.props['options'].get('loss_function')
        if loss_config:
            name = loss_config.get('loss_type', '')
            # 'stats' is an internal injection, not a user hyperparameter.
            args = [f"{k}={v}" for k, v in loss_config.items()
                    if k not in ['loss_type', 'stats']]
            return f"{name}({', '.join(args)})"
        # fallback to model's representation
        return repr(getattr(self.model, 'loss_fn', ''))

    def save_weights(self, path: str) -> None:
        """Save model.state_dict() to the given path."""
        torch.save(self.model.state_dict(), path)

    @classmethod
    def load_weights(
        cls,
        props: Dict[str, Any],
        path: str
    ) -> 'BuildModel':
        """Instantiate from props and load state_dict from path.

        NOTE(review): ``props`` must also carry the statistics path under the
        key 'input_stats_pt_path'. The previous implementation called
        ``cls(props)``, which passed the whole dict as the stats path and
        could never succeed — confirm callers provide this key.
        """
        inst = cls(input_stats_pt_path=props["input_stats_pt_path"],
                   properties=props)
        state = torch.load(path, map_location=inst.device)
        inst.model.load_state_dict(state)
        inst.model.to(inst.device)
        return inst

    def save_full(self) -> None:
        """Serialize the full model object (including architecture)."""
        torch.save(self.model, self.output_model_pth_path)

    @staticmethod
    def load_full(path: str) -> torch.nn.Module:
        """Load a model serialized with save_full."""
        return torch.load(path, weights_only=False)

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`BuildModel` build: optionally save the model and
        log its hyperparameters and architecture.

        Returns:
            int: 0 on success.
        """

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        if self.output_model_pth_path:
            self.save_full()

        fu.log("## BioBB AutoEncoder Builder ##", self.out_log)
        fu.log("", self.out_log)
        fu.log("Hyperparameters:", self.out_log)
        fu.log("----------------", self.out_log)
        # getattr with a default: __dict__.get('_hparams') returned None and
        # crashed on .items() when the attribute was missing.
        for key, value in getattr(self.model, '_hparams', {}).items():
            if key == 'options':
                fu.log(f"{key}:", self.out_log)
                for sub_key, sub_value in value.items():
                    fu.log(f" {sub_key}: {sub_value}", self.out_log)
            else:
                fu.log(f"{key}: {value}", self.out_log)
        fu.log("", self.out_log)

        fu.log("Model:", self.out_log)
        fu.log("------", self.out_log)

        # Log the architecture line by line so it is readable in the log file.
        for line in str(self.model).splitlines():
            fu.log(line, self.out_log)
        fu.log("", self.out_log)

        if self.output_model_pth_path:
            fu.log("Model saved in .pth format in "
                   f'{os.path.abspath(self.io_dict["out"]["output_model_pth_path"])}',
                   self.out_log,
                   )
            fu.log('File size: '
                   f'{get_size(self.io_dict["out"]["output_model_pth_path"])}',
                   self.out_log,
                   )

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=bool(self.output_model_pth_path),
                             raise_exception=False)

        return 0

281 

282 

def buildModel(
    properties: dict,
    input_stats_pt_path: str,
    output_model_pth_path: Optional[str] = None,
    **kwargs,
) -> int:
    """Create the :class:`BuildModel <BuildModel.BuildModel>` class and
    execute the :meth:`launch() <BuildModel.BuildModel.launch>` method."""
    # Forward arguments explicitly: the previous ``BuildModel(**dict(locals()))``
    # passed the captured **kwargs dict as a single keyword literally named
    # 'kwargs' instead of expanding it, so extra keyword arguments were
    # silently swallowed.
    return BuildModel(
        input_stats_pt_path=input_stats_pt_path,
        output_model_pth_path=output_model_pth_path,
        properties=properties,
        **kwargs,
    ).launch()

292 

293 

# Mirror the class documentation on the functional wrapper so help() and the
# generated CLI show the full argument description.
buildModel.__doc__ = BuildModel.__doc__
# Command-line entry point; get_main presumably builds an argparse-based CLI
# around the wrapper — confirm against BiobbObject in biobb_common.
main = BuildModel.get_main(buildModel, "Build a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.")


if __name__ == "__main__":
    main()

300 

301# Example usage: 

302 

303# n_features = torch.rand(100, 20) 

304# n_feat = n_features.shape[1] 

305 

306# properties = { 

307# 'model_type': 'VariationalAutoEncoder', 

308# 'n_cvs': 10, 

309# 'encoder_layers': [n_feat, 64, 32], 

310# 'decoder_layers': [32, 64, n_feat], 

311# 'options': { 

312# 'loss_function': { 

313# 'loss_type': 'ELBOLoss', 

314# 'beta': 1.0, 

315# 'reconstruction': 'mse', 

316# 'reduction': 'sum'}, 

317# 'optimizer': { 

318# 'lr': 0.001 

319# } 

320# } 

321# } 

322 

323# model_builder = BuildModel(properties) 

324# model_builder.save_full("test_model.pth") 

325# model = model_builder.load_full("test_model.pth") 

326 

327# print() 

328# print("Hyperparameters:") 

329# print("----------------") 

330# for key, value in model._hparams.items(): 

331# print(f"{key}: {value}") 

332# print() 

333# print("Model:") 

334# print("------") 

335# print(model)