Coverage for biobb_pytorch / mdae / loss / fisher.py: 93%
27 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Fisher discriminant loss for (Deep) Linear Discriminant Analysis.

The module exposes a functional implementation, :func:`fisher_discriminant_loss`,
and a :class:`torch.nn.Module` wrapper around it, :class:`FisherDiscriminantLoss`.
"""

# Public API of this module.
__all__ = [
    "FisherDiscriminantLoss",
    "fisher_discriminant_loss",
]
14# =============================================================================
15# GLOBAL IMPORTS
16# =============================================================================
18from typing import Optional
20import torch
22from mlcolvar.core.stats import LDA
23from mlcolvar.core.loss import reduce_eigenvalues_loss
26# =============================================================================
27# LOSS FUNCTIONS
28# =============================================================================
class FisherDiscriminantLoss(torch.nn.Module):
    """Fisher's discriminant ratio.

    Computes the sum (or another reducing function) of the eigenvalues of the
    ratio between Fisher's scatter matrices. This is the same loss function
    used in :class:`~mlcolvar.cvs.supervised.deeplda.DeepLDA`. All of the
    computation is delegated to :func:`fisher_discriminant_loss`; this module
    only stores the configuration.
    """

    def __init__(
        self,
        n_states: int,
        lda_mode: str = "standard",
        reduce_mode: str = "sum",
        lorentzian_reg: Optional[float] = None,
        invert_sign: bool = True,
        *,
        sw_reg: Optional[float] = 0.05,
    ):
        """Constructor.

        Parameters
        ----------
        n_states : int
            The number of states. Labels are in the range ``[0, n_states-1]``.
        lda_mode : str
            Either ``'standard'`` or ``'harmonic'``. This determines how the scatter
            matrices are computed (see also :class:`~mlcolvar.core.stats.lda.LDA`). The
            default is ``'standard'``.
        reduce_mode : str
            This determines how the eigenvalues are reduced, e.g., ``sum``, ``sum2``
            (see also :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`). The
            default is ``'sum'``.
        lorentzian_reg : float, optional
            The magnitude of the Lorentzian regularization. If not provided, it is
            set to ``2 / sw_reg`` when the loss is evaluated.
        invert_sign : bool, optional
            Whether to return the negative Fisher's discriminant ratio in order to be
            minimized with gradient descent methods. Default is ``True``.
        sw_reg : float, optional
            Keyword-only. Magnitude of the regularization of the within-class
            scatter matrix, forwarded to :func:`fisher_discriminant_loss`. It is also
            used to derive ``lorentzian_reg`` when that is not given, so it must be
            non-zero and not ``None`` in that case. Default is ``0.05``, the same
            value the functional form uses.
        """
        super().__init__()
        self.n_states = n_states
        self.lda_mode = lda_mode
        self.reduce_mode = reduce_mode
        self.lorentzian_reg = lorentzian_reg
        self.invert_sign = invert_sign
        self.sw_reg = sw_reg

    def forward(
        self,
        x: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the value of the loss function.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_features)``. Input features.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Classes labels.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        return fisher_discriminant_loss(
            x,
            labels,
            n_states=self.n_states,
            lda_mode=self.lda_mode,
            reduce_mode=self.reduce_mode,
            sw_reg=self.sw_reg,
            lorentzian_reg=self.lorentzian_reg,
            invert_sign=self.invert_sign,
        )
def fisher_discriminant_loss(
    x: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    lda_mode: str = "standard",
    reduce_mode: str = "sum",
    sw_reg: Optional[float] = 0.05,
    lorentzian_reg: Optional[float] = None,
    invert_sign: bool = True,
) -> torch.Tensor:
    """Fisher's discriminant ratio.

    Computes the sum (or another reducing function) of the eigenvalues of the
    ratio between Fisher's scatter matrices with a Lorentzian regularization.
    This is the same loss function used in :class:`~mlcolvar.cvs.supervised.deeplda.DeepLDA`.

    The Lorentzian term is ``lorentzian_reg / (1 + (m - 1)**2)``, where ``m`` is
    the sum of squared entries of ``x`` divided by ``x.size(0)``. The term is
    largest when ``m == 1``. It is subtracted when ``invert_sign`` is ``True``
    (the loss is to be minimized) and added otherwise (the loss is to be
    maximized), so in both cases it favors ``m`` close to 1.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(n_batches, n_features)``. Input features.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Classes labels.
    n_states : int
        The number of states. Labels are in the range ``[0, n_states-1]``.
    lda_mode : str, optional
        Either ``'standard'`` or ``'harmonic'``. This determines how the scatter
        matrices are computed (see also :class:`~mlcolvar.core.stats.lda.LDA`). The
        default is ``'standard'``.
    reduce_mode : str, optional
        This determines how the eigenvalues are reduced, e.g., ``sum``, ``sum2``
        (see also :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`). The
        default is ``'sum'``.
    sw_reg : float, optional
        The magnitude of the regularization for the within-scatter matrix, by default
        equal to 0.05.
    lorentzian_reg : float, optional
        The magnitude of the Lorentzian regularization. If not provided, it is
        set to ``2 / sw_reg``.
    invert_sign : bool, optional
        Whether to return the negative Fisher's discriminant ratio in order to be
        minimized with gradient descent methods. Default is ``True``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.

    Raises
    ------
    ValueError
        If ``lorentzian_reg`` is not given and ``sw_reg`` is ``None`` or ``0``,
        so the default ``2 / sw_reg`` cannot be computed. This is checked before
        any LDA computation is performed.
    """
    # Resolve the Lorentzian magnitude first so an invalid configuration fails
    # fast instead of after the (comparatively expensive) LDA eigen-decomposition.
    # The heuristic is the same used by DeepLDA.
    # TODO: ENCAPSULATE THIS IN A UTILITY FUNCTION USED BY BOTH THIS AND DEEPLDA?
    if lorentzian_reg is None:
        if sw_reg is None or sw_reg == 0:
            raise ValueError(
                f"Unable to calculate `lorentzian_reg` from `sw_reg` ({sw_reg}), please specify the value."
            )
        lorentzian_reg = 2.0 / sw_reg

    # define lda object and regularize s_w
    lda = LDA(in_features=x.shape[-1], n_states=n_states, mode=lda_mode)
    lda.sw_reg = sw_reg

    # compute LDA eigvals and reduce them to a scalar loss
    eigvals, _ = lda.compute(x, labels)
    loss = reduce_eigenvalues_loss(eigvals, mode=reduce_mode, invert_sign=invert_sign)

    # Lorentzian regularization. Its sign must follow ``invert_sign``: when the
    # loss is maximized (invert_sign=False), subtracting the term would reward
    # pushing the mean squared norm away from 1 without bound.
    mean_sq_norm = x.pow(2).sum().div(x.size(0))
    reg_loss = lorentzian_reg / (1 + (mean_sq_norm - 1).pow(2))
    if invert_sign:
        reg_loss = -reg_loss

    return loss + reg_loss