Coverage for biobb_pytorch / mdae / explainability / LRP.py: 89%
91 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import torch
2import numpy as np
3import os
4from typing import Optional
5from biobb_common.tools.file_utils import launchlogger
6from biobb_common.tools import file_utils as fu
7from biobb_pytorch.mdae.utils.log_utils import get_size
8from biobb_common.generic.biobb_object import BiobbObject
9from torch.utils.data import DataLoader
10from mlcolvar.data import DictDataset
11from biobb_pytorch.mdae.explainability.layerwise_relevance_prop import lrp_encoder
class LRP(BiobbObject):
    """
    | biobb_pytorch LRP
    | Performs Layer-wise Relevance Propagation on a trained autoencoder encoder.
    | Propagates relevance from the encoder output back to the input features of every sample in a dataset, then aggregates it into a min-max normalized global importance score per group of consecutive input features (e.g. x/y/z coordinates of one atom).

    Args:
        input_model_pth_path (str): Path to the trained model file whose encoder is analyzed. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_pt_path (str): Path to the input dataset file (.pt) used for computing relevance scores. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_results_npz_path (str) (Optional): Path to the output results file containing relevance scores (compressed NumPy archive). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_results.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Dataset** (*dict*) - ({}) Dataset/DataLoader options: ``batch_size`` (default 16) and optional ``indices`` (list, tuple or numpy array) used to subset every entry of the dataset.

    Examples:
        This example shows how to use the LRP class to perform Layer-wise Relevance Propagation::

            from biobb_pytorch.mdae.explainability.LRP import relevance_propagation

            input_model_pth_path='input_model.pth'
            input_dataset_pt_path='input_dataset.pt'
            output_results_npz_path='output_results.npz'

            prop={
                'Dataset': {
                    'batch_size': 32
                }
            }

            relevance_propagation(input_model_pth_path=input_model_pth_path,
                                  input_dataset_pt_path=input_dataset_pt_path,
                                  output_results_npz_path=output_results_npz_path,
                                  properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_pt_path: str,
        output_results_npz_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        # Call parent class constructor
        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_pt_path = input_dataset_pt_path
        self.output_results_npz_path = output_results_npz_path
        self.properties = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_pt_path": input_dataset_pt_path,
            },
            "out": {},
        }

        # The output file is optional: only register it when a path was given.
        if output_results_npz_path:
            self.io_dict["out"]["output_results_npz_path"] = output_results_npz_path

        # `or {}` also covers an explicit `'Dataset': None` entry.
        self.Dataset = self.properties.get('Dataset') or {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.results = None

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def load_model(self):
        """Load the pickled model from ``input_model_pth_path``.

        Tensors are mapped onto ``self.device`` so the model lives on the same
        device the input batches are moved to in ``compute_global_importance``.

        NOTE(security): ``weights_only=False`` unpickles arbitrary Python
        objects; only load model files from trusted sources.
        """
        return torch.load(
            self.io_dict["in"]["input_model_pth_path"],
            map_location=self.device,
            weights_only=False,
        )

    def mask_idx(self, dataset: dict, indices: np.ndarray) -> dict:
        """
        Mask the dataset (dict) for all keys.

        Every value is indexed with ``indices`` in place, and the same dict is
        returned for convenience.
        """
        for key in dataset.keys():
            dataset[key] = dataset[key][indices]
        return dataset

    def load_dataset(self):
        """Load the dataset dict, optionally subset it, and wrap it in a ``DictDataset``.

        If ``Dataset['indices']`` is given and non-empty, it is converted to a
        NumPy array and applied to every entry of the dataset dict.

        NOTE(security): ``weights_only=False`` unpickles arbitrary Python
        objects; only load dataset files from trusted sources.
        """
        # Load on CPU; batches are moved to `self.device` during LRP.
        dataset = torch.load(
            self.io_dict["in"]["input_dataset_pt_path"],
            map_location="cpu",
            weights_only=False,
        )

        # Avoid truth-testing the raw value: `bool(ndarray)` raises for arrays
        # with more than one element.
        indices = self.Dataset.get('indices')
        if indices is not None:
            indices = np.asarray(indices)
            # An empty selection keeps the full dataset, as before.
            if indices.size > 0:
                dataset = self.mask_idx(dataset, indices)

        return DictDataset(dataset)

    def create_dataloader(self, dataset):
        """Wrap ``dataset`` in a non-shuffled ``DataLoader`` (``batch_size`` defaults to 16)."""
        # Use the defaulted `self.Dataset`, so a missing 'Dataset' property does not raise KeyError.
        return DataLoader(
            dataset,
            batch_size=self.Dataset.get('batch_size', 16),
            shuffle=False,
        )

    def compute_global_importance(self, model, dataloader, latent_index=None, group_size: int = 3):
        """Run LRP over every batch and aggregate a global per-feature-group importance.

        Args:
            model: Trained model passed to ``lrp_encoder``.
            dataloader: Iterable of dict batches; the ``'data'`` entry of each
                batch is the encoder input.
            latent_index: Forwarded unchanged to ``lrp_encoder`` (``launch``
                passes ``None``).
            group_size (int): Number of consecutive input features averaged
                into one feature (3 by default, e.g. x/y/z coordinates).

        Returns:
            dict: ``global_importance`` (min-max normalized, approximately in
            [0, 1]) and ``global_importance_raw`` (mean absolute relevance per
            feature group), both NumPy arrays of length ``in_dim // group_size``.

        Raises:
            ValueError: If ``group_size`` is not positive, the dataloader yields
                no batches, or the input dimension is not a multiple of
                ``group_size``.
        """
        if group_size <= 0:
            raise ValueError(f"group_size must be a positive integer, got {group_size}.")

        all_R0 = []
        for batch in dataloader:
            X_batch = batch['data'].to(self.device)
            R0 = lrp_encoder(model, X_batch, latent_index=latent_index)
            # Detach and move to CPU so no autograd graph or GPU memory is retained.
            all_R0.append(R0.detach().cpu())

        if not all_R0:
            raise ValueError("The dataset is empty: no samples available for LRP analysis.")
        R0_all = torch.cat(all_R0, dim=0)  # [total_samples, in_dim]

        in_dim = R0_all.size(1)
        if in_dim % group_size != 0:
            raise ValueError(
                f"Input dimension {in_dim} is not divisible by group_size={group_size}; "
                "cannot group input features."
            )
        num_features = in_dim // group_size
        R0_all = R0_all.reshape(-1, num_features, group_size)
        R0_mean = R0_all.mean(dim=2)  # [total_samples, num_features]

        global_importance = R0_mean.abs().mean(dim=0)  # [num_features]
        global_importance_raw = global_importance.numpy()

        # Min-max normalization; the epsilon avoids division by zero for a constant vector.
        min_val = global_importance_raw.min()
        max_val = global_importance_raw.max()
        global_range = max_val - min_val + 1e-10
        global_importance_norm = (global_importance_raw - min_val) / global_range

        return {
            "global_importance": global_importance_norm,
            "global_importance_raw": global_importance_raw,
        }

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`LRP` class and its `.launch()` method.

        Returns:
            int: 0 (also when the run is skipped by ``check_restart``).
        """

        fu.log('## BioBB Layer-wise Relevance Propagation ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_pt_path"])}', self.out_log)
        dataset = self.load_dataset()

        # create the dataloader
        fu.log('Start LRP analysis...', self.out_log)
        dataloader = self.create_dataloader(dataset)

        # Compute LRP
        self.results = self.compute_global_importance(model, dataloader, latent_index=None)

        # Save the results if path provided
        if self.output_results_npz_path:
            np.savez_compressed(self.io_dict["out"]["output_results_npz_path"], **self.results)
            fu.log(f'Results saved to {os.path.abspath(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)
            fu.log(f'File size: {get_size(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        output_created = bool(self.output_results_npz_path)
        self.check_arguments(output_files_created=output_created, raise_exception=False)

        return 0
def relevance_propagation(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_pt_path: str,
    output_results_npz_path: Optional[str] = None,
    **kwargs,
) -> int:
    """Functional wrapper: build an :class:`LRP <LRP>` instance and run
    its :meth:`launch() <LRP.launch>` method, returning its exit code."""
    # Extra options travel as one `kwargs=` keyword argument; LRP accepts
    # arbitrary keywords and does not use them.
    tool = LRP(
        properties=properties,
        input_model_pth_path=input_model_pth_path,
        input_dataset_pt_path=input_dataset_pt_path,
        output_results_npz_path=output_results_npz_path,
        kwargs=kwargs,
    )
    return tool.launch()
# Share the class docstring so the functional API documents the same parameters.
relevance_propagation.__doc__ = LRP.__doc__

# Command-line entry point built from the functional wrapper.
main = LRP.get_main(
    relevance_propagation,
    "Performs Layer-wise Relevance Propagation on a trained autoencoder encoder.",
)

if __name__ == "__main__":
    main()