Coverage for biobb_pytorch/mdae/apply_mdae.py: 84%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-21 09:06 +0000

1#!/usr/bin/env python3 

2 

3"""Module containing the ApplyMDAE class and the command line interface.""" 

4 

5import argparse 

6import time 

7from typing import Optional 

8 

9import numpy as np 

10import torch 

11import torch.utils.data 

12from biobb_common.configuration import settings 

13from biobb_common.generic.biobb_object import BiobbObject 

14from biobb_common.tools import file_utils as fu 

15from biobb_common.tools.file_utils import launchlogger 

16 

17from biobb_pytorch.mdae.common import ( 

18 execute_model, 

19 format_time, 

20 human_readable_file_size, 

21 ndarray_denormalization, 

22 ndarray_normalization, 

23) 

24from biobb_pytorch.mdae.mdae import MDAE 

25 

26 

class ApplyMDAE(BiobbObject):
    """
    | biobb_pytorch ApplyMDAE
    | Apply a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.
    | Apply a Molecular Dynamics AutoEncoder (MDAE) PyTorch model. The resulting denoised molecular dynamics or the reduced-dimensionality molecular dynamics data can be used to analyze the dynamic properties of the system.

    Args:
        input_data_npy_path (str): Path to the input data file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/data/mdae/train_mdae_traj.npy>`_. Accepted formats: npy (edam:format_4003).
        input_model_pth_path (str): Path to the input model file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        output_reconstructed_data_npy_path (str): Path to the output reconstructed data file. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_output_reconstructed_data.npy>`_. Accepted formats: npy (edam:format_4003).
        output_latent_space_npy_path (str) (Optional): Path to the reduced dimensionality file. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_output_latent_space.npy>`_. Accepted formats: npy (edam:format_4003).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **batch_size** (*int*) - (1) number of samples/frames per batch.
            * **latent_dimensions** (*int*) - (2) min dimensionality of the latent space.
            * **num_layers** (*int*) - (4) number of layers in the encoder/decoder (4 to encode and 4 to decode).
            * **input_dimensions** (*int*) - (None) input dimensions by default it should be the number of features in the input data (number of atoms * 3 corresponding to x, y, z coordinates).
            * **output_dimensions** (*int*) - (None) output dimensions by default it should be the number of features in the input data (number of atoms * 3 corresponding to x, y, z coordinates).

    Examples:
        This is a use case of how to use the building block from Python::

            from biobb_pytorch.mdae.apply_mdae import ApplyMDAE
            prop = {
                'latent_dimensions': 2,
                'num_layers': 4
            }
            ApplyMDAE(input_data_npy_path='/path/to/myInputData.npy',
                      output_reconstructed_data_npy_path='/path/to/newReconstructedData.npy',
                      input_model_pth_path='/path/to/oldModel.pth',
                      properties=prop).launch()

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_data_npy_path: str,
        input_model_pth_path: str,
        output_reconstructed_data_npy_path: str,
        output_latent_space_npy_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """Read the properties, load and normalize the input data, and load the model weights."""
        properties = properties or {}

        # Call parent class constructor
        super().__init__(properties)
        # Snapshot of the constructor arguments; taken before any other local is created
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_data_npy_path": input_data_npy_path,
                "input_model_pth_path": input_model_pth_path,
            },
            "out": {
                "output_reconstructed_data_npy_path": output_reconstructed_data_npy_path,
                "output_latent_space_npy_path": output_latent_space_npy_path,
            },
        }

        # Properties specific for BB
        self.batch_size: int = int(
            properties.get("batch_size", 1)
        )  # number of samples/frames per batch
        self.latent_dimensions: int = int(
            properties.get("latent_dimensions", 2)
        )  # min dimensionality of the latent space
        self.num_layers: int = int(
            properties.get("num_layers", 4)
        )  # number of layers in the encoder/decoder (4 to encode and 4 to decode)

        # Input data section
        input_raw_data = np.load(self.io_dict["in"]["input_data_npy_path"])
        # Flatten axes 1 and 2 so every frame becomes one row of features;
        # the input is therefore expected to be a 3D array.
        input_reshaped_data: np.ndarray = np.reshape(
            input_raw_data,
            (len(input_raw_data), input_raw_data.shape[1] * input_raw_data.shape[2]),
        )
        # Normalization of the input data. The per-feature max/min are kept on the
        # instance because launch() needs them to denormalize the reconstruction.
        self.input_data_max_values: np.ndarray = np.max(input_reshaped_data, axis=0)
        self.input_data_min_values: np.ndarray = np.min(input_reshaped_data, axis=0)
        input_data: np.ndarray = ndarray_normalization(
            input_reshaped_data,
            max_values=self.input_data_max_values,
            min_values=self.input_data_min_values,
        )
        self.input_dimensions: int = (
            int(properties["input_dimensions"])
            if properties.get("input_dimensions")
            else input_data.shape[1]
        )  # input dimensions by default it should be the number of features in the input data (number of atoms * 3 corresponding to x, y, z coordinates)
        # NOTE(review): output_dimensions is stored here but is not passed to MDAE below.
        self.output_dimensions: int = (
            int(properties["output_dimensions"])
            if properties.get("output_dimensions")
            else self.input_dimensions
        )  # output dimensions by default it should be the number of features in the input data (number of atoms * 3 corresponding to x, y, z coordinates)

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        data_tensor = torch.FloatTensor(input_data)
        tensor_dataset = torch.utils.data.TensorDataset(data_tensor)
        # shuffle=False keeps the output rows in the same order as the input frames
        self.data_loader = torch.utils.data.DataLoader(
            tensor_dataset, batch_size=self.batch_size, shuffle=False
        )
        self.model = MDAE(
            input_dimensions=self.input_dimensions,
            num_layers=self.num_layers,
            latent_dimensions=self.latent_dimensions,
        )
        # NOTE(review): torch.load unpickles the file, so only model files from a
        # trusted source should be passed in. Consider weights_only=True once the
        # minimum supported PyTorch version provides it.
        self.model.load_state_dict(
            torch.load(
                self.io_dict["in"]["input_model_pth_path"],
                map_location=self.model.device,
            )
        )

    @launchlogger
    def launch(self) -> int:
        """Execute the :class:`ApplyMDAE <mdae.apply_mdae.ApplyMDAE>` object."""

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        fu.log(
            f"Applying MDAE model reducing dimensionality from {self.input_dimensions} to {self.latent_dimensions} and reconstructing.",
            self.out_log,
        )
        latent_space, reconstructed_data = self.apply_model(self.data_loader)
        # Undo the normalization applied in __init__ using the stored max/min values
        denormalized_reconstructed_data = ndarray_denormalization(
            reconstructed_data, self.input_data_max_values, self.input_data_min_values
        )
        # Back from flat feature rows to (frames, N, 3)
        reshaped_reconstructed_data = np.reshape(
            denormalized_reconstructed_data,
            (len(denormalized_reconstructed_data), -1, 3),
        )
        self._save_npy(
            "reconstructed data",
            self.stage_io_dict["out"]["output_reconstructed_data_npy_path"],
            reshaped_reconstructed_data,
        )

        # The latent space output is optional
        latent_space_path = self.stage_io_dict["out"].get(
            "output_latent_space_npy_path"
        )
        if latent_space_path:
            self._save_npy("latent space", latent_space_path, latent_space)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)
        return 0

    def _save_npy(self, description: str, path: str, data) -> None:
        """Write ``data`` to ``path`` in NumPy binary format and log where it went and its size."""
        # Writing through an open handle keeps the file at exactly ``path``: when
        # np.save receives a bare filename it appends ".npy" if it is missing, and
        # the size lookup below would then point at a file that does not exist.
        with open(path, "wb") as npy_file:
            np.save(npy_file, np.array(data))
        fu.log(f"Saving {description} to: {path}", self.out_log)
        fu.log(f" File size: {human_readable_file_size(path)}", self.out_log)

    def apply_model(
        self, dataloader: torch.utils.data.DataLoader
    ) -> tuple[np.ndarray, np.ndarray]:
        """Run the model over ``dataloader`` and log a summary of the run.

        Returns every element after the first of the ``execute_model`` result;
        ``launch`` unpacks it as ``(latent_space, reconstructed_data)``.
        """
        self.model.to(self.model.device)
        start_time: float = time.time()
        fu.log("Applying model:", self.out_log)
        fu.log(f" Device: {self.model.device}", self.out_log)
        fu.log(
            f" Input file: {self.stage_io_dict['in']['input_data_npy_path']}",
            self.out_log,
        )
        fu.log(
            f" File size: {human_readable_file_size(self.stage_io_dict['in']['input_data_npy_path'])}",
            self.out_log,
        )
        # Feature count of the first sample of the first batch, three coordinates per atom
        fu.log(
            f" Number of atoms: {int(len(next(iter(dataloader))[0][0])/3)}",
            self.out_log,
        )
        # Count the samples in the dataset itself: number of batches times batch size
        # over-counts whenever the last batch is only partially filled.
        fu.log(
            f" Number of frames: {len(dataloader.dataset)}",  # type: ignore
            self.out_log,
        )
        fu.log(f" Batch size: {self.batch_size}", self.out_log)
        fu.log(f" Number of layers: {self.num_layers}", self.out_log)
        fu.log(f" Input dimensions: {self.input_dimensions}", self.out_log)
        fu.log(f" Latent dimensions: {self.latent_dimensions}", self.out_log)

        # The first element of the execute_model result is not needed here
        execution_tuple = execute_model(
            self.model, dataloader, self.input_dimensions, self.latent_dimensions
        )[1:]

        fu.log(
            f" Execution time: {format_time(time.time() - start_time)}", self.out_log
        )
        return execution_tuple

247 

248 

def applyMDAE(
    input_data_npy_path: str,
    input_model_pth_path: str,
    output_reconstructed_data_npy_path: str,
    output_latent_space_npy_path: Optional[str] = None,
    properties: Optional[dict] = None,
    **kwargs,
) -> int:
    """Build an :class:`ApplyMDAE <mdae.apply_mdae.ApplyMDAE>` object from the given
    arguments and run its :meth:`launch() <mdae.apply_mdae.ApplyMDAE.launch>` method.

    All arguments are forwarded unchanged to the class constructor; the value
    returned by ``launch()`` is returned to the caller.
    """
    building_block = ApplyMDAE(
        input_data_npy_path=input_data_npy_path,
        input_model_pth_path=input_model_pth_path,
        output_reconstructed_data_npy_path=output_reconstructed_data_npy_path,
        output_latent_space_npy_path=output_latent_space_npy_path,
        properties=properties,
        **kwargs,
    )
    return building_block.launch()

268 

269 

def main():
    """Command line execution of this building block. Please check the command line documentation."""

    def wide_help_formatter(prog):
        # Very large width so argparse does not wrap the help text
        return argparse.RawTextHelpFormatter(prog, width=99999)

    parser = argparse.ArgumentParser(
        description="Apply a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.",
        formatter_class=wide_help_formatter,
    )
    parser.add_argument(
        "-c",
        "--config",
        required=False,
        help="This file can be a YAML file, JSON file or JSON string",
    )

    # Specific args of each building block
    required_args = parser.add_argument_group("required arguments")
    mandatory_options = (
        ("--input_data_npy_path", "Path to the input data file."),
        ("--input_model_pth_path", "Path to the input model file."),
        (
            "--output_reconstructed_data_npy_path",
            "Path to the output reconstructed data file.",
        ),
    )
    for option_name, option_help in mandatory_options:
        required_args.add_argument(option_name, required=True, help=option_help)

    # NOTE(review): --properties is accepted but its value is not used below;
    # the properties come from the --config reader only.
    optional_options = (
        ("--output_latent_space_npy_path", "Path to the reduced dimensionality file."),
        ("--properties", "Additional properties for the MDAE object."),
    )
    for option_name, option_help in optional_options:
        parser.add_argument(option_name, required=False, help=option_help)

    args = parser.parse_args()
    # An empty or missing --config is handed to the reader as None
    config = args.config or None
    properties = settings.ConfReader(config=config).get_prop_dic()

    applyMDAE(
        input_data_npy_path=args.input_data_npy_path,
        input_model_pth_path=args.input_model_pth_path,
        output_reconstructed_data_npy_path=args.output_reconstructed_data_npy_path,
        output_latent_space_npy_path=args.output_latent_space_npy_path,
        properties=properties,
    )

318 

319 

# Run the command line interface only when this file is executed as a script
if __name__ == "__main__":
    main()