Coverage for biobb_pytorch / mdae / loss / physics_loss.py: 22%
36 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import torch
2import numpy as np
3from biobb_pytorch.mdae.loss.utils.torch_protein_energy import TorchProteinEnergy
class PhysicsLoss(torch.nn.Module):
    """
    Reconstruction + physics-based loss for the FoldingNet model.

    Combines a mean-squared-error reconstruction term with bond, angle and
    torsion energy terms computed by a ``TorchProteinEnergy``-like object on
    decoded (interpolated) structures.

    :param stats: Mapping with a ``'topology'`` entry used to build a
        ``TorchProteinEnergy``, or ``None`` to use ``protein_energy`` instead.
        The topology object must expose ``xyz`` and ``topology.atoms``
        (presumably an ``mdtraj.Trajectory`` -- TODO confirm).
    :param protein_energy: Pre-built energy object exposing
        ``_roll_bond_angle_torsion_loss``. Only used when ``stats`` is ``None``;
        ignored otherwise.
    :param float physics_scaling_factor: Multiplier used when computing the
        ``scale`` value returned by :meth:`forward`.
    :param device: Device passed to ``TorchProteinEnergy`` when it is built
        from ``stats``. ``None`` (default) selects CUDA when available,
        otherwise CPU.
    """

    def __init__(self, stats, protein_energy=None, physics_scaling_factor=0.1, device=None):
        super().__init__()

        if stats is not None:
            top = stats['topology']

            # Coordinates of the first frame with the two axes swapped
            # (presumably (n_atoms, 3) -> (3, n_atoms); TODO confirm xyz layout).
            x0_coords = torch.tensor(top.xyz[0]).permute(1, 0)

            # One [atom name, residue name, 1-based residue index] row per atom.
            atominfo = np.array(
                [[atom.name, atom.residue.name, atom.residue.index + 1]
                 for atom in top.topology.atoms],
                dtype=object,
            )

            if device is None:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            self.protein_energy = TorchProteinEnergy(x0_coords,
                                                     atominfo,
                                                     device=device,
                                                     method='roll')
        else:
            self.protein_energy = protein_energy

        self.psf = physics_scaling_factor

    def mse_loss(self, batch, decoded):
        """
        Mean squared error between ``batch`` and ``decoded``, averaged over all elements.
        """
        return ((batch - decoded) ** 2).mean()

    def total_physics_loss(self, decoded_interpolation):
        """
        Compute bond, angle and torsion energy terms for decoded structures.

        Each term returned by ``protein_energy._roll_bond_angle_torsion_loss``
        is divided by the number of structures (``len(decoded_interpolation)``).
        When summing, infinite terms are clamped to ``1e35`` and NaN terms are
        ignored. The sum stays attached to the autograd graph (and on the
        device of the energy terms) so it can be back-propagated.

        :param torch.Tensor decoded_interpolation: Decoded structures; the
            first dimension indexes structures (presumably
            ``[batch_size, 3, n_atoms]`` -- TODO confirm).
        :returns: dict with ``'physics_loss'`` (clamped NaN-ignoring sum of the
            three terms) and the unclamped ``'bond_energy'``, ``'angle_energy'``
            and ``'torsion_energy'`` after division by the number of structures.
        """
        bond, angle, torsion = self.protein_energy._roll_bond_angle_torsion_loss(decoded_interpolation)
        n = len(decoded_interpolation)
        # Out-of-place division: in-place ops on autograd outputs can break backward.
        bond = bond / n
        angle = angle / n
        torsion = torsion / n

        # torch.stack keeps the autograd graph and device; torch.tensor([...])
        # would copy the values into a detached CPU tensor.
        _all = torch.stack([torch.as_tensor(bond), torch.as_tensor(angle), torch.as_tensor(torsion)])
        _all = _all.masked_fill(_all.isinf(), 1e35)
        total_physics = _all.nansum()

        return {'physics_loss': total_physics, 'bond_energy': bond, 'angle_energy': angle, 'torsion_energy': torsion}

    def forward(self,
                batch,
                decoded,
                decoded_interpolation
                ):
        """
        Compute the reconstruction and physics loss terms.

        :param torch.Tensor batch: Target structures.
        :param torch.Tensor decoded: Reconstructions of ``batch``.
        :param torch.Tensor decoded_interpolation: Decoded structures on which
            the physics energy terms are evaluated.
        :returns: tuple ``(mse_loss, physics_loss_dict, scale)``, where
            ``physics_loss_dict`` is the dict from :meth:`total_physics_loss` and
            ``scale = psf * mse_loss / (physics_loss + 1e-5)`` is computed without
            gradient tracking (presumably used by the caller to weight the
            physics term against the MSE term).
        """
        mse_loss = self.mse_loss(batch, decoded)

        physics_loss_dict = self.total_physics_loss(decoded_interpolation)
        physics_loss = physics_loss_dict['physics_loss']

        with torch.no_grad():
            scale = self.psf * mse_loss / (physics_loss + 1e-5)

        return mse_loss, physics_loss_dict, scale