Coverage for biobb_pytorch / mdae / decode_model.py: 97%

64 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import torch 

2from torch.utils.data import DataLoader 

3import os 

4from biobb_common.tools.file_utils import launchlogger 

5from biobb_common.tools import file_utils as fu 

6from biobb_pytorch.mdae.utils.log_utils import get_size 

7from biobb_common.generic.biobb_object import BiobbObject 

8import numpy as np 

9 

10 

class EvaluateDecoder(BiobbObject):
    """
    | biobb_pytorch evaluateDecoder
    | Evaluates a PyTorch autoencoder from the given properties.
    | Decodes a set of latent variables with the decoder of a trained PyTorch autoencoder and saves the reconstructions.

    Args:
        input_model_pth_path (str): Path to the trained model file whose decoder will be used. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_npy_path (str): Path to the input latent variables file in NumPy format (e.g. encoded 'z'). File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.npy>`_. Accepted formats: npy (edam:format_2333).
        output_results_npz_path (str): Path to the output reconstructed data file (compressed NumPy archive containing the 'xhat' array). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_results.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Dataset** (*dict*) - ({}) DataLoader options for batching the latent variables. Supported keys: 'batch_size' defaults to 16 and 'shuffle' defaults to False. Shuffling changes the order of the reconstructions with respect to the input.

    Examples:
        This example shows how to use the EvaluateDecoder class to evaluate a PyTorch autoencoder model::

            from biobb_pytorch.mdae.decode_model import evaluateDecoder

            input_model_pth_path='input_model.pth'
            input_dataset_npy_path='input_dataset.npy'
            output_results_npz_path='output_results.npz'

            prop={
                'Dataset': {
                    'batch_size': 32
                }
            }

            evaluateDecoder(input_model_pth_path=input_model_pth_path,
                            input_dataset_npy_path=input_dataset_npy_path,
                            output_results_npz_path=output_results_npz_path,
                            properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_npy_path: str,
        output_results_npz_path: str,
        properties: dict,
        **kwargs,
    ) -> None:
        properties = properties or {}

        # Call parent class constructor
        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_npy_path = input_dataset_npy_path
        self.output_results_npz_path = output_results_npz_path
        self.properties = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_npy_path": input_dataset_npy_path,
            },
            "out": {
                "output_results_npz_path": output_results_npz_path,
            },
        }

        # DataLoader options; `or {}` also covers an explicit `'Dataset': None`.
        self.Dataset = self.properties.get('Dataset') or {}
        # Filled by launch() with the dict returned by evaluate_decoder().
        self.results = None

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def load_model(self):
        """Load the full pickled model object from ``input_model_pth_path``.

        Returns:
            The object stored in the ``.pth`` file (expected to expose
            ``eval()`` and ``decode()``).
        """
        # SECURITY: weights_only=False unpickles an arbitrary Python object,
        # which is needed for a whole saved model but can execute code.
        # Only load model files from trusted sources.
        return torch.load(self.io_dict["in"]["input_model_pth_path"],
                          weights_only=False)

    def load_dataset(self):
        """Load the latent variables from ``input_dataset_npy_path``.

        Returns:
            torch.Tensor: The NumPy array converted to a float32 tensor.
        """
        dataset = torch.tensor(np.load(self.io_dict["in"]["input_dataset_npy_path"]))
        return dataset.float()

    def create_dataloader(self, dataset):
        """Wrap *dataset* in a DataLoader configured by the ``Dataset`` property.

        Args:
            dataset: Tensor (or dataset) of latent variables to batch.

        Returns:
            torch.utils.data.DataLoader: Loader using ``batch_size`` (default 16)
            and ``shuffle`` (default False) from the ``Dataset`` property.
        """
        # Use the defaulted attribute so an omitted 'Dataset' property does
        # not raise KeyError.
        ds_cfg = self.Dataset
        return DataLoader(
            dataset,
            batch_size=ds_cfg.get('batch_size', 16),
            shuffle=ds_cfg.get('shuffle', False),
        )

    @staticmethod
    def _model_device(model):
        """Return the device of the model's first parameter, or CPU if it has none."""
        parameters = model.parameters() if hasattr(model, 'parameters') else ()
        return next((param.device for param in parameters), torch.device('cpu'))

    def evaluate_decoder(self, model, dataloader):
        """Evaluate the decoder part of the model.

        Args:
            model: Model exposing ``eval()`` and ``decode(batch)``.
            dataloader: Iterable yielding batches of latent variables.

        Returns:
            dict: ``{"xhat": torch.Tensor}``. The tensor holds the CPU
            reconstructions of all batches, concatenated along the first axis
            in iteration order.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        model.eval()
        device = self._model_device(model)
        reconstructions = []
        with torch.no_grad():
            for batch in dataloader:
                # Match the model's device on the way in and return to the CPU
                # on the way out, so the result can be saved with NumPy.
                xhat = model.decode(batch.to(device))
                reconstructions.append(xhat.cpu())
        if not reconstructions:
            raise ValueError("No latent variables to decode: the input dataset is empty.")
        return {"xhat": torch.cat(reconstructions, dim=0)}

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`EvaluateDecoder` class and its `.launch()` method.

        The steps are:

        * load the model and the latent variables;
        * decode the latent variables batch by batch;
        * save the reconstructions as a compressed ``.npz`` archive at
          ``output_results_npz_path``.

        Returns:
            int: 0, both after a full run and when ``check_restart()`` short-circuits.
        """

        fu.log('## BioBB Model Evaluator ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Start Pipeline

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_npy_path"])}', self.out_log)
        dataset = self.load_dataset()

        # create the dataloader
        fu.log('Start evaluating...', self.out_log)
        dataloader = self.create_dataloader(dataset)

        # evaluate the model
        results = self.evaluate_decoder(model, dataloader)
        self.results = results

        # Save the results
        np.savez_compressed(self.io_dict["out"]["output_results_npz_path"], **results)
        fu.log(f'Evaluation Results saved to {os.path.abspath(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)
        fu.log(f'File size: {get_size(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0

162 

163 

def evaluateDecoder(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_npy_path: str,
    output_results_npz_path: str,
    **kwargs,
) -> int:
    """Functional wrapper around :class:`EvaluateDecoder`.

    Builds an :class:`EvaluateDecoder` from the given arguments and runs its
    :meth:`~EvaluateDecoder.launch` method.

    Returns:
        int: The exit code returned by ``launch()``.
    """
    # Snapshot the call arguments (including the ``kwargs`` dict itself)
    # before any other local name is bound, so only they are forwarded.
    call_arguments = dict(locals())
    tool = EvaluateDecoder(**call_arguments)
    return tool.launch()

174 

175 

# Reuse the full class documentation for the functional interface.
evaluateDecoder.__doc__ = EvaluateDecoder.__doc__

# Command-line entry point. get_main is not defined in this module; it is
# inherited (presumably from BiobbObject) and wraps evaluateDecoder.
main = EvaluateDecoder.get_main(
    evaluateDecoder,
    "Evaluates a PyTorch autoencoder from the given properties.",
)

if __name__ == "__main__":
    main()