#!/usr/bin/env python

import torch
import torch.nn as nn


class InformationBottleneckLoss(nn.Module):
    """
    Loss = reconstruction_error + beta * KL[q(z|x) || p(z)]

    Where p(z) is modeled as a mixture over representative_z (means/logvars),
    weighted by representative_weights(idle_input).
    """

    def __init__(
        self,
        beta: float = 1.0,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.beta = beta
        self.eps = eps
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def log_p(
        self,
        z: torch.Tensor,           # [B, D]
        rep_mean: torch.Tensor,    # [R, D]
        rep_logvar: torch.Tensor,  # [R, D]
        w: torch.Tensor,           # [R, 1]
        sum_up: bool = True,
    ) -> torch.Tensor:
        """
        Compute log p(z) under the mixture prior.

        Args:
            z (Tensor[B, D]):          latent samples
            rep_mean (Tensor[R, D]):   mixture means
            rep_logvar (Tensor[R, D]): mixture log-variances
            w (Tensor[R, 1]):          mixture weights
            sum_up (bool):             if True, return the [B] mixture log-density;
                                       if False, return the [B, R] component-wise log-probs
        Returns:
            Tensor[B] if sum_up else Tensor[B, R]
        """
        # Broadcast samples against components: [B, 1, D] vs [1, R, D]
        z_expand = z.unsqueeze(1)
        mu = rep_mean.unsqueeze(0)
        lv = rep_logvar.unsqueeze(0)

        # Unnormalized Gaussian log-density of z under each component: [B, R]
        # (the D/2 * log(2*pi) constant is dropped; it cancels against
        # log q(z|x) in the KL term)
        representative_log_q = -0.5 * torch.sum(
            lv + torch.pow(z_expand - mu, 2) / torch.exp(lv), dim=2
        )

        if sum_up:
            # Mixture density sum_r w_r * N(z; mu_r, var_r), then log: [B]
            log_p = torch.sum(torch.log(torch.exp(representative_log_q) @ w + self.eps), dim=1)
        else:
            # Component-wise weighted log-densities: [B, R]
            log_p = torch.log(torch.exp(representative_log_q) * w.T + self.eps)

        return log_p

    def forward(
        self,
        data_targets: torch.Tensor,         # [B, C_out]
        outputs: torch.Tensor,              # [B, C_out], log-probs
        z_sample: torch.Tensor,             # [B, D]
        z_mean: torch.Tensor,               # [B, D]
        z_logvar: torch.Tensor,             # [B, D]
        rep_mean: torch.Tensor,             # [R, D]
        rep_logvar: torch.Tensor,           # [R, D]
        w: torch.Tensor,                    # [R, 1]
        data_weights: torch.Tensor = None,  # [B] or None
        sum_up: bool = True,
    ):
        """
        Computes:
            rec_err = E_q[-log p(x|z)]
            kld     = E_q[log q(z|x) - log p(z)]
        Returns:
            loss, rec_err (scalar), kl_term (scalar)
        """
        # --- RECONSTRUCTION ---
        # cross-entropy per sample: [B]
        ce = torch.sum(-data_targets * outputs, dim=1)
        rec_err = torch.mean(ce * data_weights) if data_weights is not None else ce.mean()

        # --- KL TERM ---
        # log q(z|x): -0.5 * sum_d [logvar + (z - mean)^2 / exp(logvar)], summed
        # over latent dimensions (the 2*pi constant is dropped, as in log_p): [B]
        log_q = -0.5 * torch.sum(
            z_logvar + (z_sample - z_mean).pow(2) / z_logvar.exp(), dim=1
        )
        # log p(z) under the mixture prior: [B] (assumes sum_up=True)
        log_p = self.log_p(z_sample, rep_mean, rep_logvar, w, sum_up=sum_up)

        # per-sample KL, averaged (optionally weighted) over the batch
        kld = log_q - log_p
        kl_term = (kld * data_weights).mean() if data_weights is not None else kld.mean()

        loss = rec_err + self.beta * kl_term
        return loss, rec_err, kl_term
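

# Not part of the loss itself: a numerically stabler reference for log_p(),
# computing the same mixture log-density with logsumexp instead of
# log(exp(.) @ w). The helper name is an illustrative choice for sanity
# checking; it mirrors log_p() and likewise drops the 2*pi constant.
def _mixture_log_prob_reference(
    z: torch.Tensor,           # [B, D]
    rep_mean: torch.Tensor,    # [R, D]
    rep_logvar: torch.Tensor,  # [R, D]
    w: torch.Tensor,           # [R, 1]
    eps: float = 1e-8,
) -> torch.Tensor:
    """Mixture log-density of z, up to the dropped Gaussian constant."""
    # Per-component log-density (constant dropped): [B, R]
    log_comp = -0.5 * torch.sum(
        rep_logvar.unsqueeze(0)
        + (z.unsqueeze(1) - rep_mean.unsqueeze(0)).pow(2) / rep_logvar.unsqueeze(0).exp(),
        dim=2,
    )
    log_w = torch.log(w.squeeze(-1) + eps)            # [R]
    return torch.logsumexp(log_comp + log_w, dim=1)   # [B]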
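

if __name__ == "__main__":
    # Minimal usage sketch. The shapes below (B samples, D latent dims,
    # R mixture components, C_out output classes) and beta=0.01 are
    # illustrative assumptions; the class itself does not require them.
    torch.manual_seed(0)
    B, D, R, C_out = 4, 2, 3, 2

    criterion = InformationBottleneckLoss(beta=0.01)

    data_targets = torch.softmax(torch.randn(B, C_out), dim=1)  # soft labels
    outputs = torch.log_softmax(torch.randn(B, C_out), dim=1)   # decoder log-probs
    z_mean = torch.randn(B, D)
    z_logvar = torch.randn(B, D)
    z_sample = z_mean + torch.exp(0.5 * z_logvar) * torch.randn(B, D)  # reparameterized sample
    rep_mean = torch.randn(R, D)
    rep_logvar = torch.zeros(R, D)
    w = torch.softmax(torch.randn(R), dim=0).unsqueeze(1)       # [R, 1], sums to 1

    loss, rec_err, kl_term = criterion(
        data_targets, outputs, z_sample, z_mean, z_logvar,
        rep_mean, rep_logvar, w,
    )
    print(f"loss={loss.item():.4f}  rec_err={rec_err.item():.4f}  kl={kl_term.item():.4f}")

    # log_p should agree with the logsumexp reference up to the eps term.
    diff = (
        criterion.log_p(z_sample, rep_mean, rep_logvar, w)
        - _mixture_log_prob_reference(z_sample, rep_mean, rep_logvar, w)
    ).abs().max()
    print(f"max |log_p - reference| = {diff.item():.2e}")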