Coverage for biobb_pytorch / mdae / featurization / topology_selector.py: 47%

105 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import mdtraj as md 

2import itertools 

3 

4 

class MDTopologySelector:
    """
    Load an MDTraj topology and enumerate atom tuples for a selection.

    For a given MDTraj selection string the class lists atom pairs (bonds or
    arbitrary distances), triplets (bond angles or arbitrary triples) and
    quads (torsions or arbitrary quadruplets).

    Attributes
    ----------
    topology
        Topology object obtained from the constructor argument.
    bonds : list of tuple(int, int)
        Atom-index pairs, one per entry of ``topology.bonds``.
    n_distances, n_angles, n_dihedrals, n_atoms : int
        Size of the most recent result of ``get_atom_pairs``,
        ``get_triplets``, ``get_quads`` and ``get_atom_indices``
        respectively. Each exists only after the matching method has run.
    """

    def __init__(self, topology):
        """
        Parameters
        ----------
        topology : str | md.Trajectory | md.Topology
            Path to a structure file (e.g., .pdb, .gro), an MDTraj Trajectory, or an MDTraj Topology.

        Raises
        ------
        ValueError
            If `topology` is none of the accepted types.
        """
        if isinstance(topology, md.Trajectory):
            self.topology = topology.topology
        elif isinstance(topology, md.Topology):
            self.topology = topology
        elif isinstance(topology, str):
            # The file is loaded only to get at its topology.
            self.topology = md.load(topology).topology
        else:
            raise ValueError("`topology` must be a file path, md.Trajectory, or md.Topology instance.")

        # Precompute bond list as tuples of atom indices
        self.bonds = [(b.atom1.index, b.atom2.index) for b in self.topology.bonds]

    def select(self, selection):
        """
        Select atom indices matching an MDTraj selection string.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax, e.g., "backbone", "name CA", etc.

        Returns
        -------
        numpy.ndarray of int
            Array of atom indices.
        """
        return self.topology.select(selection)

    def _bonded_neighbors(self, sel_set):
        """Map each atom of `sel_set` to the set of atoms of `sel_set` bonded to it."""
        nbrs = {a: set() for a in sel_set}
        for i, j in self.bonds:
            # Bonds reaching outside the selection are ignored.
            if i in sel_set and j in sel_set:
                nbrs[i].add(j)
                nbrs[j].add(i)
        return nbrs

    def get_atom_pairs(self, selection, bonded=True):
        """
        Get atom pairs for a selection and record their count in `n_distances`.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax.
        bonded : bool, default=True
            If True, return only the bonds whose two atoms are both selected,
            in the order of `self.bonds`. If False, return every unique pair
            of selected atoms, whether or not the two atoms are bonded.

        Returns
        -------
        List of tuple(int, int)
            Each tuple is (i, j) of atom indices.
        """
        sel = list(self.select(selection))
        if bonded:
            sel_set = set(sel)
            atom_pairs = [(i, j) for (i, j) in self.bonds if i in sel_set and j in sel_set]
        else:
            atom_pairs = list(itertools.combinations(sel, 2))

        self.n_distances = len(atom_pairs)

        return atom_pairs

    def get_triplets(self, selection, bonded=True):
        """
        Get atom triplets for a selection and record their count in `n_angles`.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax.
        bonded : bool, default=True
            If True, return triplets that form angles (i-j-k where i-j and j-k
            are bonds between selected atoms). Every angle is listed in both
            orientations, (i, j, k) and (k, j, i), and the list order is that
            of a set. If False, return all unique triplets.

        Returns
        -------
        List of tuple(int, int, int)
            Each tuple is (i, j, k).
        """
        sel = list(self.select(selection))
        if bonded:
            nbrs = self._bonded_neighbors(set(sel))
            # j is the apex; i and k are two distinct atoms bonded to it.
            triplets = list({
                (i, j, k)
                for j, around_j in nbrs.items()
                for i in around_j
                for k in around_j
                if i != k
            })
        else:
            triplets = list(itertools.combinations(sel, 3))

        # Recorded in both modes, like n_distances in get_atom_pairs.
        self.n_angles = len(triplets)

        return triplets

    def get_quads(self, selection, bonded=True):
        """
        Get atom quads for a selection and record their count in `n_dihedrals`.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax.
        bonded : bool, default=True
            If True, return quads that form torsions (i-j-k-l where i-j, j-k,
            k-l are bonds between selected atoms). Every torsion is listed in
            both orientations, (i, j, k, l) and (l, k, j, i), and the list
            order is that of a set. If False, return all unique quadruplets.

        Returns
        -------
        List of tuple(int, int, int, int)
            Each tuple is (i, j, k, l).
        """
        sel = list(self.select(selection))
        if bonded:
            nbrs = self._bonded_neighbors(set(sel))
            # j-k is the central bond; i hangs off j and neighbor_l off k.
            # NOTE(review): nothing excludes neighbor_l == i, so a
            # three-membered ring yields quads with a repeated atom.
            quads = list({
                (i, j, k, neighbor_l)
                for j, around_j in nbrs.items()
                for k in around_j
                for i in around_j
                if i != k
                for neighbor_l in nbrs[k]
                if neighbor_l != j
            })
        else:
            quads = list(itertools.combinations(sel, 4))

        # Recorded in both modes, like n_distances in get_atom_pairs.
        self.n_dihedrals = len(quads)

        return quads

    def get_atom_indices(self, selection):
        """
        Get atom indices for a selection and record their count in `n_atoms`.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax.

        Returns
        -------
        List of int
            Atom indices.
        """
        atom_indices = list(self.select(selection))
        self.n_atoms = len(atom_indices)

        return atom_indices

    def topology_indexing(self, config):
        """
        Get the topology indexing for a given configuration.

        The configuration is stored on `self.config` and the result on
        `self.topology_idx`, which is rebuilt from scratch on every call.

        Parameters
        ----------
        config : dict
            Configuration dictionary. Each of the optional keys 'cartesian',
            'distances', 'angles' and 'dihedrals' maps to a dict with a
            mandatory 'selection' string plus optional settings
            ('fit_selection' for cartesian; 'bonded', 'cutoff' and 'periodic'
            for distances; 'bonded' and 'periodic' for angles and dihedrals).
            An optional 'options' entry is copied through unchanged.

        Returns
        -------
        Dict of indices
            Topology indices, with one entry per feature key found in `config`.
        """
        self.config = config
        self.topology_idx = {}

        if 'cartesian' in config:
            cart_cfg = config['cartesian']
            sel = cart_cfg['selection']
            entry = {'selection': sel, 'indices': self.get_atom_indices(sel)}

            fit_sel = cart_cfg.get('fit_selection', None)
            if fit_sel is not None:
                # The entry is assembled before it is stored; writing into
                # self.topology_idx['cartesian'] first would raise KeyError.
                # NOTE(review): this call leaves n_atoms holding the size of
                # the fit selection - confirm that is what callers expect.
                entry['fit_selection'] = self.get_atom_indices(fit_sel)
            self.topology_idx['cartesian'] = entry

        if 'distances' in config:
            dist_cfg = config['distances']
            sel = dist_cfg['selection']
            # Unlike angles and dihedrals, distances default to non-bonded pairs.
            pairs = self.get_atom_pairs(sel, bonded=dist_cfg.get('bonded', False))
            args = {'selection': sel, 'pairs': pairs}
            cutoff = dist_cfg.get('cutoff', None)
            if cutoff is not None:
                args['cutoff'] = cutoff
            args['periodic'] = dist_cfg.get('periodic', False)
            self.topology_idx['distances'] = args

        if 'angles' in config:
            ang_cfg = config['angles']
            sel = ang_cfg['selection']
            self.topology_idx['angles'] = {
                'selection': sel,
                'triplets': self.get_triplets(sel, bonded=ang_cfg.get('bonded', True)),
                'periodic': ang_cfg.get('periodic', False),
            }

        if 'dihedrals' in config:
            dih_cfg = config['dihedrals']
            sel = dih_cfg['selection']
            self.topology_idx['dihedrals'] = {
                'selection': sel,
                'quadruplets': self.get_quads(sel, bonded=dih_cfg.get('bonded', True)),
                'periodic': dih_cfg.get('periodic', False),
            }

        # collect any options
        if 'options' in config:
            self.topology_idx['options'] = config['options']

        return self.topology_idx