Coverage for biobb_pytorch/mdae/loss/utils/torch_protein_energy.py: 9%
209 statements
coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
# Copyright (c) 2021 Venkata K. Ramaswamy, Samuel C. Musson, Chris G. Willcocks, Matteo T. Degiacomi
#
# Molearn is free software;
# you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation;
# either version 2 of the License, or (at your option) any later version.
# molearn is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along with molearn;
# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
import torch
# import torch.nn as nn
# import torch.nn.functional as F
from .torch_protein_energy_utils import get_convolutions
class TorchProteinEnergy():
    def __init__(self, frame, pdb_atom_names,
                 padded_residues=False,
                 method=('indexed', 'convolutional', 'roll')[2],
                 device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
                 fix_h=False, alt_vdw=[], NB='repulsive'):
        '''
        At instantiation this loads the Amber parameters and creates the convolutions/indices/rolls needed to calculate the energy of the molecule.
        Energy can be assessed with the ``TorchProteinEnergy.get_energy(x)`` method.

        :param frame: Example coordinates of the structure in a torch array. The interatomic distances are used to determine the connectivity of the atoms.
            Coordinates should be of ``shape [3, N]`` where N is the number of atoms.
            If ``padded_residues = True`` then coordinates should be of ``shape [R, M, 3]`` where R is the number of residues
            and M is the maximum number of atoms in a residue.
        :param pdb_atom_names: Array of ``shape [N, 2]`` containing the pdb atom names in ``pdb_atom_names[:, 0]`` and residue names in ``pdb_atom_names[:, 1]``.
            If ``padded_residues = True`` then it should be of ``shape [R, M, 2]``.
        :param padded_residues: If True, the dataset should be formatted as ``shape [R, M, 3]`` where R is the number of residues and M is the maximum number of atoms in a residue.
            Padding should be ``nan``.
            **Note**: only ``method = "indexed"`` is currently implemented for this.
        :param method: ``method = "convolutional"`` (experimental) uses convolutions to calculate the energy terms (``padded_residues=False`` only).
            ``method = "roll"`` uses rolling and slicing to calculate the energy terms (``padded_residues=False`` only).
            ``method = "indexed"`` (experimental) is only implemented for ``padded_residues=True`` and uses index arrays to calculate the energy terms.
        :param device: ``torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")``
        '''
        self.device = device
        self.method = method
        if padded_residues:
            if method == 'indexed':
                self._padded_indexed_init(frame, pdb_atom_names)
        else:
            if method == 'convolutional':
                self._convolutional_init(frame, pdb_atom_names, fix_h=fix_h, alt_vdw=alt_vdw)
            elif method == 'roll':
                self._roll_init(frame, pdb_atom_names, NB=NB, fix_h=fix_h, alt_vdw=alt_vdw)
    def get_energy(self, x, nonbonded=False):
        '''
        :param x: Tensor of ``shape [B, 3, N]`` where B is the batch size and N is the number of atoms. If ``padded_residues = True`` then a tensor of
            ``shape [B, R, M, 3]`` where B is the batch size, R is the number of residues and M is the maximum number of atoms in a residue.
        :param nonbonded: Boolean (default False). Whether to also return a softened nonbonded energy.
        :returns: ``float`` Bond energy (average over batch)
        :returns: ``float`` Angle energy (average over batch)
        :returns: ``float`` Torsion energy (average over batch)
        :returns: ``float`` (only if ``nonbonded = True``) Non-bonded energy (average over batch)
        '''
        if nonbonded:
            return (*self.get_bonded_energy(x), self._nb_loss(x))
        else:
            return self.get_bonded_energy(x)
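    # Usage sketch (illustrative, not part of the original module): assumes `coords` is a
    # [3, N] float tensor holding one reference frame and `atom_names` is an [N, 2] array of
    # (atom name, residue name) strings, e.g. parsed from the corresponding PDB file.
    #
    #     energy = TorchProteinEnergy(coords, atom_names, method='roll')
    #     batch = coords.unsqueeze(0)                                  # shape [1, 3, N]
    #     bond_e, angle_e, torsion_e = energy.get_energy(batch)
    #     bond_e, angle_e, torsion_e, nb_e = energy.get_energy(batch, nonbonded=True)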
    def _roll_init(self, frame, pdb_atom_names, NB='full', fix_h=False, alt_vdw=[]):
        (b_masks, b_equil, b_force, b_weights,
         a_masks, a_equil, a_force, a_weights,
         t_masks, t_para, t_weights,
         vdw_R, vdw_e, vdw_14R, vdw_14e,
         q1q2, q1q2_14) = get_convolutions(frame, pdb_atom_names, fix_slice_method=True, fix_h=fix_h, alt_vdw=alt_vdw)
        self.brdiff = []
        self.br_equil = []
        self.br_force = []
        for i, j in enumerate(b_weights):
            atom1 = j.index(1)
            atom2 = j.index(-1)
            d = j.index(-1) - j.index(1)
            padding = len(j) - 2
            self.brdiff.append(d)
            # b_equil[:, 0] is just padding so can roll(-1, 1) to get correct padding
            self.br_equil.append(torch.tensor(b_equil[i, padding - 1:]).roll(-1).to(self.device).float())
            self.br_force.append(torch.tensor(b_force[i, padding - 1:]).roll(-1).to(self.device).float())
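        # Angle terms: for each harmonic angle E = k * (theta - theta0)^2 the two bond
        # vectors are rebuilt in _roll_bond_angle_torsion_loss from the rolled difference
        # vectors above, so here only roll offsets, signs and Amber parameters are stored.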
        self.ardiff = []
        self.arsign = []
        self.arroll = []
        self.ar_equil = []
        self.ar_force = []
        self.ar_masks = []
        for i, j in enumerate(a_weights):
            atom1 = j[0].index(1)
            atom2 = j[0].index(-1)
            atom3 = j[1].index(1)
            diff1 = atom2 - atom1
            diff2 = atom2 - atom3
            padding = len(j[0]) - 3
            self.arroll.append([min(atom1, atom2), min(atom2, atom3)])
            self.ardiff.append([abs(diff1) - 1, abs(diff2) - 1])
            self.arsign.append([diff1 / abs(diff1), diff2 / abs(diff2)])
            self.ar_equil.append(torch.tensor(a_equil[i, padding - 2:]).roll(-2).to(self.device).float())
            self.ar_force.append(torch.tensor(a_force[i, padding - 2:]).roll(-2).to(self.device).float())
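        # Torsion terms: store roll offsets, signs and the periodic parameters
        # (divider, barrier height, phase, periodicity), used in the loss functions as
        # E = (V / divider) * (1 + cos(n * phi - gamma)).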
        self.trdiff = []
        self.trsign = []
        self.trroll = []
        self.tr_para = []
        for i, j in enumerate(t_weights):
            atom1 = j[0].index(1)   # i-j 0
            atom2 = j[0].index(-1)  # i-j 2
            atom3 = j[1].index(-1)  # j-k 3
            atom4 = j[2].index(1)   # l-k 4
            diff1 = atom2 - atom1         # ij 2
            diff2 = atom3 - atom2         # jk 1
            diff3 = (atom4 - atom3) * -1  # lk 1
            padding = len(j[0]) - 4
            self.trroll.append([min(atom1, atom2), min(atom2, atom3), min(atom3, atom4)])
            self.trsign.append([diff1 / abs(diff1), diff2 / abs(diff2), diff3 / abs(diff3)])
            self.trdiff.append([abs(diff1) - 1, abs(diff2) - 1, abs(diff3) - 1])
            self.tr_para.append(torch.tensor(t_para[i, padding - 3:]).roll(-3, 0).to(self.device).float())
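        # 12-6 Lennard-Jones coefficients in Amber form: A = eps * R^12 and B = 2 * eps * R^6,
        # so that A / r^12 - B / r^6 = eps * ((R / r)^12 - 2 * (R / r)^6).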
        self.vdw_A = (vdw_e * (vdw_R**12)).to(self.device)
        self.vdw_B = (2 * vdw_e * (vdw_R**6)).to(self.device)
        self.q1q2 = q1q2.to(self.device)

        self.get_bonded_energy = self._bonded_roll_loss
        if NB == 'full':
            self._nb_loss = self._cdist_nb_full
        elif NB == 'repulsive':
            self._nb_loss = self._cdist_nb
    def _convolutional_init(self, frame, pdb_atom_names, NB='full', fix_h=False, alt_vdw=[]):
        (b_masks, b_equil, b_force, b_weights,
         a_masks, a_equil, a_force, a_weights,
         t_masks, t_para, t_weights,
         vdw_R, vdw_e, vdw_14R, vdw_14e,
         q1q2, q1q2_14) = get_convolutions(frame, pdb_atom_names, fix_slice_method=False, fix_h=fix_h, alt_vdw=alt_vdw)

        self.b_equil = torch.tensor(b_equil).to(self.device)
        self.b_force = torch.tensor(b_force).to(self.device)
        self.b_weights = torch.tensor(b_weights).to(self.device)

        self.a_equil = torch.tensor(a_equil).to(self.device).float()
        self.a_force = torch.tensor(a_force).to(self.device).float()
        self.a_weights = torch.tensor(a_weights).to(self.device)
        self.a_masks = torch.tensor(a_masks).to(self.device)

        self.t_para = torch.tensor(t_para).to(self.device)
        self.t_weights = torch.tensor(t_weights).to(self.device)

        self.vdw_A = (vdw_e * (vdw_R**12)).to(self.device)
        self.vdw_B = (2 * vdw_e * (vdw_R**6)).to(self.device)
        self.q1q2 = q1q2.to(self.device)

        self.get_bonded_energy = self._bonded_convolutional_loss
        if NB == 'full':
            self._nb_loss = self._cdist_nb_full
        elif NB == 'repulsive':
            self._nb_loss = self._cdist_nb
    def _padded_indexed_init(self, frame, pdb_atom_names, NB='full'):
        from molearn import get_conv_pad_res
        (bond_idxs, bond_para,
         angle_idxs, angle_para, angle_mask, ij_jk,
         torsion_idxs, torsion_para, torsion_mask, ij_jk_kl,
         vdw_R, vdw_e,
         q1q2,) = get_conv_pad_res(frame, pdb_atom_names)

        self.bond_idxs = bond_idxs.to(self.device)
        self.bond_para = bond_para.to(self.device)
        self.angle_mask = angle_mask.to(self.device)
        self.ij_jk = ij_jk.to(self.device)
        self.angle_para = angle_para.to(self.device)
        self.torsion_mask = torsion_mask.to(self.device)
        self.ij_jk_kl = ij_jk_kl.to(self.device)
        self.torsion_para = torsion_para.to(self.device)
        self.vdw_A = (vdw_e * (vdw_R**12)).to(self.device)
        self.vdw_B = (2 * vdw_e * (vdw_R**6)).to(self.device)
        self.q1q2 = q1q2.to(self.device)
        self.get_bonded_energy = self._bonded_padded_residues_loss
        self.relevant = self.bond_idxs.unique().to(self.device)
        if NB == 'full':
            self._nb_loss = self._cdist_nb_full
        elif NB == 'repulsive':
            self._nb_loss = self._cdist_nb
    def _bonded_convolutional_loss(self, x):
        bs = x.shape[0]
        bloss = self._conv_bond_loss(x)
        aloss = self._conv_angle_loss(x)
        tloss = self._conv_torsion_loss(x)
        return bloss / bs, aloss / bs, tloss / bs

    def _bonded_roll_loss(self, x):
        bs = x.shape[0]
        bloss, aloss, tloss = self._roll_bond_angle_torsion_loss(x)
        return bloss / bs, aloss / bs, tloss / bs
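    # Bonded energy for padded-residue input: the [B, R, M, 3] tensor is flattened to
    # [B, R*M, 3], bond vectors i->j are gathered with bond_idxs, and the same vectors are
    # reused (through the ij_jk / ij_jk_kl index maps and their masks) for angles and torsions.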
    def _bonded_padded_residues_loss(self, x):
        # x.shape [B, R, M, 3]
        x = x.view(x.shape[0], -1, 3)
        v = x[:, self.bond_idxs[:, 1]] - x[:, self.bond_idxs[:, 0]]  # j-i == i->j
        bloss = (((v.norm(dim=2) - self.bond_para[:, 0])**2) * self.bond_para[:, 1]).sum()
        v1 = v[:, self.ij_jk[0]] * self.angle_mask[0].view(1, -1, 1)
        v2 = v[:, self.ij_jk[1]] * self.angle_mask[1].view(1, -1, 1)
        xyz = torch.sum(v1 * v2, dim=2) / (torch.norm(v1, dim=2) * torch.norm(v2, dim=2))
        theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999))
        aloss = (((theta - self.angle_para[:, 0])**2) * self.angle_para[:, 1]).sum()

        u1 = v[:, self.ij_jk_kl[0]] * self.torsion_mask[0].view(1, -1, 1)
        u2 = v[:, self.ij_jk_kl[1]] * self.torsion_mask[1].view(1, -1, 1)
        u3 = v[:, self.ij_jk_kl[2]] * self.torsion_mask[2].view(1, -1, 1)
        u12 = torch.cross(u1, u2)
        u23 = torch.cross(u2, u3)
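        # Dihedral via the atan2 form: phi = atan2(|b2| * (b1 . (b2 x b3)), (b1 x b2) . (b2 x b3)),
        # which avoids the sign ambiguity of the acos formulation.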
        t3 = torch.atan2(u2.norm(dim=2) * ((u1 * u23).sum(dim=2)), (u12 * u23).sum(dim=2))
        p = self.torsion_para
        tloss = ((p[:, 1] / p[:, 0]) * (1 + torch.cos((p[:, 3] * t3.unsqueeze(2)) - p[:, 2]))).sum()
        bs = x.shape[0]
        return bloss / bs, aloss / bs, tloss / bs
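    # Nonbonded terms, evaluated over the full pairwise distance matrix:
    # 12-6 Lennard-Jones (A / r^12 - B / r^6) plus a Coulomb term q1q2 / r. Distances are
    # softened with _warp_domain so that clashing atoms do not produce exploding gradients,
    # and nansum skips nan entries (e.g. from padded atoms).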
    def _cdist_nb_full(self, x, cutoff=9.0, mask=False):
        dmat = torch.cdist(x.permute(0, 2, 1), x.permute(0, 2, 1))
        dmat6 = (self._warp_domain(dmat, 1.9)**6)
        LJpB = self.vdw_B / dmat6
        LJpA = self.vdw_A / (dmat6**2)
        Cp = (self.q1q2 / self._warp_domain(dmat, 0.4))
        return torch.nansum(LJpA - LJpB + Cp)
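    # 'Repulsive' variant: keeps only the r^-12 repulsion and the Coulomb term, dropping the
    # attractive r^-6 dispersion (this is the default, NB='repulsive', in __init__).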
    def _cdist_nb(self, x, cutoff=9.0, mask=False):
        dmat = torch.cdist(x.permute(0, 2, 1), x.permute(0, 2, 1))
        LJp = self.vdw_A / (self._warp_domain(dmat, 1.9)**12)
        Cp = (self.q1q2 / self._warp_domain(dmat, 0.4))
        return torch.nansum(LJp + Cp)
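    # _warp_domain(x, k) = elu(x - k) + k: the identity for x >= k, smoothly approaching
    # k - 1 as x -> -inf, so small (clashing) distances are softly clamped away from zero
    # before being used in the 1/r^n terms above.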
    def _warp_domain(self, x, k):
        return torch.nn.functional.elu(x - k, 1.0) + k
    def _conv_bond_loss(self, x):
        # x.shape [B, 3, N]
        loss = torch.tensor(0.0).float().to(self.device)
        for i, weight in enumerate(self.b_weights):
            y = torch.nn.functional.conv1d(x, weight.view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight) - 2))
            loss += (self.b_force[i] * ((y.norm(dim=1) - self.b_equil[i])**2)).sum()
        return loss
    def _conv_angle_loss(self, x):
        # x.shape [B, 3, N]
        loss = torch.tensor(0.0).float().to(self.device)
        for i, weight in enumerate(self.a_weights):
            v1 = torch.nn.functional.conv1d(x, weight[0].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[0]) - 3))
            v2 = torch.nn.functional.conv1d(x, weight[1].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[1]) - 3))
            xyz = torch.sum(v1 * v2, dim=1) / (torch.norm(v1, dim=1) * torch.norm(v2, dim=1))
            theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999))
            energy = (self.a_force[i] * ((theta - self.a_equil[i])**2)).sum(dim=0)[self.a_masks[i]].sum()
            loss += energy
        return loss
    def _conv_torsion_loss(self, x):
        # x.shape [B, 3, N]
        loss = torch.tensor(0.0).float().to(self.device)
        for i, weight in enumerate(self.t_weights):
            b1 = torch.nn.functional.conv1d(x, weight[0].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[0]) - 4))  # i-j
            b2 = torch.nn.functional.conv1d(x, weight[1].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[1]) - 4))  # j-k
            b3 = torch.nn.functional.conv1d(x, weight[2].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[2]) - 4))  # l-k
            c32 = torch.cross(b3, b2)
            c12 = torch.cross(b1, b2)
            torsion = torch.atan2((b2 * torch.cross(c32, c12)).sum(dim=1),
                                  b2.norm(dim=1) * ((c12 * c32).sum(dim=1)))
            p = self.t_para[i, :, :, :].unsqueeze(0)
            loss += ((p[:, :, 1] / p[:, :, 0]) * (1 + torch.cos((p[:, :, 3] * torsion.unsqueeze(2)) - p[:, :, 2]))).sum()
        return loss
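    # Roll-based bonded energy: bond difference vectors are built once by rolling the
    # coordinate tensor, then reused (with the offsets and signs precomputed in _roll_init)
    # to evaluate the harmonic bond, harmonic angle and periodic torsion terms.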
    def _roll_bond_angle_torsion_loss(self, x):
        # x.shape [B, 3, N]
        bloss = torch.tensor(0.0).float().to(self.device)
        aloss = torch.tensor(0.0).float().to(self.device)
        tloss = torch.tensor(0.0).float().to(self.device)
        v = []
        for i, diff in enumerate(self.brdiff):
            v.append(x - x.roll(-diff, 2))
            bloss += (((v[-1].norm(dim=1) - self.br_equil[i])**2) * self.br_force[i]).sum()

        for i, diff in enumerate(self.ardiff):
            v1 = self.arsign[i][0] * (v[diff[0]].roll(-self.arroll[i][0], 2))
            v2 = self.arsign[i][1] * (v[diff[1]].roll(-self.arroll[i][1], 2))
            xyz = torch.sum(v1 * v2, dim=1) / (torch.norm(v1, dim=1) * torch.norm(v2, dim=1))
            theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999))
            energy = (self.ar_force[i] * ((theta - self.ar_equil[i])**2))
            sum_e = energy.sum()
            aloss += sum_e

        for i, diff in enumerate(self.trdiff):
            b1 = self.trsign[i][0] * (v[diff[0]].roll(-self.trroll[i][0], 2))
            b2 = self.trsign[i][1] * (v[diff[1]].roll(-self.trroll[i][1], 2))
            b3 = self.trsign[i][2] * (v[diff[2]].roll(-self.trroll[i][2], 2))
            c32 = torch.cross(b3, b2)
            c12 = torch.cross(b1, b2)
            torsion = torch.atan2((b2 * torch.cross(c32, c12)).sum(dim=1),
                                  b2.norm(dim=1) * ((c12 * c32).sum(dim=1)))
            p = self.tr_para[i].unsqueeze(0)
            tloss += ((p[:, :, 1] / p[:, :, 0]) * (1 + torch.cos((p[:, :, 3] * torsion.unsqueeze(2)) - p[:, :, 2]))).sum()
        return bloss, aloss, tloss