Coverage for biobb_pytorch / mdae / explainability / layerwise_relevance_prop.py: 36%

162 statements  


import torch
import torch.nn as nn
import torch.nn.functional as F


def lrp_gmvae_single(model, x, latent_index, eps=1e-6):
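    """Propagate relevance (LRP, epsilon rule) through the encoder of a
    Gaussian mixture VAE for a single latent coordinate.

    The encoder is evaluated once per mixture component (one-hot y), the
    relevance of the selected latent coordinate is split across the k
    components according to q(y|x) * z_m, and each share is propagated back
    through that component's linear layers to the input features. If the model
    has an input normalization layer (norm_in), relevance is also propagated
    through it. Returns a tensor of shape [batch, in_features].
    """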

    device = x.device
    batch_size = x.shape[0]
    in_features = model.in_features
    k = model.k

    has_norm = model.norm_in is not None
    if has_norm:
        x_input = model.norm_in(x)
    else:
        x_input = x

    qy_logit = model.encoder['qy_nn'](x_input)
    qy = torch.softmax(qy_logit, dim=1)

    y_ = torch.zeros(batch_size, k, device=device)

    zm_list = []
    intermediates_list = []

    for i in range(k):
        y = y_ + torch.eye(k, device=device)[i]

        # y_transform
        module = model.encoder['y_transform']
        z_h0 = y @ module.weight.t() + module.bias
        a_h0 = z_h0  # no activation

        intermediates = [('linear', module, y, z_h0)]

        xy = torch.cat([x_input, a_h0], dim=1)

        a = xy

        # qz_nn
        for sub_module in model.encoder['qz_nn'].children():
            if isinstance(sub_module, nn.Linear):
                z = a @ sub_module.weight.t() + sub_module.bias
                intermediates.append(('linear', sub_module, a, z))
                a = z
            elif isinstance(sub_module, nn.ReLU):
                a = F.relu(a)

        # zm_layer
        module = model.encoder['zm_layer']
        z = a @ module.weight.t() + module.bias
        intermediates.append(('linear', module, a, z))

        zm = z
        zm_list.append(zm)

        intermediates_list.append(intermediates)

    zm = torch.stack(zm_list, dim=1)  # [batch, k, n_cvs]

    term_k = qy * zm[:, :, latent_index]  # [batch, k]

    selected_a = torch.sum(term_k, dim=1)  # [batch]

    R_a = selected_a.unsqueeze(1)  # [batch, 1]

    sign_a = selected_a.sign().unsqueeze(1)

    den = selected_a.unsqueeze(1) + eps * sign_a

    # Share of the top relevance assigned to each mixture component.
    R_term_k = (R_a * term_k / den).unsqueeze(2)  # [batch, k, 1]

    R0 = torch.zeros(batch_size, in_features, device=device)

    for i in range(k):
        R = R_term_k[:, i, :]  # [batch, 1]

        # Walk back through the component's layers in reverse order, skipping
        # the y_transform entry (index 0): its relevance flows to the one-hot
        # component code y, not to the input x, and only the x part of the
        # concatenated [x_input, a_h0] relevance is kept below.
        intermediates = intermediates_list[i][1:][::-1]

        for op, module, a_prev, z_prev in intermediates:
            if op == 'linear':
                sign_z = z_prev.sign()
                Z = z_prev + eps * sign_z
                s = R / Z
                c = s @ module.weight
                R = a_prev * c

        # R is now the relevance of the concatenated [x_input, a_h0]
        R_xy = R
        R_x_k = R_xy[:, :in_features]
        R0 += R_x_k

    if has_norm:
        w = (1 / model.norm_in.range).view(1, -1).to(device)
        b = (-model.norm_in.mean / model.norm_in.range).view(1, -1).to(device)
        z = x * w + b
        sign_z = z.sign()
        Z = z + eps * sign_z
        s = R0 / Z
        c = s * w
        R0 = x * c

    return R0
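
# Illustrative usage for lrp_gmvae_single, kept as comments like the example at
# the end of this file. The names `gmvae_model` and `X_batch` are placeholders
# for a trained GaussianMixtureVariationalAutoEncoder and a
# [batch, in_features] tensor on the same device; lrp_encoder below normally
# dispatches to this function for you.
#
# R0 = lrp_gmvae_single(gmvae_model, X_batch, latent_index=0)
# per_feature_importance = R0.abs().mean(dim=0)  # [in_features]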


def lrp_encoder(
    model: nn.Module,
    x: torch.Tensor,
    latent_index: int = None,
    eps: float = 1e-6
) -> torch.Tensor:
    """Main LRP function that handles both full models and encoder modules."""
    # Check if it's a GMVAE model
    if model.__class__.__name__ == 'GaussianMixtureVariationalAutoEncoder':
        if latent_index is None:
            R0 = 0
            for j in range(model.out_features):
                R0 += lrp_gmvae_single(model, x, j, eps)
            return R0
        else:
            return lrp_gmvae_single(model, x, latent_index, eps)

    # Check if model has forward_cv method (full model with encoder)
    if hasattr(model, 'forward_cv'):
        # General case for MLP models with forward_cv
        handles = []
        layers = []

        def collect_hook(module, inp, out):
            layers.append(module)

        for m in model.modules():
            if isinstance(m, nn.Linear):
                handle = m.register_forward_hook(collect_hook)
                handles.append(handle)

        model.eval()
        with torch.no_grad():
            _ = model.forward_cv(x[:1])

        for h in handles:
            h.remove()

        if len(layers) == 0:
            raise ValueError("No Linear layers found.")

        for layer in layers:
            if not isinstance(layer, nn.Linear):
                raise ValueError("LRP only supported for Linear layers in general case.")

        L = len(layers)

        has_norm = hasattr(model, 'norm_in') and model.norm_in is not None
        if has_norm:
            input_to_encoder = model.norm_in(x)
        else:
            input_to_encoder = x

        A = [input_to_encoder.clone()]
        Z = [None] * (L + 1)
        for layer_idx in range(L):
            lin = layers[layer_idx]
            z = A[layer_idx] @ lin.weight.t() + lin.bias
            sign_z = z.sign()
            Z[layer_idx + 1] = z + eps * sign_z
            if layer_idx < L - 1:
                a = F.relu(z)
            else:
                a = z
            A.append(a)

        zL = A[L]
        if latent_index is None:
            R = [None] * (L + 1)
            R[L] = zL.sum(dim=1, keepdim=True)
        else:
            R = [None] * (L + 1)
            R[L] = zL[:, [latent_index]]

        for layer_idx in range(L - 1, -1, -1):
            lin = layers[layer_idx]
            s = R[layer_idx + 1] / Z[layer_idx + 1]
            c = s @ lin.weight
            R[layer_idx] = A[layer_idx] * c

        R0 = R[0]

        if has_norm:
            w = (1 / model.norm_in.range).view(1, -1).to(x.device)
            b = (-model.norm_in.mean / model.norm_in.range).view(1, -1).to(x.device)
            z = x * w + b
            sign_z = z.sign()
            Z = z + eps * sign_z
            s = R0 / Z
            c = s * w
            R0 = x * c

        return R0

    # Fall back to simple encoder version
    return _lrp_encoder_simple(model, x, latent_index, eps)
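
# Illustrative calls for lrp_encoder, again as comments. `model` is assumed to
# be a trained model from this package (either a GMVAE or a model exposing
# forward_cv, optionally with norm_in), and `X_batch` a [batch, in_features]
# tensor on the model's device.
#
# R_cv0 = lrp_encoder(model, X_batch, latent_index=0)     # explain a single CV
# R_sum = lrp_encoder(model, X_batch, latent_index=None)  # explain the sum of all CVs
# R_enc = lrp_encoder(model.encoder, X_batch)             # bare encoder -> simple fallback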


# Layer-wise Relevance Propagation (simplified version for encoder modules)
def _lrp_encoder_simple(
    encoder: nn.Module,
    x: torch.Tensor,
    latent_index: int = None,
    eps: float = 1e-6
) -> torch.Tensor:
    """
    Perform Layer‐Wise Relevance Propagation on `encoder` for input `x`.

    Arguments:
    encoder      -- an nn.Module mapping [batch, in_dim] → [batch, latent_dim],
                    built from Linear + ReLU layers.
    x            -- input tensor, shape [batch, in_dim].
    latent_index -- which coordinate of the latent vector to explain.
                    If None, we explain the sum over all latent dims.
    eps          -- stabilization term to avoid division by zero.

    Returns:
    R0 -- relevance at the input layer, shape [batch, in_dim].
          R0[b, i] is “how important feature i was for the chosen
          latent coordinate (or sum).”
    """
    device = x.device

    # 1) Extract all Linear layers in execution order
    layers = []
    for module in encoder.modules():
        if isinstance(module, nn.Linear):
            layers.append(module)
    L = len(layers)

    # 2) FORWARD PASS: collect activations A[l] and pre-activations Z[l]
    A = [x.clone().to(device)]
    Z = [None] * (L + 1)
    for layer_idx, lin in enumerate(layers):
        z = A[layer_idx] @ lin.weight.t() + lin.bias  # shape [batch, out_dim]
        Z[layer_idx + 1] = z + eps  # add eps for numerical stability
        a = F.relu(z)
        A.append(a)

    # 3) INITIALIZE RELEVANCE at the top
    # Let zL = A[L] be shape [batch, latent_dim]
    zL = A[L]
    if latent_index is None:
        # explain the sum of all latent coords
        R = [None] * (L + 1)
        R[L] = zL.sum(dim=1, keepdim=True)  # shape [batch, 1]
    else:
        # explain a single latent coordinate
        R = [None] * (L + 1)
        R[L] = zL[:, [latent_index]]  # shape [batch, 1]

    # 4) BACKWARD PASS (LRP) from layer L → 0
    # At each step layer_idx, we have R[layer_idx+1] of shape [batch, out_dim].
    # We want R[layer_idx] of shape [batch, in_dim].
    for layer_idx in range(L - 1, -1, -1):
        lin = layers[layer_idx]
        w = lin.weight  # shape [out_dim, in_dim]
        # Z[layer_idx+1]: [batch, out_dim], R[layer_idx+1]: [batch, out_dim]
        s = R[layer_idx + 1] / Z[layer_idx + 1]  # [batch, out_dim]
        c = s @ w  # [batch, in_dim], since w is [out, in]
        # multiply by the forward activation to zero‐out inactive neurons
        R[layer_idx] = A[layer_idx] * c  # [batch, in_dim]

    # R[0] is now the relevance of each input feature
    return R[0]


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# X_batch1 = torch.tensor(data[2000:4000]).to(device)

# # explain the sum of *all* latent coords:
# R0_sum = lrp_encoder(model.encoder, X_batch1, latent_index=None)

# # To get a global feature ranking, average absolute relevance over the batch:
# R0_sum = R0_sum.reshape(R0_sum.size(0), -1, 3)

# R0_sum = R0_sum.mean(dim=2)                   # [batch, in_dim // 3]
# global_importance = R0_sum.abs().mean(dim=0)  # [in_dim // 3]
# global_importance = global_importance.cpu().detach().numpy()

# # Normalize
# global_importance = (global_importance - global_importance.min()) / (global_importance.max() - global_importance.min())
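

# The block below is a self-contained sanity-check sketch, not part of the
# library API: it builds a toy Linear+ReLU encoder with made-up sizes and
# random data, runs _lrp_encoder_simple on it, and prints the relevance shape.
# Run it directly with `python layerwise_relevance_prop.py`.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Toy encoder: 12 input features -> 2 latent coordinates (arbitrary sizes).
    toy_encoder = nn.Sequential(
        nn.Linear(12, 8),
        nn.ReLU(),
        nn.Linear(8, 2),
    )
    X_demo = torch.randn(5, 12)

    # Relevance of each input feature for latent coordinate 0.
    R_demo = _lrp_encoder_simple(toy_encoder, X_demo, latent_index=0)
    print(R_demo.shape)  # expected: torch.Size([5, 12])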