Coverage for biobb_pytorch / mdae / make_plumed.py: 69%
217 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
import os
from typing import Any, Dict, List, Optional, Tuple

import torch

from biobb_common.generic.biobb_object import BiobbObject
from biobb_common.tools import file_utils as fu
from biobb_common.tools.file_utils import launchlogger
from biobb_pytorch.mdae.utils.log_utils import get_size
class GeneratePlumed(BiobbObject):
    """
    | biobb_pytorch GeneratePlumed
    | Generate PLUMED input for biased dynamics using an MDAE model.
    | Generates a PLUMED input file, features.dat, and converts the model to .ptc format.

    Args:
        input_model_pth_path (str): Path to the trained PyTorch model (.pth) to be converted to TorchScript and used in PLUMED. File type: input. Accepted formats: pth (edam:format_2333).
        input_stats_pt_path (str) (Optional): Path to statistics file (.pt) produced during featurization, used to derive the PLUMED features.dat content. Although optional at construction time, it is required by launch() to generate features.dat. File type: input. Accepted formats: pt (edam:format_2333).
        input_reference_pdb_path (str) (Optional): Path to reference PDB used for FIT_TO_TEMPLATE actions when Cartesian features are present. Required if fit_to_template is set and Cartesian features are present. File type: input. Accepted formats: pdb (edam:format_1476).
        input_ndx_path (str) (Optional): Path to GROMACS index (NDX) file used to define groups when required by PLUMED. When given and the GROUP uses NDX_GROUP without NDX_FILE, it is added as NDX_FILE. File type: input. Accepted formats: ndx (edam:format_2033).
        output_plumed_dat_path (str): Path to the output PLUMED input file. File type: output. Accepted formats: dat (edam:format_2330).
        output_features_dat_path (str): Path to the output features.dat file describing the CVs to PLUMED. File type: output. Accepted formats: dat (edam:format_2330).
        output_model_ptc_path (str): Path to the output TorchScript model file (.ptc) for PLUMED's PYTORCH_MODEL action. File type: output. Accepted formats: ptc (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **additional_actions** (*list*) - ([]) List of extra actions (e.g. ENERGY, RMSD) written right after the INCLUDE line. Each item is a dict with "name" and optional "label" and "params".
            * **bias** (*list*) - ([]) List of biasing actions (e.g. METAD) to be added to the PLUMED file. Each item is a dict with "name" and optional "label" and "params".
            * **prints** (*dict*) - ({"ARG": "*", "STRIDE": 1, "FILE": "COLVAR"}) PRINT command parameters (e.g. ARG, STRIDE, FILE).
            * **group** (*dict*) - (None) GROUP definition options (label, NDX group or atom selection parameters).
            * **wholemolecules** (*dict*) - (None) WHOLEMOLECULES options when using Cartesian coordinates.
            * **fit_to_template** (*dict*) - (None) FIT_TO_TEMPLATE options (e.g. STRIDE, TYPE, etc.).
            * **pytorch_model** (*dict*) - (None) PYTORCH_MODEL options (label, PACE and other parameters).

    Examples:
        This example shows how to use the GeneratePlumed class to generate a PLUMED input file for biased dynamics using an MDAE model::

            from biobb_pytorch.mdae.make_plumed import generatePlumed

            prop = {
                "additional_actions": [
                    {
                        "name": "ENERGY",
                        "label": "ene"
                    },
                    {
                        "name": "RMSD",
                        "label": "rmsd",
                        "params": {
                            "TYPE": "OPTIMAL"
                        }
                    }
                ],
                "group": {
                    "label": "c_alphas",
                    "NDX_GROUP": "chA_&_C-alpha"
                },
                "wholemolecules": {
                    "ENTITY0": "c_alphas"
                },
                "fit_to_template": {
                    "STRIDE": 1,
                    "TYPE": "OPTIMAL"
                },
                "pytorch_model": {
                    "label": "cv",
                    "PACE": 1
                },
                "bias": [
                    {
                        "name": "METAD",
                        "label": "bias",
                        "params": {
                            "ARG": "cv.1",
                            "PACE": 500,
                            "HEIGHT": 1.2,
                            "SIGMA": 0.35,
                            "FILE": "HILLS",
                            "BIASFACTOR": 8
                        }
                    }
                ],
                "prints": {
                    "ARG": "cv.*,bias.*",
                    "STRIDE": 1,
                    "FILE": "COLVAR"
                }
            }

            generatePlumed(
                input_model_pth_path="model.pth",
                input_stats_pt_path="stats.pt",
                input_reference_pdb_path="reference.pdb",
                input_ndx_path="index.ndx",
                output_plumed_dat_path="plumed.dat",
                output_features_dat_path="features.dat",
                output_model_ptc_path="model.ptc",
                properties=prop
            )

    Info:
        * wrapped_software:
            * name: PLUMED with PyTorch
            * version: >=2.0
            * license: LGPL 3.0
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_stats_pt_path: Optional[str] = None,
        input_reference_pdb_path: Optional[str] = None,
        input_ndx_path: Optional[str] = None,
        output_plumed_dat_path: str = 'plumed.dat',
        output_features_dat_path: str = 'features.dat',
        output_model_ptc_path: str = 'model.ptc',
        properties: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        super().__init__(properties)
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {"input_model_pth_path": input_model_pth_path},
            "out": {
                "output_plumed_dat_path": output_plumed_dat_path,
                "output_features_dat_path": output_features_dat_path,
                "output_model_ptc_path": output_model_ptc_path
            }
        }
        if input_stats_pt_path:
            self.io_dict["in"]["input_stats_pt_path"] = input_stats_pt_path
        if input_reference_pdb_path:
            self.io_dict["in"]["input_reference_pdb_path"] = input_reference_pdb_path
        if input_ndx_path:
            self.io_dict["in"]["input_ndx_path"] = input_ndx_path

        # Properties
        self.model_pth = input_model_pth_path
        self.stats_pt = input_stats_pt_path
        self.ref_pdb = input_reference_pdb_path
        self.ndx = input_ndx_path
        self.properties = properties

        self.additional_actions = self.properties.get('additional_actions', [])
        self.group = self.properties.get('group', None)
        self.wholemolecules = self.properties.get('wholemolecules', None)
        self.fit_to_template = self.properties.get('fit_to_template', None)
        self.pytorch_model = self.properties.get('pytorch_model', None)
        self.bias = self.properties.get('bias', [])
        self.prints = self.properties.get('prints', {'ARG': '*', 'STRIDE': 1, 'FILE': 'COLVAR'})

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        # Comma-separated PYTORCH_MODEL ARG string; filled in by launch().
        self.arg = ''

        # Stats are optional at construction time. Without them the number of
        # input features is unknown (None) and launch() will refuse to run.
        self.stats = self._load_stats()
        self.n_features = None
        if self.stats is not None:
            self.n_features = self.stats.get('shape', [None, None])[1]

    def _load_stats(self) -> Optional[Dict[str, Any]]:
        """Load the featurization statistics file (.pt) if one was provided.

        Returns:
            Optional[Dict[str, Any]]: The loaded object, or None when no stats path was given.
        """
        if not self.stats_pt:
            return None
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
        return torch.load(self.stats_pt, weights_only=False, map_location='cpu')

    def _generate_features(self) -> Tuple[List[str], List[str]]:
        """Write features.dat from the loaded stats.

        Returns:
            Tuple[List[str], List[str]]: The PLUMED feature lines written to features.dat
            and the feature labels to be joined into the PYTORCH_MODEL ARG string.

        Raises:
            ValueError: If no stats file was provided.
        """
        if not self.stats_pt:
            raise ValueError('Input_stats_pt_path is required.')
        return self._generate_features_from_stats(self.stats, self.io_dict['out']['output_features_dat_path'])

    def _generate_features_from_stats(self, stats: Dict[str, Any], features_path: str) -> Tuple[List[str], List[str]]:
        """
        Generate features.dat from stats.pt for distances, angles, dihedrals, and/or cartesians.

        Args:
            stats (Dict[str, Any]): Loaded stats dictionary. The optional keys read are
                'cartesian_indices', 'distance_indices', 'angle_indices' and 'dihedral_indices'
                (0-based atom indices).
            features_path (str): Path to write features.dat.

        Returns:
            Tuple[List[str], List[str]]: Feature definition lines and ARG labels, in the same order.
        """
        feat_lines: List[str] = []
        arg_list: List[str] = []

        def to_plumed(indices) -> List[int]:
            # PLUMED atom numbers are 1-based. int() also turns numpy/tensor scalars into plain ints,
            # so they render as "5" and not as "tensor(5)".
            return [int(idx) + 1 for idx in indices]

        if 'cartesian_indices' in stats:
            pos_atoms = to_plumed(stats['cartesian_indices'])
            fu.log(f"Found {len(pos_atoms)} Cartesian features.", self.out_log)
            for atom in pos_atoms:
                feat_lines.append(f'p{atom}: POSITION ATOM={atom}')
                arg_list.extend([f'p{atom}.x', f'p{atom}.y', f'p{atom}.z'])

        if 'distance_indices' in stats:
            fu.log(f"Found {len(stats['distance_indices'])} Distance features.", self.out_log)
            for count, pair in enumerate(stats['distance_indices'], start=1):
                a, b = to_plumed(pair)
                label = f'd{count}'
                feat_lines.append(f'{label}: DISTANCE ATOMS={a},{b}')
                arg_list.append(label)

        if 'angle_indices' in stats:
            fu.log(f"Found {len(stats['angle_indices'])} Angle features.", self.out_log)
            for count, triple in enumerate(stats['angle_indices'], start=1):
                a, b, c = to_plumed(triple)
                label = f'a{count}'
                feat_lines.append(f'{label}: ANGLE ATOMS={a},{b},{c}')
                arg_list.append(label)

        if 'dihedral_indices' in stats:
            fu.log(f"Found {len(stats['dihedral_indices'])} Dihedral features.", self.out_log)
            for count, quad in enumerate(stats['dihedral_indices'], start=1):
                a, b, c, d = to_plumed(quad)
                label = f't{count}'
                feat_lines.append(f'{label}: TORSION ATOMS={a},{b},{c},{d}')
                arg_list.append(label)

        with open(features_path, 'w') as f:
            f.writelines(f'{line}\n' for line in feat_lines)

        return feat_lines, arg_list

    def _convert_model_to_ptc(self) -> None:
        """Convert the PyTorch model to TorchScript format (.ptc).

        Tries torch.jit.script first. If that fails, falls back to torch.jit.trace with a
        random (1, n_features) input.

        Raises:
            ValueError: If scripting fails and the number of input features is unknown.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted models.
        # map_location='cpu' lets GPU-trained checkpoints load on CPU-only hosts and matches the
        # CPU example input used for tracing below.
        model = torch.load(self.model_pth, weights_only=False, map_location='cpu')

        # TorchScript needs plain Python ints for these sizes. Models built from numpy shapes
        # may store numpy.int64 values here.
        for submodule in model.modules():
            for attr in ('in_features', 'out_features'):
                value = getattr(submodule, attr, None)
                if value is not None:
                    setattr(submodule, attr, int(value))

        # Export in inference mode, for both the script and trace paths
        # (freezes BatchNorm statistics and disables Dropout).
        model.eval()

        self._enable_jit_scripting(model)
        output_path = self.io_dict['out']['output_model_ptc_path']
        try:
            scripted_model = torch.jit.script(model)
            torch.jit.save(scripted_model, output_path)
            fu.log(f'Successfully scripted and saved model to {output_path}', self.out_log)
        except Exception as e:
            fu.log(f'jit.script failed: {e}. Attempting jit.trace instead.', self.out_log)
            if self.n_features is None:
                raise ValueError(
                    'Cannot trace the model: the number of input features is unknown '
                    '(no "shape" entry in the stats file).'
                ) from e
            example_input = torch.randn(1, self.n_features)  # Batch size 1, flat input
            traced_model = torch.jit.trace(model, example_input)
            torch.jit.save(traced_model, output_path)
            fu.log(f'Successfully traced and saved model to {output_path}', self.out_log)

    def _enable_jit_scripting(self, module: torch.nn.Module) -> None:
        """Set the _jit_is_scripting flag to True on the module and its submodules, where the flag exists."""
        # Module.modules() yields the module itself first, then every descendant.
        for submodule in module.modules():
            if hasattr(submodule, '_jit_is_scripting'):
                submodule._jit_is_scripting = True

    @staticmethod
    def _format_params(params: Dict[str, Any]) -> str:
        """Render a mapping as space-separated PLUMED ``KEY=VALUE`` pairs, in insertion order."""
        return ' '.join(f'{k}={v}' for k, v in params.items())

    @classmethod
    def _format_action(cls, action: Dict[str, Any]) -> str:
        """Render an action dict ({'name', optional 'label', optional 'params'}) as one PLUMED line."""
        label = action.get('label', '')
        prefix = f'{label}: ' if label else ''
        return f"{prefix}{action['name']} {cls._format_params(action.get('params', {}))}".rstrip()

    def _build_plumed_lines(self) -> List[str]:
        """Build the list of lines for the PLUMED file.

        The order is: INCLUDE of features.dat, additional actions, GROUP, WHOLEMOLECULES and
        FIT_TO_TEMPLATE (only when Cartesian features exist), PYTORCH_MODEL, bias actions, PRINT.
        self.arg must already hold the ARG string (launch() sets it).

        Raises:
            ValueError: If FIT_TO_TEMPLATE is requested for Cartesian features without a reference PDB.
        """
        features_path = os.path.abspath(self.io_dict['out']['output_features_dat_path'])
        model_path = os.path.abspath(self.io_dict['out']['output_model_ptc_path'])
        lines = [f'INCLUDE FILE={features_path}']

        # Additional actions (e.g., ENERGY, other metrics)
        lines.extend(self._format_action(action) for action in self.additional_actions)

        # GROUP
        if self.group:
            group_label = self.group.get('label', 'C-alpha')
            group_params = {k: v for k, v in self.group.items() if k not in ('label', 'name')}
            # PLUMED resolves NDX_GROUP via NDX_FILE. Wire in the provided index file unless the
            # user already set NDX_FILE explicitly.
            if self.ndx and 'NDX_GROUP' in group_params and 'NDX_FILE' not in group_params:
                group_params['NDX_FILE'] = os.path.abspath(self.ndx)
            lines.append(f"{group_label}: GROUP {self._format_params(group_params)}")
            fu.log(f'Using GROUP: {group_label}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in group_params.items():
                fu.log(f' > {k.upper()}: {v}', self.out_log)

        uses_positions = self.stats is not None and 'cartesian_indices' in self.stats

        # WHOLEMOLECULES
        if uses_positions and self.wholemolecules:
            params = self._format_params(self.wholemolecules)
            lines.append(f'WHOLEMOLECULES {params}')
            fu.log(f'Using WHOLEMOLECULES with parameters: {params}', self.out_log)
        elif uses_positions:
            fu.log('WARNING: Using Cartesian coordinates but no WHOLEMOLECULES parameters provided; add WHOLEMOLECULES in properties.', self.out_log)
        elif self.wholemolecules:
            fu.log('NOTE: WHOLEMOLECULES parameters provided but no POSITION features detected; skipping WHOLEMOLECULES.', self.out_log)

        # FIT_TO_TEMPLATE
        if uses_positions and self.fit_to_template:
            if not self.ref_pdb:
                raise ValueError('input_reference_pdb_path is required to use FIT_TO_TEMPLATE with Cartesian features.')
            ref_path = os.path.abspath(self.ref_pdb)
            lines.append(f'FIT_TO_TEMPLATE REFERENCE={ref_path} {self._format_params(self.fit_to_template)}')
            fu.log('Using FIT_TO_TEMPLATE', self.out_log)
            fu.log(f' Reference PDB: {ref_path}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in self.fit_to_template.items():
                fu.log(f' > {k.upper()}: {v}', self.out_log)
        elif uses_positions:
            fu.log('WARNING: Using Cartesian coordinates but no FIT_TO_TEMPLATE parameters provided; add FIT_TO_TEMPLATE in properties.', self.out_log)
        elif self.fit_to_template:
            fu.log('NOTE: FIT_TO_TEMPLATE parameters provided but no POSITION features detected; skipping FIT_TO_TEMPLATE.', self.out_log)

        # PYTORCH_MODEL: user-supplied options override the FILE/ARG defaults.
        pyt_label = 'cv'
        pyt_params: Dict[str, Any] = {'FILE': model_path, 'ARG': self.arg}
        if self.pytorch_model:
            pyt_label = self.pytorch_model.get('label', 'cv')
            pyt_params.update({k: v for k, v in self.pytorch_model.items() if k != 'label'})
        lines.append(f'{pyt_label}: PYTORCH_MODEL {self._format_params(pyt_params)}')
        fu.log(f'Using PYTORCH_MODEL: {pyt_label}', self.out_log)
        fu.log(f' Model ptc file: {pyt_params["FILE"]}', self.out_log)
        # A list (not a set) keeps the log order deterministic and equal to the file order.
        extra_params = [f'{k}: {v}' for k, v in pyt_params.items() if k not in ('ARG', 'FILE')]
        if extra_params:
            fu.log(' Parameters:', self.out_log)
            for param in extra_params:
                fu.log(f' > {param}', self.out_log)

        # Bias actions
        for command in self.bias:
            lines.append(self._format_action(command))
            fu.log('Using Bias:', self.out_log)
            fu.log(f" Command: {command['name']}", self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in command.get('params', {}).items():
                fu.log(f' > {k}: {v}', self.out_log)

        # PRINT
        lines.append(f'PRINT {self._format_params(self.prints)}')

        return lines

    @launchlogger
    def launch(self) -> int:
        """Execute the generation of PLUMED files.

        Writes features.dat, the TorchScript model (.ptc) and the PLUMED input file.

        Returns:
            int: 0 on success (or when a restart skips the step).

        Raises:
            ValueError: If no stats file was provided.
        """
        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Generate the features first. This validates that stats were provided before the
        # (potentially slow) model conversion runs.
        _, arg_list = self._generate_features()
        self.arg = ','.join(arg_list)
        self._convert_model_to_ptc()
        plumed_lines = self._build_plumed_lines()

        has_cartesian = 'cartesian_indices' in self.stats
        if self.ndx is None and has_cartesian:
            fu.log('WARNING: When employing Cartesian coordinates as collective variables (CVs) for biasing in PLUMED, '
                   'an NDX index file is required to properly define atom groups for fitting and alignment purposes, '
                   'make sure to provide a NDX file.', self.out_log)

        fu.log(f'Generated features.dat at {os.path.abspath(self.io_dict["out"]["output_features_dat_path"])}', self.out_log)
        fu.log(f'File size: {get_size(self.io_dict["out"]["output_features_dat_path"])}', self.out_log)

        with open(self.io_dict['out']['output_plumed_dat_path'], 'w') as f:
            f.write('\n'.join(plumed_lines) + '\n')
        fu.log(f'Generated PLUMED file at {os.path.abspath(self.io_dict["out"]["output_plumed_dat_path"])}', self.out_log)
        fu.log(f'File size: {get_size(self.io_dict["out"]["output_plumed_dat_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0
def generatePlumed(
    input_model_pth_path: str,
    input_stats_pt_path: Optional[str] = None,
    input_reference_pdb_path: Optional[str] = None,
    input_ndx_path: Optional[str] = None,
    output_plumed_dat_path: str = 'plumed.dat',
    output_features_dat_path: str = 'features.dat',
    output_model_ptc_path: str = 'model.ptc',
    properties: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> int:
    """Create the :class:`GeneratePlumed <make_plumed.GeneratePlumed>` class and
    execute the :meth:`launch() <make_plumed.GeneratePlumed.launch>` method.

    All arguments are forwarded unchanged to the GeneratePlumed constructor.

    Returns:
        int: The value returned by GeneratePlumed.launch().
    """
    # locals() holds exactly the parameters above, including the collected 'kwargs' dict,
    # which the constructor's own **kwargs absorbs.
    return GeneratePlumed(**dict(locals())).launch()
# The functional wrapper exposes the same documentation as the class it drives.
setattr(generatePlumed, '__doc__', GeneratePlumed.__doc__)

# Command-line entry point built by the BiobbObject helper around the functional wrapper.
main = GeneratePlumed.get_main(
    generatePlumed,
    "Generate PLUMED input for biased dynamics using an MDAE model.",
)

if __name__ == "__main__":
    main()