Coverage for biobb_pytorch / mdae / loss / ib_loss.py: 96%
26 statements
#!/usr/bin/env python

import torch
import torch.nn as nn


class InformationBottleneckLoss(nn.Module):
    """
    Loss = reconstruction_error + beta * KL[q(z|x) || p(z)]

    Where p(z) is modeled as a mixture over representative_z (means/logvars),
    weighted by representative_weights(idle_input).
    """

    def __init__(
        self,
        beta: float = 1.0,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.beta = beta
        self.eps = eps
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
    def log_p(
        self,
        z: torch.Tensor,           # [B] or [B,1]
        rep_mean: torch.Tensor,    # [R]
        rep_logvar: torch.Tensor,  # [R]
        w: torch.Tensor,           # [R] or [R,1]
        sum_up: bool = True
    ) -> torch.Tensor:
        """
        Compute log p(z) under the mixture prior.

        Args:
            z (Tensor[B] or [B,1]): latent samples
            rep_mean (Tensor[R]): mixture means
            rep_logvar (Tensor[R]): mixture log-vars
            w (Tensor[R] or [R,1]): mixture weights
            sum_up (bool): if True, returns [B] log-density;
                if False, returns [B,R] component-wise log-probs

        Returns:
            Tensor[B] if sum_up else Tensor[B,R]
        """
        z_expand = z.unsqueeze(1)
        mu = rep_mean.unsqueeze(0)
        lv = rep_logvar.unsqueeze(0)

        representative_log_q = -0.5 * torch.sum(lv + torch.pow(z_expand - mu, 2) /
                                                torch.exp(lv), dim=2)

        if sum_up:
            log_p = torch.sum(torch.log(torch.exp(representative_log_q) @ w + self.eps), dim=1)
        else:
            log_p = torch.log(torch.exp(representative_log_q) * w.T + self.eps)

        return log_p
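    # Note: exponentiating representative_log_q and taking the log afterwards
    # can underflow when the component log-densities are very negative; the
    # commented-out variant at the bottom of this file performs the same
    # marginalization with torch.logsumexp, which is the numerically stabler
    # form.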
    def forward(
        self,
        data_targets: torch.Tensor,  # [B, C_out]
        outputs: torch.Tensor,       # [B, C_out], log-probs
        z_sample: torch.Tensor,      # [B]
        z_mean: torch.Tensor,        # [B]
        z_logvar: torch.Tensor,      # [B]
        rep_mean: torch.Tensor,      # [R]
        rep_logvar: torch.Tensor,    # [R]
        w: torch.Tensor,             # [R] or [R,1]
        data_weights: torch.Tensor = None,
        sum_up: bool = True,
    ):
        """
        Computes:
            rec_err = E_q[-log p(x|z)]
            kld     = E_q[log q(z|x) - log p(z)]
        Returns:
            loss, rec_err (scalar), kl_term (scalar)
        """
        # --- RECONSTRUCTION ---
        # cross-entropy per sample: [B]
        ce = torch.sum(-data_targets * outputs, dim=1)
        rec_err = torch.mean(ce * data_weights) if data_weights is not None else ce.mean()

        # --- KL TERM ---
        # log q(z|x): -0.5 * [logvar + (z - mean)^2 / exp(logvar)]
        log_q = -0.5 * (z_logvar + (z_sample - z_mean).pow(2).div(z_logvar.exp()))
        # log p(z): mixture prior
        log_p = self.log_p(z_sample, rep_mean, rep_logvar, w, sum_up=sum_up)

        # per-sample KL
        kld = log_q - log_p
        kl_term = (kld * data_weights).mean() if data_weights is not None else kld.mean()

        loss = rec_err + self.beta * kl_term
        return loss, rec_err, kl_term
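

# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical smoke test of the loss above. The shapes used here
# ([B, 1] latents, [R, 1] representative means/log-vars, [R, 1] weights), the
# beta value, and the variable names are illustrative assumptions, not
# requirements of the class.
if __name__ == "__main__":
    B, C_out, R = 32, 10, 5
    criterion = InformationBottleneckLoss(beta=1e-3)

    outputs = torch.log_softmax(torch.randn(B, C_out), dim=1)   # decoder log-probs
    data_targets = torch.softmax(torch.randn(B, C_out), dim=1)  # soft targets
    z_mean, z_logvar = torch.randn(B, 1), torch.randn(B, 1)
    z_sample = z_mean + torch.exp(0.5 * z_logvar) * torch.randn_like(z_mean)
    rep_mean, rep_logvar = torch.randn(R, 1), torch.randn(R, 1)
    w = torch.softmax(torch.randn(R, 1), dim=0)                 # mixture weights summing to 1

    loss, rec_err, kl_term = criterion(
        data_targets, outputs, z_sample, z_mean, z_logvar,
        rep_mean, rep_logvar, w,
    )
    print(loss.item(), rec_err.item(), kl_term.item())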
# class InformationBottleneckLoss(nn.Module):
#     """
#     Loss = reconstruction_error + beta * KL[q(z|x) || p(z)]
#
#     Where p(z) is a mixture over representative_z (means/logvars),
#     weighted by representative_weights(idle_input).
#     """
#     def __init__(self, beta: float = 1.0, eps: float = 1e-8):
#         super().__init__()
#         self.beta = beta
#         self.eps = eps
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#
#     def log_p(
#         self,
#         z: torch.Tensor,           # [B, n_cvs]
#         rep_mean: torch.Tensor,    # [k, n_cvs]
#         rep_logvar: torch.Tensor,  # [k, n_cvs]
#         w: torch.Tensor,           # [k, 1]
#     ) -> torch.Tensor:
#         """
#         Compute log p(z) under the mixture prior.
#
#         Args:
#             z: (Tensor[B, n_cvs]) latent samples
#             rep_mean: (Tensor[k, n_cvs]) mixture means
#             rep_logvar: (Tensor[k, n_cvs]) mixture log-variances
#             w: (Tensor[k, 1]) mixture weights
#
#         Returns:
#             Tensor[B]
#         """
#         batch_size, n_cvs = z.shape
#         k = rep_mean.shape[0]
#
#         # Expand dimensions for broadcasting
#         z_expand = z.unsqueeze(1)     # [B, 1, n_cvs]
#         mu = rep_mean.unsqueeze(0)    # [1, k, n_cvs]
#         lv = rep_logvar.unsqueeze(0)  # [1, k, n_cvs]
#
#         var = torch.exp(lv)
#
#         # Log-probability per dimension per component
#         log_prob_per_dim = -0.5 * (math.log(2 * math.pi) + lv + ((z_expand - mu) ** 2) / var)  # [B, k, n_cvs]
#
#         # Sum over dimensions to get log_prob per component
#         log_prob_comp = log_prob_per_dim.sum(dim=2)  # [B, k]
#
#         # Add log-weights
#         log_w = torch.log(w + self.eps).squeeze(-1)  # [k]
#         log_prob_comp += log_w.unsqueeze(0)          # [B, k]
#
#         # Marginalize over components via logsumexp
#         log_p = torch.logsumexp(log_prob_comp, dim=1)  # [B]
#
#         return log_p
#
#     def forward(
#         self,
#         data_targets: torch.Tensor,  # [B, C_out]
#         outputs: torch.Tensor,       # [B, C_out], log-probs
#         z_sample: torch.Tensor,      # [B, n_cvs]
#         z_mean: torch.Tensor,        # [B, n_cvs]
#         z_logvar: torch.Tensor,      # [B, n_cvs]
#         rep_mean: torch.Tensor,      # [k, n_cvs]
#         rep_logvar: torch.Tensor,    # [k, n_cvs]
#         w: torch.Tensor,             # [k, 1]
#         data_weights: torch.Tensor = None,  # [B] or None
#     ):
#         """
#         Computes:
#             rec_err = E_q[-log p(x|z)]
#             kld = E_q[log q(z|x) - log p(z)]
#         Returns:
#             loss, rec_err (scalar), kl_term (scalar)
#         """
#         # Reconstruction error: cross-entropy per sample
#         ce = torch.sum(-data_targets * outputs, dim=1)  # [B]
#         rec_err = (ce * data_weights).mean() if data_weights is not None else ce.mean()
#
#         # KL term: log q(z|x) - log p(z)
#         # log q(z|x): full multivariate diagonal Gaussian log-prob
#         log_q_per_dim = -0.5 * (math.log(2 * math.pi) + z_logvar + ((z_sample - z_mean) ** 2) / z_logvar.exp())  # [B, n_cvs]
#         log_q = log_q_per_dim.sum(dim=1)  # [B]
#
#         # log p(z): mixture prior
#         log_p = self.log_p(z_sample, rep_mean, rep_logvar, w)  # [B]
#
#         # Per-sample KL
#         kld = log_q - log_p  # [B]
#         kl_term = (kld * data_weights).mean() if data_weights is not None else kld.mean()
#
#         loss = rec_err + self.beta * kl_term
#         return loss, rec_err, kl_term
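# Note: the commented-out variant above calls math.log(2 * math.pi); if it is
# ever re-enabled, an ``import math`` would have to be added next to the torch
# imports at the top of this file.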