# Source: biobb_pytorch/mdae/make_plumed.py

import os
from typing import Any, Dict, List, Optional, Tuple

import torch
from biobb_common.generic.biobb_object import BiobbObject
from biobb_common.tools import file_utils as fu
from biobb_common.tools.file_utils import launchlogger

from biobb_pytorch.mdae.utils.log_utils import get_size
8  
9  
class GeneratePlumed(BiobbObject):
    """
    | biobb_pytorch GeneratePlumed
    | Generate PLUMED input for biased dynamics using an MDAE model.
    | Generates a PLUMED input file, features.dat, and converts the model to .ptc format.

    Args:
        input_model_pth_path (str): Path to the trained PyTorch model (.pth) to be converted to TorchScript and used in PLUMED. File type: input. Accepted formats: pth (edam:format_2333).
        input_stats_pt_path (str) (Optional): Path to statistics file (.pt) produced during featurization, used to derive the PLUMED features.dat content. Required by launch(). File type: input. Accepted formats: pt (edam:format_2333).
        input_reference_pdb_path (str) (Optional): Path to reference PDB used for FIT_TO_TEMPLATE actions when Cartesian features are present (required when fit_to_template is set and Cartesian features are used). File type: input. Accepted formats: pdb (edam:format_1476).
        input_ndx_path (str) (Optional): Path to GROMACS index (NDX) file used to define groups; written as NDX_FILE of the GROUP action when the group uses NDX_GROUP and no NDX_FILE is given. File type: input. Accepted formats: ndx (edam:format_2033).
        output_plumed_dat_path (str): Path to the output PLUMED input file. File type: output. Accepted formats: dat (edam:format_2330).
        output_features_dat_path (str): Path to the output features.dat file describing the CVs to PLUMED. File type: output. Accepted formats: dat (edam:format_2330).
        output_model_ptc_path (str): Path to the output TorchScript model file (.ptc) for PLUMED's PYTORCH_MODEL action. File type: output. Accepted formats: ptc (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **additional_actions** (*list*) - ([]) Extra PLUMED actions (dicts with "name" and optional "label" and "params"), e.g. ENERGY or RMSD, written right after the INCLUDE line.
            * **bias** (*list*) - ([]) List of biasing actions (e.g. METAD) to be added to the PLUMED file.
            * **prints** (*dict*) - ({"ARG": "*", "STRIDE": 1, "FILE": "COLVAR"}) PRINT command parameters (e.g. ARG, STRIDE, FILE).
            * **group** (*dict*) - (None) GROUP definition options (label, NDX group or atom selection parameters).
            * **wholemolecules** (*dict*) - (None) WHOLEMOLECULES options when using Cartesian coordinates.
            * **fit_to_template** (*dict*) - (None) FIT_TO_TEMPLATE options (e.g. STRIDE, TYPE, etc.).
            * **pytorch_model** (*dict*) - (None) PYTORCH_MODEL options (label, PACE and other parameters).

    Examples:
        This example shows how to use the GeneratePlumed class to generate a PLUMED input file for biased dynamics using an MDAE model::

            from biobb_pytorch.mdae.make_plumed import generatePlumed

            prop = {
                "additional_actions": [
                    {
                        "name": "ENERGY",
                        "label": "ene"
                    },
                    {
                        "name": "RMSD",
                        "label": "rmsd",
                        "params": {
                            "TYPE": "OPTIMAL"
                        }
                    }
                ],
                "group": {
                    "label": "c_alphas",
                    "NDX_GROUP": "chA_&_C-alpha"
                },
                "wholemolecules": {
                    "ENTITY0": "c_alphas"
                },
                "fit_to_template": {
                    "STRIDE": 1,
                    "TYPE": "OPTIMAL"
                },
                "pytorch_model": {
                    "label": "cv",
                    "PACE": 1
                },
                "bias": [
                    {
                        "name": "METAD",
                        "label": "bias",
                        "params": {
                            "ARG": "cv.1",
                            "PACE": 500,
                            "HEIGHT": 1.2,
                            "SIGMA": 0.35,
                            "FILE": "HILLS",
                            "BIASFACTOR": 8
                        }
                    }
                ],
                "prints": {
                    "ARG": "cv.*,bias.*",
                    "STRIDE": 1,
                    "FILE": "COLVAR"
                }
            }

            generatePlumed(
                input_model_pth_path="model.pth",
                input_stats_pt_path="stats.pt",
                output_plumed_dat_path="plumed.dat",
                output_features_dat_path="features.dat",
                output_model_ptc_path="model.ptc",
                properties=prop
            )

    Info:
        * wrapped_software:
            * name: PLUMED with PyTorch
            * version: >=2.0
            * license: LGPL 3.0
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_stats_pt_path: Optional[str] = None,
        input_reference_pdb_path: Optional[str] = None,
        input_ndx_path: Optional[str] = None,
        output_plumed_dat_path: str = 'plumed.dat',
        output_features_dat_path: str = 'features.dat',
        output_model_ptc_path: str = 'model.ptc',
        properties: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        super().__init__(properties)
        # Snapshot of the constructor arguments, presumably read by BiobbObject helpers
        # such as check_arguments(); keep it immediately after super().__init__().
        self.locals_var_dict = locals().copy()

        # Input/Output files; optional inputs are only registered when provided.
        self.io_dict = {
            "in": {"input_model_pth_path": input_model_pth_path},
            "out": {
                "output_plumed_dat_path": output_plumed_dat_path,
                "output_features_dat_path": output_features_dat_path,
                "output_model_ptc_path": output_model_ptc_path,
            },
        }
        if input_stats_pt_path:
            self.io_dict["in"]["input_stats_pt_path"] = input_stats_pt_path
        if input_reference_pdb_path:
            self.io_dict["in"]["input_reference_pdb_path"] = input_reference_pdb_path
        if input_ndx_path:
            self.io_dict["in"]["input_ndx_path"] = input_ndx_path

        # Properties
        self.model_pth = input_model_pth_path
        self.stats_pt = input_stats_pt_path
        self.ref_pdb = input_reference_pdb_path
        self.ndx = input_ndx_path
        self.properties = properties

        self.additional_actions = self.properties.get('additional_actions', [])
        self.group = self.properties.get('group', None)
        self.wholemolecules = self.properties.get('wholemolecules', None)
        self.fit_to_template = self.properties.get('fit_to_template', None)
        self.pytorch_model = self.properties.get('pytorch_model', None)
        self.bias = self.properties.get('bias', [])
        self.prints = self.properties.get('prints', {'ARG': '*', 'STRIDE': 1, 'FILE': 'COLVAR'})

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        # Featurization statistics; None when no stats file was given (launch() then raises).
        self.stats = self._load_stats()
        # stats['shape'] is presumably (n_samples, n_features); the feature count sizes the
        # example input used when the model has to be traced instead of scripted.
        self.n_features = self.stats.get('shape', [None, None])[1] if self.stats else None

    def _load_stats(self) -> Optional[Dict[str, Any]]:
        """Load the featurization statistics from ``input_stats_pt_path``.

        Returns:
            Optional[Dict[str, Any]]: The loaded object, or None when no path was given.
        """
        if not self.stats_pt:
            return None
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load stats files from trusted sources.
        return torch.load(self.stats_pt, weights_only=False)

    def _uses_positions(self) -> bool:
        """Return True when the loaded stats contain Cartesian (POSITION) features."""
        return bool(self.stats) and 'cartesian_indices' in self.stats

    def _generate_features(self) -> Tuple[List[str], List[str]]:
        """Write features.dat and return its lines plus the PLUMED labels of the features.

        Returns:
            Tuple[List[str], List[str]]: (features.dat lines, feature labels to join into ARG).

        Raises:
            ValueError: If no ``input_stats_pt_path`` was provided.
        """
        if not self.stats_pt:
            raise ValueError('Input_stats_pt_path is required.')
        return self._generate_features_from_stats(
            self.stats, self.io_dict['out']['output_features_dat_path']
        )

    def _generate_features_from_stats(
        self, stats: Dict[str, Any], features_path: str
    ) -> Tuple[List[str], List[str]]:
        """Generate features.dat from stats for cartesians, distances, angles and/or dihedrals.

        Features are written in a fixed order (positions, distances, angles, dihedrals);
        presumably this must match the column order the model was trained on — TODO confirm
        against the featurizer.

        Args:
            stats (Dict[str, Any]): Loaded stats dictionary (optional keys ``cartesian_indices``,
                ``distance_indices``, ``angle_indices``, ``dihedral_indices``; 0-based atom indices).
            features_path (str): Path to write features.dat.

        Returns:
            Tuple[List[str], List[str]]: (features.dat lines, feature labels for the ARG keyword).

        Raises:
            ValueError: If a distance/angle/dihedral entry has the wrong number of atom indices.
        """

        def to_plumed(indices) -> List[int]:
            # PLUMED numbers atoms from 1. int() also unwraps numpy / 0-d tensor scalars
            # so they are written as plain numbers.
            return [int(idx) + 1 for idx in indices]

        feat_lines: List[str] = []
        arg_list: List[str] = []

        if 'cartesian_indices' in stats:
            pos_atoms = to_plumed(stats['cartesian_indices'])
            fu.log(f"Found {len(pos_atoms)} Cartesian features.", self.out_log)
            for atom in pos_atoms:
                feat_lines.append(f'p{atom}: POSITION ATOM={atom}')
                arg_list.extend(f'p{atom}.{axis}' for axis in 'xyz')

        # (stats key, name used in the log, label prefix, PLUMED action, atoms per feature)
        atom_features = (
            ('distance_indices', 'Distance', 'd', 'DISTANCE', 2),
            ('angle_indices', 'Angle', 'a', 'ANGLE', 3),
            ('dihedral_indices', 'Dihedral', 't', 'TORSION', 4),
        )
        for key, kind, prefix, action, n_atoms in atom_features:
            if key not in stats:
                continue
            groups = stats[key]
            fu.log(f'Found {len(groups)} {kind} features.', self.out_log)
            for number, group in enumerate(groups, start=1):
                atoms = to_plumed(group)
                if len(atoms) != n_atoms:
                    raise ValueError(
                        f'{kind} feature {number} has {len(atoms)} atom indices; expected {n_atoms}.'
                    )
                label = f'{prefix}{number}'
                feat_lines.append(f"{label}: {action} ATOMS={','.join(map(str, atoms))}")
                arg_list.append(label)

        with open(features_path, 'w') as handle:
            handle.writelines(f'{line}\n' for line in feat_lines)

        return feat_lines, arg_list

    def _convert_model_to_ptc(self) -> None:
        """Convert the PyTorch model to TorchScript (.ptc) for PLUMED's PYTORCH_MODEL.

        ``torch.jit.script`` is tried first. If scripting or saving fails, the model is
        traced in eval mode with one random input of width ``n_features`` instead.

        Raises:
            ValueError: If tracing is needed but the number of input features is unknown.
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load models from trusted sources.
        model = torch.load(self.model_pth, weights_only=False)

        # JIT needs plain Python ints: sizes stored as e.g. numpy.int64 break scripting.
        for submodule in model.modules():
            if hasattr(submodule, 'in_features'):
                submodule.in_features = int(submodule.in_features)
            if hasattr(submodule, 'out_features'):
                submodule.out_features = int(submodule.out_features)

        self._enable_jit_scripting(model)
        output_path = self.io_dict['out']['output_model_ptc_path']
        try:
            scripted_model = torch.jit.script(model)
            torch.jit.save(scripted_model, output_path)
            fu.log(f'Successfully scripted and saved model to {output_path}', self.out_log)
        except Exception as err:  # scripting support depends on the model; fall back to tracing
            fu.log(f'jit.script failed ({err}): Attempting jit.trace instead.', self.out_log)
            if self.n_features is None:
                raise ValueError(
                    'Cannot trace the model: the number of input features is unknown '
                    "(provide input_stats_pt_path with a 'shape' entry)."
                ) from err
            # Eval mode is required for tracing BatchNorm layers with a batch of size 1.
            model.eval()
            example_input = torch.randn(1, int(self.n_features))  # batch size 1, flat input
            traced_model = torch.jit.trace(model, example_input)
            torch.jit.save(traced_model, output_path)
            fu.log(f'Successfully traced and saved model to {output_path}', self.out_log)

    def _enable_jit_scripting(self, module: torch.nn.Module) -> None:
        """Set ``_jit_is_scripting = True`` on the module and every submodule that defines it."""
        # Module.modules() includes the module itself as well as all descendants.
        for subm in module.modules():
            if hasattr(subm, '_jit_is_scripting'):
                subm._jit_is_scripting = True

    @staticmethod
    def _join_params(params: Dict[str, Any]) -> str:
        """Render a mapping as space-separated PLUMED ``KEY=VALUE`` keywords."""
        return ' '.join(f'{k}={v}' for k, v in params.items())

    @classmethod
    def _format_action(cls, action: Dict[str, Any]) -> str:
        """Render ``{'name', 'label'?, 'params'?}`` as ``[label: ]NAME KEY=VALUE ...``."""
        label = action.get('label', '')
        prefix = f'{label}: ' if label else ''
        return f"{prefix}{action['name']} {cls._join_params(action.get('params', {}))}"

    def _build_plumed_lines(self) -> List[str]:
        """Build the list of lines for the PLUMED file.

        Order: INCLUDE of features.dat, additional actions, GROUP, WHOLEMOLECULES and
        FIT_TO_TEMPLATE (only with Cartesian features), PYTORCH_MODEL, bias actions, PRINT.
        Reads ``self.arg``, which ``launch`` sets from the generated features.

        Raises:
            ValueError: If FIT_TO_TEMPLATE is requested for Cartesian features but no
                reference PDB was provided.
        """
        features_path = os.path.abspath(self.io_dict['out']['output_features_dat_path'])
        model_path = os.path.abspath(self.io_dict['out']['output_model_ptc_path'])
        lines = [f'INCLUDE FILE={features_path}']

        # Additional actions (e.g., ENERGY, other metrics)
        lines.extend(self._format_action(action) for action in self.additional_actions)

        # GROUP
        if self.group:
            group_label = self.group.get('label', 'C-alpha')
            group_params = {k: v for k, v in self.group.items() if k not in ('label', 'name')}
            # PLUMED's GROUP reads NDX_GROUP from the index file named by NDX_FILE.
            if self.ndx and 'NDX_GROUP' in group_params and 'NDX_FILE' not in group_params:
                group_params['NDX_FILE'] = os.path.abspath(self.ndx)
            lines.append(f'{group_label}: GROUP {self._join_params(group_params)}')
            fu.log(f'Using GROUP: {group_label}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in group_params.items():
                fu.log(f' > {k.upper()}: {v}', self.out_log)

        uses_positions = self._uses_positions()

        # WHOLEMOLECULES (only meaningful with Cartesian features)
        if uses_positions and self.wholemolecules:
            params = self._join_params(self.wholemolecules)
            lines.append(f'WHOLEMOLECULES {params}')
            fu.log(f'Using WHOLEMOLECULES with parameters: {params}', self.out_log)
        elif uses_positions:
            fu.log('WARNING: Using Cartesian coordinates but no WHOLEMOLECULES parameters provided; add WHOLEMOLECULES in properties.', self.out_log)
        elif self.wholemolecules:
            fu.log('NOTE: WHOLEMOLECULES parameters provided but no POSITION features detected; skipping WHOLEMOLECULES.', self.out_log)

        # FIT_TO_TEMPLATE (only meaningful with Cartesian features; needs a reference PDB)
        if uses_positions and self.fit_to_template:
            if not self.ref_pdb:
                raise ValueError(
                    'FIT_TO_TEMPLATE requires input_reference_pdb_path when Cartesian features are used.'
                )
            ref_path = os.path.abspath(self.ref_pdb)
            lines.append(f'FIT_TO_TEMPLATE REFERENCE={ref_path} {self._join_params(self.fit_to_template)}')
            fu.log('Using FIT_TO_TEMPLATE', self.out_log)
            fu.log(f' Reference PDB: {ref_path}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in self.fit_to_template.items():
                fu.log(f' > {k.upper()}: {v}', self.out_log)
        elif uses_positions:
            fu.log('WARNING: Using Cartesian coordinates but no FIT_TO_TEMPLATE parameters provided; add FIT_TO_TEMPLATE in properties.', self.out_log)
        elif self.fit_to_template:
            fu.log('NOTE: FIT_TO_TEMPLATE parameters provided but no POSITION features detected; skipping FIT_TO_TEMPLATE.', self.out_log)

        # PYTORCH_MODEL: defaults first so user-supplied keys (other than 'label') override them.
        pyt_label = 'cv'
        pyt_params: Dict[str, Any] = {'FILE': model_path, 'ARG': self.arg}
        if self.pytorch_model:
            pyt_label = self.pytorch_model.get('label', 'cv')
            pyt_params.update({k: v for k, v in self.pytorch_model.items() if k != 'label'})
        lines.append(f'{pyt_label}: PYTORCH_MODEL {self._join_params(pyt_params)}')
        fu.log(f'Using PYTORCH_MODEL: {pyt_label}', self.out_log)
        fu.log(f' Model ptc file: {model_path}', self.out_log)
        # A list (not a set) keeps the log order deterministic and matching the file.
        extra_params = [
            f'{k}: {v}' for k, v in pyt_params.items()
            if k != 'ARG' and not str(k).startswith('FILE')
        ]
        if extra_params:
            fu.log(' Parameters:', self.out_log)
            for param in extra_params:
                fu.log(f' > {param}', self.out_log)

        # Bias actions
        for command in self.bias:
            lines.append(self._format_action(command))
            fu.log('Using Bias:', self.out_log)
            fu.log(f" Command: {command['name']}", self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in command.get('params', {}).items():
                fu.log(f' > {k}: {v}', self.out_log)

        # PRINT
        lines.append(f'PRINT {self._join_params(self.prints)}')

        return lines

    @launchlogger
    def launch(self) -> int:
        """Execute the generation of PLUMED files.

        Writes features.dat, the TorchScript model and the PLUMED input file.

        Returns:
            int: 0 on success (also when the restart check skips the step).

        Raises:
            ValueError: If no stats file was provided, or see ``_build_plumed_lines`` /
                ``_convert_model_to_ptc``.
        """
        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Features first: this validates the stats input before the slower model conversion.
        _, arg_list = self._generate_features()
        self.arg = ','.join(arg_list)
        self._convert_model_to_ptc()
        plumed_lines = self._build_plumed_lines()

        if self.ndx is None and self._uses_positions():
            fu.log('WARNING: When employing Cartesian coordinates as collective variables (CVs) for biasing in PLUMED, '
                   'an NDX index file is required to properly define atom groups for fitting and alignment purposes, '
                   'make sure to provide a NDX file.', self.out_log)

        features_path = self.io_dict["out"]["output_features_dat_path"]
        fu.log(f'Generated features.dat at {os.path.abspath(features_path)}', self.out_log)
        fu.log(f'File size: {get_size(features_path)}', self.out_log)

        plumed_path = self.io_dict['out']['output_plumed_dat_path']
        with open(plumed_path, 'w') as f:
            f.write('\n'.join(plumed_lines) + '\n')
        fu.log(f'Generated PLUMED file at {os.path.abspath(plumed_path)}', self.out_log)
        fu.log(f'File size: {get_size(plumed_path)}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0
415  
416  
def generatePlumed(
    input_model_pth_path: str,
    input_stats_pt_path: Optional[str] = None,
    input_reference_pdb_path: Optional[str] = None,
    input_ndx_path: Optional[str] = None,
    output_plumed_dat_path: str = 'plumed.dat',
    output_features_dat_path: str = 'features.dat',
    output_model_ptc_path: str = 'model.ptc',
    properties: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> int:
    """Functional wrapper: build a :class:`GeneratePlumed <make_plumed.GeneratePlumed>` tool
    and run its :meth:`launch() <make_plumed.GeneratePlumed.launch>` method.

    Returns:
        int: Whatever ``GeneratePlumed.launch()`` returns.
    """
    # Snapshot of every parameter (the ``kwargs`` dict included, under the name 'kwargs'),
    # forwarded to the constructor keyword-for-keyword.
    call_args = locals().copy()
    tool = GeneratePlumed(**call_args)
    return tool.launch()
431  
432  
# Give the functional wrapper the same documentation as the class it drives.
generatePlumed.__doc__ = GeneratePlumed.__doc__

# Entry point produced by BiobbObject.get_main around the functional wrapper
# (presumably a command-line interface; see the __main__ guard below).
main = GeneratePlumed.get_main(
    generatePlumed,
    "Generate PLUMED input for biased dynamics using an MDAE model.",
)

if __name__ == "__main__":
    main()