Coverage for biobb_pytorch / mdae / loss / eigvals.py: 61%

36 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Reduce eigenvalues loss.

Turns a set of eigenvalues into a scalar loss, either through the
:class:`ReduceEigenvaluesLoss` module or the functional
:func:`reduce_eigenvalues_loss`.
"""

__all__ = [
    "ReduceEigenvaluesLoss",
    "reduce_eigenvalues_loss",
]

12 

13 

14# ============================================================================= 

15# GLOBAL IMPORTS 

16# ============================================================================= 

17 

18import torch 

19 

20 

21# ============================================================================= 

22# LOSS FUNCTIONS 

23# ============================================================================= 

24 

25 

class ReduceEigenvaluesLoss(torch.nn.Module):
    """Module wrapper around :func:`reduce_eigenvalues_loss`.

    Reduces a set of eigenvalues to a scalar through a monotonic function
    f(x), the sum by default. Unless ``invert_sign`` is False, the module
    returns -f(x). Minimizing it with gradient descent then maximizes the
    eigenvalues.

    Supported reductions (``mode``):
        - sum     : sum_i (lambda_i)
        - sum2    : sum_i (lambda_i)**2
        - gap     : (lambda_1-lambda_2)
        - its     : sum_i (1/log(lambda_i))
        - single  : (lambda_i)
        - single2 : (lambda_i)**2
    """

    def __init__(
        self,
        mode: str = "sum",
        n_eig: int = 0,
        invert_sign: bool = True,
    ):
        """Store the reduction settings.

        Parameters
        ----------
        mode : str, optional
            Name of the reduction applied to the eigenvalues (see the class
            docstring). Default is ``'sum'``.
        n_eig : int, optional
            How many eigenvalues enter the loss. 0 (the default) means all
            of them. For ``'single'`` and ``'single2'`` it instead selects
            which eigenvalue is used (1-based).
        invert_sign : bool, optional
            If True (default), the negated value is returned so the loss can
            be minimized with gradient-descent methods.
        """
        super().__init__()
        # Plain attributes, re-read by forward() on every call.
        self.mode, self.n_eig, self.invert_sign = mode, n_eig, invert_sign

    def forward(self, evals: torch.Tensor) -> torch.Tensor:
        """Apply the configured reduction to ``evals``.

        Parameters
        ----------
        evals : torch.Tensor
            Eigenvalues, forwarded unchanged to
            :func:`reduce_eigenvalues_loss`.

        Returns
        -------
        loss : torch.Tensor
            The value :func:`reduce_eigenvalues_loss` computes with the
            stored ``mode``, ``n_eig`` and ``invert_sign``.
        """
        return reduce_eigenvalues_loss(
            evals,
            mode=self.mode,
            n_eig=self.n_eig,
            invert_sign=self.invert_sign,
        )

80 

81 

def reduce_eigenvalues_loss(
    evals: torch.Tensor,
    mode: str = "sum",
    n_eig: int = 0,
    invert_sign: bool = True,
) -> torch.Tensor:
    """Calculate a monotonic function f(x) of the eigenvalues, by default the sum.

    By default it returns -f(x), for use as a loss function that maximizes
    the eigenvalues in gradient descent schemes.

    Parameters
    ----------
    evals : torch.Tensor
        Eigenvalues. Every reduction indexes ``evals`` along its first
        dimension (``evals[:n_eig]``, ``evals[0]``, ...), so a 1D tensor of
        eigenvalues is the natural input.
    mode : str, optional
        Function of the eigenvalues to optimize (see notes). Default is ``'sum'``.
    n_eig: int, optional
        Number of eigenvalues to include in the loss (default: 0 --> all).
        For ``'single'`` and ``'single2'`` it is the 1-based index of the
        eigenvalue to use. Ignored by ``'gap'``.
    invert_sign: bool, optional
        Whether to return the opposite of the function (in order to be minimized
        with GD methods). Default is ``True``.

    Notes
    -----
    The following functions are implemented:
        - sum     : sum_i (lambda_i)
        - sum2    : sum_i (lambda_i)**2
        - gap     : (lambda_1-lambda_2)
        - its     : sum_i (1/log(lambda_i))
        - single  : (lambda_i)
        - single2 : (lambda_i)**2

    Returns
    -------
    loss : torch.Tensor (scalar)
        Loss value. ``evals`` is never modified in place.

    Raises
    ------
    ValueError
        If ``n_eig`` exceeds the number of eigenvalues, if ``n_eig`` is 0
        with ``'single'``/``'single2'``, or if ``mode`` is unknown.
    """
    n_evals = len(evals)

    # 0 means "use all eigenvalues". The single-eigenvalue modes have no
    # sensible default index, so they require n_eig explicitly.
    if n_eig > 0 and n_evals < n_eig:
        raise ValueError(
            f"n_eig ({n_eig}) must not exceed the number of eigenvalues ({n_evals})."
        )
    if n_eig == 0:
        if mode in ("single", "single2"):
            raise ValueError("n_eig must be specified when using single or single2.")
        n_eig = n_evals

    if mode == "sum":
        loss = torch.sum(evals[:n_eig])
    elif mode == "sum2":
        loss = torch.sum(torch.pow(evals[:n_eig], 2))
    elif mode == "gap":
        loss = evals[0] - evals[1]
    elif mode == "its":
        loss = torch.sum(1 / torch.log(evals[:n_eig]))
    elif mode == "single":
        loss = evals[n_eig - 1]
    elif mode == "single2":
        loss = torch.pow(evals[n_eig - 1], 2)
    else:
        raise ValueError(
            f"unknown mode : {mode}. options: "
            "'sum','sum2','gap','single','its','single2'."
        )

    # Negate out of place. For 'single', ``loss`` is a view into ``evals``,
    # so an in-place ``loss *= -1`` would overwrite the caller's tensor. It
    # would also fail under autograd when ``evals`` is a leaf requiring grad.
    if invert_sign:
        loss = -loss

    return loss