Coverage for biobb_pytorch / mdae / loss / elbo.py: 49%

61 statements  

coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Evidence Lower Bound (ELBO) loss functions used to train variational autoencoders.
"""

__all__ = ["ELBOGaussiansLoss", "elbo_gaussians_loss", "ELBOLoss", "ELBOGaussianMixtureLoss"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import math
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F

from mlcolvar.core.loss.mse import mse_loss


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================



class ELBOGaussiansLoss(torch.nn.Module):
    """ELBO loss function assuming the latent and reconstruction distributions are Gaussian.

    The ELBO uses the MSE as the reconstruction loss (i.e., assumes that the
    decoder outputs the mean of a Gaussian distribution with variance 1), and
    the KL divergence between two normal distributions ``N(mean, var)`` and
    ``N(0, 1)``, where ``mean`` and ``var`` are the output of the encoder.
    """

    def forward(
        self,
        target: torch.Tensor,
        output: torch.Tensor,
        mean: torch.Tensor,
        log_variance: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the value of the loss function.

        Parameters
        ----------
        target : torch.Tensor
            Shape ``(n_batches, in_features)``. Data points (e.g. input of the
            encoder or time-lagged features).
        output : torch.Tensor
            Shape ``(n_batches, in_features)``. Output of the decoder.
        mean : torch.Tensor
            Shape ``(n_batches, latent_features)``. The means of the Gaussian
            distributions associated with the inputs.
        log_variance : torch.Tensor
            Shape ``(n_batches, latent_features)``. The logarithm of the variances
            of the Gaussian distributions associated with the inputs.
        weights : torch.Tensor, optional
            Shape ``(n_batches,)`` or ``(n_batches, 1)``. If given, the average over
            batches is weighted. The default (``None``) is unweighted.

        Returns
        -------
        loss : torch.Tensor
            The value of the loss function.
        """
        return elbo_gaussians_loss(target, output, mean, log_variance, weights)


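# Example (illustrative sketch, not part of the module API): training code could
# call ``ELBOGaussiansLoss`` on the targets, the decoder output, and the encoder's
# mean/log-variance tensors. The shapes below are assumptions for the sketch.
#
#     loss_fn = ELBOGaussiansLoss()
#     target = torch.randn(32, 10)        # (n_batches, in_features)
#     output = torch.randn(32, 10)        # decoder reconstruction
#     mean = torch.zeros(32, 2)           # (n_batches, latent_features)
#     log_variance = torch.zeros(32, 2)
#     loss = loss_fn(target, output, mean, log_variance)   # scalar tensor

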

def elbo_gaussians_loss(
    target: torch.Tensor,
    output: torch.Tensor,
    mean: torch.Tensor,
    log_variance: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """ELBO loss function assuming the latent and reconstruction distributions are Gaussian.

    The ELBO uses the MSE as the reconstruction loss (i.e., assumes that the
    decoder outputs the mean of a Gaussian distribution with variance 1), and
    the KL divergence between two normal distributions ``N(mean, var)`` and
    ``N(0, 1)``, where ``mean`` and ``var`` are the output of the encoder.

    Parameters
    ----------
    target : torch.Tensor
        Shape ``(n_batches, in_features)``. Data points (e.g. input of the
        encoder or time-lagged features).
    output : torch.Tensor
        Shape ``(n_batches, in_features)``. Output of the decoder.
    mean : torch.Tensor
        Shape ``(n_batches, latent_features)``. The means of the Gaussian
        distributions associated with the inputs.
    log_variance : torch.Tensor
        Shape ``(n_batches, latent_features)``. The logarithm of the variances
        of the Gaussian distributions associated with the inputs.
    weights : torch.Tensor, optional
        Shape ``(n_batches,)`` or ``(n_batches, 1)``. If given, the average over
        batches is weighted. The default (``None``) is unweighted.

    Returns
    -------
    loss : torch.Tensor
        The value of the loss function.
    """
    # KL divergence between N(mean, variance) and N(0, 1).
    # See https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
    kl = -0.5 * (log_variance - log_variance.exp() - mean**2 + 1).sum(dim=1)

    # Weighted mean over batches.
    if weights is None:
        kl = kl.mean()
    else:
        weights = weights.squeeze()
        if weights.shape != kl.shape:
            raise ValueError(
                f"weights should be a tensor of shape (n_batches,) or (n_batches,1), not {weights.shape}."
            )
        kl = (kl * weights).sum()

    # Reconstruction loss.
    reconstruction = mse_loss(output, target, weights=weights)

    return reconstruction + kl


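# Example (illustrative sketch): the functional form also accepts optional
# per-sample weights, e.g. reweighting factors from a biased simulation. Shapes
# and values are assumptions for the sketch; the weights are normalized here so
# that the weighted sum behaves as a weighted average.
#
#     target = torch.randn(32, 10)
#     output = torch.randn(32, 10)
#     mean, log_variance = torch.zeros(32, 2), torch.zeros(32, 2)
#     weights = torch.rand(32)
#     weights = weights / weights.sum()
#     loss = elbo_gaussians_loss(target, output, mean, log_variance, weights)

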

class ELBOLoss(nn.Module):
    """
    Variational autoencoder ELBO loss function.

    Implements the evidence lower bound (ELBO) objective:
        L = reconstruction_loss + beta * KL_divergence

    Reconstruction loss options:
        - Mean squared error (MSE) -> assumes a Gaussian decoder with unit variance
        - Binary cross-entropy (BCE) -> assumes a Bernoulli decoder

    The KL divergence is computed analytically between the approximate posterior
    q(z|x) = N(mu, diag(var)) and the prior p(z) = N(0, I):
        KL(q||p) = -0.5 * sum(1 + log(var) - mu^2 - var)

    Parameters
    ----------
    beta : float, default=1.0
        Scaling factor for the KL divergence term (beta-VAE).
    reconstruction : {'mse', 'bce'}, default='mse'
        Type of reconstruction loss:
        - 'mse': use mean squared error
        - 'bce': use binary cross-entropy
    reduction : {'sum', 'mean', 'none'}, default='sum'
        How to reduce the loss:
        - 'sum': sum over all elements
        - 'mean': average over all elements
        - 'none': no reduction over the batch (returns per-sample losses)
    """

    def __init__(
        self,
        beta: float = 1.0,
        reconstruction: str = 'mse',
        reduction: str = 'sum'
    ):
        super().__init__()
        if reconstruction not in {'mse', 'bce'}:
            raise ValueError(f"Unsupported reconstruction '{reconstruction}', choose 'mse' or 'bce'.")
        if reduction not in {'sum', 'mean', 'none'}:
            raise ValueError(f"Unsupported reduction '{reduction}', choose 'sum', 'mean', or 'none'.")

        self.beta = beta
        self.reconstruction = reconstruction
        self.reduction = reduction

    def forward(
        self,
        x: torch.Tensor,
        recon_x: torch.Tensor,
        mu: torch.Tensor,
        log_var: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the combined ELBO loss.

        Parameters
        ----------
        x : Tensor
            Original input tensor (shape: [batch_size, ...]).
        recon_x : Tensor
            Reconstructed output tensor (same shape as x).
        mu : Tensor
            Mean of the approximate posterior q(z|x) (shape: [batch_size, latent_dim]).
        log_var : Tensor
            Log-variance of q(z|x) (same shape as mu).

        Returns
        -------
        loss : Tensor
            Scalar loss (if reduction != 'none') or tensor of per-sample losses.
        """
        # Reconstruction loss.
        if self.reconstruction == 'bce':
            # For binary data, use BCE.
            recon_loss = F.binary_cross_entropy(
                recon_x, x, reduction=self.reduction
            )
        else:
            # For continuous data, use MSE.
            recon_loss = F.mse_loss(
                recon_x, x, reduction=self.reduction
            )

        # Analytic KL divergence between N(mu, var) and N(0, I), with var = exp(log_var),
        # summed over the latent dimension for each sample.
        var = torch.exp(log_var)
        kl_div = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - var,
            dim=1
        )

        # Reduce both terms consistently.
        if self.reduction == 'mean':
            kl_div = kl_div.mean()
        elif self.reduction == 'sum':
            kl_div = kl_div.sum()
        else:
            # 'none': sum the per-element reconstruction loss over the non-batch
            # dimensions so that it aligns with the per-sample KL divergence.
            recon_loss = recon_loss.reshape(recon_loss.shape[0], -1).sum(dim=1)

        # Scale the KL term and add the reconstruction term.
        return recon_loss + self.beta * kl_div


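# Example (illustrative sketch): a beta-VAE style objective with MSE
# reconstruction summed over all elements. The shapes and the beta value are
# assumptions for the sketch; beta != 1.0 rescales the KL term.
#
#     loss_fn = ELBOLoss(beta=0.5, reconstruction='mse', reduction='sum')
#     x = torch.rand(32, 10)              # inputs (keep in [0, 1] if using 'bce')
#     recon_x = torch.rand(32, 10)        # decoder output
#     mu = torch.zeros(32, 2)             # latent means
#     log_var = torch.zeros(32, 2)        # latent log-variances
#     loss = loss_fn(x, recon_x, mu, log_var)

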

class ELBOGaussianMixtureLoss(nn.Module):
    """
    Gaussian mixture VAE loss.

    Combines:
      1) Entropy regularization: -∑_i q(y=i|x) log q(y=i|x)
      2) Reconstruction + KL, averaged over the cluster posterior:
           - E_{q(y|x)} [ log p(x|z,y) ]
           + E_{q(y|x)} [ KL( q(z|x,y) ‖ p(z|y) ) ]
    """

    def __init__(self, k: int, r_nent: float = 1.0):
        """
        Args:
            k: Number of mixture components.
            r_nent: Weight on the entropy term.
        """
        super().__init__()
        self.k = k
        self.r_nent = r_nent

    @staticmethod
    def log_normal(x: torch.Tensor,
                   mu: torch.Tensor,
                   var: torch.Tensor,
                   eps: float = 1e-10) -> torch.Tensor:
        """
        Compute log N(x; mu, var) summed over the last dim:
            -½ ∑ [ log(2π) + (x−μ)² / var + log var ]
        """
        const = math.log(2 * math.pi)
        return -0.5 * torch.sum(
            const + (x - mu).pow(2) / (var + eps) + (var + eps).log(),
            dim=-1
        )

    def forward(self,
                x: torch.Tensor,
                qy_logit: torch.Tensor,
                xm_list: list[torch.Tensor],
                xv_list: list[torch.Tensor],
                z_list: list[torch.Tensor],
                zm_list: list[torch.Tensor],
                zv_list: list[torch.Tensor],
                zm_prior_list: list[torch.Tensor],
                zv_prior_list: list[torch.Tensor]
                ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch, n_features] input data.
            qy_logit: [batch, k] cluster logits.
            xm_list, xv_list: length-k lists of [batch, n_features] decoder means/variances.
            z_list, zm_list, zv_list: length-k lists of [batch, n_cvs] latent samples, means, variances.
            zm_prior_list, zv_prior_list: length-k lists of [batch, n_cvs] prior means/variances.

        Returns:
            total: scalar loss, the batch mean of
                   r_nent * nent + ∑_i qy_i * (rec_i + KL_i).
            nent: the (scalar) entropy regularization term.
        """
        # 1) Cluster posteriors q(y|x).
        qy = F.softmax(qy_logit, dim=1)  # [batch, k]

        # 2) Entropy regularization, averaged over the batch:
        #    nent = -E[ log q(y|x) ]
        nent = -torch.sum(qy * F.log_softmax(qy_logit, dim=1), dim=1).mean()

        # 3) Per-component reconstruction + KL.
        comp_losses = []
        for i in range(self.k):
            # Reconstruction: -log p(x | z_i).
            rec_i = -self.log_normal(x, xm_list[i], xv_list[i])
            # KL divergence KL( q(z|x,y=i) ‖ p(z|y=i) ), estimated from the single
            # sample z_i as log q(z_i) - log p(z_i).
            kl_i = (
                self.log_normal(z_list[i], zm_list[i], zv_list[i]) -
                self.log_normal(z_list[i], zm_prior_list[i], zv_prior_list[i])
            )
            comp_losses.append(rec_i + kl_i)  # shape [batch]

        # 4) Weight each component by qy[:, i], sum over components, and average
        #    over the batch.
        weighted = sum(qy[:, i] * comp_losses[i] for i in range(self.k))  # [batch]
        total = self.r_nent * nent + weighted.mean()

        return total, nent
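

# Example (illustrative sketch): the Gaussian-mixture ELBO expects per-component
# lists of decoder and encoder statistics. The encoder/decoder producing these
# tensors is assumed; the shapes are illustrative (k components, a batch of 16,
# 10 input features, 2 collective variables).
#
#     k, batch, n_features, n_cvs = 3, 16, 10, 2
#     loss_fn = ELBOGaussianMixtureLoss(k=k, r_nent=1.0)
#     x = torch.randn(batch, n_features)
#     qy_logit = torch.randn(batch, k)
#     xm = [torch.randn(batch, n_features) for _ in range(k)]
#     xv = [torch.ones(batch, n_features) for _ in range(k)]
#     z = [torch.randn(batch, n_cvs) for _ in range(k)]
#     zm = [torch.zeros(batch, n_cvs) for _ in range(k)]
#     zv = [torch.ones(batch, n_cvs) for _ in range(k)]
#     zm_prior = [torch.zeros(batch, n_cvs) for _ in range(k)]
#     zv_prior = [torch.ones(batch, n_cvs) for _ in range(k)]
#     total, nent = loss_fn(x, qy_logit, xm, xv, z, zm, zv, zm_prior, zv_prior)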