Coverage for biobb_pytorch/mdae/loss/physics_loss.py: 22%

36 statements  


import torch
import numpy as np
from biobb_pytorch.mdae.loss.utils.torch_protein_energy import TorchProteinEnergy


class PhysicsLoss(torch.nn.Module):
    """
    Physics loss for the FoldingNet model.
    """

    def __init__(self, stats, protein_energy=None, physics_scaling_factor=0.1):
        super().__init__()

        if stats is not None:
            top = stats['topology']

            # Reference coordinates of the first frame, reshaped to [3, n_atoms]
            x0_coords = torch.tensor(top.xyz[0]).permute(1, 0)

            # Per-atom metadata: atom name, residue name, 1-based residue index
            atominfo = []
            for i in top.topology.atoms:
                atominfo.append([i.name, i.residue.name, i.residue.index + 1])
            atominfo = np.array(atominfo, dtype=object)

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            self.protein_energy = TorchProteinEnergy(x0_coords,
                                                     atominfo,
                                                     device=device,
                                                     method='roll')
        else:
            # Reuse a pre-built energy evaluator when no topology stats are provided
            self.protein_energy = protein_energy

        self.psf = physics_scaling_factor

    def mse_loss(self, batch, decoded):
        """
        Mean squared error loss for the FoldingNet model.
        """
        return ((batch - decoded) ** 2).mean()

    def total_physics_loss(self, decoded_interpolation):
        """
        Compute the physics (energy) terms for decoded interpolated structures,
        typically structures decoded from random interpolations between the
        latent vectors of adjacent samples during training and validation.

        Bond, angle and torsion energies are evaluated with ``self.protein_energy``,
        averaged over the batch, and summed after clamping infinite values.

        :param torch.Tensor decoded_interpolation: tensor of shape
            [batch_size, 3, n_atoms] with the decoded interpolated structures.
        :return: dict with ``physics_loss`` (total), ``bond_energy``,
            ``angle_energy`` and ``torsion_energy``.
        """

        bond, angle, torsion = self.protein_energy._roll_bond_angle_torsion_loss(decoded_interpolation)
        n = len(decoded_interpolation)
        bond /= n
        angle /= n
        torsion /= n
        # torch.stack keeps the energy terms in the autograd graph
        # (torch.tensor would copy the values and detach them)
        _all = torch.stack([bond, angle, torsion])
        _all[_all.isinf()] = 1e35  # clamp diverging energies
        total_physics = _all.nansum()

        return {'physics_loss': total_physics,
                'bond_energy': bond,
                'angle_energy': angle,
                'torsion_energy': torsion}

    def forward(self, batch, decoded, decoded_interpolation):
        """
        Forward pass for the FoldingNet model.

        Returns the reconstruction (MSE) loss, the dictionary of physics energy
        terms, and a gradient-free scaling factor that balances the physics loss
        against the MSE loss.
        """
        mse_loss = self.mse_loss(batch, decoded)

        physics_loss_dict = self.total_physics_loss(decoded_interpolation)
        physics_loss = physics_loss_dict['physics_loss']

        # Compute the scale without gradients so it acts as a constant weight
        with torch.no_grad():
            scale = self.psf * mse_loss / (physics_loss + 1e-5)

        return mse_loss, physics_loss_dict, scale
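
For reference, below is a minimal sketch of how the three values returned by forward could be combined into a single training loss, following the scaled-sum pattern implied by the scale computation above. The model, batch, stats and interpolated_z names, and the encode/decode calls, are hypothetical placeholders for the surrounding autoencoder and training code, not part of this module.

# Hypothetical usage sketch: "model", "batch", "stats" and "interpolated_z"
# stand in for the surrounding autoencoder/training code.
loss_fn = PhysicsLoss(stats=stats, physics_scaling_factor=0.1)

decoded = model.decode(model.encode(batch))           # reconstructed structures
decoded_interpolation = model.decode(interpolated_z)  # structures from interpolated latents

mse_loss, physics_dict, scale = loss_fn(batch, decoded, decoded_interpolation)

# Weighted sum: the gradient-free scale keeps the physics contribution at
# roughly physics_scaling_factor times the reconstruction error.
total_loss = mse_loss + scale * physics_dict['physics_loss']
total_loss.backward()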