Coverage for biobb_pytorch / mdae / encode_model.py: 97%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import torch 

2from torch.utils.data import DataLoader 

3import os 

4from biobb_common.tools.file_utils import launchlogger 

5from biobb_common.tools import file_utils as fu 

6from biobb_pytorch.mdae.utils.log_utils import get_size 

7from biobb_common.generic.biobb_object import BiobbObject 

8from mlcolvar.data import DictDataset 

9import numpy as np 

10 

11 

class EvaluateEncoder(BiobbObject):
    """
    | biobb_pytorch EvaluateEncoder
    | Encode data with a Molecular Dynamics AutoEncoder (MDAE) model.
    | Evaluates a PyTorch autoencoder from the given properties.

    Args:
        input_model_pth_path (str): Path to the trained model file whose encoder will be used. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_pt_path (str): Path to the input dataset file (.pt) to encode. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_results_npz_path (str): Path to the output latent-space results file (compressed NumPy archive, typically containing 'z'). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_results.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Dataset** (*dict*) - ({}) DataLoader options used while encoding: batch_size (default 16) and shuffle (default False).

    Examples:
        This example shows how to use the EvaluateEncoder class to evaluate a PyTorch autoencoder model::

            from biobb_pytorch.mdae.encode_model import evaluateEncoder

            input_model_pth_path='input_model.pth'
            input_dataset_pt_path='input_dataset.pt'
            output_results_npz_path='output_results.npz'

            prop={
                'Dataset': {
                    'batch_size': 32
                }
            }

            evaluateEncoder(input_model_pth_path=input_model_pth_path,
                            input_dataset_pt_path=input_dataset_pt_path,
                            output_results_npz_path=output_results_npz_path,
                            properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_pt_path: str,
        output_results_npz_path: str,
        properties: dict,
        **kwargs,
    ) -> None:

        properties = properties or {}

        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_pt_path = input_dataset_pt_path
        self.output_results_npz_path = output_results_npz_path
        self.properties = properties.copy()
        # Snapshot of the constructor's local arguments; presumably consumed by
        # BiobbObject helpers (restart/sandbox logic) — TODO confirm.
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_pt_path": input_dataset_pt_path,
            },
            "out": {
                "output_results_npz_path": output_results_npz_path,
            },
        }

        # Optional DataLoader configuration; defaults to an empty dict.
        self.Dataset = self.properties.get('Dataset', {})
        # Latent-space arrays written by the last successful launch().
        self.results = None

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def load_model(self):
        """Load the full model object stored at ``input_model_pth_path``."""
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load model files that come from a trusted source.
        return torch.load(self.io_dict["in"]["input_model_pth_path"],
                          weights_only=False)

    def load_dataset(self):
        """Load the ``.pt`` dataset and wrap it in an mlcolvar ``DictDataset``."""
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load dataset files that come from a trusted source.
        dataset = torch.load(self.io_dict["in"]["input_dataset_pt_path"],
                             weights_only=False)
        return DictDataset(dataset)

    def create_dataloader(self, dataset):
        """Build the DataLoader used for encoding.

        Reads ``batch_size`` (default 16) and ``shuffle`` (default False) from the
        optional ``Dataset`` property; a missing or ``None`` property means defaults.
        """
        # Use the value resolved in __init__ so an omitted 'Dataset' property
        # falls back to defaults instead of raising KeyError.
        ds_cfg = self.Dataset or {}
        return DataLoader(
            dataset,
            batch_size=ds_cfg.get('batch_size', 16),
            shuffle=ds_cfg.get('shuffle', False),
        )

    def evaluate_encoder(self, model, dataloader):
        """Evaluate the encoder part of the model.

        Args:
            model: Trained model exposing ``eval()`` and ``forward_cv()``.
            dataloader: Iterable of dict batches holding the input under ``'data'``.

        Returns:
            dict: ``{"z": tensor}`` with the per-batch ``forward_cv`` outputs
            concatenated along dimension 0.

        Raises:
            ValueError: If *dataloader* yields no batches.
        """
        model.eval()
        z_all = []
        with torch.no_grad():
            for batch in dataloader:
                z_all.append(model.forward_cv(batch['data']))
        # torch.cat on an empty list fails with an unhelpful message; be explicit.
        if not z_all:
            raise ValueError('The input dataset is empty: there is nothing to encode.')
        return {"z": torch.cat(z_all, dim=0)}

    @staticmethod
    def _to_numpy(results):
        """Convert every tensor in *results* into a host (CPU) NumPy array."""
        return {
            key: value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else np.asarray(value)
            for key, value in results.items()
        }

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`EvaluateEncoder` class and its `.launch()` method.

        Returns:
            int: 0 on success, including when check_restart() skips the step.
        """

        fu.log('## BioBB Model Evaluator ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Start Pipeline

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_pt_path"])}', self.out_log)
        dataset = self.load_dataset()

        # create the dataloader
        fu.log('Start evaluating...', self.out_log)
        dataloader = self.create_dataloader(dataset)

        # evaluate the model
        results = self.evaluate_encoder(model, dataloader)

        # Save the results: move tensors to CPU NumPy arrays first, because
        # np.savez_compressed cannot serialise device (e.g. CUDA) tensors.
        self.results = self._to_numpy(results)
        output_path = self.io_dict["out"]["output_results_npz_path"]
        np.savez_compressed(output_path, **self.results)
        fu.log(f'Evaluation Results saved to {os.path.abspath(output_path)}', self.out_log)
        fu.log(f'File size: {get_size(output_path)}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0

164 

165 

def evaluateEncoder(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_pt_path: str,
    output_results_npz_path: str,
    **kwargs,
) -> int:
    """Functional entry point: build an :class:`EvaluateEncoder` and run its ``launch()``.

    Returns the integer status produced by ``EvaluateEncoder.launch()``.
    """
    # Extra keyword arguments are forwarded as one ``kwargs`` keyword, which the
    # constructor absorbs through its own ``**kwargs``.
    step = EvaluateEncoder(
        properties=properties,
        input_model_pth_path=input_model_pth_path,
        input_dataset_pt_path=input_dataset_pt_path,
        output_results_npz_path=output_results_npz_path,
        kwargs=kwargs,
    )
    return step.launch()

176 

177 

# Expose the full class documentation on the functional wrapper as well.
evaluateEncoder.__doc__ = EvaluateEncoder.__doc__
# Command-line entry point built by the inherited BiobbObject.get_main helper.
# NOTE(review): get_main may read evaluateEncoder.__doc__, so keep the
# assignment above first — TODO confirm.
main = EvaluateEncoder.get_main(evaluateEncoder, "Encode data with a Molecular Dynamics AutoEncoder (MDAE) model.")

if __name__ == "__main__":
    main()