Coverage for biobb_pytorch / mdae / featurization / topology_selector.py: 47%
105 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import mdtraj as md
2import itertools
class MDTopologySelector:
    """
    Load an MDTraj topology and enumerate atom groups for an atom selection.

    For a given MDTraj selection string the class can return atom indices,
    atom pairs (bonds or all pairs), triplets (bond angles or all triples)
    and quads (torsions or all quadruplets).

    Attributes
    ----------
    topology : md.Topology
        The topology every selection is evaluated against.
    bonds : list of tuple(int, int)
        Atom-index pairs, one per bond in ``topology.bonds``.
    n_atoms, n_distances, n_angles, n_dihedrals : int
        Size of the most recent result of ``get_atom_indices``,
        ``get_atom_pairs``, ``get_triplets`` and ``get_quads`` respectively.
        Each attribute only exists after the matching method has been called.
    config, topology_idx : dict
        Last configuration given to ``topology_indexing`` and its result.
    """

    def __init__(self, topology):
        """
        Parameters
        ----------
        topology : str | md.Trajectory | md.Topology
            Path to a structure file (e.g., .pdb, .gro), an MDTraj Trajectory, or an MDTraj Topology.

        Raises
        ------
        ValueError
            If `topology` is none of the accepted types.
        """
        if isinstance(topology, md.Trajectory):
            self.topology = topology.topology
        elif isinstance(topology, md.Topology):
            self.topology = topology
        elif isinstance(topology, str):
            # Only the topology of the loaded structure is kept.
            self.topology = md.load(topology).topology
        else:
            raise ValueError("`topology` must be a file path, md.Trajectory, or md.Topology instance.")

        # Precompute bond list as tuples of atom indices
        self.bonds = [(b.atom1.index, b.atom2.index) for b in self.topology.bonds]

    def select(self, selection):
        """
        Select atom indices matching an MDTraj selection string.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax, e.g., "backbone", "name CA", etc.

        Returns
        -------
        numpy.ndarray of int
            Array of atom indices.
        """
        return self.topology.select(selection)

    def _bonded_neighbors(self, atom_indices):
        """
        Build the bond adjacency restricted to the given atoms.

        Parameters
        ----------
        atom_indices : iterable of int
            Atoms to keep; bonds with an end outside this set are ignored.

        Returns
        -------
        dict of int -> set of int
            For every kept atom, the kept atoms it is bonded to.
        """
        kept = set(atom_indices)
        nbrs = {a: set() for a in kept}
        for i, j in self.bonds:
            if i in kept and j in kept:
                nbrs[i].add(j)
                nbrs[j].add(i)
        return nbrs

    def get_atom_pairs(self, selection, bonded=True):
        """
        Get atom pairs for a selection and store their count in ``n_distances``.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax.
        bonded : bool, default=True
            If True, return only the bonds whose two atoms are both selected.
            If False, return every unique pair of selected atoms.

        Returns
        -------
        List of tuple(int, int)
            Each tuple is (i, j) of atom indices.
        """
        sel = list(self.select(selection))
        if bonded:
            sel_set = set(sel)
            atom_pairs = [(i, j) for (i, j) in self.bonds if i in sel_set and j in sel_set]
        else:
            atom_pairs = list(itertools.combinations(sel, 2))

        self.n_distances = len(atom_pairs)
        return atom_pairs

    def get_triplets(self, selection, bonded=True):
        """
        Get atom triplets for a selection and store their count in ``n_angles``.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax.
        bonded : bool, default=True
            If True, return triplets that form angles (i-j-k where i-j and j-k are bonds).
            Both orientations, (i, j, k) and (k, j, i), are returned.
            If False, return all unique triplets.

        Returns
        -------
        List of tuple(int, int, int)
            Each tuple is (i, j, k).
        """
        sel = list(self.select(selection))
        if bonded:
            found = set()
            for j, around_j in self._bonded_neighbors(sel).items():
                # Every ordered pair of distinct neighbours of j is an angle centred on j.
                found.update((i, j, k) for i, k in itertools.permutations(around_j, 2))
            triplets = list(found)
        else:
            triplets = list(itertools.combinations(sel, 3))

        # Recorded on both paths so the counter never goes stale.
        self.n_angles = len(triplets)
        return triplets

    def get_quads(self, selection, bonded=True):
        """
        Get atom quads for a selection and store their count in ``n_dihedrals``.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax.
        bonded : bool, default=True
            If True, return quads that form torsions (i-j-k-l where i-j, j-k, k-l are bonds).
            Both orientations, (i, j, k, l) and (l, k, j, i), are returned.
            If False, return all unique quadruplets.

        Returns
        -------
        List of tuple(int, int, int, int)
            Each tuple is (i, j, k, l).

        Notes
        -----
        In bonded mode only i != k and l != j are enforced, so a three-membered
        ring yields quads whose first and last atom coincide.
        """
        sel = list(self.select(selection))
        if bonded:
            nbrs = self._bonded_neighbors(sel)
            found = set()
            for j, around_j in nbrs.items():
                for k in around_j:  # j-k is the central bond
                    for i in around_j:
                        if i == k:
                            continue
                        for neighbor_l in nbrs[k]:
                            if neighbor_l == j:
                                continue
                            found.add((i, j, k, neighbor_l))
            quads = list(found)
        else:
            quads = list(itertools.combinations(sel, 4))

        # Recorded on both paths so the counter never goes stale.
        self.n_dihedrals = len(quads)
        return quads

    def get_atom_indices(self, selection):
        """
        Get atom indices for a selection and store their count in ``n_atoms``.

        Parameters
        ----------
        selection : str
            MDTraj selection syntax.

        Returns
        -------
        List of int
            Atom indices.
        """
        atom_indices = list(self.select(selection))
        self.n_atoms = len(atom_indices)
        return atom_indices

    def topology_indexing(self, config):
        """
        Get the topology indexing for a given configuration.

        Parameters
        ----------
        config : dict
            Configuration dictionary. Recognised keys are 'cartesian',
            'distances', 'angles', 'dihedrals' and 'options'. Every feature
            entry must provide 'selection'. Optional keys are 'fit_selection'
            (cartesian), 'bonded' and 'periodic' (distances, angles,
            dihedrals) and 'cutoff' (distances). 'bonded' defaults to False
            for distances and True for angles and dihedrals; 'periodic'
            defaults to False.

        Returns
        -------
        Dict of indices
            Topology indices, one entry per feature key present in `config`,
            plus 'options' copied through unchanged when given. The same dict
            is stored as ``self.topology_idx``.

        Raises
        ------
        KeyError
            If a feature entry lacks its 'selection' key.
        """
        self.config = config
        self.topology_idx = {}

        if 'cartesian' in config:
            cartesian_cfg = config['cartesian']
            sel = cartesian_cfg['selection']
            entry = {'selection': sel, 'indices': self.get_atom_indices(sel)}
            fit_sel = cartesian_cfg.get('fit_selection', None)
            if fit_sel is not None:
                # Resolved with select() rather than get_atom_indices() so that
                # n_atoms keeps describing the feature selection, not the fit atoms.
                entry['fit_selection'] = list(self.select(fit_sel))
            self.topology_idx['cartesian'] = entry

        if 'distances' in config:
            distances_cfg = config['distances']
            sel = distances_cfg['selection']
            pairs = self.get_atom_pairs(sel, bonded=distances_cfg.get('bonded', False))
            args = {'selection': sel, 'pairs': pairs}
            cutoff = distances_cfg.get('cutoff', None)
            if cutoff is not None:
                args['cutoff'] = cutoff
            # NOTE(review): the original indentation was lost in this listing;
            # 'periodic' is assumed to be recorded unconditionally, as it is for
            # angles and dihedrals - confirm against the source file.
            args['periodic'] = distances_cfg.get('periodic', False)
            self.topology_idx['distances'] = args

        if 'angles' in config:
            angles_cfg = config['angles']
            sel = angles_cfg['selection']
            self.topology_idx['angles'] = {
                'selection': sel,
                'triplets': self.get_triplets(sel, bonded=angles_cfg.get('bonded', True)),
                'periodic': angles_cfg.get('periodic', False),
            }

        if 'dihedrals' in config:
            dihedrals_cfg = config['dihedrals']
            sel = dihedrals_cfg['selection']
            self.topology_idx['dihedrals'] = {
                'selection': sel,
                'quadruplets': self.get_quads(sel, bonded=dihedrals_cfg.get('bonded', True)),
                'periodic': dihedrals_cfg.get('periodic', False),
            }

        # collect any options
        if 'options' in config:
            self.topology_idx['options'] = config['options']

        return self.topology_idx