Coverage for biobb_pytorch/mdae/evaluate_model.py: 83% (94 statements)
import os
import torch
from torch.utils.data import DataLoader
from biobb_common.tools.file_utils import launchlogger
from biobb_common.tools import file_utils as fu
from biobb_pytorch.mdae.utils.log_utils import get_size
from biobb_common.generic.biobb_object import BiobbObject
from mlcolvar.data import DictDataset
import numpy as np

class EvaluateModel(BiobbObject):
    """
    | biobb_pytorch EvaluateModel
    | Evaluate a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.
    | Evaluates a PyTorch autoencoder from the given properties.

    Args:
        input_model_pth_path (str): Path to the trained model file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_pt_path (str): Path to the input dataset file (.pt) to evaluate on. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_results_npz_path (str): Path to the output evaluation results file (compressed NumPy archive). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_results.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Dataset** (*dict*) - ({}) mlcolvar DictDataset / DataLoader options (e.g. batch_size, shuffle).

    Examples:
        This example shows how to use the EvaluateModel class to evaluate a PyTorch autoencoder model::

            from biobb_pytorch.mdae.evaluate_model import evaluateModel

            input_model_pth_path = 'input_model.pth'
            input_dataset_pt_path = 'input_dataset.pt'
            output_results_npz_path = 'output_results.npz'

            prop = {
                'Dataset': {
                    'batch_size': 32
                }
            }

            evaluateModel(input_model_pth_path=input_model_pth_path,
                          input_dataset_pt_path=input_dataset_pt_path,
                          output_results_npz_path=output_results_npz_path,
                          properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_pt_path: str,
        output_results_npz_path: str,
        properties: dict,
        **kwargs,
    ) -> None:
        properties = properties or {}
        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_pt_path = input_dataset_pt_path
        self.output_results_npz_path = output_results_npz_path
        self.properties = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_pt_path": input_dataset_pt_path,
            },
            "out": {
                "output_results_npz_path": output_results_npz_path,
            },
        }

        self.Dataset = self.properties.get('Dataset', {})
        self.results = None

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def load_model(self):
        return torch.load(self.io_dict["in"]["input_model_pth_path"],
                          weights_only=False)
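    # Note: torch.load(..., weights_only=False) unpickles the full model object,
    # so the checkpoint must come from a trusted source and the model class must
    # be importable in the current environment.
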
    def load_dataset(self):
        dataset = torch.load(self.io_dict["in"]["input_dataset_pt_path"],
                             weights_only=False)
        return DictDataset(dataset)

    def create_dataloader(self, dataset):
        ds_cfg = self.Dataset  # 'Dataset' options from the tool properties (may be empty)
        return DataLoader(
            dataset,
            batch_size=ds_cfg.get('batch_size', 16),
            shuffle=ds_cfg.get('shuffle', False),
        )
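    # Illustrative (hypothetical) 'Dataset' block in the tool properties; only
    # batch_size and shuffle are read here, any other keys are ignored:
    #
    #     prop = {'Dataset': {'batch_size': 64, 'shuffle': False}}
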
    def evaluate_full_model(self, model, dataloader):
        """Evaluate the model on the data, computing average loss and collecting output variables."""
        output_variables = model.eval_variables
        all_results = []
        all_losses = []
        result_dict = {}

        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                result = model.evaluate_model(batch, batch_idx)
                # Note: Consider replacing with model.validation_step(batch, batch_idx) or
                # model.loss_fn(model(batch['data']), batch['data']) for an eval-specific loss
                batch_loss = model.training_step(batch, batch_idx)
                all_results.append(result)
                all_losses.append(batch_loss.item())  # Use .item() to get a scalar

        # After all batches, collect per variable (assuming result is a list/tuple of tensors)
        for i, var in enumerate(output_variables):
            var_results = [res[i] for res in all_results]
            result_dict[var] = torch.cat(var_results) if var_results else torch.tensor([])  # Concat if tensors

        # Average loss (use np.mean for simplicity with a list of scalars)
        avg_loss = np.mean(all_losses) if all_losses else 0.0

        # Add to dictionary
        result_dict['eval_loss'] = avg_loss

        # Optional: Convert tensors to NumPy arrays for saving to .npz
        for key in result_dict:
            if torch.is_tensor(result_dict[key]):
                result_dict[key] = result_dict[key].numpy()

        return result_dict
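    # Sketch (illustrative): the dictionary returned above is written with
    # np.savez_compressed in launch(), so it can be read back with np.load.
    # The per-variable key names depend on model.eval_variables; the file and
    # key names below are only examples.
    #
    #     import numpy as np
    #     results = np.load('output_results.npz')
    #     print(results.files)                  # per-variable arrays plus 'eval_loss'
    #     print(float(results['eval_loss']))    # average loss over all batches
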
    def evaluate_encoder(self, model, dataloader):
        """Evaluate the encoder part of the model."""
        model.eval()
        with torch.no_grad():
            all_z = []
            for batch in dataloader:
                z = model.forward_cv(batch['data'])
                all_z.append(z)
            return torch.cat(all_z, dim=0) if all_z else torch.tensor([])

    def evaluate_decoder(self, model, dataloader):
        """Evaluate the decoder part of the model."""
        model.eval()
        with torch.no_grad():
            all_reconstructions = []
            for batch in dataloader:
                reconstructions = model.decode(batch['data'])
                all_reconstructions.append(reconstructions)
            return torch.cat(all_reconstructions, dim=0) if all_reconstructions else torch.tensor([])
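    # Illustrative (not called by launch()): the two helpers above can be used
    # to inspect the halves of the autoencoder separately, e.g. projecting the
    # evaluation dataset onto the latent space via the encoder:
    #
    #     z = self.evaluate_encoder(model, dataloader)  # tensor of latent coordinates
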
    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`EvaluateModel` class and its `.launch()` method.
        """
        fu.log('## BioBB Model Evaluator ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0
        self.stage_files()

        # Start pipeline

        # Load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        model = self.load_model()

        # Load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_pt_path"])}', self.out_log)
        dataset = self.load_dataset()

        # Create the dataloader
        fu.log('Start evaluating...', self.out_log)
        dataloader = self.create_dataloader(dataset)

        # Evaluate the model
        results = self.evaluate_full_model(model, dataloader)

        # Save the results
        np.savez_compressed(self.io_dict["out"]["output_results_npz_path"], **results)
        fu.log(f'Evaluation Results saved to {os.path.abspath(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)
        fu.log(f'File size: {get_size(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporary files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0

def evaluateModel(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_pt_path: str,
    output_results_npz_path: str,
    **kwargs,
) -> int:
    """Create the :class:`EvaluateModel <EvaluateModel.EvaluateModel>` class and
    execute the :meth:`launch() <EvaluateModel.EvaluateModel.launch>` method."""
    return EvaluateModel(**dict(locals())).launch()


evaluateModel.__doc__ = EvaluateModel.__doc__

main = EvaluateModel.get_main(evaluateModel, "Evaluate a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.")

if __name__ == "__main__":
    main()