Coverage for biobb_pytorch / mdae / loss / utils / torch_protein_energy.py: 9%

209 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1# Copyright (c) 2021 Venkata K. Ramaswamy, Samuel C. Musson, Chris G. Willcocks, Matteo T. Degiacomi 

2# 

3# Molearn is free software ; 

4# you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation ; 

5# either version 2 of the License, or (at your option) any later version. 

6# molearn is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY ; 

7# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

8# See the GNU General Public License for more details. 

9# You should have received a copy of the GNU General Public License along with molearn ; 

10# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. 

11 

12import torch 

13# import torch.nn as nn 

14# import torch.nn.functional as F 

15from .torch_protein_energy_utils import get_convolutions 

16 

17 

class TorchProteinEnergy():
    '''
    Differentiable Amber-style bonded (and optionally non-bonded) energy terms
    for a fixed protein topology, implemented with torch operations so the
    energy can be used as a loss on atomic coordinates.
    '''

    def __init__(self, frame, pdb_atom_names,
                 padded_residues=False,
                 method=('indexed', 'convolutional', 'roll')[2],
                 device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
                 fix_h=False, alt_vdw=None, NB='repulsive'):
        '''
        At instantiation will load amber parameters and create the necessary
        convolutions/indexes/rolls to calculate the energy of the molecule.
        Energy can be assessed with the ``TorchProteinEnergy.get_energy(x)`` method.

        :param frame: Example coordinates of the structure in a torch array. The
            interatomic distance will be used to determine the connectivity of
            the atoms. Coordinates should be of ``shape [3, N]`` where N is the
            number of atoms. If ``padded_residues = True`` then coordinates
            should be of ``shape [R, M, 3]`` where R is the number of residues
            and M is the maximum number of atoms in a residue.
        :param pdb_atom_names: Array of ``shape [N, 2]`` containing the pdb atom
            names in ``pdb_atom_names[:, 0]`` and residue names in
            ``pdb_atom_names[:, 1]``. If ``padded_residues = True`` then should
            be ``shape [R, M, 2]``.
        :param padded_residues: If True the dataset should be formatted as
            ``shape [R, M, 3]`` where R is the number of residues and M is the
            maximum number of atoms. Padding should be ``nan``.
            **Note** only ``method = "indexed"`` is currently implemented for this.
        :param method: ``method = "convolutional"`` (currently experimental)
            uses convolutions to calculate forces (padded_residues=False only).

            ``method = "roll"`` uses rolling and slicing to calculate forces
            (padded_residues=False only).

            ``method = "indexed"`` (experimental) is only implemented for
            padded_residues=True. Uses indexes to calculate forces.
        :param device: ``torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")``
        :param fix_h: forwarded to ``get_convolutions`` (hydrogen handling) —
            presumably toggles special treatment of H atoms; see that helper.
        :param alt_vdw: optional alternative van der Waals parameters forwarded
            to ``get_convolutions``. Defaults to an empty list.
        :param NB: non-bonded flavour: ``"full"`` (12-6 Lennard-Jones plus
            Coulomb) or ``"repulsive"`` (repulsive LJ term plus Coulomb).
        :raises ValueError: for an unknown ``method`` or ``NB``.
        :raises NotImplementedError: if ``padded_residues=True`` is combined
            with a method other than ``"indexed"``.
        '''
        self.device = device
        self.method = method
        # None sentinel avoids the shared-mutable-default-argument pitfall.
        alt_vdw = [] if alt_vdw is None else alt_vdw
        if padded_residues:
            if method != 'indexed':
                # Unsupported combinations used to fail silently, leaving the
                # object without a get_bonded_energy attribute; fail fast instead.
                raise NotImplementedError('padded_residues=True only supports method="indexed"')
            self._padded_indexed_init(frame, pdb_atom_names, NB=NB)
        elif method == 'convolutional':
            # Bug fix: NB used to be silently dropped for this method.
            self._convolutional_init(frame, pdb_atom_names, NB=NB, fix_h=fix_h, alt_vdw=alt_vdw)
        elif method == 'roll':
            self._roll_init(frame, pdb_atom_names, NB=NB, fix_h=fix_h, alt_vdw=alt_vdw)
        else:
            raise ValueError(f'Unknown method "{method}"')

    def get_energy(self, x, nonbonded=False):
        '''
        :param x: tensor of shape [B, 3, N] where B is batch_size and N is the
            number of atoms. If padded_residues = True then tensor of shape
            [B, R, M, 3] where R is the number of residues and M is the maximum
            number of atoms in a residue.
        :param nonbonded: Boolean (default False), whether to also return a
            softened non-bonded energy.
        :returns: ``float`` Bond energy (average over batch)
        :returns: ``float`` Angle energy (average over batch)
        :returns: ``float`` Torsion energy (average over batch)
        :returns: ``float`` (only if nonbonded = True) non-bonded energy
            (average over batch)
        '''
        if nonbonded:
            return (*self.get_bonded_energy(x), self._nb_loss(x))
        return self.get_bonded_energy(x)

    def _set_nb_loss(self, NB):
        # Single dispatch point for the non-bonded implementation, shared by all
        # init paths (previously triplicated, and an unknown NB silently left
        # self._nb_loss unset, deferring failure to an opaque AttributeError).
        if NB == 'full':
            self._nb_loss = self._cdist_nb_full
        elif NB == 'repulsive':
            self._nb_loss = self._cdist_nb
        else:
            raise ValueError(f'Unknown NB "{NB}", expected "full" or "repulsive"')

    def _roll_init(self, frame, pdb_atom_names, NB='full', fix_h=False, alt_vdw=None):
        # Precompute roll/slice bookkeeping so the bonded terms can be evaluated
        # with tensor rolls instead of explicit index gathers.
        (b_masks, b_equil, b_force, b_weights,
         a_masks, a_equil, a_force, a_weights,
         t_masks, t_para, t_weights,
         vdw_R, vdw_e, vdw_14R, vdw_14e,
         q1q2, q1q2_14) = get_convolutions(frame, pdb_atom_names, fix_slice_method=True,
                                           fix_h=fix_h,
                                           alt_vdw=[] if alt_vdw is None else alt_vdw)

        # Bonds: per kernel, store the roll offset between the two bonded atoms
        # and the aligned equilibrium-length / force-constant parameters.
        self.brdiff = []
        self.br_equil = []
        self.br_force = []
        for i, j in enumerate(b_weights):
            atom1 = j.index(1)
            atom2 = j.index(-1)
            d = atom2 - atom1
            padding = len(j) - 2
            self.brdiff.append(d)
            # b_equil[:, 0] is just padding so can roll(-1, 1) to get correct padding
            self.br_equil.append(torch.tensor(b_equil[i, padding - 1:]).roll(-1).to(self.device).float())
            self.br_force.append(torch.tensor(b_force[i, padding - 1:]).roll(-1).to(self.device).float())

        # Angles: each angle is assembled from two bond vectors; record which
        # cached bond vector to reuse (diff), its sign, and the roll offset.
        self.ardiff = []
        self.arsign = []
        self.arroll = []
        self.ar_equil = []
        self.ar_force = []
        self.ar_masks = []
        for i, j in enumerate(a_weights):
            atom1 = j[0].index(1)
            atom2 = j[0].index(-1)
            atom3 = j[1].index(1)
            diff1 = atom2 - atom1
            diff2 = atom2 - atom3
            padding = len(j[0]) - 3
            self.arroll.append([min(atom1, atom2), min(atom2, atom3)])
            self.ardiff.append([abs(diff1) - 1, abs(diff2) - 1])
            self.arsign.append([diff1 / abs(diff1), diff2 / abs(diff2)])
            self.ar_equil.append(torch.tensor(a_equil[i, padding - 2:]).roll(-2).to(self.device).float())
            self.ar_force.append(torch.tensor(a_force[i, padding - 2:]).roll(-2).to(self.device).float())

        # Torsions: assembled from three bond vectors (i-j, j-k, l-k).
        self.trdiff = []
        self.trsign = []
        self.trroll = []
        self.tr_para = []
        for i, j in enumerate(t_weights):
            atom1 = j[0].index(1)    # i-j 0
            atom2 = j[0].index(-1)   # i-j 2
            atom3 = j[1].index(-1)   # j-k 3
            atom4 = j[2].index(1)    # l-k 4
            diff1 = atom2 - atom1          # ij 2
            diff2 = atom3 - atom2          # jk 1
            diff3 = (atom4 - atom3) * -1   # lk 1
            padding = len(j[0]) - 4
            self.trroll.append([min(atom1, atom2), min(atom2, atom3), min(atom3, atom4)])
            self.trsign.append([diff1 / abs(diff1), diff2 / abs(diff2), diff3 / abs(diff3)])
            self.trdiff.append([abs(diff1) - 1, abs(diff2) - 1, abs(diff3) - 1])
            self.tr_para.append(torch.tensor(t_para[i, padding - 3:]).roll(-3, 0).to(self.device).float())

        # Amber 12-6 LJ coefficients: A = e*R^12 (repulsion), B = 2*e*R^6 (attraction).
        self.vdw_A = (vdw_e * (vdw_R**12)).to(self.device)
        self.vdw_B = (2 * vdw_e * (vdw_R**6)).to(self.device)
        self.q1q2 = q1q2.to(self.device)

        self.get_bonded_energy = self._bonded_roll_loss
        self._set_nb_loss(NB)

    def _convolutional_init(self, frame, pdb_atom_names, NB='full', fix_h=False, alt_vdw=None):
        # Load amber parameters as conv1d kernels; bonded terms are evaluated by
        # convolving the coordinate tensor with fixed finite-difference kernels.
        (b_masks, b_equil, b_force, b_weights,
         a_masks, a_equil, a_force, a_weights,
         t_masks, t_para, t_weights,
         vdw_R, vdw_e, vdw_14R, vdw_14e,
         q1q2, q1q2_14) = get_convolutions(frame, pdb_atom_names, fix_slice_method=False,
                                           fix_h=fix_h,
                                           alt_vdw=[] if alt_vdw is None else alt_vdw)

        self.b_equil = torch.tensor(b_equil).to(self.device)
        self.b_force = torch.tensor(b_force).to(self.device)
        self.b_weights = torch.tensor(b_weights).to(self.device)

        self.a_equil = torch.tensor(a_equil).to(self.device).float()
        self.a_force = torch.tensor(a_force).to(self.device).float()
        self.a_weights = torch.tensor(a_weights).to(self.device)
        self.a_masks = torch.tensor(a_masks).to(self.device)

        self.t_para = torch.tensor(t_para).to(self.device)
        self.t_weights = torch.tensor(t_weights).to(self.device)

        # Amber 12-6 LJ coefficients: A = e*R^12 (repulsion), B = 2*e*R^6 (attraction).
        self.vdw_A = (vdw_e * (vdw_R**12)).to(self.device)
        self.vdw_B = (2 * vdw_e * (vdw_R**6)).to(self.device)
        self.q1q2 = q1q2.to(self.device)

        self.get_bonded_energy = self._bonded_convolutional_loss
        self._set_nb_loss(NB)

    def _padded_indexed_init(self, frame, pdb_atom_names, NB='full'):
        # Index-based bonded terms for residue-padded input ([R, M, 3], nan-padded).
        from molearn import get_conv_pad_res
        (bond_idxs, bond_para,
         angle_idxs, angle_para, angle_mask, ij_jk,
         torsion_idxs, torsion_para, torsion_mask, ij_jk_kl,
         vdw_R, vdw_e,
         q1q2,) = get_conv_pad_res(frame, pdb_atom_names)

        self.bond_idxs = bond_idxs.to(self.device)
        self.bond_para = bond_para.to(self.device)
        self.angle_mask = angle_mask.to(self.device)
        self.ij_jk = ij_jk.to(self.device)
        self.angle_para = angle_para.to(self.device)
        self.torsion_mask = torsion_mask.to(self.device)
        self.ij_jk_kl = ij_jk_kl.to(self.device)
        self.torsion_para = torsion_para.to(self.device)
        # Amber 12-6 LJ coefficients: A = e*R^12 (repulsion), B = 2*e*R^6 (attraction).
        self.vdw_A = (vdw_e * (vdw_R**12)).to(self.device)
        self.vdw_B = (2 * vdw_e * (vdw_R**6)).to(self.device)
        self.q1q2 = q1q2.to(self.device)
        self.get_bonded_energy = self._bonded_padded_residues_loss
        # Atoms that take part in at least one bond (i.e. real, non-padding atoms).
        self.relevant = self.bond_idxs.unique().to(self.device)
        self._set_nb_loss(NB)

    def _bonded_convolutional_loss(self, x):
        # Batch-averaged bond/angle/torsion energies via the convolutional path.
        bs = x.shape[0]
        bloss = self._conv_bond_loss(x)
        aloss = self._conv_angle_loss(x)
        tloss = self._conv_torsion_loss(x)
        return bloss / bs, aloss / bs, tloss / bs

    def _bonded_roll_loss(self, x):
        # Batch-averaged bond/angle/torsion energies via the roll path.
        bs = x.shape[0]
        bloss, aloss, tloss = self._roll_bond_angle_torsion_loss(x)
        return bloss / bs, aloss / bs, tloss / bs

    def _bonded_padded_residues_loss(self, x):
        # x.shape [B, R, M, 3]; flatten residues so atoms are indexable as [B, R*M, 3].
        x = x.view(x.shape[0], -1, 3)
        # Bond vectors j-i == i->j for every bonded pair.
        v = x[:, self.bond_idxs[:, 1]] - x[:, self.bond_idxs[:, 0]]
        # Harmonic bond term: force * (|v| - equilibrium)^2.
        bloss = (((v.norm(dim=2) - self.bond_para[:, 0])**2) * self.bond_para[:, 1]).sum()
        # Harmonic angle term from the two bond vectors of each angle.
        v1 = v[:, self.ij_jk[0]] * self.angle_mask[0].view(1, -1, 1)
        v2 = v[:, self.ij_jk[1]] * self.angle_mask[1].view(1, -1, 1)
        xyz = torch.sum(v1 * v2, dim=2) / (torch.norm(v1, dim=2) * torch.norm(v2, dim=2))
        # Clamp keeps acos differentiable and finite at collinear geometries.
        theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999))
        aloss = (((theta - self.angle_para[:, 0])**2) * self.angle_para[:, 1]).sum()

        # Torsion term: dihedral angle via atan2 of cross/dot products.
        u1 = v[:, self.ij_jk_kl[0]] * self.torsion_mask[0].view(1, -1, 1)
        u2 = v[:, self.ij_jk_kl[1]] * self.torsion_mask[1].view(1, -1, 1)
        u3 = v[:, self.ij_jk_kl[2]] * self.torsion_mask[2].view(1, -1, 1)
        # Explicit dim: torch.cross's default ("first dim of size 3") is
        # deprecated and picks the wrong axis whenever an earlier dim happens
        # to have size 3; the xyz components live on dim 2 here.
        u12 = torch.cross(u1, u2, dim=2)
        u23 = torch.cross(u2, u3, dim=2)
        t3 = torch.atan2(u2.norm(dim=2) * ((u1 * u23).sum(dim=2)), (u12 * u23).sum(dim=2))
        p = self.torsion_para
        # Amber torsion: (Vn/divider) * (1 + cos(n*phi - phase)).
        tloss = ((p[:, 1] / p[:, 0]) * (1 + torch.cos((p[:, 3] * t3.unsqueeze(2)) - p[:, 2]))).sum()
        bs = x.shape[0]
        return bloss / bs, aloss / bs, tloss / bs

    def _cdist_nb_full(self, x, cutoff=9.0, mask=False):
        # Softened full non-bonded energy: 12-6 Lennard-Jones plus Coulomb.
        # NOTE(review): cutoff/mask are currently unused — kept for interface
        # compatibility; confirm whether a cutoff was intended here.
        dmat = torch.cdist(x.permute(0, 2, 1), x.permute(0, 2, 1))
        dmat6 = (self._warp_domain(dmat, 1.9)**6)
        LJpB = self.vdw_B / dmat6
        LJpA = self.vdw_A / (dmat6**2)
        Cp = (self.q1q2 / self._warp_domain(dmat, 0.4))
        # nansum ignores entries parameterised as nan (e.g. excluded pairs/padding).
        return torch.nansum(LJpA - LJpB + Cp)

    def _cdist_nb(self, x, cutoff=9.0, mask=False):
        # Softened repulsive-only non-bonded energy: LJ repulsion plus Coulomb.
        dmat = torch.cdist(x.permute(0, 2, 1), x.permute(0, 2, 1))
        LJp = self.vdw_A / (self._warp_domain(dmat, 1.9)**12)
        Cp = (self.q1q2 / self._warp_domain(dmat, 0.4))
        return torch.nansum(LJp + Cp)

    def _warp_domain(self, x, k):
        # Smoothly lower-bound x at roughly k: identity for x >> k, exponential
        # approach to k - 1 below it, preventing division blow-ups at r -> 0.
        return torch.nn.functional.elu(x - k, 1.0) + k

    def _conv_bond_loss(self, x):
        # x shape [B, 3, N]; each kernel computes bond vectors for one bond class.
        loss = torch.tensor(0.0).float().to(self.device)
        for i, weight in enumerate(self.b_weights):
            y = torch.nn.functional.conv1d(x, weight.view(1, 1, -1).repeat(3, 1, 1).to(self.device),
                                           groups=3, padding=(len(weight) - 2))
            loss += (self.b_force[i] * ((y.norm(dim=1) - self.b_equil[i])**2)).sum()
        return loss

    def _conv_angle_loss(self, x):
        # x shape [B, 3, N]; two kernels per angle class give the two bond vectors.
        loss = torch.tensor(0.0).float().to(self.device)
        for i, weight in enumerate(self.a_weights):
            v1 = torch.nn.functional.conv1d(x, weight[0].view(1, 1, -1).repeat(3, 1, 1).to(self.device),
                                            groups=3, padding=(len(weight[0]) - 3))
            v2 = torch.nn.functional.conv1d(x, weight[1].view(1, 1, -1).repeat(3, 1, 1).to(self.device),
                                            groups=3, padding=(len(weight[1]) - 3))
            xyz = torch.sum(v1 * v2, dim=1) / (torch.norm(v1, dim=1) * torch.norm(v2, dim=1))
            # Clamp keeps acos differentiable and finite at collinear geometries.
            theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999))
            energy = (self.a_force[i] * ((theta - self.a_equil[i])**2)).sum(dim=0)[self.a_masks[i]].sum()
            loss += energy
        return loss

    def _conv_torsion_loss(self, x):
        # x shape [B, 3, N]; three kernels per torsion class give b1=i-j, b2=j-k, b3=l-k.
        loss = torch.tensor(0.0).float().to(self.device)
        for i, weight in enumerate(self.t_weights):
            b1 = torch.nn.functional.conv1d(x, weight[0].view(1, 1, -1).repeat(3, 1, 1).to(self.device),
                                            groups=3, padding=(len(weight[0]) - 4))  # i-j
            b2 = torch.nn.functional.conv1d(x, weight[1].view(1, 1, -1).repeat(3, 1, 1).to(self.device),
                                            groups=3, padding=(len(weight[1]) - 4))  # j-k
            b3 = torch.nn.functional.conv1d(x, weight[2].view(1, 1, -1).repeat(3, 1, 1).to(self.device),
                                            groups=3, padding=(len(weight[2]) - 4))  # l-k
            # Explicit dim=1 (xyz axis): the deprecated default would pick the
            # batch dim whenever batch size happens to equal 3.
            c32 = torch.cross(b3, b2, dim=1)
            c12 = torch.cross(b1, b2, dim=1)
            torsion = torch.atan2((b2 * torch.cross(c32, c12, dim=1)).sum(dim=1),
                                  b2.norm(dim=1) * ((c12 * c32).sum(dim=1)))
            p = self.t_para[i, :, :, :].unsqueeze(0)
            # Amber torsion: (Vn/divider) * (1 + cos(n*phi - phase)).
            loss += ((p[:, :, 1] / p[:, :, 0]) * (1 + torch.cos((p[:, :, 3] * torsion.unsqueeze(2)) - p[:, :, 2]))).sum()
        return loss

    def _roll_bond_angle_torsion_loss(self, x):
        # x.shape [B, 3, N]; bond vectors are built once by rolling the
        # coordinate tensor and then reused for angles and torsions.
        bloss = torch.tensor(0.0).float().to(self.device)
        aloss = torch.tensor(0.0).float().to(self.device)
        tloss = torch.tensor(0.0).float().to(self.device)
        v = []
        for i, diff in enumerate(self.brdiff):
            v.append(x - x.roll(-diff, 2))
            bloss += (((v[-1].norm(dim=1) - self.br_equil[i])**2) * self.br_force[i]).sum()

        for i, diff in enumerate(self.ardiff):
            v1 = self.arsign[i][0] * (v[diff[0]].roll(-self.arroll[i][0], 2))
            v2 = self.arsign[i][1] * (v[diff[1]].roll(-self.arroll[i][1], 2))
            xyz = torch.sum(v1 * v2, dim=1) / (torch.norm(v1, dim=1) * torch.norm(v2, dim=1))
            # Clamp keeps acos differentiable and finite at collinear geometries.
            theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999))
            energy = (self.ar_force[i] * ((theta - self.ar_equil[i])**2))
            aloss += energy.sum()

        for i, diff in enumerate(self.trdiff):
            b1 = self.trsign[i][0] * (v[diff[0]].roll(-self.trroll[i][0], 2))
            b2 = self.trsign[i][1] * (v[diff[1]].roll(-self.trroll[i][1], 2))
            b3 = self.trsign[i][2] * (v[diff[2]].roll(-self.trroll[i][2], 2))
            # Explicit dim=1 (xyz axis): the deprecated default would pick the
            # batch dim whenever batch size happens to equal 3.
            c32 = torch.cross(b3, b2, dim=1)
            c12 = torch.cross(b1, b2, dim=1)
            torsion = torch.atan2((b2 * torch.cross(c32, c12, dim=1)).sum(dim=1),
                                  b2.norm(dim=1) * ((c12 * c32).sum(dim=1)))
            p = self.tr_para[i].unsqueeze(0)
            # Amber torsion: (Vn/divider) * (1 + cos(n*phi - phase)).
            tloss += (((p[:, :, 1] / p[:, :, 0]) * (1 + torch.cos((p[:, :, 3] * torsion.unsqueeze(2)) - p[:, :, 2]))).sum())
        return bloss, aloss, tloss