Coverage for biobb_pytorch / mdae / decode_model.py: 97%
64 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import torch
2from torch.utils.data import DataLoader
3import os
4from biobb_common.tools.file_utils import launchlogger
5from biobb_common.tools import file_utils as fu
6from biobb_pytorch.mdae.utils.log_utils import get_size
7from biobb_common.generic.biobb_object import BiobbObject
8import numpy as np
class EvaluateDecoder(BiobbObject):
    """
    | biobb_pytorch evaluateDecoder
    | Evaluates a PyTorch autoencoder from the given properties.
    | Evaluates a PyTorch autoencoder from the given properties.

    Args:
        input_model_pth_path (str): Path to the trained model file whose decoder will be used. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_npy_path (str): Path to the input latent variables file in NumPy format (e.g. encoded 'z'). File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.npy>`_. Accepted formats: npy (edam:format_2333).
        output_results_npz_path (str): Path to the output reconstructed data file (compressed NumPy archive, typically containing 'xhat'). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_results.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Dataset** (*dict*) - ({}) DataLoader options (e.g. batch_size, shuffle) for batching the latent variables.

    Examples:
        This example shows how to use the EvaluateDecoder class to evaluate a PyTorch autoencoder model::

            from biobb_pytorch.mdae.decode_model import evaluateDecoder

            input_model_pth_path='input_model.pth'
            input_dataset_npy_path='input_dataset.npy'
            output_results_npz_path='output_results.npz'

            prop={
                'Dataset': {
                    'batch_size': 32
                }
            }

            evaluateDecoder(input_model_pth_path=input_model_pth_path,
                    input_dataset_npy_path=input_dataset_npy_path,
                    output_results_npz_path=output_results_npz_path,
                    properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_npy_path: str,
        output_results_npz_path: str,
        properties: dict,
        **kwargs,
    ) -> None:
        """Store the file paths and tool properties and run the base-class checks.

        Args:
            input_model_pth_path: Path of the serialized model to load.
            input_dataset_npy_path: Path of the ``.npy`` array fed to the decoder.
            output_results_npz_path: Path of the ``.npz`` archive to write.
            properties: Tool parameters; ``None`` is treated as an empty dict.
            **kwargs: Accepted for interface compatibility; only captured
                through ``locals()`` below.
        """
        properties = properties or {}

        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_npy_path = input_dataset_npy_path
        self.output_results_npz_path = output_results_npz_path
        self.properties = properties.copy()
        # Snapshot of the constructor arguments (includes ``self`` and ``kwargs``).
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_npy_path": input_dataset_npy_path,
            },
            "out": {
                "output_results_npz_path": output_results_npz_path,
            },
        }

        self.Dataset = self.properties.get('Dataset', {})
        # Initialised here; no method of this class assigns it afterwards.
        self.results = None

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def load_model(self):
        """Load and return the object serialized at ``input_model_pth_path``."""
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, so only load model files from a trusted source.
        return torch.load(self.io_dict["in"]["input_model_pth_path"],
                          weights_only=False)

    def load_dataset(self):
        """Load the ``.npy`` input file and return it as a float32 tensor."""
        dataset = torch.tensor(np.load(self.io_dict["in"]["input_dataset_npy_path"]))
        return dataset.float()

    def create_dataloader(self, dataset):
        """Wrap ``dataset`` in a ``DataLoader`` configured from ``properties['Dataset']``.

        Only ``batch_size`` (default 16) and ``shuffle`` (default False) are
        read from the ``Dataset`` options.
        """
        # 'Dataset' is optional (documented default: {}), so fall back to an
        # empty mapping instead of raising KeyError when it is missing or None.
        ds_cfg = self.properties.get('Dataset') or {}
        return DataLoader(
            dataset,
            batch_size=ds_cfg.get('batch_size', 16),
            # With shuffle left at False the outputs keep the input row order.
            shuffle=ds_cfg.get('shuffle', False),
        )

    def evaluate_decoder(self, model, dataloader):
        """Evaluate the decoder part of the model.

        Calls ``model.decode`` on every batch with gradients disabled and
        returns ``{"xhat": tensor}``, the per-batch outputs concatenated along
        dimension 0.
        """
        model.eval()
        reconstructions = []
        with torch.no_grad():
            for batch in dataloader:
                xhat = model.decode(batch)
                reconstructions.append(xhat)
        return {"xhat": torch.cat(reconstructions, dim=0)}

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`EvaluateDecoder` class and its `.launch()` method.
        """

        fu.log('## BioBB Model Evaluator ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Start Pipeline

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_npy_path"])}', self.out_log)
        dataset = self.load_dataset()

        # create the dataloader
        fu.log('Start evaluating...', self.out_log)
        dataloader = self.create_dataloader(dataset)

        # evaluate the model
        results = self.evaluate_decoder(model, dataloader)

        # Save the results
        np.savez_compressed(self.io_dict["out"]["output_results_npz_path"], **results)
        fu.log(f'Evaluation Results saved to {os.path.abspath(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)
        fu.log(f'File size: {get_size(self.io_dict["out"]["output_results_npz_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0
def evaluateDecoder(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_npy_path: str,
    output_results_npz_path: str,
    **kwargs,
) -> int:
    """Build an :class:`EvaluateDecoder` from the given arguments, run its
    ``launch()`` method and return the resulting exit code."""
    # The extra keyword arguments are handed over as one keyword literally
    # named ``kwargs``, holding the collected dict.
    tool = EvaluateDecoder(
        properties=properties,
        input_model_pth_path=input_model_pth_path,
        input_dataset_npy_path=input_dataset_npy_path,
        output_results_npz_path=output_results_npz_path,
        kwargs=kwargs,
    )
    return tool.launch()
# The functional wrapper shares the class docstring.
evaluateDecoder.__doc__ = EvaluateDecoder.__doc__

# NOTE(review): get_main is defined outside this file; presumably it builds the
# command-line entry point from the wrapper and this description — confirm.
main = EvaluateDecoder.get_main(
    evaluateDecoder,
    "Evaluates a PyTorch autoencoder from the given properties.",
)

if __name__ == "__main__":
    main()