Coverage for biobb_pytorch/mdae/loss/fisher.py: 93% (27 statements)

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Fisher discriminant loss for (Deep) Linear Discriminant Analysis.
"""

__all__ = ["FisherDiscriminantLoss", "fisher_discriminant_loss"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional

import torch

from mlcolvar.core.stats import LDA
from mlcolvar.core.loss import reduce_eigenvalues_loss


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================

class FisherDiscriminantLoss(torch.nn.Module):
    """Fisher's discriminant ratio.

    Computes the sum (or another reducing function) of the eigenvalues of the
    ratio between Fisher's scatter matrices. This is the same loss function
    used in :class:`~mlcolvar.cvs.supervised.deeplda.DeepLDA`.
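
    Examples
    --------
    A minimal usage sketch (illustrative only; shapes and labels are arbitrary)::

        import torch

        loss_fn = FisherDiscriminantLoss(n_states=2)
        x = torch.randn(100, 5)            # (n_batches, n_features)
        labels = torch.randint(2, (100,))  # integer labels in [0, 1]
        loss = loss_fn(x, labels)          # scalar tensor, minimized in training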

    """

    def __init__(
        self,
        n_states: int,
        lda_mode: str = "standard",
        reduce_mode: str = "sum",
        lorentzian_reg: Optional[float] = None,
        invert_sign: bool = True,
    ):
        """Constructor.

        Parameters
        ----------
        n_states : int
            The number of states. Labels are in the range ``[0, n_states-1]``.
        lda_mode : str
            Either ``'standard'`` or ``'harmonic'``. This determines how the scatter
            matrices are computed (see also :class:`~mlcolvar.core.stats.lda.LDA`).
            The default is ``'standard'``.
        reduce_mode : str
            This determines how the eigenvalues are reduced, e.g., ``sum``, ``sum2``
            (see also :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`).
            The default is ``'sum'``.
        lorentzian_reg : float, optional
            The magnitude of the Lorentzian regularization. If not provided, this
            is set automatically (see :func:`fisher_discriminant_loss`).
        invert_sign : bool, optional
            Whether to return the negative Fisher's discriminant ratio so that it
            can be minimized with gradient descent methods. Default is ``True``.
        """
        super().__init__()
        self.n_states = n_states
        self.lda_mode = lda_mode
        self.reduce_mode = reduce_mode
        self.lorentzian_reg = lorentzian_reg
        self.invert_sign = invert_sign

    def forward(
        self,
        x: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the value of the loss function.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_features)``. Input features.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Class labels.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        return fisher_discriminant_loss(
            x,
            labels,
            n_states=self.n_states,
            lda_mode=self.lda_mode,
            reduce_mode=self.reduce_mode,
            lorentzian_reg=self.lorentzian_reg,
            invert_sign=self.invert_sign,
        )


def fisher_discriminant_loss(
    x: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    lda_mode: str = "standard",
    reduce_mode: str = "sum",
    sw_reg: Optional[float] = 0.05,
    lorentzian_reg: Optional[float] = None,
    invert_sign: bool = True,
) -> torch.Tensor:
    """Fisher's discriminant ratio.

    Computes the sum (or another reducing function) of the eigenvalues of the
    ratio between Fisher's scatter matrices, with a Lorentzian regularization.
    This is the same loss function used in :class:`~mlcolvar.cvs.supervised.deeplda.DeepLDA`.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(n_batches, n_features)``. Input features.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Class labels.
    n_states : int
        The number of states. Labels are in the range ``[0, n_states-1]``.
    lda_mode : str, optional
        Either ``'standard'`` or ``'harmonic'``. This determines how the scatter
        matrices are computed (see also :class:`~mlcolvar.core.stats.lda.LDA`).
        The default is ``'standard'``.
    reduce_mode : str, optional
        This determines how the eigenvalues are reduced, e.g., ``sum``, ``sum2``
        (see also :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`).
        The default is ``'sum'``.
    sw_reg : float, optional
        The magnitude of the regularization for the within-scatter matrix. The
        default is 0.05.
    lorentzian_reg : float, optional
        The magnitude of the Lorentzian regularization. If not provided, this
        is set automatically from ``sw_reg``.
    invert_sign : bool, optional
        Whether to return the negative Fisher's discriminant ratio so that it
        can be minimized with gradient descent methods. Default is ``True``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
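
    Examples
    --------
    A minimal usage sketch (illustrative only; shapes and regularization
    values are arbitrary)::

        import torch

        x = torch.randn(100, 5)            # (n_batches, n_features)
        labels = torch.randint(3, (100,))  # integer labels in [0, 2]
        loss = fisher_discriminant_loss(
            x, labels, n_states=3, sw_reg=0.05, lorentzian_reg=40.0
        )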

    """

    # Define the LDA object.
    lda = LDA(in_features=x.shape[-1], n_states=n_states, mode=lda_mode)

    # Regularize the within-scatter matrix S_w.
    lda.sw_reg = sw_reg

    # Compute the LDA eigenvalues and reduce them to a scalar loss.
    eigvals, _ = lda.compute(x, labels)
    loss = reduce_eigenvalues_loss(eigvals, mode=reduce_mode, invert_sign=invert_sign)

    # Add the Lorentzian regularization. The heuristic is the same used by DeepLDA.
    # TODO: ENCAPSULATE THIS IN A UTILITY FUNCTION USED BY BOTH THIS AND DEEPLDA?
    if lorentzian_reg is None:
        if sw_reg is None or sw_reg == 0:
            raise ValueError(
                f"Unable to calculate `lorentzian_reg` from `sw_reg` ({sw_reg}), please specify the value."
            )
        lorentzian_reg = 2.0 / sw_reg
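    # With the default sw_reg = 0.05 this gives lorentzian_reg = 40. The penalty
    # below is a (negative) Lorentzian well centered at a mean squared norm of 1:
    # with r = ||x||^2 / n_batches it equals -lorentzian_reg / (1 + (r - 1)^2),
    # so minimizing the total loss keeps the features' mean squared norm near 1.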

    reg_loss = x.pow(2).sum().div(x.size(0))
    reg_loss = -lorentzian_reg / (1 + (reg_loss - 1).pow(2))

    return loss + reg_loss
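

# =============================================================================
# SMOKE TEST
# =============================================================================
# Illustrative smoke test: a minimal sketch, not a definitive harness. It
# exercises both entry points on random data (shapes and values are arbitrary)
# and assumes only the dependencies imported above (torch and mlcolvar).

if __name__ == "__main__":
    torch.manual_seed(0)
    n_batches, n_features, n_states = 100, 5, 2
    x = torch.randn(n_batches, n_features)
    labels = torch.randint(n_states, (n_batches,))

    # Class-based interface: lorentzian_reg is derived automatically from the
    # functional default sw_reg = 0.05.
    loss_fn = FisherDiscriminantLoss(n_states=n_states)
    print("class API loss:     ", loss_fn(x, labels).item())

    # Functional interface with explicit regularization values.
    loss = fisher_discriminant_loss(
        x, labels, n_states=n_states, sw_reg=0.05, lorentzian_reg=40.0
    )
    print("functional API loss:", loss.item())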