Coverage for biobb_pytorch / mdae / featurization / featurizer.py: 54%
190 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import mdtraj as md
2import numpy as np
3import torch
4from biobb_pytorch.mdae.featurization.normalization import Normalization
5from mlcolvar.core.transform.utils import Statistics
class Featurizer:
    """
    Extract geometric features (distances, angles, dihedrals, Cartesian
    coordinates) from MD trajectories using MDTraj.

    Supports selections by atom indices or MDTraj selection strings.

    Parameters:
    -------------
    trajectory_file : str
        Path to the trajectory file (e.g., .dcd, .xtc).
    topology_file : str
        Path to the topology file (e.g., .pdb, .gro).
    input_labels_npy_path : str, optional
        Path to a .npy file containing labels for each frame.
    input_weights_npy_path : str, optional
        Path to a .npy file containing weights for each frame.
    """

    def __init__(self, trajectory_file, topology_file, input_labels_npy_path=None, input_weights_npy_path=None):
        """
        Load the trajectory with MDTraj and store it on the instance.

        Parameters:
        -------------
        trajectory_file : str
            Path to the trajectory file (e.g., .dcd, .xtc).
        topology_file : str
            Path to the topology file (e.g., .pdb, .gro).
        input_labels_npy_path : str, optional
            Path to a .npy file with labels; loaded lazily by compute_features.
        input_weights_npy_path : str, optional
            Path to a .npy file with weights; loaded lazily by compute_features.
        """
        # Load trajectory and topology
        trajectory = md.load(trajectory_file, top=topology_file)

        self.trajectory = trajectory
        self.topology = trajectory.topology
        self.input_labels_npy_path = input_labels_npy_path
        self.input_weights_npy_path = input_weights_npy_path

        # Trajectory built from the first frame's coordinates and the full
        # topology; set_statistics slices it to produce stats['topology'].
        self.complete_top = md.Trajectory(xyz=trajectory.xyz[0], topology=trajectory.topology)

        # State filled in by compute_features. Initialised here so the getters
        # and set_statistics do not hit a missing attribute beforehand.
        self.idx_cartesian = None
        self.idx_dist = None
        self.idx_triplets = None
        self.idx_quads = None
        self.features = {}
        self.n_features = None
        self.n_frames = None

    def select_atoms(self, selection):
        """
        Convert a selection specifier into atom indices.

        Parameters:
        -------------
        selection : str or list[int] or np.ndarray
            If str: MDTraj topology query (e.g., 'name CA', 'resid 0 to 10'),
            evaluated against self.topology.
            If list or array: explicit atom indices.

        Returns:
        --------
        np.ndarray
            Array of selected atom indices.
        """
        if isinstance(selection, str):
            return self.topology.select(selection)
        return np.array(selection, dtype=int)

    def filter_topology(self, selection, topology):
        """
        Slice `topology` down to the atoms matched by `selection`.

        Parameters:
        -------------
        selection : str or list[int] or np.ndarray
            Selection specifier, resolved with select_atoms (i.e. against
            self.topology, not against the `topology` argument).
        topology : object exposing `atom_slice`
            set_statistics passes self.complete_top, an md.Trajectory.

        Returns:
        --------
        Whatever `topology.atom_slice(indices)` returns.
        """
        idx = self.select_atoms(selection)
        return topology.atom_slice(idx)

    def _dicts_to_tuples(self, dict_list, expected_length):
        """
        Internal helper: convert list of dicts mapping resid->atom to tuple of atom indices.

        Each dict must have exactly expected_length entries. Keys are residue indices,
        values are atom names. Residue indices must match MDTraj topology resid numbering.

        Raises:
        -------
        ValueError
            If a dict has the wrong number of entries or a (resid, name)
            query matches no atom.
        """
        tuples = []
        for d in dict_list:
            if len(d) != expected_length:
                raise ValueError(f"Expected dict with {expected_length} entries, got {len(d)}")
            idxs = []
            for resid, atom_name in d.items():
                sel = f"resid {resid} and name {atom_name}"
                arr = self.select_atoms(sel)
                if len(arr) == 0:
                    raise ValueError(f"No atom found for {sel}")
                # take first match
                idxs.append(int(arr[0]))
            tuples.append(tuple(idxs))
        return tuples

    def _resolve_atom(self, atom):
        """
        Internal helper: turn one atom specifier into a single integer index.

        A string is treated as an MDTraj selection and its first match is
        used (same policy as _dicts_to_tuples); anything else is cast to int.

        Raises:
        -------
        ValueError
            If a string selection matches no atom.
        """
        if isinstance(atom, str):
            matches = self.select_atoms(atom)
            if len(matches) == 0:
                raise ValueError(f"No atom found for {atom}")
            # int() on the whole match array fails unless exactly one atom
            # matched, so pick the first match explicitly.
            return int(matches[0])
        return int(atom)

    def _resolve_index_tuples(self, groups, size):
        """
        Internal helper shared by idx_distances / idx_angles / idx_dihedrals.

        Parameters:
        -------------
        groups : list of tuples or dicts
            Each entry names `size` atoms, either as a dict {resid: atom_name}
            or as a tuple of indices / selection strings.
        size : int
            Number of atoms per entry (2, 3 or 4).

        Returns:
        --------
        np.ndarray
            Integer array of shape (len(groups), size); (0, size) when empty.
        """
        if len(groups) > 0 and isinstance(groups[0], dict):
            groups = self._dicts_to_tuples(groups, size)

        resolved = []
        for group in groups:
            if len(group) != size:
                raise ValueError(f"Expected {size} atoms per entry, got {len(group)}")
            resolved.append(tuple(self._resolve_atom(atom) for atom in group))
        # reshape keeps a well-formed 2-D array even for an empty input list
        return np.array(resolved, dtype=int).reshape(-1, size)

    def idx_distances(self, pairs):
        """
        Convert pairs of atom specifiers to MDTraj topology indices.

        Parameters:
        -------------
        pairs : list of tuples or dicts
            Each tuple or dict must have 2 entries (resid, atom_name).

        Returns:
        --------
        np.ndarray
            Array of shape (n_pairs, 2) with atom indices.
        """
        return self._resolve_index_tuples(pairs, 2)

    def compute_distances(self, idx_pairs, cutoff, periodic: bool = True):
        """
        Compute inter-atomic distances for given pairs.

        Parameters:
        -------------
        idx_pairs : array-like, shape (n_pairs, 2)
            Atom index pairs, e.g. the output of idx_distances.
        cutoff : float or None
            If not None, keep only the pairs whose distance is below
            `cutoff` in at least one frame.
        periodic : bool
            Forwarded to md.compute_distances.

        Returns:
        --------
        distances : np.ndarray
            Output of md.compute_distances, restricted to the kept pairs.
        idx_pairs : np.ndarray
            The (possibly filtered) index pairs matching the columns of `distances`.
        """
        # Boolean-mask indexing below requires an array, not a list
        idx_pairs = np.asarray(idx_pairs)
        distances = md.compute_distances(self.trajectory, idx_pairs, periodic=periodic)

        # apply cutoff and get only pairs within cutoff
        if cutoff is not None:
            keep_cols = np.any(distances < cutoff, axis=0)
            idx_pairs = idx_pairs[keep_cols]
            distances = distances[:, keep_cols]

        return distances, idx_pairs

    def polar2cartesian(self, a):
        """
        Encode angles given in degrees as their sine and cosine.

        Parameters:
        -------------
        a : np.ndarray
            Array of shape (n_frames, n_angles) with angles in degrees.

        Returns:
        --------
        np.ndarray
            Array of shape (n_frames, n_angles * 2): all sine columns first,
            followed by all cosine columns.
        """
        # NOTE(review): the input is converted from degrees here, while
        # compute_features feeds this method the raw output of
        # md.compute_angles / md.compute_dihedrals, which MDTraj documents as
        # radians -- confirm the intended units before relying on periodicity.
        a = np.deg2rad(a)
        # Stack sin and cos values
        return np.column_stack((np.sin(a), np.cos(a)))

    def cartesian2polar(self, cart_angles):
        """
        Invert polar2cartesian: recover angles in degrees from sin/cos columns.

        Parameters:
        -------------
        cart_angles : np.ndarray
            Array of shape (n_frames, n_angles * 2) laid out as produced by
            polar2cartesian (sine columns first, then cosine columns).

        Returns:
        --------
        np.ndarray
            Angles in degrees. Shape (n_frames, n_angles); for a 2-column
            input the result is 1-D of shape (n_frames,).

        Raises:
        -------
        ValueError
            If the number of columns is zero or odd.
        """
        cart_angles = np.asarray(cart_angles)
        n_cols = cart_angles.shape[1]
        if n_cols == 0 or n_cols % 2 != 0:
            raise ValueError(f"Expected an even, non-zero number of columns, got {n_cols}")

        # Pair sine column i with cosine column i + half; reading only columns
        # 0 and 1 would pair two sines whenever n_angles > 1.
        half = n_cols // 2
        angles = np.rad2deg(np.arctan2(cart_angles[:, :half], cart_angles[:, half:]))

        # Keep the historical 1-D return for the single-angle layout
        if half == 1:
            return angles[:, 0]
        return angles

    def idx_angles(self, triplets):
        """
        Convert triplets of atom specifiers to MDTraj topology indices.

        Parameters:
        -------------
        triplets : list of tuples or dicts
            Each tuple or dict must have 3 entries (resid, atom_name).

        Returns:
        --------
        np.ndarray
            Array of shape (n_triplets, 3) with atom indices.
        """
        return self._resolve_index_tuples(triplets, 3)

    def compute_angles(self, idx_triplets, periodic: bool = True):
        """
        Compute angles between triplets of atoms.

        `idx_triplets` is an array of atom indices with 3 columns (e.g. from
        idx_angles); the result of md.compute_angles is returned unchanged.
        """
        return md.compute_angles(self.trajectory, idx_triplets, periodic=periodic)

    def idx_dihedrals(self, quadruplets):
        """
        Convert quadruplets of atom specifiers to MDTraj topology indices.

        Parameters:
        -------------
        quadruplets : list of tuples or dicts
            Each tuple or dict must have 4 entries (resid, atom_name).

        Returns:
        --------
        np.ndarray
            Array of shape (n_quadruplets, 4) with atom indices.
        """
        return self._resolve_index_tuples(quadruplets, 4)

    def compute_dihedrals(self, idx_quads, periodic: bool = True):
        """
        Compute dihedral angles for quadruplets of atoms.

        `idx_quads` is an array of atom indices with 4 columns (e.g. from
        idx_dihedrals); the result of md.compute_dihedrals is returned unchanged.
        """
        return md.compute_dihedrals(self.trajectory, idx_quads, periodic=periodic)

    def compute_cartesian(self, indices):
        """
        Extract Cartesian coordinates for selected atoms.

        Parameters:
        -------------
        indices : list[int]
            List of atom indices to extract coordinates for.

        Returns:
        --------
        np.ndarray
            self.trajectory.xyz restricted to `indices` along the atom axis.
        """
        return self.trajectory.xyz[:, indices, :]

    def combine_features(self, *feature_arrays):
        """
        Concatenate multiple feature arrays along the feature axis (axis 1).
        """
        return np.concatenate(feature_arrays, axis=1)

    def timelag(self, data: np.ndarray, lag: int):
        """
        Split into X and Y where Y[t] = X[t+lag].

        Parameters
        ----------
        data : np.ndarray, shape (n_times, n_features)
        lag : int

        Returns
        -------
        X : np.ndarray, shape (n_times-lag, n_features)
        Y : np.ndarray, shape (n_times-lag, n_features)

        Raises
        ------
        ValueError
            If lag is not in [1, n_times-1].
        """
        if lag < 1 or lag >= data.shape[0]:
            raise ValueError("lag must be between 1 and n_times-1")

        return data[:-lag], data[lag:]

    def get_n_features(self):
        """
        Get the number of features in the combined feature array
        (None until compute_features has been called).
        """
        return self.n_features

    def get_n_frames(self):
        """
        Get the number of frames in the combined feature array
        (None until compute_features has been called).
        """
        return self.n_frames

    def get_atom_info(self, selection):
        """
        Get [atom name, residue name, residue index + 1] for each selected atom.

        Returns an object-dtype np.ndarray with one row per atom.
        """
        idx = self.select_atoms(selection)
        sliced = self.trajectory.atom_slice(idx)
        atom_info = [
            [atom.name, atom.residue.name, atom.residue.index + 1]
            for atom in sliced.topology.atoms
        ]
        return np.array(atom_info, dtype=object)

    def set_statistics(self, combined: np.ndarray, feature_dict: dict):
        """
        Build the statistics dictionary for the combined feature array.

        Starts from mlcolvar's Statistics(...).to_dict() and adds:
        'shape', 'selection', 'topology', 'parametric' (global mean and std
        of the flattened array) and the index arrays of every feature type
        that was computed.
        """
        stats = Statistics(torch.FloatTensor(combined)).to_dict()
        stats['shape'] = [self.n_frames, self.n_features]

        selection = feature_dict.get('cartesian', {}).get('selection')
        stats['selection'] = selection
        # The key above always exists (possibly as None), so the "name CA"
        # default must be applied explicitly rather than via dict.get.
        stats['topology'] = self.filter_topology(
            selection if selection is not None else "name CA",
            self.complete_top,
        )

        flat = torch.from_numpy(combined.flatten())
        stats["parametric"] = [torch.mean(flat), torch.std(flat)]

        if self.idx_cartesian is not None:
            stats['cartesian_indices'] = self.idx_cartesian
        if self.idx_dist is not None:
            stats['distance_indices'] = self.idx_dist.tolist()
        if self.idx_triplets is not None:
            stats['angle_indices'] = self.idx_triplets.tolist()
        if self.idx_quads is not None:
            stats['dihedral_indices'] = self.idx_quads.tolist()

        return stats

    def compute_features(self, feature_dict: dict):
        """
        Compute and combine multiple feature types in one call.

        feature_dict keys:
        - 'distances': {'pairs': list of tuple or dict (2 entries),
                        'cutoff': float or None (default None),
                        'periodic': bool (default True)}
        - 'angles': {'triplets': list of tuple or dict (3 entries),
                     'periodic': bool (default True)}
        - 'dihedrals': {'quadruplets': list of tuple or dict (4 entries),
                        'periodic': bool (default True)}
        - 'cartesian': {'indices': atom indices, 'selection': optional selection}
        - 'options': {'norm_in': kwargs for Normalization (must contain 'mode'),
                      'timelag': int lag}

        Returns:
        dataset : dict with 'data' and, when present, 'target' (time-lagged
                  data), 'labels' and 'weights'
        stats : dict produced by set_statistics (computed before normalization)

        Raises:
        ValueError if feature_dict requests no feature type.
        """
        self.idx_cartesian = None
        self.idx_dist = None
        self.idx_triplets = None
        self.idx_quads = None

        self.features = {}
        arrays = []

        if 'distances' in feature_dict:
            dist_cfg = feature_dict['distances']
            self.idx_dist = self.idx_distances(dist_cfg['pairs'])
            d, self.idx_dist = self.compute_distances(self.idx_dist,
                                                      cutoff=dist_cfg.get('cutoff'),
                                                      periodic=dist_cfg.get('periodic', True))
            self.n_distances = d.shape[1]
            self.features['distances'] = d
            arrays.append(d)

        if 'angles' in feature_dict:
            angle_cfg = feature_dict['angles']
            self.idx_triplets = self.idx_angles(angle_cfg['triplets'])
            a = self.compute_angles(self.idx_triplets,
                                    periodic=angle_cfg.get('periodic', True))
            self.n_angles = a.shape[1]
            # NOTE(review): see polar2cartesian regarding degrees vs. radians
            a = self.polar2cartesian(a)
            self.features['angles'] = a
            arrays.append(a)

        if 'dihedrals' in feature_dict:
            dihedral_cfg = feature_dict['dihedrals']
            self.idx_quads = self.idx_dihedrals(dihedral_cfg['quadruplets'])
            phi = self.compute_dihedrals(self.idx_quads,
                                         periodic=dihedral_cfg.get('periodic', True))
            self.n_dihedrals = phi.shape[1]
            phi = self.polar2cartesian(phi)
            self.features['dihedrals'] = phi
            arrays.append(phi)

        if 'cartesian' in feature_dict:
            self.idx_cartesian = feature_dict['cartesian']['indices']
            cart = self.compute_cartesian(self.idx_cartesian)
            # Flatten (atoms, xyz) into one feature axis per frame
            self.features['cartesian'] = cart.reshape(self.trajectory.n_frames, -1)
            self.n_cartesian = cart.shape[1]
            arrays.append(self.features['cartesian'])

        if not arrays:
            raise ValueError(
                "feature_dict must contain at least one of "
                "'distances', 'angles', 'dihedrals' or 'cartesian'"
            )

        combined = self.combine_features(*arrays)

        self.n_features = combined.shape[1]
        self.n_frames = combined.shape[0]

        stats = self.set_statistics(combined, feature_dict)

        labels = np.load(self.input_labels_npy_path) if self.input_labels_npy_path else None
        weights = np.load(self.input_weights_npy_path) if self.input_weights_npy_path else None

        options = feature_dict.get('options', {})

        if 'norm_in' in options:
            norm_cfg = options['norm_in']
            if norm_cfg['mode'] != 'custom':
                # Stored on the caller's dict on purpose, as before: the
                # computed stats stay visible in feature_dict afterwards.
                norm_cfg['stats'] = stats

            norm_in = Normalization(combined.shape[1], **norm_cfg)

            combined = norm_in(torch.FloatTensor(combined))
            combined = combined.numpy()

        # Add timelag features if specified
        if 'timelag' in options:
            lag = options['timelag']
            combined, combined_lag = self.timelag(combined, lag)
            dataset = {"data": combined, "target": combined_lag}

            # Labels and weights are taken from the lagged half of the split
            if labels is not None:
                dataset["labels"] = self.timelag(labels, lag)[1]
            if weights is not None:
                dataset["weights"] = self.timelag(weights, lag)[1]
        else:
            dataset = {"data": combined}

            if labels is not None:
                dataset["labels"] = labels
            if weights is not None:
                dataset["weights"] = weights

        return dataset, stats
# Usage example
# trajectory_file = "/home/pzanders/Documents/Simulations/GodMD/domini/1NE4_6NO7_b.dcd"
# topology_file = "/home/pzanders/Documents/Simulations/GodMD/domini/1NE4_6NO7_b.godmd.pdb"
# featurizer = Featurizer(trajectory_file, topology_file)
# atom_pairs = [(57, 78), (57, 79)]
# atom_triplets = [(57, 78, 79), (57, 78, 80)]
# atom_quadruplets = [(57, 78, 79, 80), (57, 78, 79, 81)]
# feature_dict = {
#     'distances': {"pairs": atom_pairs, "cutoff": 0.5, "periodic": True},
#     'angles': {"triplets": atom_triplets, "periodic": True},
#     'dihedrals': {"quadruplets": atom_quadruplets, "periodic": True}
# }
# dataset, stats = featurizer.compute_features(feature_dict)