Coverage for biobb_pytorch / mdae / loss / ib_loss.py: 96%

26 statements  

coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

#!/usr/bin/env python

import torch
import torch.nn as nn


class InformationBottleneckLoss(nn.Module):
    """
    Loss = reconstruction_error + beta * KL[q(z|x) || p(z)]

    Where p(z) is modeled as a mixture over representative_z (means/logvars),
    weighted by representative_weights(idle_input).
    """

    def __init__(
        self,
        beta: float = 1.0,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.beta = beta
        self.eps = eps
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def log_p(
        self,
        z: torch.Tensor,           # [B, n_cvs]
        rep_mean: torch.Tensor,    # [R, n_cvs]
        rep_logvar: torch.Tensor,  # [R, n_cvs]
        w: torch.Tensor,           # [R, 1]
        sum_up: bool = True
    ) -> torch.Tensor:
        """
        Compute log p(z) under the mixture prior.

        Args:
            z (Tensor[B, n_cvs]): latent samples
            rep_mean (Tensor[R, n_cvs]): mixture means
            rep_logvar (Tensor[R, n_cvs]): mixture log-variances
            w (Tensor[R, 1]): mixture weights
            sum_up (bool): if True, returns [B] log-density;
                if False, returns [B, R] component-wise log-probs
        Returns:
            Tensor[B] if sum_up else Tensor[B, R]
        """

        z_expand = z.unsqueeze(1)
        mu = rep_mean.unsqueeze(0)
        lv = rep_logvar.unsqueeze(0)

        representative_log_q = -0.5 * torch.sum(lv + torch.pow(z_expand - mu, 2) /
                                                torch.exp(lv), dim=2)
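        # For sum_up=True this evaluates, per sample b:
        #     log p(z_b) = log( sum_r w[r] * exp(representative_log_q[b, r]) + eps )
        # The component log-densities above omit the constant -0.5 * n_cvs * log(2*pi);
        # log q(z|x) in forward() drops the same constant, so it cancels in the KL term.
        # eps keeps the log finite if every exp(...) underflows to zero.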

        if sum_up:
            log_p = torch.sum(torch.log(torch.exp(representative_log_q) @ w + self.eps), dim=1)
        else:
            log_p = torch.log(torch.exp(representative_log_q) * w.T + self.eps)

        return log_p

    def forward(
        self,
        data_targets: torch.Tensor,   # [B, C_out]
        outputs: torch.Tensor,        # [B, C_out], log-probs
        z_sample: torch.Tensor,       # [B, n_cvs]
        z_mean: torch.Tensor,         # [B, n_cvs]
        z_logvar: torch.Tensor,       # [B, n_cvs]
        rep_mean: torch.Tensor,       # [R, n_cvs]
        rep_logvar: torch.Tensor,     # [R, n_cvs]
        w: torch.Tensor,              # [R, 1]
        data_weights: torch.Tensor = None,  # [B] or None
        sum_up: bool = True,
    ):
        """
        Computes:
            rec_err = E_q[-log p(x|z)]
            kld = E_q[log q(z|x) - log p(z)]
        Returns:
            loss, rec_err (scalar), kl_term (scalar)
        """

        # --- RECONSTRUCTION ---
        # cross-entropy per sample: [B]
        ce = torch.sum(-data_targets * outputs, dim=1)
        rec_err = torch.mean(ce * data_weights) if data_weights is not None else ce.mean()

        # --- KL TERM ---
        # log q(z|x): -0.5 * sum_d [logvar + (z - mean)^2 / exp(logvar)]  -> [B]
        log_q = -0.5 * torch.sum(z_logvar + (z_sample - z_mean).pow(2).div(z_logvar.exp()), dim=1)
        # log p(z): mixture prior  -> [B]
        log_p = self.log_p(z_sample, rep_mean, rep_logvar, w, sum_up=sum_up)

        # per-sample KL
        kld = log_q - log_p
        kl_term = (kld * data_weights).mean() if data_weights is not None else kld.mean()

        loss = rec_err + self.beta * kl_term
        return loss, rec_err, kl_term
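
# NOTE: the commented-out block below is an alternative implementation of the same
# loss left in the source; it keeps the 2*pi normalization constants explicit,
# marginalizes the mixture prior with logsumexp, and would additionally require
# `import math`.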

# class InformationBottleneckLoss(nn.Module):
#     """
#     Loss = reconstruction_error + beta * KL[q(z|x) || p(z)]
#     Where p(z) is a mixture over representative_z (means/logvars),
#     weighted by representative_weights(idle_input).
#     """
#     def __init__(self, beta: float = 1.0, eps: float = 1e-8):
#         super().__init__()
#         self.beta = beta
#         self.eps = eps
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"

#     def log_p(
#         self,
#         z: torch.Tensor,           # [B, n_cvs]
#         rep_mean: torch.Tensor,    # [k, n_cvs]
#         rep_logvar: torch.Tensor,  # [k, n_cvs]
#         w: torch.Tensor,           # [k, 1]
#     ) -> torch.Tensor:
#         """
#         Compute log p(z) under the mixture prior.

#         Args:
#             z: (Tensor[B, n_cvs]) latent samples
#             rep_mean: (Tensor[k, n_cvs]) mixture means
#             rep_logvar: (Tensor[k, n_cvs]) mixture log-variances
#             w: (Tensor[k, 1]) mixture weights

#         Returns:
#             Tensor[B]
#         """
#         batch_size, n_cvs = z.shape
#         k = rep_mean.shape[0]

#         # Expand dimensions for broadcasting
#         z_expand = z.unsqueeze(1)     # [B, 1, n_cvs]
#         mu = rep_mean.unsqueeze(0)    # [1, k, n_cvs]
#         lv = rep_logvar.unsqueeze(0)  # [1, k, n_cvs]

#         var = torch.exp(lv)

#         # Log-probability per dimension per component
#         log_prob_per_dim = -0.5 * (math.log(2 * math.pi) + lv + ((z_expand - mu) ** 2) / var)  # [B, k, n_cvs]

#         # Sum over dimensions to get log_prob per component
#         log_prob_comp = log_prob_per_dim.sum(dim=2)  # [B, k]

#         # Add log-weights
#         log_w = torch.log(w + self.eps).squeeze(-1)  # [k]
#         log_prob_comp += log_w.unsqueeze(0)          # [B, k]

#         # Marginalize over components via logsumexp
#         log_p = torch.logsumexp(log_prob_comp, dim=1)  # [B]

#         return log_p

#     def forward(
#         self,
#         data_targets: torch.Tensor,   # [B, C_out]
#         outputs: torch.Tensor,        # [B, C_out], log-probs
#         z_sample: torch.Tensor,       # [B, n_cvs]
#         z_mean: torch.Tensor,         # [B, n_cvs]
#         z_logvar: torch.Tensor,       # [B, n_cvs]
#         rep_mean: torch.Tensor,       # [k, n_cvs]
#         rep_logvar: torch.Tensor,     # [k, n_cvs]
#         w: torch.Tensor,              # [k, 1]
#         data_weights: torch.Tensor = None,  # [B] or None
#     ):
#         """
#         Computes:
#             rec_err = E_q[-log p(x|z)]
#             kld = E_q[log q(z|x) - log p(z)]
#         Returns:
#             loss, rec_err (scalar), kl_term (scalar)
#         """
#         # Reconstruction error: cross-entropy per sample
#         ce = torch.sum(-data_targets * outputs, dim=1)  # [B]
#         rec_err = (ce * data_weights).mean() if data_weights is not None else ce.mean()

#         # KL term: log q(z|x) - log p(z)
#         # log q(z|x): full multivariate diagonal Gaussian log-prob
#         log_q_per_dim = -0.5 * (math.log(2 * math.pi) + z_logvar + ((z_sample - z_mean) ** 2) / z_logvar.exp())  # [B, n_cvs]
#         log_q = log_q_per_dim.sum(dim=1)  # [B]

#         # log p(z): mixture prior
#         log_p = self.log_p(z_sample, rep_mean, rep_logvar, w)  # [B]

#         # Per-sample KL
#         kld = log_q - log_p  # [B]
#         kl_term = (kld * data_weights).mean() if data_weights is not None else kld.mean()

#         loss = rec_err + self.beta * kl_term
#         return loss, rec_err, kl_term
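
For orientation, a minimal usage sketch of the loss defined above follows. The import path is inferred from the coverage header (biobb_pytorch/mdae/loss/ib_loss.py), and the tensor sizes (batch B, output classes C_out, latent dimensions n_cvs, R mixture components) are illustrative assumptions rather than values prescribed by the library.

import torch
from biobb_pytorch.mdae.loss.ib_loss import InformationBottleneckLoss  # path inferred from the header

torch.manual_seed(0)
B, C_out, n_cvs, R = 4, 5, 2, 3  # assumed sizes, for illustration only

criterion = InformationBottleneckLoss(beta=0.01)

# One-hot targets and decoder log-probabilities (forward() expects log-probs)
data_targets = torch.nn.functional.one_hot(torch.randint(0, C_out, (B,)), C_out).float()
outputs = torch.log_softmax(torch.randn(B, C_out), dim=1)

# Encoder posterior q(z|x) and a reparameterized sample
z_mean = torch.randn(B, n_cvs)
z_logvar = torch.randn(B, n_cvs)
z_sample = z_mean + torch.exp(0.5 * z_logvar) * torch.randn(B, n_cvs)

# Mixture prior parameters and normalized weights [R, 1]
rep_mean = torch.randn(R, n_cvs)
rep_logvar = torch.zeros(R, n_cvs)
w = torch.softmax(torch.randn(R), dim=0).unsqueeze(1)

loss, rec_err, kl_term = criterion(data_targets, outputs, z_sample, z_mean,
                                   z_logvar, rep_mean, rep_logvar, w)
print(loss.item(), rec_err.item(), kl_term.item())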