Coverage for biobb_pytorch / mdae / featurization / plumed_feat.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import mdtraj as md 

2import numpy as np 

3from collections import defaultdict 

4import itertools 

5 

6 

class FeaturesGenerator:
    """Generate a PLUMED ``features.dat`` file from a reference structure.

    The reference structure is loaded with MDTraj and provides both the
    topology (atom selections, bonds) and the coordinates used to pick
    distance pairs that fall within a cutoff.

    Supported feature types (keys of ``selection_dict``):

    * ``'cartesian'`` -> one ``POSITION`` action per selected atom.
    * ``'distances'`` -> one ``DISTANCE`` action per selected atom pair closer
      than ``cutoff`` in the first frame of the reference structure.
    * ``'angles'``    -> one ``ANGLE`` action per pair of bonds sharing an atom.
    * ``'dihedrals'`` -> one ``TORSION`` action per phi/psi/omega dihedral whose
      four atoms are all part of the selection.

    Atom indices are 0-based in MDTraj and are written 1-based for PLUMED.

    Usage:
        generator = FeaturesGenerator(ref_structure='reference.pdb', selection_dict=your_dict)
        generator.generate(output_file='features.dat')
    """

    def __init__(self, ref_structure: str, selection_dict: dict):
        """Load the reference structure and store the feature specification.

        Args:
            ref_structure: Path to the reference PDB or other MDTraj-loadable
                file (provides topology and coordinates).
            selection_dict: Dictionary specifying features, e.g.,
                {
                    'cartesian': {'selection': 'backbone'},
                    'distances': {'selection': 'name CA', 'cutoff': 0.4, 'periodic': True, 'bonded': False},
                    'angles': {'selection': 'backbone', 'periodic': True, 'bonded': True},
                    'dihedrals': {'selection': 'backbone', 'periodic': True, 'bonded': True}
                }
        """
        self.traj = md.load(ref_structure)
        self.top = self.traj.top
        self.selection_dict = selection_dict

    def generate(self, output_file: str = 'features.dat'):
        """Write the PLUMED commands for all requested features to a file.

        Features are emitted in a fixed order: cartesian, distances, angles,
        dihedrals. Feature types absent from ``selection_dict`` are skipped.

        Args:
            output_file: Path to output features.dat file.

        Raises:
            ValueError: If distances are requested without a ``cutoff`` or the
                distance selection matches no atoms.
            NotImplementedError: If angles or dihedrals are requested with
                ``bonded`` set to a false value.
        """
        builders = (
            ('cartesian', self._cartesian_lines),
            ('distances', self._distance_lines),
            ('angles', self._angle_lines),
            ('dihedrals', self._dihedral_lines),
        )
        lines = []
        for key, build in builders:
            if key in self.selection_dict:
                lines.extend(build(self.selection_dict[key]))

        with open(output_file, 'w') as f:
            f.writelines(f"{line}\n" for line in lines)

        print(f"Features file generated at {output_file}")

    def _cartesian_lines(self, spec):
        """Return one POSITION command per selected atom, in index order."""
        atoms = sorted(self.top.select(spec['selection']))
        return [
            f"p{n}: POSITION ATOM={atom_idx + 1}"
            for n, atom_idx in enumerate(atoms, start=1)
        ]

    def _distance_lines(self, spec):
        """Return DISTANCE commands for selected pairs closer than the cutoff."""
        sel = spec['selection']
        cutoff = spec.get('cutoff')
        if cutoff is None:
            raise ValueError("Cutoff must be provided for distances.")
        periodic = spec.get('periodic', True)
        include_bonded = spec.get('bonded', True)

        atoms = self.top.select(sel)
        if len(atoms) == 0:
            raise ValueError(f"No atoms selected for distances with '{sel}'")

        # The cutoff is compared with the values returned by
        # md.compute_distances for the first frame only.
        candidates = list(itertools.combinations(atoms, 2))
        pairs = set()
        if candidates:
            dists = md.compute_distances(self.traj[0:1], candidates, periodic=periodic)[0]
            pairs = {
                (min(a1, a2), max(a1, a2))
                for (a1, a2), dist in zip(candidates, dists)
                if dist < cutoff
            }

        if not include_bonded:
            bond_set = {frozenset({b.atom1.index, b.atom2.index}) for b in self.top.bonds}
            pairs = {p for p in pairs if frozenset(p) not in bond_set}

        pbc_str = "" if periodic else " NOPBC"
        return [
            f"d{n}: DISTANCE ATOMS={a1 + 1},{a2 + 1}{pbc_str}"
            for n, (a1, a2) in enumerate(sorted(pairs), start=1)
        ]

    def _angle_lines(self, spec):
        """Return ANGLE commands for every pair of bonds sharing a selected atom."""
        if not spec.get('bonded', True):
            raise NotImplementedError("Non-bonded angles not supported; too many combinations.")
        periodic = spec.get('periodic', True)

        sel_atoms = {int(idx) for idx in self.top.select(spec['selection'])}

        # Adjacency restricted to bonds whose two atoms are both selected.
        adj = defaultdict(set)
        for bond in self.top.bonds:
            a1, a2 = bond.atom1.index, bond.atom2.index
            if a1 in sel_atoms and a2 in sel_atoms:
                adj[a1].add(a2)
                adj[a2].add(a1)

        # The shared atom must stay in the middle slot: PLUMED's three-atom
        # ANGLE is measured at the second atom. Only the two outer atoms are
        # ordered (first < last) to avoid emitting mirrored duplicates.
        triplets = {
            (first, vertex, last)
            for vertex, neighbours in adj.items()
            for first, last in itertools.combinations(sorted(neighbours), 2)
        }

        pbc_str = "" if periodic else " NOPBC"
        return [
            f"a{n}: ANGLE ATOMS={a1 + 1},{a2 + 1},{a3 + 1}{pbc_str}"
            for n, (a1, a2, a3) in enumerate(sorted(triplets), start=1)
        ]

    def _dihedral_lines(self, spec):
        """Return TORSION commands for phi/psi/omega dihedrals inside the selection."""
        if not spec.get('bonded', True):
            raise NotImplementedError("Non-bonded dihedrals not supported; too many combinations.")
        periodic = spec.get('periodic', True)

        # Resolve the selection once instead of once per candidate atom.
        sel_atoms = {int(idx) for idx in self.top.select(spec['selection'])}

        # Only MDTraj's backbone dihedral definitions are used; the selection
        # acts as a filter on them and cannot introduce other torsions.
        phi_indices = md.compute_phi(self.traj)[0]
        psi_indices = md.compute_psi(self.traj)[0]
        omega_indices = md.compute_omega(self.traj)[0]
        all_indices = np.vstack((phi_indices, psi_indices, omega_indices))

        unique_dihedrals = {
            tuple(int(idx) for idx in row)
            for row in all_indices
            if all(int(idx) in sel_atoms for idx in row)
        }

        pbc_str = "" if periodic else " NOPBC"
        return [
            f"t{n}: TORSION ATOMS={a1 + 1},{a2 + 1},{a3 + 1},{a4 + 1}{pbc_str}"
            for n, (a1, a2, a3, a4) in enumerate(sorted(unique_dihedrals), start=1)
        ]