Coverage for biobb_pytorch / mdae / evaluate_model.py: 83%

94 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import os 

2import torch 

3from torch.utils.data import DataLoader 

4from biobb_common.tools.file_utils import launchlogger 

5from biobb_common.tools import file_utils as fu 

6from biobb_pytorch.mdae.utils.log_utils import get_size 

7from biobb_common.generic.biobb_object import BiobbObject 

8from mlcolvar.data import DictDataset 

9import numpy as np 

10 

11 

class EvaluateModel(BiobbObject):
    """
    | biobb_pytorch EvaluateModel
    | Evaluate a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.
    | Evaluates a PyTorch autoencoder from the given properties.

    Args:
        input_model_pth_path (str): Path to the trained model file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_pt_path (str): Path to the input dataset file (.pt) to evaluate on. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_results_npz_path (str): Path to the output evaluation results file (compressed NumPy archive). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_results.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Dataset** (*dict*) - ({}) mlcolvar DictDataset / DataLoader options (e.g. batch_size, shuffle).

    Examples:
        This example shows how to use the EvaluateModel class to evaluate a PyTorch autoencoder model::

            from biobb_pytorch.mdae.evaluate_model import evaluateModel

            input_model_pth_path='input_model.pth'
            input_dataset_pt_path='input_dataset.pt'
            output_results_npz_path='output_results.npz'

            prop={
                'Dataset': {
                    'batch_size': 32
                }
            }

            evaluateModel(input_model_pth_path=input_model_pth_path,
                    input_dataset_pt_path=input_dataset_pt_path,
                    output_results_npz_path=output_results_npz_path,
                    properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_pt_path: str,
        output_results_npz_path: str,
        properties: dict,
        **kwargs,
    ) -> None:

        # A missing/None properties argument is treated as "no options".
        properties = properties or {}

        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_pt_path = input_dataset_pt_path
        self.output_results_npz_path = output_results_npz_path
        self.properties = properties.copy()
        # Snapshot of the constructor arguments; taken before any further
        # local names are introduced so that it only holds the call inputs.
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_pt_path": input_dataset_pt_path,
            },
            "out": {
                "output_results_npz_path": output_results_npz_path,
            },
        }

        # 'Dataset' is optional; fall back to an empty option set when it is
        # absent or explicitly None.
        self.Dataset = self.properties.get('Dataset') or {}
        self.results = None

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def load_model(self):
        """Load and return the object stored in the input model file.

        SECURITY NOTE: ``weights_only=False`` makes ``torch.load`` unpickle
        arbitrary Python objects, which can execute code embedded in the
        file. Only load model files from trusted sources.
        """
        return torch.load(self.io_dict["in"]["input_model_pth_path"],
                          weights_only=False)

    def load_dataset(self):
        """Load the input dataset file and wrap its content in a ``DictDataset``.

        SECURITY NOTE: ``weights_only=False`` makes ``torch.load`` unpickle
        arbitrary Python objects. Only load dataset files from trusted
        sources.
        """
        dataset = torch.load(self.io_dict["in"]["input_dataset_pt_path"],
                             weights_only=False)
        return DictDataset(dataset)

    def create_dataloader(self, dataset):
        """Build a ``DataLoader`` over ``dataset`` from the 'Dataset' property.

        Recognised options are ``batch_size`` (default 16) and ``shuffle``
        (default False). The 'Dataset' property itself is optional: when it
        is missing or None the defaults are used instead of raising
        ``KeyError``.
        """
        ds_cfg = self.properties.get('Dataset') or {}
        return DataLoader(
            dataset,
            batch_size=ds_cfg.get('batch_size', 16),
            shuffle=ds_cfg.get('shuffle', False),
        )

    def evaluate_full_model(self, model, dataloader):
        """Evaluate the model on the data, computing average loss and collecting output variables.

        Args:
            model: Object exposing ``eval_variables``, ``eval()``,
                ``evaluate_model(batch, batch_idx)`` and
                ``training_step(batch, batch_idx)``.
            dataloader: Iterable of batches handed to the model as-is.

        Returns:
            dict: One entry per name in ``model.eval_variables`` holding the
            concatenation of that variable over all batches (as a NumPy
            array), plus ``'eval_loss'`` with the mean of the per-batch
            losses (``0.0`` when the dataloader yields no batches).
        """

        output_variables = model.eval_variables
        all_results = []
        all_losses = []
        result_dict = {}

        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                result = model.evaluate_model(batch, batch_idx)
                # Note: Consider replacing with model.validation_step(batch, batch_idx) or
                # model.loss_fn(model(batch['data']), batch['data']) for eval-specific loss
                batch_loss = model.training_step(batch, batch_idx)
                all_results.append(result)
                all_losses.append(batch_loss.item())  # Use .item() to get scalar

        # After all batches, collect per variable (assuming result is list/tuple of tensors)
        for i, var in enumerate(output_variables):
            var_results = [res[i] for res in all_results]
            result_dict[var] = torch.cat(var_results) if var_results else torch.tensor([])

        # Unweighted mean over batches (a smaller final batch counts the
        # same as a full one).
        result_dict['eval_loss'] = np.mean(all_losses) if all_losses else 0.0

        # Convert tensors to NumPy arrays for saving to .npz. Tensors are
        # detached and moved to host memory first because Tensor.numpy()
        # only works on CPU tensors that do not require grad.
        for key, value in result_dict.items():
            if torch.is_tensor(value):
                result_dict[key] = value.detach().cpu().numpy()

        return result_dict

    def evaluate_encoder(self, model, dataloader):
        """Evaluate the encoder part of the model.

        Feeds ``batch['data']`` of every batch to ``model.forward_cv`` and
        returns the outputs concatenated along dimension 0 (an empty tensor
        when there are no batches). Not called by :meth:`launch`.
        """
        model.eval()
        with torch.no_grad():
            all_z = []
            for batch in dataloader:
                z = model.forward_cv(batch['data'])
                all_z.append(z)
        return torch.cat(all_z, dim=0) if all_z else torch.tensor([])

    def evaluate_decoder(self, model, dataloader):
        """Evaluate the decoder part of the model.

        Feeds ``batch['data']`` of every batch to ``model.decode`` and
        returns the outputs concatenated along dimension 0 (an empty tensor
        when there are no batches). Not called by :meth:`launch`.
        """
        # NOTE(review): batch['data'] is passed straight to the decoder, so
        # the dataloader presumably has to provide latent-space inputs here;
        # confirm against the model's decode() contract.
        model.eval()
        with torch.no_grad():
            all_reconstructions = []
            for batch in dataloader:
                reconstructions = model.decode(batch['data'])
                all_reconstructions.append(reconstructions)
        return torch.cat(all_reconstructions, dim=0) if all_reconstructions else torch.tensor([])

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`EvaluateModel` class and its `.launch()` method.
        """

        fu.log('## BioBB Model Evaluator ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Start Pipeline

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_pt_path"])}', self.out_log)
        dataset = self.load_dataset()

        # create the dataloader
        fu.log('Start evaluating...', self.out_log)
        dataloader = self.create_dataloader(dataset)

        # evaluate the model
        results = self.evaluate_full_model(model, dataloader)

        # Save the results
        np.savez_compressed(self.io_dict["out"]["output_results_npz_path"], **results)
        fu.log(f'Evaluation Results saved to {os.path.abspath(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)
        fu.log(f'File size: {get_size(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0

210 

211 

def evaluateModel(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_pt_path: str,
    output_results_npz_path: str,
    **kwargs,
) -> int:
    """Create the :class:`EvaluateModel <EvaluateModel.EvaluateModel>` class and
    execute the :meth:`launch() <EvaluateModel.EvaluateModel.launch>` method."""
    # Forward exactly the names this function received. Extra keyword
    # arguments are handed over as a single keyword literally named
    # ``kwargs`` (holding the dict), not unpacked.
    evaluator = EvaluateModel(
        properties=properties,
        input_model_pth_path=input_model_pth_path,
        input_dataset_pt_path=input_dataset_pt_path,
        output_results_npz_path=output_results_npz_path,
        kwargs=kwargs,
    )
    return evaluator.launch()

222 

223 

# Reuse the class documentation for the functional wrapper so both entry
# points present the same description.
evaluateModel.__doc__ = EvaluateModel.__doc__

# Command-line entry point, built by EvaluateModel.get_main from the wrapper
# function and a one-line description of the tool.
main = EvaluateModel.get_main(
    evaluateModel,
    "Evaluate a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.",
)

if __name__ == "__main__":
    main()