Coverage for biobb_pytorch / mdae / loss / tda_loss.py: 18%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1#!/usr/bin/env python 

2 

3# ============================================================================= 

4# MODULE DOCSTRING 

5# ============================================================================= 

6 

7""" 

8Target Discriminant Analysis Loss Function. 

9""" 

10 

11__all__ = ["TDALoss", "tda_loss"] 

12 

13 

14# ============================================================================= 

15# GLOBAL IMPORTS 

16# ============================================================================= 

17 

18from typing import Union 

19from warnings import warn 

20 

21import torch 

22 

23 

24# ============================================================================= 

25# LOSS FUNCTIONS 

26# ============================================================================= 

27 

28 

class TDALoss(torch.nn.Module):
    """Loss measuring the distance of class-wise CV statistics from Gaussian targets.

    Thin ``torch.nn.Module`` wrapper around :func:`tda_loss` that stores the
    target centers and standard deviations as buffers.
    """

    def __init__(
        self,
        n_states: int,
        target_centers: Union[list, torch.Tensor],
        target_sigmas: Union[list, torch.Tensor],
        alpha: float = 1,
        beta: float = 100,
    ):
        """Store the target distribution and the loss prefactors.

        Parameters
        ----------
        n_states : int
            Number of states. The integer labels are expected to lie between
            0 and ``n_states-1``.
        target_centers : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
        target_sigmas : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian
            targets.
        alpha : float, optional
            Prefactor of the centers loss component, by default 1.
        beta : float, optional
            Prefactor of the sigmas loss component, by default 100.
        """
        super().__init__()

        # Anything that is not already a tensor is converted with the
        # ``torch.Tensor`` constructor; tensors are stored as given.
        centers = (
            target_centers
            if isinstance(target_centers, torch.Tensor)
            else torch.Tensor(target_centers)
        )
        sigmas = (
            target_sigmas
            if isinstance(target_sigmas, torch.Tensor)
            else torch.Tensor(target_sigmas)
        )

        self.n_states = n_states
        self.alpha = alpha
        self.beta = beta

        # Buffers (not parameters): they are part of the module state but
        # are not trained.
        self.register_buffer("target_centers", centers)
        self.register_buffer("target_sigmas", sigmas)

    def forward(
        self, H: torch.Tensor, labels: torch.Tensor, return_loss_terms: bool = False
    ) -> torch.Tensor:
        """Evaluate the loss on a batch by delegating to ``tda_loss``.

        Parameters
        ----------
        H : torch.Tensor
            Shape ``(n_batches, n_features)``. Output of the NN.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Labels of the dataset.
        return_loss_terms : bool, optional
            If ``True``, the loss terms associated to the centers and standard
            deviations of the target Gaussians are returned as well. Default
            is ``False``.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        loss_centers : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the centers of the target Gaussians.
        loss_sigmas : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the standard deviations of the target
            Gaussians.
        """
        return tda_loss(
            H,
            labels,
            n_states=self.n_states,
            target_centers=self.target_centers,
            target_sigmas=self.target_sigmas,
            alpha=self.alpha,
            beta=self.beta,
            return_loss_terms=return_loss_terms,
        )

104 

105 

def tda_loss(
    H: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    target_centers: Union[list, torch.Tensor],
    target_sigmas: Union[list, torch.Tensor],
    alpha: float = 1,
    beta: float = 100,
    return_loss_terms: bool = False,
) -> torch.Tensor:
    """
    Compute a loss function as the distance from a simple Gaussian target distribution.

    For every state ``i`` the mean and standard deviation of the rows of ``H``
    labelled ``i`` are compared with ``target_centers[i]`` and
    ``target_sigmas[i]``; the squared differences are summed over states and
    CVs, weighted by ``alpha`` (centers) and ``beta`` (sigmas).

    Parameters
    ----------
    H : torch.Tensor
        Shape ``(n_batches, n_cvs)``. Output of the NN.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Labels of the dataset.
    n_states : int
        The integer labels are expected to be in between 0 and ``n_states-1``.
    target_centers : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
    target_sigmas : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
    alpha : float, optional
        Centers loss component prefactor, by default 1.
    beta : float, optional
        Sigmas loss component prefactor, by default 100.
    return_loss_terms : bool, optional
        If ``True``, the loss terms associated to the center and standard deviations
        of the target Gaussians are returned as well. Default is ``False``.

    Returns
    -------
    loss : torch.Tensor
        Loss value (sum of the two terms below).
    loss_centers : torch.Tensor, optional
        Only returned if ``return_loss_terms is True`` (the function then
        returns the tuple ``(loss, loss_centers, loss_sigmas)``). The value of
        the loss term associated to the centers of the target Gaussians.
    loss_sigmas : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the standard deviations of the target Gaussians.

    Raises
    ------
    ValueError
        If one of the states ``0 .. n_states-1`` has no sample in the batch.

    Warns
    -----
    UserWarning
        If a state has exactly one sample in the batch; its standard deviation
        is then taken to be 0.
    """
    if not isinstance(target_centers, torch.Tensor):
        target_centers = torch.Tensor(target_centers)
    if not isinstance(target_sigmas, torch.Tensor):
        target_sigmas = torch.Tensor(target_sigmas)

    device = H.device
    target_centers = target_centers.to(device)
    target_sigmas = target_sigmas.to(device)

    # Scalar accumulators. The per-state terms are *added* (not written in
    # place into a pre-allocated buffer), so the usual type promotion applies
    # and e.g. a float64 ``H`` is not silently cast down to the targets' dtype.
    loss_centers = torch.zeros((), dtype=target_centers.dtype, device=device)
    loss_sigmas = torch.zeros((), dtype=target_sigmas.dtype, device=device)

    for i in range(n_states):
        # indices of the elements belonging to class i (computed only once)
        members = torch.nonzero(labels == i, as_tuple=True)
        n_samples = members[0].numel()
        if n_samples == 0:
            raise ValueError(
                f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
        H_red = H[members]

        # compute mean and standard deviation over the class i
        mu = torch.mean(H_red, 0)
        if n_samples == 1:
            warn(
                f"There is only one sample for state {i} in this batch! Std is set to 0, this may affect the training! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
            # std of a single sample is undefined; use zeros on the right
            # device and with the right dtype instead
            sigma = torch.zeros_like(mu)
        else:
            sigma = torch.std(H_red, 0)

        # accumulate loss function contributions for class i
        loss_centers = loss_centers + alpha * (mu - target_centers[i]).pow(2).sum()
        loss_sigmas = loss_sigmas + beta * (sigma - target_sigmas[i]).pow(2).sum()

    # get total model loss
    loss = loss_centers + loss_sigmas

    if return_loss_terms:
        return loss, loss_centers, loss_sigmas
    return loss