Coverage for biobb_pytorch / mdae / loss / autocorrelation.py: 94%

17 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1#!/usr/bin/env python 

2 

3# ============================================================================= 

4# MODULE DOCSTRING 

5# ============================================================================= 

6 

7""" 

8Autocorrelation loss. 

9""" 

10 

11__all__ = ["AutocorrelationLoss", "autocorrelation_loss"] 

12 

13 

14# ============================================================================= 

15# GLOBAL IMPORTS 

16# ============================================================================= 

17 

18from typing import Optional 

19 

20import torch 

21 

22from mlcolvar.core.stats.tica import TICA 

23from mlcolvar.core.loss.eigvals import reduce_eigenvalues_loss 

24 

25 

26# ============================================================================= 

27# LOSS FUNCTIONS 

28# ============================================================================= 

29 

30 

class AutocorrelationLoss(torch.nn.Module):
    """(Weighted) autocorrelation loss as a ``torch.nn.Module``.

    Reduces (by default, ``sum2``) the eigenvalues of the autocorrelation
    matrix into a scalar loss. This is the same loss function used in
    :class:`~mlcolvar.cvs.timelagged.deeptica.DeepTICA`. All computation is
    delegated to :func:`autocorrelation_loss`.

    """

    def __init__(self, reduce_mode: str = "sum2", invert_sign: bool = True):
        """Constructor.

        Parameters
        ----------
        reduce_mode : str
            How the eigenvalues are reduced, e.g. ``sum`` or ``sum2``
            (see also :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`).
            The default is ``'sum2'``.
        invert_sign : bool, optional
            Whether to return the negative autocorrelation so that it can be
            minimized with gradient descent methods. Default is ``True``.
        """
        super().__init__()
        # Plain (non-tensor) attributes, forwarded unchanged on every call.
        self.reduce_mode, self.invert_sign = reduce_mode, invert_sign

    def forward(
        self,
        x: torch.Tensor,
        x_lag: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
        weights_lag: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Estimate the autocorrelation loss for a pair of time-lagged samples.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_features)``. The features of the sample at
            time ``t``.
        x_lag : torch.Tensor
            Shape ``(n_batches, n_features)``. The features of the sample at
            time ``t + lag``.
        weights : torch.Tensor, optional
            Shape ``(n_batches,)`` or ``(n_batches, 1)``. The weights associated
            to ``x`` (the sample at time ``t``). Default is ``None``.
        weights_lag : torch.Tensor, optional
            Shape ``(n_batches,)`` or ``(n_batches, 1)``. The weights associated
            to ``x_lag`` (the sample at time ``t + lag``). Default is ``None``.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        # Pass everything by keyword, using the options stored at construction.
        loss = autocorrelation_loss(
            x=x,
            x_lag=x_lag,
            weights=weights,
            weights_lag=weights_lag,
            reduce_mode=self.reduce_mode,
            invert_sign=self.invert_sign,
        )
        return loss

94 

95 

def autocorrelation_loss(
    x: torch.Tensor,
    x_lag: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    weights_lag: Optional[torch.Tensor] = None,
    reduce_mode: str = "sum2",
    invert_sign: bool = True,
) -> torch.Tensor:
    """(Weighted) autocorrelation loss.

    Reduces (by default, ``sum2``) the eigenvalues of the autocorrelation
    matrix, estimated with TICA, into a scalar loss. This is the same loss
    function used in :class:`~mlcolvar.cvs.timelagged.deeptica.DeepTICA`.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(n_batches, n_features)``. The features of the sample at
        time ``t``.
    x_lag : torch.Tensor
        Shape ``(n_batches, n_features)``. The features of the sample at
        time ``t + lag``.
    weights : torch.Tensor, optional
        Shape ``(n_batches,)`` or ``(n_batches, 1)``. The weights associated
        to ``x`` (the sample at time ``t``). Default is ``None``.
    weights_lag : torch.Tensor, optional
        Shape ``(n_batches,)`` or ``(n_batches, 1)``. The weights associated
        to ``x_lag`` (the sample at time ``t + lag``). Default is ``None``.
    reduce_mode : str
        How the eigenvalues are reduced, e.g. ``sum`` or ``sum2``
        (see also :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`).
        The default is ``'sum2'``.
    invert_sign : bool, optional
        Whether to return the negative autocorrelation so that it can be
        minimized with gradient descent methods. Default is ``True``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    """
    # A new TICA estimator is built on every call, sized from the last
    # dimension of ``x``.
    n_features = x.shape[-1]
    estimator = TICA(in_features=n_features)

    # compute() yields (eigenvalues, <second item>); only the eigenvalues
    # feed the loss.
    eigenvalues, _discarded = estimator.compute(
        data=[x, x_lag],
        weights=[weights, weights_lag],
    )

    return reduce_eigenvalues_loss(
        eigenvalues,
        mode=reduce_mode,
        invert_sign=invert_sign,
    )