Coverage for biobb_pytorch / mdae / loss / autocorrelation.py: 94%
17 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1#!/usr/bin/env python
3# =============================================================================
4# MODULE DOCSTRING
5# =============================================================================
7"""
8Autocorrelation loss.
9"""
11__all__ = ["AutocorrelationLoss", "autocorrelation_loss"]
14# =============================================================================
15# GLOBAL IMPORTS
16# =============================================================================
18from typing import Optional
20import torch
22from mlcolvar.core.stats.tica import TICA
23from mlcolvar.core.loss.eigvals import reduce_eigenvalues_loss
26# =============================================================================
27# LOSS FUNCTIONS
28# =============================================================================
class AutocorrelationLoss(torch.nn.Module):
    """(Weighted) autocorrelation loss.

    Computes the sum (or another reducing function) of the eigenvalues of the
    autocorrelation matrix. This is the same loss function used in
    :class:`~mlcolvar.cvs.timelagged.deeptica.DeepTICA`.
    """

    def __init__(self, reduce_mode: str = "sum2", invert_sign: bool = True):
        """Constructor.

        Parameters
        ----------
        reduce_mode : str
            How the eigenvalues are reduced, e.g., ``sum``, ``sum2`` (see also
            :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`).
            Default is ``'sum2'``.
        invert_sign : bool, optional
            Whether to return the negative autocorrelation so that it can be
            minimized with gradient descent methods. Default is ``True``.
        """
        super().__init__()
        # Stored options are forwarded verbatim to the functional loss.
        self.reduce_mode = reduce_mode
        self.invert_sign = invert_sign

    def forward(
        self,
        x: torch.Tensor,
        x_lag: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
        weights_lag: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Estimate the autocorrelation.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_features)``. The features of the sample at
            time ``t``.
        x_lag : torch.Tensor
            Shape ``(n_batches, n_features)``. The features of the sample at
            time ``t + lag``.
        weights : torch.Tensor, optional
            Shape ``(n_batches,)`` or ``(n_batches, 1)``. The weights
            associated to ``x`` at time ``t``. Default is ``None``.
        weights_lag : torch.Tensor, optional
            Shape ``(n_batches,)`` or ``(n_batches, 1)``. The weights
            associated to ``x`` at time ``t + lag``. Default is ``None``.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        # Delegate to the functional implementation with this module's options.
        options = dict(
            weights=weights,
            weights_lag=weights_lag,
            reduce_mode=self.reduce_mode,
            invert_sign=self.invert_sign,
        )
        return autocorrelation_loss(x, x_lag, **options)
def autocorrelation_loss(
    x: torch.Tensor,
    x_lag: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    weights_lag: Optional[torch.Tensor] = None,
    reduce_mode: str = "sum2",
    invert_sign: bool = True,
) -> torch.Tensor:
    """(Weighted) autocorrelation loss.

    Computes the sum (or another reducing function) of the eigenvalues of the
    autocorrelation matrix. This is the same loss function used in
    :class:`~mlcolvar.cvs.timelagged.deeptica.DeepTICA`.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(n_batches, n_features)``. The features of the sample at
        time ``t``.
    x_lag : torch.Tensor
        Shape ``(n_batches, n_features)``. The features of the sample at
        time ``t + lag``.
    weights : torch.Tensor, optional
        Shape ``(n_batches,)`` or ``(n_batches, 1)``. The weights associated
        to ``x`` at time ``t``. Default is ``None``.
    weights_lag : torch.Tensor, optional
        Shape ``(n_batches,)`` or ``(n_batches, 1)``. The weights associated
        to ``x`` at time ``t + lag``. Default is ``None``.
    reduce_mode : str
        How the eigenvalues are reduced, e.g., ``sum``, ``sum2`` (see also
        :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`).
        Default is ``'sum2'``.
    invert_sign : bool, optional
        Whether to return the negative autocorrelation so that it can be
        minimized with gradient descent methods. Default is ``True``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    """
    # Estimate the eigenvalues of the time-lagged (auto)correlation problem.
    estimator = TICA(in_features=x.shape[-1])
    eigenvalues, _ = estimator.compute(
        data=[x, x_lag],
        weights=[weights, weights_lag],
    )
    # Collapse the eigenvalue spectrum into a single scalar loss.
    return reduce_eigenvalues_loss(eigenvalues, mode=reduce_mode, invert_sign=invert_sign)