Coverage for biobb_pytorch / mdae / make_plumed.py: 69%

217 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1 

import os
from typing import Any, Dict, List, Optional, Tuple

import torch
from biobb_common.generic.biobb_object import BiobbObject
from biobb_common.tools import file_utils as fu
from biobb_common.tools.file_utils import launchlogger

from biobb_pytorch.mdae.utils.log_utils import get_size

9 

10 

class GeneratePlumed(BiobbObject):
    """
    | biobb_pytorch GeneratePlumed
    | Generate PLUMED input for biased dynamics using an MDAE model.
    | Generates a PLUMED input file, features.dat, and converts the model to .ptc format.

    Args:
        input_model_pth_path (str): Path to the trained PyTorch model (.pth) to be converted to TorchScript and used in PLUMED. File type: input. Accepted formats: pth (edam:format_2333).
        input_stats_pt_path (str) (Optional): Path to statistics file (.pt) produced during featurization, used to derive the PLUMED features.dat content (required to generate features.dat). File type: input. Accepted formats: pt (edam:format_2333).
        input_reference_pdb_path (str) (Optional): Path to reference PDB used for FIT_TO_TEMPLATE actions when Cartesian features are present (required in that case unless REFERENCE is given in fit_to_template). File type: input. Accepted formats: pdb (edam:format_1476).
        input_ndx_path (str) (Optional): Path to GROMACS index (NDX) file; when provided it is passed as NDX_FILE to a GROUP that uses NDX_GROUP. File type: input. Accepted formats: ndx (edam:format_2033).
        output_plumed_dat_path (str): Path to the output PLUMED input file. File type: output. Accepted formats: dat (edam:format_2330).
        output_features_dat_path (str): Path to the output features.dat file describing the CVs to PLUMED. File type: output. Accepted formats: dat (edam:format_2330).
        output_model_ptc_path (str): Path to the output TorchScript model file (.ptc) for PLUMED's PYTORCH_MODEL action. File type: output. Accepted formats: ptc (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **additional_actions** (*list*) - ([]) Extra PLUMED actions (dicts with "name" and optional "label" and "params") written right after the INCLUDE line (e.g. ENERGY, RMSD).
            * **bias** (*list*) - ([]) List of biasing actions (e.g. METAD) to be added to the PLUMED file.
            * **prints** (*dict*) - ({"ARG": "*", "STRIDE": 1, "FILE": "COLVAR"}) PRINT command parameters (e.g. ARG, STRIDE, FILE).
            * **group** (*dict*) - (None) GROUP definition options (label, NDX group or atom selection parameters).
            * **wholemolecules** (*dict*) - (None) WHOLEMOLECULES options when using Cartesian coordinates.
            * **fit_to_template** (*dict*) - (None) FIT_TO_TEMPLATE options (e.g. STRIDE, TYPE, etc.). REFERENCE defaults to input_reference_pdb_path.
            * **pytorch_model** (*dict*) - (None) PYTORCH_MODEL options (label, PACE and other parameters).

    Examples:
        This example shows how to use the GeneratePlumed class to generate a PLUMED input file for biased dynamics using an MDAE model::

            from biobb_pytorch.mdae.make_plumed import generatePlumed

            prop = {
                "additional_actions": [
                    {
                        "name": "ENERGY",
                        "label": "ene"
                    },
                    {
                        "name": "RMSD",
                        "label": "rmsd",
                        "params": {
                            "TYPE": "OPTIMAL"
                        }
                    }
                ],
                "group": {
                    "label": "c_alphas",
                    "NDX_GROUP": "chA_&_C-alpha"
                },
                "wholemolecules": {
                    "ENTITY0": "c_alphas"
                },
                "fit_to_template": {
                    "STRIDE": 1,
                    "TYPE": "OPTIMAL"
                },
                "pytorch_model": {
                    "label": "cv",
                    "PACE": 1
                },
                "bias": [
                    {
                        "name": "METAD",
                        "label": "bias",
                        "params": {
                            "ARG": "cv.1",
                            "PACE": 500,
                            "HEIGHT": 1.2,
                            "SIGMA": 0.35,
                            "FILE": "HILLS",
                            "BIASFACTOR": 8
                        }
                    }
                ],
                "prints": {
                    "ARG": "cv.*,bias.*",
                    "STRIDE": 1,
                    "FILE": "COLVAR"
                }
            }

            generatePlumed(
                input_model_pth_path="model.pth",
                input_stats_pt_path="stats.pt",
                input_reference_pdb_path="reference.pdb",
                input_ndx_path="index.ndx",
                output_plumed_dat_path="plumed.dat",
                output_features_dat_path="features.dat",
                output_model_ptc_path="model.ptc",
                properties=prop
            )

    Info:
        * wrapped_software:
            * name: PLUMED with PyTorch
            * version: >=2.0
            * license: LGPL 3.0
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_stats_pt_path: Optional[str] = None,
        input_reference_pdb_path: Optional[str] = None,
        input_ndx_path: Optional[str] = None,
        output_plumed_dat_path: str = 'plumed.dat',
        output_features_dat_path: str = 'features.dat',
        output_model_ptc_path: str = 'model.ptc',
        properties: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        super().__init__(properties)
        # Snapshot of the constructor arguments; must be taken before any other
        # local name is defined in this method.
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {"input_model_pth_path": input_model_pth_path},
            "out": {
                "output_plumed_dat_path": output_plumed_dat_path,
                "output_features_dat_path": output_features_dat_path,
                "output_model_ptc_path": output_model_ptc_path
            }
        }
        if input_stats_pt_path:
            self.io_dict["in"]["input_stats_pt_path"] = input_stats_pt_path
        if input_reference_pdb_path:
            self.io_dict["in"]["input_reference_pdb_path"] = input_reference_pdb_path
        if input_ndx_path:
            self.io_dict["in"]["input_ndx_path"] = input_ndx_path

        # Input paths as given by the caller
        self.model_pth = input_model_pth_path
        self.stats_pt = input_stats_pt_path
        self.ref_pdb = input_reference_pdb_path
        self.ndx = input_ndx_path
        self.properties = properties

        # Properties
        self.additional_actions = self.properties.get('additional_actions', [])
        self.group = self.properties.get('group', None)
        self.wholemolecules = self.properties.get('wholemolecules', None)
        self.fit_to_template = self.properties.get('fit_to_template', None)
        self.pytorch_model = self.properties.get('pytorch_model', None)
        self.bias = self.properties.get('bias', [])
        self.prints = self.properties.get('prints', {'ARG': '*', 'STRIDE': 1, 'FILE': 'COLVAR'})

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        # Comma-separated ARG string for PYTORCH_MODEL; filled in by launch().
        self.arg: str = ''

        # Featurization statistics. Empty dict (instead of None) when no stats
        # file was given, so membership tests like "'cartesian_indices' in
        # self.stats" never fail on None.
        self.stats: Dict[str, Any] = self._load_stats() or {}
        shape = self.stats.get('shape')
        # Index 1 of 'shape' is taken as the number of model inputs
        # (assumes a 2-D (n_samples, n_features) shape - TODO confirm).
        self.n_features: Optional[int] = (
            int(shape[1]) if shape is not None and len(shape) > 1 else None
        )

    def _load_stats(self) -> Optional[Dict[str, Any]]:
        """Load the featurization statistics file if a path was provided.

        Returns:
            The object stored in the stats file, or None when no path was given.
        """
        if not self.stats_pt:
            return None
        # NOTE(security): weights_only=False performs full unpickling; only
        # load stats files from trusted sources.
        return torch.load(self.stats_pt, weights_only=False)

    def _generate_features(self) -> Tuple[List[str], List[str]]:
        """
        Write features.dat from the loaded statistics.

        Returns:
            Tuple[List[str], List[str]]: The PLUMED lines written to features.dat
            and the list of component labels to pass as ARG to PYTORCH_MODEL.

        Raises:
            ValueError: If no input_stats_pt_path was provided.
        """
        if not self.stats_pt:
            raise ValueError('Input_stats_pt_path is required.')
        return self._generate_features_from_stats(self.stats, self.io_dict['out']['output_features_dat_path'])

    def _generate_features_from_stats(self, stats: Dict[str, Any], features_path: str) -> Tuple[List[str], List[str]]:
        """
        Generate features.dat from stats.pt for cartesians, distances, angles and/or dihedrals.

        Args:
            stats (Dict[str, Any]): Loaded stats dictionary. Recognised keys are
                'cartesian_indices', 'distance_indices', 'angle_indices' and
                'dihedral_indices', all holding 0-based atom indices.
            features_path (str): Path to write features.dat.

        Returns:
            Tuple[List[str], List[str]]: Lines written to features.dat and the
            ARG labels, in the same order the features are emitted.

        Raises:
            ValueError: If a distance/angle/dihedral entry has the wrong number of atoms.
        """
        feat_lines: List[str] = []
        arg_list: List[str] = []

        def to_plumed(indices) -> List[int]:
            # PLUMED atom numbering is 1-based. int() also unwraps numpy/torch
            # scalars so they are formatted as plain numbers.
            return [int(idx) + 1 for idx in indices]

        if 'cartesian_indices' in stats:
            pos_atoms = to_plumed(stats['cartesian_indices'])
            fu.log(f"Found {len(pos_atoms)} Cartesian features.", self.out_log)
            for atom in pos_atoms:
                feat_lines.append(f'p{atom}: POSITION ATOM={atom}')
                arg_list.extend([f'p{atom}.x', f'p{atom}.y', f'p{atom}.z'])

        # (stats key, log name, label prefix, PLUMED action, atoms per feature)
        internal_features = (
            ('distance_indices', 'Distance', 'd', 'DISTANCE', 2),
            ('angle_indices', 'Angle', 'a', 'ANGLE', 3),
            ('dihedral_indices', 'Dihedral', 't', 'TORSION', 4),
        )
        for key, kind, prefix, action, n_atoms in internal_features:
            if key not in stats:
                continue
            entries = stats[key]
            fu.log(f"Found {len(entries)} {kind} features.", self.out_log)
            for count, entry in enumerate(entries, start=1):
                atoms = to_plumed(entry)
                if len(atoms) != n_atoms:
                    raise ValueError(f'{key} entries must have {n_atoms} atom indices, got {len(atoms)}.')
                label = f'{prefix}{count}'
                feat_lines.append(f'{label}: {action} ATOMS={",".join(map(str, atoms))}')
                arg_list.append(label)

        with open(features_path, 'w') as f:
            f.writelines(line + '\n' for line in feat_lines)

        return feat_lines, arg_list

    def _convert_model_to_ptc(self) -> None:
        """Convert the PyTorch model to TorchScript format (.ptc).

        Tries torch.jit.script first and falls back to torch.jit.trace with a
        random (1, n_features) input if scripting (or saving) fails.

        Raises:
            ValueError: If tracing is needed but the number of input features is unknown.
        """
        # NOTE(security): weights_only=False performs full unpickling; only
        # load model files from trusted sources.
        model = torch.load(self.model_pth, weights_only=False)

        # Convert numpy.int64 attributes to Python int for JIT compatibility.
        # modules() yields the model itself and every descendant.
        for submodule in model.modules():
            for attr in ('in_features', 'out_features'):
                value = getattr(submodule, attr, None)
                if value is not None:
                    setattr(submodule, attr, int(value))

        self._enable_jit_scripting(model)
        output_path = self.io_dict['out']['output_model_ptc_path']
        try:
            scripted_model = torch.jit.script(model)
            torch.jit.save(scripted_model, output_path)
            fu.log(f'Successfully scripted and saved model to {output_path}', self.out_log)
        except Exception as e:  # deliberate: any scripting failure falls back to tracing
            fu.log(f'jit.script failed: {e}. Attempting jit.trace instead.', self.out_log)
            if self.n_features is None:
                raise ValueError(
                    'Cannot trace the model: the number of input features is unknown '
                    '(a stats file with a "shape" entry is required).'
                ) from e
            # Set to eval mode for tracing (required for BatchNorm with batch size 1)
            model.eval()
            example_input = torch.randn(1, self.n_features)  # Batch size 1, flat input
            traced_model = torch.jit.trace(model, example_input)
            torch.jit.save(traced_model, output_path)
            fu.log(f'Successfully traced and saved model to {output_path}', self.out_log)

    def _enable_jit_scripting(self, module: torch.nn.Module) -> None:
        """Set _jit_is_scripting to True on the module and all its submodules that define it."""
        # modules() includes `module` itself.
        for submodule in module.modules():
            if hasattr(submodule, '_jit_is_scripting'):
                submodule._jit_is_scripting = True

    @staticmethod
    def _plumed_params(params: Dict[str, Any]) -> str:
        """Render a mapping as space-separated PLUMED ``KEY=VALUE`` pairs."""
        return ' '.join(f'{k}={v}' for k, v in params.items())

    def _plumed_action(self, action: Dict[str, Any]) -> str:
        """Render an action dict ({"name", optional "label", optional "params"}) as one PLUMED line."""
        label = action.get('label', '')
        prefix = f'{label}: ' if label else ''
        return f"{prefix}{action['name']} {self._plumed_params(action.get('params', {}))}".rstrip()

    def _build_plumed_lines(self) -> List[str]:
        """Build the list of lines for the PLUMED file.

        Order: INCLUDE of features.dat, additional actions, GROUP,
        WHOLEMOLECULES / FIT_TO_TEMPLATE (only with Cartesian features),
        PYTORCH_MODEL, bias actions and PRINT. Reads ``self.arg``, so it must
        be called after the features have been generated.

        Raises:
            ValueError: If FIT_TO_TEMPLATE is requested for Cartesian features
                but neither a reference PDB nor a REFERENCE parameter is given.
        """
        lines = [f'INCLUDE FILE={os.path.abspath(self.io_dict["out"]["output_features_dat_path"])}']

        # Additional actions (e.g., ENERGY, other metrics)
        lines.extend(self._plumed_action(action) for action in self.additional_actions)

        # GROUP
        if self.group:
            group_label = self.group.get('label', 'C-alpha')
            group_params = {k: v for k, v in self.group.items() if k not in ('label', 'name')}
            if self.ndx and 'NDX_GROUP' in group_params and 'NDX_FILE' not in group_params:
                # GROUP reads NDX_GROUP from the index named by NDX_FILE; use the provided NDX.
                group_params = {'NDX_FILE': os.path.abspath(self.ndx), **group_params}
            lines.append(f"{group_label}: GROUP {self._plumed_params(group_params)}")
            fu.log(f'Using GROUP: {group_label}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in group_params.items():
                fu.log(f' > {k.upper()}: {v}', self.out_log)

        uses_positions = 'cartesian_indices' in self.stats

        # WHOLEMOLECULES
        if uses_positions:
            if self.wholemolecules:
                params = self._plumed_params(self.wholemolecules)
                lines.append(f'WHOLEMOLECULES {params}')
                fu.log(f'Using WHOLEMOLECULES with parameters: {params}', self.out_log)
            else:
                fu.log('WARNING: Using Cartesian coordinates but no WHOLEMOLECULES parameters provided; add WHOLEMOLECULES in properties.', self.out_log)
        elif self.wholemolecules:
            fu.log('NOTE: WHOLEMOLECULES parameters provided but no POSITION features detected; skipping WHOLEMOLECULES.', self.out_log)

        # FIT_TO_TEMPLATE
        if uses_positions:
            if self.fit_to_template:
                fit_params = dict(self.fit_to_template)
                if 'REFERENCE' not in fit_params:
                    if not self.ref_pdb:
                        raise ValueError('FIT_TO_TEMPLATE with Cartesian features requires input_reference_pdb_path '
                                         '(or a REFERENCE entry in fit_to_template).')
                    fit_params = {'REFERENCE': os.path.abspath(self.ref_pdb), **fit_params}
                lines.append(f'FIT_TO_TEMPLATE {self._plumed_params(fit_params)}')
                fu.log('Using FIT_TO_TEMPLATE', self.out_log)
                fu.log(f' Reference PDB: {fit_params["REFERENCE"]}', self.out_log)
                fu.log(' Parameters:', self.out_log)
                for k, v in fit_params.items():
                    if k != 'REFERENCE':
                        fu.log(f' > {k.upper()}: {v}', self.out_log)
            else:
                fu.log('WARNING: Using Cartesian coordinates but no FIT_TO_TEMPLATE parameters provided; add FIT_TO_TEMPLATE in properties.', self.out_log)
        elif self.fit_to_template:
            fu.log('NOTE: FIT_TO_TEMPLATE parameters provided but no POSITION features detected; skipping FIT_TO_TEMPLATE.', self.out_log)

        # PYTORCH_MODEL (user options may override FILE and ARG)
        pyt_label = 'cv'
        pyt_params: Dict[str, Any] = {
            'FILE': os.path.abspath(self.io_dict['out']['output_model_ptc_path']),
            'ARG': self.arg,
        }
        if self.pytorch_model:
            pyt_label = self.pytorch_model.get('label', 'cv')
            pyt_params.update({k: v for k, v in self.pytorch_model.items() if k != 'label'})
        lines.append(f'{pyt_label}: PYTORCH_MODEL {self._plumed_params(pyt_params)}')
        fu.log(f'Using PYTORCH_MODEL: {pyt_label}', self.out_log)
        fu.log(f' Model ptc file: {pyt_params["FILE"]}', self.out_log)
        # Ordered list (not a set) so parameters are logged deterministically.
        extra_params = [(k, v) for k, v in pyt_params.items() if k not in ('ARG', 'FILE')]
        if extra_params:
            fu.log(' Parameters:', self.out_log)
            for k, v in extra_params:
                fu.log(f' > {k}: {v}', self.out_log)

        # Bias actions
        for command in self.bias:
            lines.append(self._plumed_action(command))
            fu.log('Using Bias:', self.out_log)
            fu.log(f' Command: {command["name"]}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in command.get('params', {}).items():
                fu.log(f' > {k}: {v}', self.out_log)

        # PRINT
        lines.append(f'PRINT {self._plumed_params(self.prints)}')

        return lines

    @launchlogger
    def launch(self) -> int:
        """Execute the generation of PLUMED files.

        Writes features.dat, the TorchScript model (.ptc) and the PLUMED input file.

        Returns:
            int: 0 on success, and also when check_restart() is truthy (nothing is done then).
        """

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Generate features first: it fails fast when no stats file was given,
        # before the (more expensive) model conversion.
        _, arg_list = self._generate_features()
        self.arg = ','.join(arg_list)
        self._convert_model_to_ptc()
        plumed_lines = self._build_plumed_lines()

        if not self.ndx and 'cartesian_indices' in self.stats:
            fu.log('WARNING: When employing Cartesian coordinates as collective variables (CVs) for biasing in PLUMED, '
                   'an NDX index file is required to properly define atom groups for fitting and alignment purposes, '
                   'make sure to provide a NDX file.', self.out_log)

        fu.log(f'Generated features.dat at {os.path.abspath(self.io_dict["out"]["output_features_dat_path"])}', self.out_log)
        fu.log(f'File size: {get_size(self.io_dict["out"]["output_features_dat_path"])}', self.out_log)

        with open(self.io_dict['out']['output_plumed_dat_path'], 'w') as f:
            f.write('\n'.join(plumed_lines) + '\n')
        fu.log(f'Generated PLUMED file at {os.path.abspath(self.io_dict["out"]["output_plumed_dat_path"])}', self.out_log)
        fu.log(f'File size: {get_size(self.io_dict["out"]["output_plumed_dat_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0

416 

417 

def generatePlumed(
    input_model_pth_path: str,
    input_stats_pt_path: Optional[str] = None,
    input_reference_pdb_path: Optional[str] = None,
    input_ndx_path: Optional[str] = None,
    output_plumed_dat_path: str = 'plumed.dat',
    output_features_dat_path: str = 'features.dat',
    output_model_ptc_path: str = 'model.ptc',
    properties: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> int:
    """Create the :class:`GeneratePlumed <make_plumed.GeneratePlumed>` class and
    execute the :meth:`launch() <make_plumed.GeneratePlumed.launch>` method.

    Every argument, including the ``kwargs`` mapping itself, is forwarded by
    keyword to the constructor. Returns the value returned by ``launch()``.
    """
    # Snapshot the arguments first, while they are the only local names.
    call_args = dict(locals())
    tool = GeneratePlumed(**call_args)
    return tool.launch()

432 

433 

# The functional wrapper exposes the same documentation as the class.
generatePlumed.__doc__ = GeneratePlumed.__doc__

# Command-line entry point returned by GeneratePlumed.get_main for the wrapper.
main = GeneratePlumed.get_main(
    generatePlumed,
    "Generate PLUMED input for biased dynamics using an MDAE model.",
)

if __name__ == "__main__":
    main()