Coverage for biobb_pytorch / mdae / feat2traj.py: 71%
76 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1from biobb_common.generic.biobb_object import BiobbObject
2from biobb_common.tools.file_utils import launchlogger
3import torch
4import numpy as np
5import mdtraj as md
6import os
class Feat2Traj(BiobbObject):
    """
    | biobb_pytorch Feat2Traj
    | Converts reconstructed features (.npz) to a trajectory using cartesian indices and topology from the stats (.pt) file.
    | Reads the reconstructed coordinates ('xhat' array) from a .npz results file, reshapes them to (frames, atoms, 3) using the cartesian indices stored in a .pt stats file, and writes them as a trajectory using the best available topology (stats file, input PDB, or a placeholder).

    Args:
        input_results_npz_path (str): Path to the input reconstructed results file (.npz), typically containing an 'xhat' array. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_results.npz>`_. Accepted formats: npz (edam:format_2333).
        input_stats_pt_path (str): Path to the input model statistics file (.pt) containing cartesian indices and optionally topology. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_model.pt>`_. Accepted formats: pt (edam:format_2333).
        input_topology_path (str) (optional): Path to the topology file (PDB) used if no suitable topology (matching atom count) is found in the stats file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_topology.pdb>`_. Accepted formats: pdb (edam:format_1476).
        output_traj_path (str): Path to save the trajectory in xtc/pdb/dcd format. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.xtc>`_. Accepted formats: xtc (edam:format_3875), pdb (edam:format_1476), dcd (edam:format_3878).
        output_top_path (str) (optional): Path to save the output topology file (pdb), written from the first frame when the trajectory is xtc or dcd and ignored for pdb output. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pdb>`_. Accepted formats: pdb (edam:format_1476).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **restart** (*bool*) - (False) [WF property] Do not execute if output files exist.

    Examples:
        This example shows how to use the Feat2Traj class to convert reconstructed features to a trajectory using cartesian indices and topology from the stats file::

            from biobb_pytorch.mdae.feat2traj import feat2traj

            input_results_npz_path='input_results.npz'
            input_stats_pt_path='input_model.pt'
            input_topology_path='input_topology.pdb'
            output_traj_path='output_model.xtc'
            output_top_path='output_model.pdb'

            prop={}

            feat2traj(input_results_npz_path=input_results_npz_path,
                      input_stats_pt_path=input_stats_pt_path,
                      input_topology_path=input_topology_path,
                      output_traj_path=output_traj_path,
                      output_top_path=output_top_path,
                      properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    # Trajectory extensions this tool knows how to write (compared case-insensitively).
    _SUPPORTED_EXTENSIONS = ('.xtc', '.dcd', '.pdb')

    def __init__(
        self,
        input_results_npz_path: str,
        input_stats_pt_path: str,
        input_topology_path: str = None,
        output_traj_path: str = None,
        output_top_path: str = None,
        properties: dict = None,
        **kwargs,
    ) -> None:
        properties = properties or {}
        super().__init__(properties)

        self.input_results_npz_path = input_results_npz_path
        self.input_stats_pt_path = input_stats_pt_path
        self.input_topology_path = input_topology_path
        self.output_traj_path = output_traj_path
        self.output_top_path = output_top_path
        self.properties = properties.copy()
        self.locals_var_dict = locals().copy()
        self.io_dict = {
            "in": {
                "input_results_npz_path": input_results_npz_path,
                "input_stats_pt_path": input_stats_pt_path,
                "input_topology_path": input_topology_path,
            },
            "out": {
                "output_traj_path": output_traj_path,
                "output_top_path": output_top_path,
            },
        }
        self.check_properties(properties)
        self.check_arguments()

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`Feat2Traj` class and its `.launch()` method.

        Returns:
            int: 0 on success, including when the step is skipped because of ``restart``.

        Raises:
            ValueError: if the output extension is unsupported, the stats file has no
                cartesian indices, or the features cannot be reshaped to
                ``(n_frames, n_atoms, 3)``.
        """
        # Honour the documented ``restart`` property (skip when outputs already exist).
        if self.check_restart():
            return 0

        # Validate the output format before doing any expensive loading.
        ext = self._output_extension()

        features = self._load_features(self.input_results_npz_path)
        cartesian_indices, topology = self._load_stats(self.input_stats_pt_path)

        # One cartesian index per atom; each atom contributes x, y, z.
        n_atoms = len(cartesian_indices)
        n_frames = features.shape[0]
        if features.size != n_frames * n_atoms * 3:
            raise ValueError(
                f"Features of shape {features.shape} cannot be reshaped to "
                f"({n_frames}, {n_atoms}, 3) for {n_atoms} cartesian indices."
            )
        coords = features.reshape((n_frames, n_atoms, 3))

        top = self._resolve_topology(topology, n_atoms)
        traj = md.Trajectory(xyz=coords, topology=top)

        if ext is not None:
            self._save_trajectory(traj, ext)
        return 0

    def _output_extension(self):
        """Return the lower-cased output extension, or None when no trajectory path was given.

        Raises:
            ValueError: if the extension is not one of xtc, dcd or pdb.
        """
        if not self.output_traj_path:
            return None
        raw_ext = os.path.splitext(self.output_traj_path)[1]
        ext = raw_ext.lower()
        if ext not in self._SUPPORTED_EXTENSIONS:
            raise ValueError(f'Unknown trajectory extension: {raw_ext}')
        return ext

    @staticmethod
    def _load_features(npz_path: str) -> np.ndarray:
        """Return the 'xhat' array stored in the .npz results file."""
        # Use the NpzFile as a context manager so the underlying file handle is closed.
        with np.load(npz_path) as data:
            return np.asarray(data['xhat'])

    @staticmethod
    def _load_stats(stats_path: str):
        """Return ``(cartesian_indices, topology)`` from the stats file.

        ``topology`` is None when the stats dict has no 'topology' entry.

        Raises:
            ValueError: if the stats object is not a dict or holds no cartesian indices.
        """
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load stats files from trusted sources.
        stats = torch.load(stats_path, weights_only=False)
        cartesian_indices = stats.get('cartesian_indices') if isinstance(stats, dict) else None
        if cartesian_indices is None:
            raise ValueError('No cartesian indices found in stats file.')
        # 'topology' is optional, so a missing key must not raise KeyError.
        return np.array(cartesian_indices), stats.get('topology')

    @staticmethod
    def _topology_from_stats(topology):
        """Best-effort conversion of the stats 'topology' entry to an MDTraj Topology (or None)."""
        if topology is None:
            return None
        try:
            if isinstance(topology, md.Topology):
                return topology
            if isinstance(topology, md.Trajectory):
                return topology.topology
            if isinstance(topology, str) and os.path.exists(topology):
                return md.load_topology(topology)
            if isinstance(topology, dict) and 'pdb_string' in topology:
                import tempfile
                # MDTraj loaders take file paths, so round-trip the PDB text through a temp file.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    pdb_path = os.path.join(tmp_dir, 'topology.pdb')
                    with open(pdb_path, 'w') as handle:
                        handle.write(topology['pdb_string'])
                    return md.load_topology(pdb_path)
        except Exception as e:
            # Deliberately best-effort: fall back to the other topology sources.
            print(f"Warning: Could not load topology from stats file: {e}")
        return None

    def _resolve_topology(self, topology, n_atoms: int):
        """Return a topology with ``n_atoms`` atoms: stats file first, then input PDB, then a placeholder."""
        top = self._topology_from_stats(topology)
        if top is not None and top.n_atoms != n_atoms:
            # An atom-count mismatch is not a "suitable" topology (md.Trajectory would reject it).
            print(f"Warning: Topology in stats file has {top.n_atoms} atoms, expected {n_atoms}; ignoring it.")
            top = None
        if top is None and self.input_topology_path is not None and os.path.exists(self.input_topology_path):
            top = md.load_topology(self.input_topology_path)
        if top is None:
            top = self._placeholder_topology(n_atoms)
        return top

    @staticmethod
    def _placeholder_topology(n_atoms: int):
        """Build a single-chain, single-residue ('RES') topology of ``n_atoms`` carbon 'CA' atoms."""
        top = md.Topology()
        chain = top.add_chain()
        residue = top.add_residue('RES', chain)
        for _ in range(n_atoms):
            top.add_atom('CA', element=md.element.carbon, residue=residue)
        return top

    def _save_trajectory(self, traj, ext: str) -> None:
        """Write ``traj`` in the format given by ``ext``.

        For xtc/dcd, also write the first frame as PDB when ``output_top_path`` is set.
        """
        if ext == '.pdb':
            traj.save_pdb(self.output_traj_path)
            return
        if ext == '.xtc':
            traj.save_xtc(self.output_traj_path)
        else:
            traj.save_dcd(self.output_traj_path)
        # xtc/dcd store no atom names, so a companion PDB is written when a path is given.
        # output_top_path is optional; passing None to save_pdb would fail.
        if self.output_top_path:
            traj[0].save_pdb(self.output_top_path)
def feat2traj(
    input_results_npz_path: str,
    input_stats_pt_path: str,
    input_topology_path: str = None,
    output_traj_path: str = None,
    output_top_path: str = None,
    properties: dict = None,
    **kwargs,
) -> int:
    """Functional wrapper around :class:`Feat2Traj <feat2traj.Feat2Traj>`.

    Instantiate the class with every argument received here and run its
    :meth:`launch() <feat2traj.Feat2Traj.launch>` method, returning its exit code.
    """
    # Snapshot the call arguments before any other local name is bound.
    call_arguments = dict(locals())
    tool = Feat2Traj(**call_arguments)
    return tool.launch()
# The functional API shares the class documentation (argument list, examples, info).
feat2traj.__doc__ = Feat2Traj.__doc__

# Command-line entry point built by BiobbObject.get_main around the functional wrapper.
# The description reflects the real inputs: features come from the .npz file ('xhat'),
# cartesian indices and topology from the .pt stats file.
main = Feat2Traj.get_main(
    feat2traj,
    "Converts reconstructed features (.npz) to a trajectory using cartesian indices and topology from the stats (.pt) file.",
)

if __name__ == "__main__":
    main()