Coverage for biobb_pytorch / mdae / explainability / layerwise_relevance_prop.py: 36%
162 statements
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

def lrp_gmvae_single(model, x, latent_index, eps=1e-6):
    """LRP (epsilon rule) through the encoder of a Gaussian-mixture VAE.

    Explains a single latent coordinate `latent_index` of the mixture mean zm,
    weighted by the cluster responsibilities qy. Returns the relevance of each
    input feature, shape [batch, in_features].
    """
    device = x.device
    batch_size = x.shape[0]
    in_features = model.in_features
    k = model.k

    has_norm = model.norm_in is not None
    if has_norm:
        x_input = model.norm_in(x)
    else:
        x_input = x

    # Cluster responsibilities q(y|x)
    qy_logit = model.encoder['qy_nn'](x_input)
    qy = torch.softmax(qy_logit, dim=1)

    y_ = torch.zeros(batch_size, k, device=device)

    zm_list = []
    intermediates_list = []

    for i in range(k):
        # One-hot indicator for cluster i
        y = y_ + torch.eye(k, device=device)[i]

        # y_transform: embed the one-hot cluster indicator
        module = model.encoder['y_transform']
        z_h0 = y @ module.weight.t() + module.bias
        a_h0 = z_h0  # no activation

        # Only the x-branch relevance is kept at the end, so the y_transform
        # layer itself is not recorded for the backward pass.
        intermediates = []

        xy = torch.cat([x_input, a_h0], dim=1)

        a = xy

        # qz_nn: record (input, pre-activation) pairs for each Linear layer
        for sub_module in model.encoder['qz_nn'].children():
            if isinstance(sub_module, nn.Linear):
                z = a @ sub_module.weight.t() + sub_module.bias
                intermediates.append(('linear', sub_module, a, z))
                a = z
            elif isinstance(sub_module, nn.ReLU):
                a = F.relu(a)

        # zm_layer: mixture-component mean of the latent code
        module = model.encoder['zm_layer']
        z = a @ module.weight.t() + module.bias
        intermediates.append(('linear', module, a, z))

        zm = z
        zm_list.append(zm)

        intermediates_list.append(intermediates)

    zm = torch.stack(zm_list, dim=1)  # [batch, k, n_cvs]

    # Contribution of each mixture component to the selected latent coordinate
    term_k = qy * zm[:, :, latent_index]  # [batch, k]

    selected_a = torch.sum(term_k, dim=1)  # [batch]

    R_a = selected_a.unsqueeze(1)  # [batch, 1]

    sign_a = selected_a.sign().unsqueeze(1)

    den = R_a + eps * sign_a  # stabilized denominator, [batch, 1]

    # Distribute the output relevance over the k components (epsilon rule)
    R_term_k = (R_a * term_k / den).unsqueeze(2)  # [batch, k, 1]

    R0 = torch.zeros(batch_size, in_features, device=device)

    for i in range(k):
        R = R_term_k[:, i, :]  # [batch, 1]

        # Propagate relevance back through zm_layer and qz_nn (epsilon rule)
        intermediates = intermediates_list[i][::-1]

        for op, module, a_prev, z_prev in intermediates:
            if op == 'linear':
                sign_z = z_prev.sign()
                Z = z_prev + eps * sign_z
                s = R / Z
                c = s @ module.weight
                R = a_prev * c

        # R is now the relevance for xy = [x_input, a_h0]
        R_xy = R
        # Keep only the x part; relevance on the y embedding is discarded
        R_x_k = R_xy[:, :in_features]
        R0 += R_x_k

    if has_norm:
        # Propagate relevance through the input normalization (x - mean) / range
        w = (1 / model.norm_in.range).view(1, -1).to(device)
        b = (-model.norm_in.mean / model.norm_in.range).view(1, -1).to(device)
        z = x * w + b
        sign_z = z.sign()
        Z = z + eps * sign_z
        s = R0 / Z
        c = s * w
        R0 = x * c

    return R0
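

# --- Usage sketch (illustration only) ----------------------------------------
# `lrp_gmvae_single` assumes a model exposing `in_features`, `k`, an optional
# `norm_in` module, and an `encoder` ModuleDict with 'qy_nn', 'y_transform',
# 'qz_nn' and 'zm_layer' entries, as accessed above. The stand-in class below
# is a hypothetical minimal example, not the real
# GaussianMixtureVariationalAutoEncoder from biobb_pytorch.
#
# class _ToyGMVAE(nn.Module):  # hypothetical stand-in
#     def __init__(self, in_features=12, k=3, hidden=16, n_cvs=2):
#         super().__init__()
#         self.in_features = in_features
#         self.k = k
#         self.norm_in = None
#         self.encoder = nn.ModuleDict({
#             'qy_nn': nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU(),
#                                    nn.Linear(hidden, k)),
#             'y_transform': nn.Linear(k, hidden),
#             'qz_nn': nn.Sequential(nn.Linear(in_features + hidden, hidden),
#                                    nn.ReLU()),
#             'zm_layer': nn.Linear(hidden, n_cvs),
#         })
#
# toy = _ToyGMVAE()
# x = torch.randn(8, toy.in_features)
# R = lrp_gmvae_single(toy, x, latent_index=0)  # [8, in_features]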


def lrp_encoder(
    model: nn.Module,
    x: torch.Tensor,
    latent_index: Optional[int] = None,
    eps: float = 1e-6
) -> torch.Tensor:
    """Main LRP entry point, handling both full models and bare encoder modules.

    If `latent_index` is None, the relevances of all latent coordinates are
    summed; otherwise only the selected coordinate is explained.
    """
    # Check if it's a GMVAE model
    if model.__class__.__name__ == 'GaussianMixtureVariationalAutoEncoder':
        if latent_index is None:
            R0 = 0
            for j in range(model.out_features):
                R0 += lrp_gmvae_single(model, x, j, eps)
            return R0
        else:
            return lrp_gmvae_single(model, x, latent_index, eps)

    # Check if model has forward_cv method (full model with encoder)
    if hasattr(model, 'forward_cv'):
        # General case for MLP models with forward_cv: record the Linear
        # layers in execution order using forward hooks.
        handles = []
        layers = []

        def collect_hook(module, inp, out):
            layers.append(module)

        for m in model.modules():
            if isinstance(m, nn.Linear):
                handle = m.register_forward_hook(collect_hook)
                handles.append(handle)

        model.eval()
        with torch.no_grad():
            _ = model.forward_cv(x[:1])  # single sample, just to record layer order

        for h in handles:
            h.remove()

        if len(layers) == 0:
            raise ValueError("No Linear layers found.")

        for layer in layers:
            if not isinstance(layer, nn.Linear):
                raise ValueError("LRP only supported for Linear layers in general case.")

        L = len(layers)

        has_norm = hasattr(model, 'norm_in') and model.norm_in is not None
        if has_norm:
            input_to_encoder = model.norm_in(x)
        else:
            input_to_encoder = x

        # Forward pass: activations A[l] and stabilized pre-activations Z[l]
        A = [input_to_encoder.clone()]
        Z = [None] * (L + 1)
        for layer_idx in range(L):
            lin = layers[layer_idx]
            z = A[layer_idx] @ lin.weight.t() + lin.bias
            sign_z = z.sign()
            Z[layer_idx + 1] = z + eps * sign_z
            if layer_idx < L - 1:
                a = F.relu(z)
            else:
                a = z  # no activation after the last layer
            A.append(a)

        # Initialize relevance at the latent layer
        zL = A[L]
        if latent_index is None:
            R = [None] * (L + 1)
            R[L] = zL.sum(dim=1, keepdim=True)
        else:
            R = [None] * (L + 1)
            R[L] = zL[:, [latent_index]]

        # Backward pass (epsilon rule), layer L -> 0
        for layer_idx in range(L - 1, -1, -1):
            lin = layers[layer_idx]
            s = R[layer_idx + 1] / Z[layer_idx + 1]
            c = s @ lin.weight
            R[layer_idx] = A[layer_idx] * c

        R0 = R[0]

        if has_norm:
            # Propagate relevance through the input normalization
            w = (1 / model.norm_in.range).view(1, -1).to(x.device)
            b = (-model.norm_in.mean / model.norm_in.range).view(1, -1).to(x.device)
            z = x * w + b
            sign_z = z.sign()
            Z = z + eps * sign_z
            s = R0 / Z
            c = s * w
            R0 = x * c

        return R0

    # Fall back to simple encoder version
    return _lrp_encoder_simple(model, x, latent_index, eps)
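

# --- Usage sketch (illustration only) ----------------------------------------
# `lrp_encoder` dispatches on the model type: GMVAE models by class name,
# models exposing `forward_cv`, and plain Linear/ReLU encoder modules via the
# `_lrp_encoder_simple` fallback. The encoder architecture and shapes below
# are assumptions chosen only to illustrate the call signature.
#
# encoder = nn.Sequential(nn.Linear(30, 16), nn.ReLU(), nn.Linear(16, 2))
# x = torch.randn(64, 30)
# R_cv0 = lrp_encoder(encoder, x, latent_index=0)  # relevance for CV 0, [64, 30]
# R_sum = lrp_encoder(encoder, x)                  # sum over all latent coords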


# Layer-wise Relevance Propagation (simplified version for encoder modules)
def _lrp_encoder_simple(
    encoder: nn.Module,
    x: torch.Tensor,
    latent_index: Optional[int] = None,
    eps: float = 1e-6
) -> torch.Tensor:
    """
    Perform Layer-Wise Relevance Propagation on `encoder` for input `x`.

    Arguments:
        encoder -- an nn.Module mapping [batch, in_dim] → [batch, latent_dim],
                   built from Linear + ReLU layers.
        x -- input tensor, shape [batch, in_dim].
        latent_index -- which coordinate of the latent vector to explain.
                        If None, we explain the sum over all latent dims.
        eps -- stabilization term to avoid division by zero.

    Returns:
        R0 -- relevance at the input layer, shape [batch, in_dim].
              R0[b, i] is "how important feature i was for the chosen
              latent coordinate (or sum)."
    """
    device = x.device

    # 1) Extract all Linear layers in execution order
    layers = []
    for module in encoder.modules():
        if isinstance(module, nn.Linear):
            layers.append(module)
    L = len(layers)

    # 2) FORWARD PASS: collect activations A[l] and pre-activations Z[l]
    A = [x.clone().to(device)]
    Z = [None] * (L + 1)
    for layer_idx, lin in enumerate(layers):
        z = A[layer_idx] @ lin.weight.t() + lin.bias  # shape [batch, out_dim]
        Z[layer_idx + 1] = z + eps  # add eps for numerical stability
        a = F.relu(z)
        A.append(a)

    # 3) INITIALIZE RELEVANCE at the top
    #    Let zL = A[L] be shape [batch, latent_dim]
    zL = A[L]
    if latent_index is None:
        # explain the sum of all latent coords
        R = [None] * (L + 1)
        R[L] = zL.sum(dim=1, keepdim=True)  # shape [batch, 1]
    else:
        # explain a single latent coordinate
        R = [None] * (L + 1)
        R[L] = zL[:, [latent_index]]  # shape [batch, 1]

    # 4) BACKWARD PASS (LRP) from layer L → 0
    #    At each step layer_idx, we have R[layer_idx+1] of shape [batch, out_dim].
    #    We want R[layer_idx] of shape [batch, in_dim].
    for layer_idx in range(L - 1, -1, -1):
        lin = layers[layer_idx]
        w = lin.weight  # shape [out_dim, in_dim]
        # Z[layer_idx+1]: [batch, out_dim], R[layer_idx+1]: [batch, out_dim]
        s = R[layer_idx + 1] / Z[layer_idx + 1]  # [batch, out_dim]
        c = s @ w  # [batch, in_dim], since w is [out, in]
        # multiply by the forward activation to zero-out inactive neurons
        R[layer_idx] = A[layer_idx] * c  # [batch, in_dim]

    # R[0] is now the relevance of each input feature
    return R[0]
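

# --- Worked example (illustration only) ---------------------------------------
# For a single bias-free Linear layer with one output and a positive
# pre-activation, the epsilon rule above reduces to R0_i ≈ x_i * w_i, i.e.
# each feature's additive contribution to z = sum_i x_i * w_i. The layer and
# input below are hypothetical, chosen only to make this check concrete.
#
# lin = nn.Linear(4, 1, bias=False)
# x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
# R0 = _lrp_encoder_simple(lin, x)  # [1, 4]
# manual = x * lin.weight           # x_i * w_i, same shape
# # R0 and `manual` agree up to the eps stabilizer when (x @ lin.weight.t()) > 0.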


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# X_batch1 = torch.tensor(data[2000:4000]).to(device)

# # explain the sum of *all* latent coords:
# R0_sum = lrp_encoder(model.encoder, X_batch1, latent_index=None)

# # To get a global feature ranking, average absolute relevance over the batch:
# R0_sum = R0_sum.reshape(R0_sum.size(0), -1, 3)  # [batch, in_dim // 3, 3]
# R0_sum = R0_sum.mean(dim=2)                     # [batch, in_dim // 3]
# global_importance = R0_sum.abs().mean(dim=0)    # [in_dim // 3]
# global_importance = global_importance.cpu().detach().numpy()

# # Normalize to [0, 1]
# global_importance = (global_importance - global_importance.min()) / (global_importance.max() - global_importance.min())
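
# # Per-CV explanation (illustrative sketch; reuses X_batch1 from above).
# # n_cvs is the latent dimensionality of the encoder and is assumed known here:
# for j in range(n_cvs):
#     R0_cv = lrp_encoder(model.encoder, X_batch1, latent_index=j)  # [batch, in_dim]
#     cv_importance = R0_cv.abs().mean(dim=0)                       # [in_dim]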