Coverage for biobb_pytorch / mdae / featurization / plumed_feat.py: 0%
88 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import mdtraj as md
2import numpy as np
3from collections import defaultdict
4import itertools
class FeaturesGenerator:
    """
    Build a PLUMED ``features.dat`` file from an MDTraj-loadable reference structure.

    The reference structure supplies the topology (atom selection, bonds) and the
    coordinates of its first frame, which are used to keep only atom pairs closer
    than a cutoff when generating distances. A selection dictionary chooses which
    feature families to emit: cartesian positions, distances, angles and dihedrals.

    Usage:
        generator = FeaturesGenerator(ref_structure='reference.pdb', selection_dict=your_dict)
        generator.generate(output_file='features.dat')
    """

    def __init__(self, ref_structure: str, selection_dict: dict):
        """
        Load the reference structure and store the feature selection.

        Args:
            ref_structure: Path to the reference PDB or other MDTraj-loadable file
                (provides topology and coordinates).
            selection_dict: Dictionary specifying features, e.g.
                {
                    'cartesian': {'selection': 'backbone'},
                    'distances': {'selection': 'name CA', 'cutoff': 0.4, 'periodic': True, 'bonded': False},
                    'angles': {'selection': 'backbone', 'periodic': True, 'bonded': True},
                    'dihedrals': {'selection': 'backbone', 'periodic': True, 'bonded': True}
                }
                Each 'selection' is an MDTraj atom-selection string.
        """
        self.traj = md.load(ref_structure)
        self.top = self.traj.top
        self.selection_dict = selection_dict

    def generate(self, output_file: str = 'features.dat'):
        """
        Write the PLUMED commands for every requested feature family to a file.

        Features are emitted in the order cartesian (labels p1..), distances
        (d1..), angles (a1..), dihedrals (t1..). Atom indices are converted from
        MDTraj's 0-based indices to PLUMED's 1-based ones.

        Args:
            output_file: Path of the features.dat file to write.

        Raises:
            ValueError: if distances are requested without a 'cutoff', or their
                selection matches no atoms.
            NotImplementedError: if non-bonded angles or dihedrals are requested.
        """
        lines = []
        if 'cartesian' in self.selection_dict:
            lines.extend(self._cartesian_lines(self.selection_dict['cartesian']))
        if 'distances' in self.selection_dict:
            lines.extend(self._distance_lines(self.selection_dict['distances']))
        if 'angles' in self.selection_dict:
            lines.extend(self._angle_lines(self.selection_dict['angles']))
        if 'dihedrals' in self.selection_dict:
            lines.extend(self._dihedral_lines(self.selection_dict['dihedrals']))

        # All features are collected before opening the file, so a validation
        # error above leaves no partially written output behind.
        with open(output_file, 'w', encoding='utf-8') as fh:
            fh.writelines(f"{line}\n" for line in lines)

        print(f"Features file generated at {output_file}")

    def _cartesian_lines(self, sel_dict):
        """Return one POSITION command per selected atom, in index order."""
        atoms = sorted(self.top.select(sel_dict['selection']))
        return [f"p{n}: POSITION ATOM={idx + 1}" for n, idx in enumerate(atoms, start=1)]

    def _distance_lines(self, sel_dict):
        """Return DISTANCE commands for selected atom pairs closer than the cutoff."""
        sel = sel_dict['selection']
        cutoff = sel_dict.get('cutoff', None)
        if cutoff is None:
            raise ValueError("Cutoff must be provided for distances.")
        periodic = sel_dict.get('periodic', True)
        include_bonded = sel_dict.get('bonded', True)

        atoms = self.top.select(sel)
        if len(atoms) == 0:
            raise ValueError(f"No atoms selected for distances with '{sel}'")

        # Measure every candidate pair on the first frame and keep those within
        # the cutoff (same units as md.compute_distances returns).
        candidates = list(itertools.combinations(atoms, 2))
        pairs = set()
        if candidates:
            dists = md.compute_distances(self.traj[0:1], candidates, periodic=periodic)[0]
            pairs = {
                (int(min(a1, a2)), int(max(a1, a2)))
                for (a1, a2), dist in zip(candidates, dists)
                if dist < cutoff
            }

        if not include_bonded:
            bond_set = {frozenset((b.atom1.index, b.atom2.index)) for b in self.top.bonds}
            pairs = {p for p in pairs if frozenset(p) not in bond_set}

        pbc_str = "" if periodic else " NOPBC"
        return [
            f"d{n}: DISTANCE ATOMS={a1 + 1},{a2 + 1}{pbc_str}"
            for n, (a1, a2) in enumerate(sorted(pairs), start=1)
        ]

    def _angle_lines(self, sel_dict):
        """Return ANGLE commands for every bonded angle among the selected atoms."""
        sel = sel_dict['selection']
        if not sel_dict.get('bonded', True):
            raise NotImplementedError("Non-bonded angles not supported; too many combinations.")

        bonds = ((b.atom1.index, b.atom2.index) for b in self.top.bonds)
        triplets = self._bonded_triplets(bonds, self.top.select(sel))

        pbc_str = "" if sel_dict.get('periodic', True) else " NOPBC"
        return [
            f"a{n}: ANGLE ATOMS={a1 + 1},{a2 + 1},{a3 + 1}{pbc_str}"
            for n, (a1, a2, a3) in enumerate(triplets, start=1)
        ]

    @staticmethod
    def _bonded_triplets(bonds, selected):
        """
        Return the sorted list of bonded angles (end1, vertex, end2) among ``selected``.

        PLUMED's ANGLE measures the angle at the SECOND atom, so the vertex must
        stay in the middle position. The two end atoms are ordered (end1 < end2)
        so each angle appears exactly once.

        Args:
            bonds: Iterable of (atom_index, atom_index) pairs.
            selected: Iterable of atom indices to consider; bonds with an atom
                outside it are ignored.

        Returns:
            Sorted list of (end1, vertex, end2) tuples of Python ints.
        """
        selected = {int(a) for a in selected}
        # A set of neighbours so duplicated bonds cannot yield degenerate (i, j, i) angles.
        neighbours = defaultdict(set)
        for a1, a2 in bonds:
            a1, a2 = int(a1), int(a2)
            if a1 != a2 and a1 in selected and a2 in selected:
                neighbours[a1].add(a2)
                neighbours[a2].add(a1)

        triplets = {
            (end1, vertex, end2)
            for vertex in selected
            for end1, end2 in itertools.combinations(sorted(neighbours[vertex]), 2)
        }
        return sorted(triplets)

    def _dihedral_lines(self, sel_dict):
        """
        Return TORSION commands for MDTraj backbone dihedrals inside the selection.

        Candidates are MDTraj's phi, psi and omega torsions (protein backbone);
        a torsion is kept only when all four of its atoms are selected.
        """
        sel = sel_dict['selection']
        if not sel_dict.get('bonded', True):
            raise NotImplementedError("Non-bonded dihedrals not supported; too many combinations.")

        phi_indices = md.compute_phi(self.traj)[0]
        psi_indices = md.compute_psi(self.traj)[0]
        omega_indices = md.compute_omega(self.traj)[0]
        all_indices = np.vstack((phi_indices, psi_indices, omega_indices))

        # Parse the selection once instead of once per atom of every candidate row.
        selected = set(self.top.select(sel).tolist())
        unique_dihedrals = {
            tuple(int(idx) for idx in row)
            for row in all_indices
            if all(int(idx) in selected for idx in row)
        }

        pbc_str = "" if sel_dict.get('periodic', True) else " NOPBC"
        return [
            f"t{n}: TORSION ATOMS={a1 + 1},{a2 + 1},{a3 + 1},{a4 + 1}{pbc_str}"
            for n, (a1, a2, a3, a4) in enumerate(sorted(unique_dihedrals), start=1)
        ]