Coverage for biobb_pytorch / mdae / explainability / LRP.py: 89%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import torch 

2import numpy as np 

3import os 

4from typing import Optional 

5from biobb_common.tools.file_utils import launchlogger 

6from biobb_common.tools import file_utils as fu 

7from biobb_pytorch.mdae.utils.log_utils import get_size 

8from biobb_common.generic.biobb_object import BiobbObject 

9from torch.utils.data import DataLoader 

10from mlcolvar.data import DictDataset 

11from biobb_pytorch.mdae.explainability.layerwise_relevance_prop import lrp_encoder 

12 

13 

class LRP(BiobbObject):
    """
    | biobb_pytorch LRP
    | Performs Layer-wise Relevance Propagation on a trained autoencoder encoder.

    Args:
        input_model_pth_path (str): Path to the trained model file whose encoder is analyzed. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_pt_path (str): Path to the input dataset file (.pt) used for computing relevance scores. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_results_npz_path (str) (Optional): Path to the output results file containing relevance scores (compressed NumPy archive). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_results.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Dataset** (*dict*) - ({}) Dataset/DataLoader options (e.g. batch_size and optional indices to subset the dataset).

    Examples:
        This example shows how to use the LRP class to perform Layer-wise Relevance Propagation::

            from biobb_pytorch.mdae.explainability import relevancePropagation

            input_model_pth_path='input_model.pth'
            input_dataset_pt_path='input_dataset.pt'
            output_results_npz_path='output_results.npz'

            prop={
                'Dataset': {
                    'batch_size': 32
                }
            }

            LRP(input_model_pth_path=input_model_pth_path,
                input_dataset_pt_path=input_dataset_pt_path,
                output_results_npz_path=None,
                properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_pt_path: str,
        output_results_npz_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:
        # Never mutate a caller-supplied dict default; normalize None -> {}.
        properties = properties or {}

        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_pt_path = input_dataset_pt_path
        self.output_results_npz_path = output_results_npz_path
        self.properties = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_pt_path": input_dataset_pt_path,
            },
            "out": {},
        }

        # The results file is optional: only register it when a path is given.
        if output_results_npz_path:
            self.io_dict["out"]["output_results_npz_path"] = output_results_npz_path

        # Dataset/DataLoader options; defaults to {} so downstream lookups are safe.
        self.Dataset = self.properties.get('Dataset', {})
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Populated by launch() with the relevance dictionaries.
        self.results = None

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def load_model(self):
        """Load and return the full (unpickled) model object from the input .pth file.

        NOTE(review): ``weights_only=False`` unpickles arbitrary objects — only
        load model files from trusted sources.
        """
        return torch.load(self.io_dict["in"]["input_model_pth_path"], weights_only=False)

    def mask_idx(self, dataset: dict, indices: np.ndarray) -> dict:
        """
        Mask the dataset (dict) for all keys, keeping only the rows in ``indices``.

        The dict is modified in place and also returned for convenience.
        """
        for key in dataset.keys():
            dataset[key] = dataset[key][indices]
        return dataset

    def load_dataset(self):
        """Load the input .pt dataset and wrap it in a :class:`DictDataset`.

        If ``Dataset.indices`` is provided (list or ndarray), the dataset is
        subset to those rows first.
        """
        dataset = torch.load(self.io_dict["in"]["input_dataset_pt_path"], weights_only=False)

        # Use an explicit `is not None` check: truth-testing a multi-element
        # ndarray raises "The truth value of an array ... is ambiguous".
        indices = self.Dataset.get('indices')
        if indices is not None:
            # np.asarray accepts both lists and ndarrays (no copy for ndarrays),
            # and avoids the unbound-variable path of the old isinstance chain.
            dataset = self.mask_idx(dataset, np.asarray(indices))

        return DictDataset(dataset)

    def create_dataloader(self, dataset):
        """Build a non-shuffling DataLoader over ``dataset``.

        ``batch_size`` comes from the optional 'Dataset' property (default 16).
        """
        # Read from self.Dataset (already defaulted to {} in __init__) so a
        # missing 'Dataset' property no longer raises KeyError.
        return DataLoader(
            dataset,
            batch_size=self.Dataset.get('batch_size', 16),
            shuffle=False,
        )

    def compute_global_importance(self, model, dataloader, latent_index=None):
        """Run LRP over every batch and aggregate per-feature importance.

        Args:
            model: Trained autoencoder whose encoder is analyzed.
            dataloader: DataLoader yielding dict batches with a 'data' key.
            latent_index: Optional latent dimension to propagate from; passed
                through to :func:`lrp_encoder`.

        Returns:
            dict with 'global_importance' (min-max normalized, ndarray) and
            'global_importance_raw' (unnormalized, ndarray), both of length
            ``in_dim // 3``.
        """
        all_R0 = []
        for batch in dataloader:
            X_batch = batch['data'].to(self.device)  # Assuming DictDataset with 'data' key
            R0 = lrp_encoder(model, X_batch, latent_index=latent_index)
            all_R0.append(R0.cpu())  # Move to CPU to save GPU memory
        R0_all = torch.cat(all_R0, dim=0)  # [total_samples, in_dim]

        # Reshape assuming features grouped by 3 (e.g. x/y/z coordinates) —
        # TODO confirm: in_dim must be divisible by 3 or trailing features are dropped.
        num_features = R0_all.size(1) // 3
        R0_all = R0_all.reshape(-1, num_features, 3)
        R0_mean = R0_all.mean(dim=2)  # [total_samples, num_features]

        # Mean absolute relevance across samples -> one score per feature.
        global_importance = R0_mean.abs().mean(dim=0)  # [num_features]
        global_importance_raw = global_importance.detach().numpy()

        # Min-max normalize to [0, 1]; epsilon guards a constant vector.
        min_val = global_importance_raw.min()
        max_val = global_importance_raw.max()
        global_range = max_val - min_val + 1e-10  # Avoid division by zero
        global_importance_norm = (global_importance_raw - min_val) / global_range

        return {
            "global_importance": global_importance_norm,
            "global_importance_raw": global_importance_raw,
        }

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`LRP` class and its `.launch()` method.
        """

        fu.log('## BioBB Layer-wise Relevance Propagation ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_pt_path"])}', self.out_log)
        dataset = self.load_dataset()

        # create the dataloader
        fu.log('Start LRP analysis...', self.out_log)
        dataloader = self.create_dataloader(dataset)

        # Compute LRP
        self.results = self.compute_global_importance(model, dataloader, latent_index=None)

        # Save the results if path provided
        if self.output_results_npz_path:
            np.savez_compressed(self.io_dict["out"]["output_results_npz_path"], **self.results)
            fu.log(f'Results saved to {os.path.abspath(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)
            fu.log(f'File size: {get_size(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        # Output creation is only checked when a results path was requested.
        output_created = bool(self.output_results_npz_path)
        self.check_arguments(output_files_created=output_created, raise_exception=False)

        return 0

200 

201 

def relevance_propagation(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_pt_path: str,
    output_results_npz_path: Optional[str] = None,
    **kwargs,
) -> int:
    """Create the :class:`LRP <LRP>` class and
    execute the :meth:`launch() <LRP.launch>` method."""
    # Forward arguments explicitly (LRP.__init__ ignores extra **kwargs).
    lrp = LRP(
        input_model_pth_path=input_model_pth_path,
        input_dataset_pt_path=input_dataset_pt_path,
        output_results_npz_path=output_results_npz_path,
        properties=properties,
        **kwargs,
    )
    return lrp.launch()

212 

213 

# Mirror the class docstring on the functional wrapper so CLI/help output matches.
relevance_propagation.__doc__ = LRP.__doc__
# Build the command-line entry point via the BiobbObject helper; `get_main`
# presumably generates an argparse-based CLI around the wrapper — TODO confirm.
main = LRP.get_main(relevance_propagation, "Performs Layer-wise Relevance Propagation on a trained autoencoder encoder.")


if __name__ == "__main__":
    main()