Coverage for biobb_pytorch / mdae / feat2traj.py: 71%

76 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1from biobb_common.generic.biobb_object import BiobbObject 

2from biobb_common.tools.file_utils import launchlogger 

3import torch 

4import numpy as np 

5import mdtraj as md 

6import os 

7 

8 

class Feat2Traj(BiobbObject):
    """
    | biobb_pytorch Feat2Traj
    | Converts a .pt file (features) to a trajectory using cartesian indices and topology from the stats file.
    | Converts a .pt file (features) to a trajectory using cartesian indices and topology from the stats file.

    Args:
        input_results_npz_path (str): Path to the input reconstructed results file (.npz), typically containing an 'xhat' array. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_results.npz>`_. Accepted formats: npz (edam:format_2333).
        input_stats_pt_path (str): Path to the input model statistics file (.pt) containing cartesian indices and optionally topology. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_model.pt>`_. Accepted formats: pt (edam:format_2333).
        input_topology_path (str) (optional): Path to the topology file (PDB) used if no suitable topology is found in the stats file. Used if no topology is found in stats. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/mdae/ref_input_topology.pdb>`_. Accepted formats: pdb (edam:format_1476).
        output_traj_path (str): Path to save the trajectory in xtc/pdb/dcd format. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.xtc>`_. Accepted formats: xtc (edam:format_3875), pdb (edam:format_1476), dcd (edam:format_3878).
        output_top_path (str) (optional): Path to save the output topology file (pdb). Used if trajectory format requires separate topology. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/mdae/output_model.pdb>`_. Accepted formats: pdb (edam:format_1476).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **restart** (*bool*) - (False) [WF property] Do not execute if output files exist.

    Examples:
        This example shows how to use the Feat2Traj class to convert a .pt file (features) to a trajectory using cartesian indices and topology from the stats file::

            from biobb_pytorch.mdae.feat2traj import feat2traj

            input_results_npz_path='input_results.npz'
            input_stats_pt_path='input_model.pt'
            input_topology_path='input_topology.pdb'
            output_traj_path='output_model.xtc'
            output_top_path='output_model.pdb'

            prop={}

            feat2traj(input_results_npz_path=input_results_npz_path,
                    input_stats_pt_path=input_stats_pt_path,
                    input_topology_path=input_topology_path,
                    output_traj_path=output_traj_path,
                    output_top_path=output_top_path,
                    properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_results_npz_path: str,
        input_stats_pt_path: str,
        input_topology_path: str = None,
        output_traj_path: str = None,
        output_top_path: str = None,
        properties: dict = None,
        **kwargs,
    ) -> None:
        properties = properties or {}
        super().__init__(properties)

        # Plain attributes used directly by launch().
        self.input_results_npz_path = input_results_npz_path
        self.input_stats_pt_path = input_stats_pt_path
        self.input_topology_path = input_topology_path
        self.output_traj_path = output_traj_path
        self.output_top_path = output_top_path
        self.properties = properties.copy()
        # Snapshot of the constructor arguments; no extra locals are created
        # above this line so the snapshot content stays stable.
        self.locals_var_dict = locals().copy()
        # Input/output file registry handed to the BiobbObject checks below.
        self.io_dict = {
            "in": {
                "input_results_npz_path": input_results_npz_path,
                "input_stats_pt_path": input_stats_pt_path,
                "input_topology_path": input_topology_path,
            },
            "out": {
                "output_traj_path": output_traj_path,
                "output_top_path": output_top_path,
            },
        }
        self.check_properties(properties)
        self.check_arguments()

    @staticmethod
    def _topology_from_stats(topology):
        """Turn the 'topology' entry of the stats file into an MDTraj topology.

        Accepted forms are an ``md.Topology``, an ``md.Trajectory``, a path to
        an existing topology file, or a dict holding a ``'pdb_string'`` entry.
        This is best effort: any failure is reported with a warning and
        ``None`` is returned so the caller can fall back to other sources.
        """
        if topology is None:
            return None
        try:
            if isinstance(topology, md.Topology):
                return topology
            if isinstance(topology, md.Trajectory):
                return topology.topology
            if isinstance(topology, str) and os.path.exists(topology):
                return md.load_topology(topology)
            if isinstance(topology, dict) and 'pdb_string' in topology:
                import tempfile

                # MDTraj loaders expect a file path, so the PDB text is
                # written to a short-lived file instead of an in-memory buffer.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    pdb_path = os.path.join(tmp_dir, 'topology.pdb')
                    with open(pdb_path, 'w') as handle:
                        handle.write(topology['pdb_string'])
                    return md.load_topology(pdb_path)
        except Exception as e:
            # Deliberately broad: a broken embedded topology must not abort
            # the conversion while other topology sources remain available.
            print(f"Warning: Could not load topology from stats file: {e}")
        return None

    @staticmethod
    def _placeholder_topology(n_atoms):
        """Build a single-chain, single-residue topology of ``n_atoms`` carbons."""
        top = md.Topology()
        chain = top.add_chain()
        res = top.add_residue('RES', chain)
        for _ in range(n_atoms):
            top.add_atom('CA', element=md.element.carbon, residue=res)
        return top

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`Feat2Traj` class and its `.launch()` method.

        Reads the 'xhat' array from the results file, reshapes it to
        (frames, atoms, 3) using the number of cartesian indices stored in the
        stats file, attaches a topology and writes the trajectory.

        Returns:
            int: 0 on completion.

        Raises:
            ValueError: If the stats file holds no cartesian indices, if the
                feature array cannot be reshaped to (frames, atoms, 3), or if
                the output trajectory extension is not xtc, dcd or pdb.
        """
        # Load features; the context manager closes the .npz archive once the
        # array has been read.
        with np.load(self.input_results_npz_path) as results:
            features = results['xhat']

        # Load stats and extract cartesian indices and topology.
        # NOTE(review): weights_only=False unpickles arbitrary objects, so the
        # stats file must come from a trusted source.
        stats = torch.load(self.input_stats_pt_path,
                           weights_only=False)
        if not isinstance(stats, dict) or 'cartesian_indices' not in stats:
            raise ValueError('No cartesian indices found in stats file.')
        cartesian_indices = np.array(stats['cartesian_indices'])
        # The topology entry is optional in the stats file.
        topology = stats.get('topology')

        n_atoms = len(cartesian_indices)
        n_frames = features.shape[0]
        if features.size != n_frames * n_atoms * 3:
            raise ValueError(
                f'Cannot reshape features with shape {features.shape} into '
                f'({n_frames}, {n_atoms}, 3) coordinates.'
            )
        coords = features.reshape((n_frames, n_atoms, 3))

        # Topology precedence: stats file, then input_topology_path, then a
        # generated placeholder.
        top = self._topology_from_stats(topology)
        if top is None and self.input_topology_path is not None and os.path.exists(self.input_topology_path):
            top = md.load_topology(self.input_topology_path)
        if top is None:
            top = self._placeholder_topology(n_atoms)
        traj = md.Trajectory(xyz=coords, topology=top)

        if self.output_traj_path:
            ext = os.path.splitext(self.output_traj_path)[1].lower()
            if ext == '.xtc':
                traj.save_xtc(self.output_traj_path)
            elif ext == '.dcd':
                traj.save_dcd(self.output_traj_path)
            elif ext == '.pdb':
                traj.save_pdb(self.output_traj_path)
            else:
                raise ValueError(f'Unknown trajectory extension: {ext}')
            # xtc/dcd carry no topology; write the first frame as PDB only
            # when the optional topology output was requested.
            if ext in ('.xtc', '.dcd') and self.output_top_path:
                traj[0].save_pdb(self.output_top_path)
        return 0

156 

157 

def feat2traj(
    input_results_npz_path: str,
    input_stats_pt_path: str,
    input_topology_path: str = None,
    output_traj_path: str = None,
    output_top_path: str = None,
    properties: dict = None,
    **kwargs,
) -> int:
    """Create the :class:`Feat2Traj <Feat2Traj.Feat2Traj>` class and
    execute the :meth:`launch() <Feat2Traj.feat2traj.launch>` method."""
    # Snapshot the arguments before any other local name exists. The snapshot
    # also holds a 'kwargs' entry, which is forwarded as a keyword of that name.
    call_arguments = dict(locals())
    tool = Feat2Traj(**call_arguments)
    return tool.launch()

170 

171 

# The functional wrapper exposes the same documentation as the class.
feat2traj.__doc__ = Feat2Traj.__doc__

# Command-line entry point built by the class-level get_main helper from the
# wrapper function and a one-line tool description.
main = Feat2Traj.get_main(
    feat2traj,
    "Converts a .pt file (features) to a trajectory using cartesian indices and topology from the stats file.",
)

if __name__ == "__main__":
    main()