#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

7 """
8 Evidence Lower BOund (ELBO) loss functions used to train variational Autoencoders.
9 """
10  
__all__ = ["ELBOGaussiansLoss", "elbo_gaussians_loss", "ELBOLoss", "ELBOGaussianMixtureLoss"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional
import torch
import math
from torch import nn
from torch.nn import functional as F
from mlcolvar.core.loss.mse import mse_loss


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class ELBOGaussiansLoss(torch.nn.Module):
    """ELBO loss function assuming the latent and reconstruction distributions are Gaussian.

    The ELBO uses the MSE as the reconstruction loss (i.e., assumes that the
    decoder outputs the mean of a Gaussian distribution with variance 1), and
    the KL divergence between two normal distributions ``N(mean, var)`` and
    ``N(0, 1)``, where ``mean`` and ``var`` are the output of the encoder.
    """

    def forward(
        self,
        target: torch.Tensor,
        output: torch.Tensor,
        mean: torch.Tensor,
        log_variance: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the value of the loss function.

        Parameters
        ----------
        target : torch.Tensor
            Shape ``(n_batches, in_features)``. Data points (e.g. input of the
            encoder or time-lagged features).
        output : torch.Tensor
            Shape ``(n_batches, in_features)``. Output of the decoder.
        mean : torch.Tensor
            Shape ``(n_batches, latent_features)``. The means of the Gaussian
            distributions associated with the inputs.
        log_variance : torch.Tensor
            Shape ``(n_batches, latent_features)``. The logarithm of the variances
            of the Gaussian distributions associated with the inputs.
        weights : torch.Tensor, optional
            Shape ``(n_batches,)`` or ``(n_batches, 1)``. If given, the average
            over batches is weighted. The default (``None``) is unweighted.

        Returns
        -------
        loss : torch.Tensor
            The value of the loss function.
        """
        return elbo_gaussians_loss(target, output, mean, log_variance, weights)


def elbo_gaussians_loss(
    target: torch.Tensor,
    output: torch.Tensor,
    mean: torch.Tensor,
    log_variance: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
82 """ELBO loss function assuming the latent and reconstruction distributions are Gaussian.
83  
84 The ELBO uses the MSE as the reconstruction loss (i.e., assumes that the
85 decoder outputs the mean of a Gaussian distribution with variance 1), and
86 the KL divergence between two normal distributions ``N(mean, var)`` and
87 ``N(0, 1)``, where ``mean`` and ``var`` are the output of the encoder.
88  
89 Parameters
90 ----------
91 target : torch.Tensor
92 Shape ``(n_batches, in_features)``. Data points (e.g. input of encoder
93 or time-lagged features).
94 output : torch.Tensor
95 Shape ``(n_batches, in_features)``. Output of the decoder.
96 mean : torch.Tensor
97 Shape ``(n_batches, latent_features)``. The means of the Gaussian
98 distributions associated to the inputs.
99 log_variance : torch.Tensor
100 Shape ``(n_batches, latent_features)``. The logarithm of the variances
101 of the Gaussian distributions associated to the inputs.
102 weights : torch.Tensor, optional
103 Shape ``(n_batches,)`` or ``(n_batches,1)``. If given, the average over
104 batches is weighted. The default (``None``) is unweighted.
105  
106 Returns
107 -------
108 loss: torch.Tensor
109 The value of the loss function.
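
    Examples
    --------
    A minimal sketch with random tensors (the shapes below are illustrative):

    >>> target = torch.randn(8, 30)
    >>> output = torch.randn(8, 30)
    >>> mean = torch.randn(8, 2)
    >>> log_variance = torch.zeros(8, 2)
    >>> loss = elbo_gaussians_loss(target, output, mean, log_variance)
    >>> loss.ndim
    0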
110 """
111 # KL divergence between N(mean, variance) and N(0, 1).
112 # See https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
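    # Closed form (per latent dimension): KL(N(m, v) || N(0, 1)) = -0.5 * (log v - v - m^2 + 1).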
    kl = -0.5 * (log_variance - log_variance.exp() - mean**2 + 1).sum(dim=1)

    # Weighted mean over batches.
    if weights is None:
        kl = kl.mean()
    else:
        weights = weights.squeeze()
        if weights.shape != kl.shape:
            raise ValueError(
                f"weights should be a tensor of shape (n_batches,) or (n_batches,1), not {weights.shape}."
            )
        kl = (kl * weights).sum()

    # Reconstruction loss.
    reconstruction = mse_loss(output, target, weights=weights)

    return reconstruction + kl


class ELBOLoss(nn.Module):
    """
    Variational Autoencoder ELBO loss function.

    Implements the evidence lower bound (ELBO) objective:
        L = reconstruction_loss + beta * KL_divergence

    Reconstruction loss options:
    - Mean-squared error (MSE) -> assumes Gaussian decoder with unit variance
    - Binary cross-entropy (BCE) -> assumes Bernoulli decoder

    KL divergence is computed analytically between the approximate posterior
    q(z|x) = N(mu, diag(var)) and the prior p(z) = N(0, I):
        KL(q||p) = -0.5 * sum(1 + log(var) - mu^2 - var)

    Parameters
    ----------
    beta : float, default=1.0
        Scaling factor for the KL divergence term (beta-VAE).
    reconstruction : {'mse', 'bce'}, default='mse'
        Type of reconstruction loss:
        - 'mse': use mean squared error
        - 'bce': use binary cross-entropy
    reduction : {'sum', 'mean', 'none'}, default='sum'
        How to reduce the reconstruction loss over elements:
        - 'sum': sum over all elements
        - 'mean': average over all elements
        - 'none': no reduction (returns per-element loss)
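
    Examples
    --------
    A minimal usage sketch (the tensor shapes below are illustrative, not part
    of the API):

    >>> loss_fn = ELBOLoss(beta=0.5, reconstruction='mse', reduction='sum')
    >>> x = torch.randn(8, 30)
    >>> recon_x = torch.randn(8, 30)
    >>> mu = torch.randn(8, 2)
    >>> log_var = torch.zeros(8, 2)
    >>> loss = loss_fn(x, recon_x, mu, log_var)
    >>> loss.ndim
    0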
160 """
161  
    def __init__(
        self,
        beta: float = 1.0,
        reconstruction: str = 'mse',
        reduction: str = 'sum'
    ):
        super().__init__()
        if reconstruction not in {'mse', 'bce'}:
            raise ValueError(f"Unsupported reconstruction '{reconstruction}', choose 'mse' or 'bce'.")
        if reduction not in {'sum', 'mean', 'none'}:
            raise ValueError(f"Unsupported reduction '{reduction}', choose 'sum', 'mean', or 'none'.")

        self.beta = beta
        self.reconstruction = reconstruction
        self.reduction = reduction

    def forward(
        self,
        x: torch.Tensor,
        recon_x: torch.Tensor,
        mu: torch.Tensor,
        log_var: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the combined ELBO loss.

        Parameters
        ----------
        x : Tensor
            Original input tensor (shape: [batch_size, ...]).
        recon_x : Tensor
            Reconstructed output tensor (same shape as x).
        mu : Tensor
            Mean of the approximate posterior q(z|x) (shape: [batch_size, latent_dim]).
        log_var : Tensor
            Log-variance of q(z|x) (same shape as mu).

        Returns
        -------
        loss : Tensor
            Scalar loss (if reduction != 'none') or tensor of per-element losses.
        """
        # Reconstruction loss
        if self.reconstruction == 'bce':
            # For binary data, use BCE
            recon_loss = F.binary_cross_entropy(
                recon_x, x, reduction=self.reduction
            )
        else:
            # For continuous data, use MSE
            recon_loss = F.mse_loss(
                recon_x, x, reduction=self.reduction
            )

        # Analytic KL divergence between N(mu, var) and N(0, I)
        # var = exp(log_var)
        var = torch.exp(log_var)
        kl_div = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - var,
            dim=1  # sum over latent dimension for each sample
        )

        # Combine terms: sum or mean over batch
        if self.reduction == 'mean':
            kl_div = kl_div.mean()
        elif self.reduction == 'sum':
            kl_div = kl_div.sum()
        else:
            # 'none': keep the per-sample KL, reshaped so it broadcasts against
            # the per-element reconstruction loss.
            kl_div = kl_div.view(kl_div.shape[0], *([1] * (recon_loss.dim() - 1)))

        # Scale KL and add reconstruction
        return recon_loss + self.beta * kl_div


class ELBOGaussianMixtureLoss(nn.Module):
    """
    Gaussian Mixture VAE loss.

    Combines:
      1) Entropy regularization: -∑_i q(y=i|x) log q(y=i|x)
      2) Reconstruction + KL:
           - E_{q(y|x)} [ log p(x|z,y) ]
           + E_{q(y|x)} [ KL( q(z|x,y) ‖ p(z|y) ) ]
    """

    def __init__(self, k: int, r_nent: float = 1.0):
        """
        Args:
            k: Number of mixture components.
            r_nent: Weight on the entropy term.
        """
        super().__init__()
        self.k = k
        self.r_nent = r_nent

    @staticmethod
    def log_normal(x: torch.Tensor,
                   mu: torch.Tensor,
                   var: torch.Tensor,
                   eps: float = 1e-10) -> torch.Tensor:
        """
        Compute log N(x; mu, var) summed over the last dim:
            -½ ∑ [ log(2π) + (x−μ)^2 / var + log var ]
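
        A quick numerical cross-check (random tensors; ``torch.distributions``
        is used here only for comparison):

        >>> x = torch.randn(4, 3)
        >>> mu = torch.zeros(4, 3)
        >>> var = torch.ones(4, 3)
        >>> lp = ELBOGaussianMixtureLoss.log_normal(x, mu, var)
        >>> ref = torch.distributions.Normal(mu, var.sqrt()).log_prob(x).sum(-1)
        >>> bool(torch.allclose(lp, ref, atol=1e-5))
        True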
264 """
265 const = math.log(2 * math.pi)
266 return -0.5 * torch.sum(
267 const + (x - mu).pow(2) / (var + eps) + var.log(),
268 dim=-1
269 )
270  
    def forward(self,
                x: torch.Tensor,
                qy_logit: torch.Tensor,
                xm_list: list[torch.Tensor],
                xv_list: list[torch.Tensor],
                z_list: list[torch.Tensor],
                zm_list: list[torch.Tensor],
                zv_list: list[torch.Tensor],
                zm_prior_list: list[torch.Tensor],
                zv_prior_list: list[torch.Tensor]
                ) -> tuple[torch.Tensor, torch.Tensor]:
282 """
283 Args:
284 x [batch, n_features] Input data
285 qy_logit [batch, k] Cluster logits
286 xm_list, xv_list length-k lists of [batch, n_features]
287 z_list, zm_list, zv_list length-k lists of [batch, n_cvs]
288 zm_prior_list, zv_prior_list length-k lists of [batch, n_cvs]
289 Returns:
290 scalar loss = mean_batch( r_nent*nent + ∑_i qy_i * [rec_i + KL_i] )
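
        Example (a minimal sketch; shapes and values are illustrative only):

        >>> loss_fn = ELBOGaussianMixtureLoss(k=2)
        >>> batch, n_features, n_cvs = 8, 30, 2
        >>> x = torch.randn(batch, n_features)
        >>> qy_logit = torch.randn(batch, loss_fn.k)
        >>> xm = [torch.randn(batch, n_features) for _ in range(loss_fn.k)]
        >>> xv = [torch.ones(batch, n_features) for _ in range(loss_fn.k)]
        >>> z = [torch.randn(batch, n_cvs) for _ in range(loss_fn.k)]
        >>> zm = [torch.randn(batch, n_cvs) for _ in range(loss_fn.k)]
        >>> zv = [torch.ones(batch, n_cvs) for _ in range(loss_fn.k)]
        >>> zm_p = [torch.zeros(batch, n_cvs) for _ in range(loss_fn.k)]
        >>> zv_p = [torch.ones(batch, n_cvs) for _ in range(loss_fn.k)]
        >>> loss, nent = loss_fn(x, qy_logit, xm, xv, z, zm, zv, zm_p, zv_p)
        >>> loss.shape
        torch.Size([8])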
291 """
292 # 1) cluster posteriors
293 qy = F.softmax(qy_logit, dim=1) # [batch, k]
294  
295 # 2) entropy regularization (cross-entropy of qy wrt itself)
296 # nent = -E[ log q(y|x) ]
297 nent = -torch.sum(qy * F.log_softmax(qy_logit, dim=1), dim=1).mean()

        # 3) per-component reconstruction + KL
        comp_losses = []
        for i in range(self.k):
            # reconstruction: -log p(x | z_i, y=i)
            rec_i = -self.log_normal(x, xm_list[i], xv_list[i])
            # KL( q(z|x,y=i) ‖ p(z|y=i) ), estimated at the sampled z_i
            kl_i = (
                self.log_normal(z_list[i], zm_list[i], zv_list[i])
                - self.log_normal(z_list[i], zm_prior_list[i], zv_prior_list[i])
            )
            comp_losses.append(rec_i + kl_i)  # shape [batch]

        # 4) weight each component by qy[:, i] and sum
        weighted = [qy[:, i] * comp_losses[i] for i in range(self.k)]
        total = self.r_nent * nent + sum(weighted)  # shape [batch]

        return total, nent