Coverage for biobb_pytorch / mdae / featurization / featurizer.py: 54%

190 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import mdtraj as md 

2import numpy as np 

3import torch 

4from biobb_pytorch.mdae.featurization.normalization import Normalization 

5from mlcolvar.core.transform.utils import Statistics 

6 

7 

class Featurizer:
    """
    Extract geometric features (distances, angles, dihedrals, Cartesian
    coordinates) from MD trajectories using MDTraj.

    Atoms can be given as explicit indices or as MDTraj selection strings.

    Parameters:
    -------------
    trajectory_file : str
        Path to the trajectory file (e.g., .dcd, .xtc).
    topology_file : str
        Path to the topology file (e.g., .pdb, .gro).
    input_labels_npy_path : str, optional
        Path to a .npy file containing labels for each frame.
    input_weights_npy_path : str, optional
        Path to a .npy file containing weights for each frame.
    """

    def __init__(self, trajectory_file, topology_file, input_labels_npy_path=None, input_weights_npy_path=None):
        """
        Load the trajectory with MDTraj and remember the optional label/weight paths.

        Parameters:
        -------------
        trajectory_file : str
            Path to the trajectory file (e.g., .dcd, .xtc).
        topology_file : str
            Path to the topology file (e.g., .pdb, .gro).
        input_labels_npy_path : str, optional
            Path to a .npy file with labels; loaded lazily in compute_features.
        input_weights_npy_path : str, optional
            Path to a .npy file with weights; loaded lazily in compute_features.
        """
        trajectory = md.load(trajectory_file, top=topology_file)

        self.trajectory = trajectory
        self.topology = trajectory.topology
        self.input_labels_npy_path = input_labels_npy_path
        self.input_weights_npy_path = input_weights_npy_path

        # Single-frame trajectory (first frame + full topology); set_statistics
        # slices it to build the 'topology' entry of the statistics dict.
        self.complete_top = md.Trajectory(xyz=trajectory.xyz[0], topology=trajectory.topology)

    def select_atoms(self, selection):
        """
        Convert a selection specifier into atom indices.

        Parameters:
        -------------
        selection : str or list[int] or np.ndarray
            If str: MDTraj topology query (e.g., 'name CA', 'resid 0 to 10'),
            evaluated against self.topology.
            If list or array: explicit atom indices.

        Returns:
        --------
        np.ndarray
            Array of selected atom indices.
        """
        if isinstance(selection, str):
            return self.topology.select(selection)
        return np.array(selection, dtype=int)

    def filter_topology(self, selection, topology):
        """
        Slice `topology` down to the atoms matched by `selection`.

        Parameters:
        -------------
        selection : str or list[int] or np.ndarray
            Selection specifier, resolved with select_atoms (string queries are
            therefore evaluated against self.topology, not against `topology`).
        topology : object exposing atom_slice(indices)
            set_statistics passes the single-frame md.Trajectory stored in
            self.complete_top.

        Returns:
        --------
        The result of topology.atom_slice(indices).
        """
        idx = self.select_atoms(selection)
        return topology.atom_slice(idx)

    def _dicts_to_tuples(self, dict_list, expected_length):
        """
        Internal helper: convert list of dicts mapping resid->atom name to tuples of atom indices.

        Each dict must have exactly expected_length entries. Keys are residue indices,
        values are atom names. Residue indices must match MDTraj topology resid numbering.
        Because residue indices are dict keys, one dict cannot reference two atoms
        of the same residue.

        Raises:
        -------
        ValueError
            If a dict has the wrong number of entries or a (resid, name) query
            matches no atom.
        """
        tuples = []
        for d in dict_list:
            if len(d) != expected_length:
                raise ValueError(f"Expected dict with {expected_length} entries, got {len(d)}")
            idxs = []
            for resid, atom_name in d.items():
                sel = f"resid {resid} and name {atom_name}"
                arr = self.select_atoms(sel)
                if len(arr) == 0:
                    raise ValueError(f"No atom found for {sel}")
                # take first match
                idxs.append(int(arr[0]))
            tuples.append(tuple(idxs))
        return tuples

    def _resolve_atom_index(self, spec):
        """
        Internal helper: turn one tuple member into a single integer atom index.

        Integers (or integer-like values) are passed through int(); strings are
        treated as MDTraj selections that must match exactly one atom.

        Raises:
        -------
        ValueError
            If a string selection matches zero or more than one atom.
        """
        if not isinstance(spec, str):
            return int(spec)
        matches = self.select_atoms(spec)
        if len(matches) != 1:
            raise ValueError(
                f"Selection '{spec}' must match exactly one atom, matched {len(matches)}"
            )
        return int(matches[0])

    def _resolve_groups(self, groups, size):
        """
        Internal helper shared by idx_distances / idx_angles / idx_dihedrals.

        Accepts a list of dicts (resid -> atom name) or a list of `size`-tuples
        whose members are atom indices or single-atom selection strings, and
        returns an integer array of shape (n_groups, size). An empty input
        yields an array of shape (0, size).
        """
        if len(groups) > 0 and isinstance(groups[0], dict):
            groups = self._dicts_to_tuples(groups, size)

        resolved = []
        for group in groups:
            if len(group) != size:
                raise ValueError(f"Expected {size} atoms per entry, got {len(group)}")
            resolved.append(tuple(self._resolve_atom_index(member) for member in group))
        # reshape keeps the (n, size) contract even when the list is empty
        return np.array(resolved, dtype=int).reshape(-1, size)

    def idx_distances(self, pairs):
        """
        Resolve atom pairs to MDTraj atom indices.

        Parameters:
        -------------
        pairs : list of tuples or dicts
            Tuples hold 2 atom indices or single-atom selection strings;
            dicts hold 2 entries {resid: atom_name}.

        Returns:
        --------
        np.ndarray
            Array of shape (n_pairs, 2) with atom indices.
        """
        return self._resolve_groups(pairs, 2)

    def compute_distances(self, idx_pairs, cutoff, periodic: bool = True):
        """
        Compute inter-atomic distances for the given index pairs.

        Parameters:
        -------------
        idx_pairs : array-like, shape (n_pairs, 2)
            Atom index pairs (e.g. the output of idx_distances).
        cutoff : float or None
            If not None, keep only the pairs whose distance is below `cutoff`
            in at least one frame.
        periodic : bool
            Forwarded to md.compute_distances.

        Returns:
        --------
        distances : np.ndarray, shape (n_frames, n_kept_pairs)
        idx_pairs : np.ndarray, shape (n_kept_pairs, 2)
        """
        # asarray so the boolean mask below also works when a list is passed
        idx_pairs = np.asarray(idx_pairs)
        distances = md.compute_distances(self.trajectory, idx_pairs, periodic=periodic)

        if cutoff is not None:
            keep_cols = np.any(distances < cutoff, axis=0)
            idx_pairs = idx_pairs[keep_cols]
            distances = distances[:, keep_cols]

        return distances, idx_pairs

    def polar2cartesian(self, a, degrees: bool = True):
        """
        Encode angles as their sine and cosine.

        Parameters:
        -------------
        a : np.ndarray
            Array of shape (n_frames, n_angles) (or (n_frames,) for one angle).
        degrees : bool, default True
            True if `a` is in degrees (original behaviour); pass False for
            angles already in radians.

        Returns:
        --------
        np.ndarray
            Array of shape (n_frames, n_angles * 2): the first n_angles columns
            are the sines, the last n_angles columns the cosines.
        """
        if degrees:
            a = np.deg2rad(a)
        return np.column_stack((np.sin(a), np.cos(a)))

    def cartesian2polar(self, cart_angles, degrees: bool = True):
        """
        Recover angles from the sine/cosine encoding produced by polar2cartesian.

        Parameters:
        -------------
        cart_angles : np.ndarray
            Array of shape (n_frames, n_angles * 2): sines in the first half of
            the columns, cosines in the second half.
        degrees : bool, default True
            Return degrees (original behaviour) or, if False, radians.

        Returns:
        --------
        np.ndarray
            Array of shape (n_frames, n_angles). For a single angle
            (2 input columns) a 1-D array of shape (n_frames,) is returned,
            as in earlier versions.

        Raises:
        -------
        ValueError
            If the number of columns is odd.
        """
        cart_angles = np.asarray(cart_angles)
        n_cols = cart_angles.shape[1]
        if n_cols % 2 != 0:
            raise ValueError(f"Expected an even number of columns (sin|cos), got {n_cols}")

        n_angles = n_cols // 2
        angles = np.arctan2(cart_angles[:, :n_angles], cart_angles[:, n_angles:])
        if n_angles == 1:
            # backward compatibility: the single-angle case has always been 1-D
            angles = angles[:, 0]
        return np.rad2deg(angles) if degrees else angles

    def idx_angles(self, triplets):
        """
        Resolve atom triplets to MDTraj atom indices.

        Parameters:
        -------------
        triplets : list of tuples or dicts
            Tuples hold 3 atom indices or single-atom selection strings;
            dicts hold 3 entries {resid: atom_name}.

        Returns:
        --------
        np.ndarray
            Array of shape (n_triplets, 3) with atom indices.
        """
        return self._resolve_groups(triplets, 3)

    def compute_angles(self, idx_triplets, periodic: bool = True):
        """
        Compute angles for triplets of atom indices with md.compute_angles.

        `idx_triplets` is an (n_triplets, 3) index array (e.g. from idx_angles).
        The MDTraj documentation states the result is in radians.
        """
        return md.compute_angles(self.trajectory, idx_triplets, periodic=periodic)

    def idx_dihedrals(self, quadruplets):
        """
        Resolve atom quadruplets to MDTraj atom indices.

        Parameters:
        -------------
        quadruplets : list of tuples or dicts
            Tuples hold 4 atom indices or single-atom selection strings;
            dicts hold 4 entries {resid: atom_name}.

        Returns:
        --------
        np.ndarray
            Array of shape (n_quadruplets, 4) with atom indices.
        """
        return self._resolve_groups(quadruplets, 4)

    def compute_dihedrals(self, idx_quads, periodic: bool = True):
        """
        Compute dihedral angles for quadruplets of atom indices with md.compute_dihedrals.

        `idx_quads` is an (n_quadruplets, 4) index array (e.g. from idx_dihedrals).
        The MDTraj documentation states the result is in radians.
        """
        return md.compute_dihedrals(self.trajectory, idx_quads, periodic=periodic)

    def compute_cartesian(self, indices):
        """
        Return the Cartesian coordinates of the selected atoms.

        Parameters:
        -------------
        indices : list[int]
            List of atom indices to extract.

        Returns:
        --------
        np.ndarray
            self.trajectory.xyz indexed on the atom axis by `indices`.
        """
        return self.trajectory.xyz[:, indices, :]

    def combine_features(self, *feature_arrays):
        """
        Concatenate multiple feature arrays along the feature axis (axis=1).

        Raises:
        -------
        ValueError
            If no array is given.
        """
        if not feature_arrays:
            raise ValueError("combine_features needs at least one feature array")
        return np.concatenate(feature_arrays, axis=1)

    def timelag(self, data: np.ndarray, lag: int):
        """
        Split into X and Y where Y[t] = X[t+lag].

        Parameters
        ----------
        data : np.ndarray, shape (n_times, n_features)
        lag : int

        Returns
        -------
        X : np.ndarray, shape (n_times-lag, n_features)
        Y : np.ndarray, shape (n_times-lag, n_features)

        Raises
        ------
        ValueError
            If lag is not between 1 and n_times-1.
        """
        if lag < 1 or lag >= data.shape[0]:
            raise ValueError("lag must be between 1 and n_times-1")

        return data[:-lag], data[lag:]

    def get_n_features(self):
        """
        Get the number of features in the combined feature array.

        Only available after compute_features has been called.
        """
        return self.n_features

    def get_n_frames(self):
        """
        Get the number of frames in the combined feature array.

        Only available after compute_features has been called.
        """
        return self.n_frames

    def get_atom_info(self, selection):
        """
        Describe the selected atoms.

        Returns an object array with one row [atom name, residue name,
        residue index + 1] per selected atom.
        """
        idx = self.select_atoms(selection)
        sliced = self.trajectory.atom_slice(idx)
        atom_info = [
            [atom.name, atom.residue.name, atom.residue.index + 1]
            for atom in sliced.topology.atoms
        ]
        return np.array(atom_info, dtype=object)

    def set_statistics(self, combined: np.ndarray, feature_dict: dict):
        """
        Build the statistics dict for the combined feature array.

        Starts from mlcolvar's Statistics(...).to_dict() and adds:
        'shape', 'selection' (the cartesian selection or None), 'topology'
        (complete_top sliced by the selection, falling back to "name CA" when
        no selection is available), 'parametric' ([mean, std] over all values)
        and the index arrays of the feature types that were computed.
        """
        stats = Statistics(torch.FloatTensor(combined)).to_dict()
        stats['shape'] = [combined.shape[0], combined.shape[1]]

        selection = feature_dict['cartesian'].get('selection') if 'cartesian' in feature_dict else None
        stats['selection'] = selection
        # Fall back to C-alpha atoms when no cartesian selection is given;
        # passing None to filter_topology would fail in select_atoms.
        stats['topology'] = self.filter_topology(
            selection if selection is not None else "name CA", self.complete_top
        )

        flat = torch.from_numpy(combined.flatten())
        stats["parametric"] = [torch.mean(flat), torch.std(flat)]

        # getattr: the index attributes only exist once compute_features has run
        idx_cartesian = getattr(self, 'idx_cartesian', None)
        idx_dist = getattr(self, 'idx_dist', None)
        idx_triplets = getattr(self, 'idx_triplets', None)
        idx_quads = getattr(self, 'idx_quads', None)

        if idx_cartesian is not None:
            stats['cartesian_indices'] = idx_cartesian
        if idx_dist is not None:
            stats['distance_indices'] = idx_dist.tolist()
        if idx_triplets is not None:
            stats['angle_indices'] = idx_triplets.tolist()
        if idx_quads is not None:
            stats['dihedral_indices'] = idx_quads.tolist()

        return stats

    def compute_features(self, feature_dict: dict):
        """
        Compute and combine multiple feature types in one call.

        feature_dict keys:
          - 'distances': {'pairs': ..., 'cutoff': float or None, 'periodic': bool}
          - 'angles': {'triplets': ..., 'periodic': bool}
          - 'dihedrals': {'quadruplets': ..., 'periodic': bool}
          - 'cartesian': {'indices': ..., 'selection': ...}
          - 'options': optional dict with 'norm_in' (keyword arguments for
            Normalization) and/or 'timelag' (int lag)
        'cutoff' defaults to None and 'periodic' to True when omitted.

        Angles and dihedrals are stored as sin/cos pairs (see polar2cartesian).

        Returns:
            dataset : dict with 'data' and, depending on the options and on the
                      label/weight paths, 'target', 'labels', 'weights'
            stats : dict built by set_statistics (computed before normalization
                    and before time-lagging)

        Raises:
            ValueError : if feature_dict requests no feature type.
        """
        self.idx_cartesian = None
        self.idx_dist = None
        self.idx_triplets = None
        self.idx_quads = None

        self.features = {}
        arrays = []

        if 'distances' in feature_dict:
            spec = feature_dict['distances']
            self.idx_dist = self.idx_distances(spec['pairs'])
            d, self.idx_dist = self.compute_distances(self.idx_dist,
                                                      cutoff=spec.get('cutoff'),
                                                      periodic=spec.get('periodic', True))
            self.n_distances = d.shape[1]
            self.features['distances'] = d
            arrays.append(d)

        if 'angles' in feature_dict:
            spec = feature_dict['angles']
            self.idx_triplets = self.idx_angles(spec['triplets'])
            a = self.compute_angles(self.idx_triplets,
                                    periodic=spec.get('periodic', True))
            self.n_angles = a.shape[1]
            # MDTraj angles are in radians, so skip the degree conversion
            a = self.polar2cartesian(a, degrees=False)
            self.features['angles'] = a
            arrays.append(a)

        if 'dihedrals' in feature_dict:
            spec = feature_dict['dihedrals']
            self.idx_quads = self.idx_dihedrals(spec['quadruplets'])
            phi = self.compute_dihedrals(self.idx_quads,
                                         periodic=spec.get('periodic', True))
            self.n_dihedrals = phi.shape[1]
            # MDTraj dihedrals are in radians, so skip the degree conversion
            phi = self.polar2cartesian(phi, degrees=False)
            self.features['dihedrals'] = phi
            arrays.append(phi)

        if 'cartesian' in feature_dict:
            self.idx_cartesian = feature_dict['cartesian']['indices']
            cart = self.compute_cartesian(self.idx_cartesian)
            # flatten (n_frames, n_atoms, 3) to (n_frames, n_atoms * 3)
            self.features['cartesian'] = cart.reshape(self.trajectory.n_frames, -1)
            self.n_cartesian = cart.shape[1]
            arrays.append(self.features['cartesian'])

        if not arrays:
            raise ValueError(
                "feature_dict must contain at least one of "
                "'distances', 'angles', 'dihedrals' or 'cartesian'"
            )

        combined = self.combine_features(*arrays)

        self.n_features = combined.shape[1]
        self.n_frames = combined.shape[0]

        stats = self.set_statistics(combined, feature_dict)

        labels = np.load(self.input_labels_npy_path) if self.input_labels_npy_path else None
        weights = np.load(self.input_weights_npy_path) if self.input_weights_npy_path else None

        options = feature_dict.get('options', {})

        if 'norm_in' in options:
            norm_cfg = options['norm_in']
            if norm_cfg['mode'] != 'custom':
                # NOTE: writes into the caller's feature_dict, as before
                norm_cfg['stats'] = stats

            norm_in = Normalization(combined.shape[1], **norm_cfg)
            combined = norm_in(torch.FloatTensor(combined))
            combined = combined.detach().numpy()

        # Add timelag features if specified
        if 'timelag' in options:
            lag = options['timelag']
            combined, combined_lag = self.timelag(combined, lag)

            dataset = {"data": combined, "target": combined_lag}

            # NOTE(review): the stored labels/weights are the lagged halves,
            # i.e. aligned with "target" rather than "data" - confirm intended.
            if labels is not None:
                labels, labels_lag = self.timelag(labels, lag)
                dataset["labels"] = labels_lag

            if weights is not None:
                weights, weights_lag = self.timelag(weights, lag)
                dataset["weights"] = weights_lag

            return dataset, stats

        dataset = {"data": combined}
        if labels is not None:
            dataset["labels"] = labels
        if weights is not None:
            dataset["weights"] = weights

        return dataset, stats

462 

463 

464# Usage example 

465# trajectory_file = "/home/pzanders/Documents/Simulations/GodMD/domini/1NE4_6NO7_b.dcd"

466# topology_file = "/home/pzanders/Documents/Simulations/GodMD/domini/1NE4_6NO7_b.godmd.pdb"

467

468# featurizer = Featurizer(trajectory_file, topology_file)

469 

470# atom_pairs = [(57, 78), (57, 79)] 

471# atom_triplets = [(57, 78, 79), (57, 78, 80)] 

472# atom_quadruplets = [(57, 78, 79, 80), (57, 78, 79, 81)] 

473 

474# feature_dict = { 

475# 'distances': {"pairs": atom_pairs, "cutoff": 0.5, "periodic": True}, 

476# 'angles': {"triplets": atom_triplets, "periodic": True}, 

477# 'dihedrals': {"quadruplets": atom_quadruplets, "periodic": True} 

478# } 

479# dataset, stats = featurizer.compute_features(feature_dict)