Coverage for biobb_pytorch/mdae/train_mdae.py: 90%

188 statements  

coverage.py v7.6.7, created at 2024-11-21 09:06 +0000

#!/usr/bin/env python3

"""Module containing the TrainMDAE class and the command line interface."""

import argparse
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.utils.data
from biobb_common.configuration import settings
from biobb_common.generic.biobb_object import BiobbObject
from biobb_common.tools import file_utils as fu
from biobb_common.tools.file_utils import launchlogger
from torch.optim.adam import Adam
from torch.optim.optimizer import Optimizer

from biobb_pytorch.mdae.common import (
    execute_model,
    format_time,
    get_loss_function,
    get_optimizer_function,
    human_readable_file_size,
    ndarray_denormalization,
    ndarray_normalization,
)
from biobb_pytorch.mdae.mdae import MDAE



class TrainMDAE(BiobbObject):
    """
    | biobb_pytorch TrainMDAE
    | Train a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.
    | Train a Molecular Dynamics AutoEncoder (MDAE) PyTorch model; the resulting Auto-associative Neural Network (AANN) can be applied to reduce the dimensionality of molecular dynamics data and to analyze the dynamic properties of the system.

    Args:
        input_train_npy_path (str): Path to the input train data file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/data/mdae/train_mdae_traj.npy>`_. Accepted formats: npy (edam:format_4003).
        output_model_pth_path (str): Path to the output model file. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_model_pth_path (str) (Optional): Path to the input model file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        output_train_data_npz_path (str) (Optional): Path to the output train data file. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_output_train_data.npz>`_. Accepted formats: npz (edam:format_4003).
        output_performance_npz_path (str) (Optional): Path to the output performance file. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_output_performance.npz>`_. Accepted formats: npz (edam:format_4003).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **latent_dimensions** (*int*) - (2) Minimum dimensionality of the latent space.
            * **num_layers** (*int*) - (4) Number of layers in the encoder/decoder (4 to encode and 4 to decode).
            * **num_epochs** (*int*) - (100) Number of epochs (iterations over the whole dataset) for training.
            * **lr** (*float*) - (0.0001) Learning rate.
            * **lr_step_size** (*int*) - (100) Period of learning rate decay.
            * **gamma** (*float*) - (0.1) Multiplicative factor of learning rate decay.
            * **checkpoint_interval** (*int*) - (25) Number of epochs between model checkpoint saves, or 0 to disable checkpoints.
            * **output_checkpoint_prefix** (*str*) - ("checkpoint_epoch_") Prefix for the checkpoint files.
            * **partition** (*float*) - (0.8) Fraction of the data used for training; e.g. 0.8 = 80% for training and 20% for validation.
            * **batch_size** (*int*) - (1) Number of samples/frames per batch.
            * **log_interval** (*int*) - (10) Number of epochs between training progress log entries.
            * **input_dimensions** (*int*) - (None) Input dimensions; by default, the number of features in the input data (number of atoms * 3, corresponding to the x, y, z coordinates).
            * **output_dimensions** (*int*) - (None) Output dimensions; by default, the number of features in the input data (number of atoms * 3, corresponding to the x, y, z coordinates).
            * **loss_function** (*str*) - ("MSELoss") Loss function to be used. Values: MSELoss, L1Loss, SmoothL1Loss, BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, CTCLoss, NLLLoss, KLDivLoss, PoissonNLLLoss, NLLLoss2d, CosineEmbeddingLoss, HingeEmbeddingLoss, MarginRankingLoss, MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, TripletMarginLoss, HuberLoss, SoftMarginLoss.
            * **optimizer** (*str*) - ("Adam") Optimizer algorithm to be used. Values: Adadelta, Adagrad, Adam, AdamW, SparseAdam, Adamax, ASGD, LBFGS, RMSprop, Rprop, SGD.
            * **seed** (*int*) - (None) Random seed for reproducibility.

    Examples:
        This is a use case of how to use the building block from Python::

            from biobb_pytorch.mdae.train_mdae import trainMDAE

            prop = {
                'latent_dimensions': 2,
                'num_layers': 4,
                'num_epochs': 100,
                'lr': 0.0001,
                'checkpoint_interval': 25,
                'partition': 0.8,
                'batch_size': 1,
                'log_interval': 10,
                'input_dimensions': 3,
                'output_dimensions': 3,
                'loss_function': 'MSELoss',
                'optimizer': 'Adam'
            }

            trainMDAE(input_train_npy_path='/path/to/myInputData.npy',
                      output_model_pth_path='/path/to/newModel.pth',
                      input_model_pth_path='/path/to/oldModel.pth',
                      output_train_data_npz_path='/path/to/newTrainData.npz',
                      output_performance_npz_path='/path/to/newPerformance.npz',
                      properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """


    def __init__(
        self,
        input_train_npy_path: str,
        output_model_pth_path: str,
        input_model_pth_path: Optional[str] = None,
        output_train_data_npz_path: Optional[
            str
        ] = None,  # npz of train_losses, valid_losses
        output_performance_npz_path: Optional[
            str
        ] = None,  # npz of evaluate_losses, latent_space, reconstructed_data
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        # Call parent class constructor
        super().__init__(properties)
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_train_npy_path": input_train_npy_path,
                "input_model_pth_path": input_model_pth_path,
            },
            "out": {
                "output_model_pth_path": output_model_pth_path,
                "output_train_data_npz_path": output_train_data_npz_path,
                "output_performance_npz_path": output_performance_npz_path,
            },
        }

        # Properties specific for BB
        self.latent_dimensions: int = int(
            properties.get("latent_dimensions", 2)
        )  # minimum dimensionality of the latent space
        self.num_layers: int = int(
            properties.get("num_layers", 4)
        )  # number of layers in the encoder/decoder (4 to encode and 4 to decode)
        self.num_epochs: int = int(
            properties.get("num_epochs", 100)
        )  # number of epochs (iterations over the whole dataset) for training
        self.lr: float = float(properties.get("lr", 0.0001))  # learning rate
        self.lr_step_size: int = int(
            properties.get("lr_step_size", 100)
        )  # period of learning rate decay
        self.gamma: float = float(
            properties.get("gamma", 0.1)
        )  # multiplicative factor of learning rate decay
        self.checkpoint_interval: int = int(
            properties.get("checkpoint_interval", 25)
        )  # number of epochs between model checkpoint saves, or 0 to disable
        self.output_checkpoint_prefix: str = properties.get(
            "output_checkpoint_prefix", "checkpoint_epoch_"
        )  # prefix for the checkpoint files
        self.partition: float = float(
            properties.get("partition", 0.8)
        )  # fraction of the data used for training; e.g. 0.8 = 80% training, 20% validation
        self.seed: Optional[int] = (
            int(properties.get("seed", "42")) if properties.get("seed", None) else None
        )  # random seed for reproducibility
        self.batch_size: int = int(
            properties.get("batch_size", 1)
        )  # number of samples/frames per batch
        self.log_interval: int = int(
            properties.get("log_interval", 10)
        )  # number of epochs between training progress log entries

        # Input data section
        input_raw_data = np.load(self.io_dict["in"]["input_train_npy_path"])
        # Reshape the input data to a 2D array (frames x flattened coordinates)
        input_train_reshaped_data: np.ndarray = np.reshape(
            input_raw_data,
            (len(input_raw_data), input_raw_data.shape[1] * input_raw_data.shape[2]),
        )
        # Normalization of the input data
        self.input_train_data_max_values: np.ndarray = np.max(
            input_train_reshaped_data, axis=0
        )
        self.input_train_data_min_values: np.ndarray = np.min(
            input_train_reshaped_data, axis=0
        )
        input_train_data: np.ndarray = ndarray_normalization(
            input_train_reshaped_data,
            max_values=self.input_train_data_max_values,
            min_values=self.input_train_data_min_values,
        )
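        # The input .npy is expected to hold a (n_frames, n_atoms, 3) array, flattened
        # above to (n_frames, n_atoms * 3). ndarray_normalization is assumed to apply
        # per-feature min-max scaling, (x - min) / (max - min), using the column-wise
        # extrema stored just above so the output can be denormalized after inference.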


        self.input_dimensions: int = (
            int(properties["input_dimensions"])
            if properties.get("input_dimensions")
            else input_train_data.shape[1]
        )  # input dimensions; by default, the number of features in the input data (number of atoms * 3, corresponding to the x, y, z coordinates)
        self.output_dimensions: int = (
            int(properties["output_dimensions"])
            if properties.get("output_dimensions")
            else self.input_dimensions
        )  # output dimensions; by default, the same as the input dimensions

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        # Select the data for the training and validation steps
        index_train_data = int(self.partition * input_train_data.shape[0])
        index_validation_data = int((1 - self.partition) * input_train_data.shape[0])
        train_tensor = torch.FloatTensor(input_train_data[:index_train_data, :])
        validation_tensor = torch.FloatTensor(
            input_train_data[-index_validation_data:, :]
        )
        performance_tensor = torch.FloatTensor(input_train_data)
        train_dataset = torch.utils.data.TensorDataset(train_tensor)
        validation_dataset = torch.utils.data.TensorDataset(validation_tensor)
        performance_dataset = torch.utils.data.TensorDataset(performance_tensor)

        # Seed
        if self.seed:
            torch.manual_seed(self.seed)
            np.random.seed(self.seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(self.seed)

        self.train_dataloader: torch.utils.data.DataLoader = (
            torch.utils.data.DataLoader(
                dataset=train_dataset,
                batch_size=self.batch_size,
                drop_last=True,
                shuffle=True,
            )
        )
        self.validation_dataloader: torch.utils.data.DataLoader = (
            torch.utils.data.DataLoader(
                dataset=validation_dataset,
                batch_size=self.batch_size,
                drop_last=True,
                shuffle=False,
            )
        )
        self.performance_dataloader: torch.utils.data.DataLoader = (
            torch.utils.data.DataLoader(
                dataset=performance_dataset,
                batch_size=self.batch_size,
                drop_last=False,
                shuffle=False,
            )
        )
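        # Note: drop_last=True on the training and validation loaders discards any
        # final partial batch, while the performance loader keeps every frame
        # (drop_last=False) so the reconstruction and latent space cover the full
        # trajectory; only the training loader is shuffled.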


        # Create the model
        self.model = MDAE(
            input_dimensions=self.input_dimensions,
            num_layers=self.num_layers,
            latent_dimensions=self.latent_dimensions,
        )
        if self.io_dict["in"]["input_model_pth_path"]:
            self.model.load_state_dict(
                torch.load(
                    self.io_dict["in"]["input_model_pth_path"],
                    map_location=self.model.device,
                )
            )

        # Define loss function and optimizer algorithm
        loss_function_str: str = properties.get("loss_function", "")
        try:
            self.loss_function: torch.nn.modules.loss._Loss = get_loss_function(
                loss_function_str
            )()
            fu.log(f"Using loss function: {self.loss_function}", self.out_log)
        except ValueError:
            fu.log(f"Invalid loss function: {loss_function_str}", self.out_log)
            fu.log("Using default loss function: MSELoss", self.out_log)
            self.loss_function = torch.nn.MSELoss()

        optimizer_str: str = properties.get("optimizer", "")
        try:
            self.optimizer = get_optimizer_function(optimizer_str)(
                self.model.parameters(), lr=self.lr
            )
            fu.log(f"Using optimizer: {self.optimizer}", self.out_log)
        except ValueError:
            fu.log(f"Invalid optimizer: {optimizer_str}", self.out_log)
            self.optimizer = Adam(self.model.parameters(), lr=self.lr)


    @launchlogger
    def launch(self) -> int:
        """Execute the :class:`TrainMDAE <mdae.train_mdae.TrainMDAE>` object."""

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Train the model
        train_losses, validation_losses, best_model, best_model_epoch = (
            self.train_model()
        )
        if self.stage_io_dict["out"].get("output_train_data_npz_path"):
            np.savez(
                self.stage_io_dict["out"]["output_train_data_npz_path"],
                train_losses=np.array(train_losses),
                validation_losses=np.array(validation_losses),
            )
            fu.log(
                f'Saving train data to: {self.stage_io_dict["out"]["output_train_data_npz_path"]}',
                self.out_log,
            )
            fu.log(
                f'  File size: {human_readable_file_size(self.stage_io_dict["out"]["output_train_data_npz_path"])}',
                self.out_log,
            )

        # Evaluate the model
        if self.stage_io_dict["out"].get("output_performance_npz_path"):
            evaluate_losses, latent_space, reconstructed_data = self.evaluate_model(
                self.performance_dataloader, self.loss_function
            )
            denormalized_reconstructed_data = ndarray_denormalization(
                reconstructed_data,
                self.input_train_data_max_values,
                self.input_train_data_min_values,
            )
            reshaped_reconstructed_data = np.reshape(
                denormalized_reconstructed_data,
                (len(denormalized_reconstructed_data), -1, 3),
            )
            np.savez(
                self.stage_io_dict["out"]["output_performance_npz_path"],
                evaluate_losses=np.array(evaluate_losses),
                latent_space=np.array(latent_space),
                denormalized_reconstructed_data=np.array(reshaped_reconstructed_data),
            )
            fu.log(
                f'Saving evaluation data to: {self.stage_io_dict["out"]["output_performance_npz_path"]}',
                self.out_log,
            )
            fu.log(
                f'  File size: {human_readable_file_size(self.stage_io_dict["out"]["output_performance_npz_path"])}',
                self.out_log,
            )

        # Save the model
        torch.save(best_model, self.stage_io_dict["out"]["output_model_pth_path"])
        fu.log(
            f'Saving best model to: {self.stage_io_dict["out"]["output_model_pth_path"]}',
            self.out_log,
        )
        fu.log(f"  Best model epoch: {best_model_epoch}", self.out_log)
        fu.log(
            f'  File size: {human_readable_file_size(self.stage_io_dict["out"]["output_model_pth_path"])}',
            self.out_log,
        )

        # Copy files to host
        self.copy_to_host()

        # Remove temporary files
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)
        return 0


    def train_model(self) -> tuple[list[float], list[float], dict, int]:
        self.model.to(self.model.device)
        train_losses: list[float] = []
        validation_losses: list[float] = []
        best_valid_loss: float = float("inf")  # Initialize best valid loss to infinity

        start_time: float = time.time()
        fu.log("Start Training:", self.out_log)
        fu.log(f"  Device: {self.model.device}", self.out_log)
        fu.log(
            f"  Train input file: {self.stage_io_dict['in']['input_train_npy_path']}",
            self.out_log,
        )
        fu.log(
            f"  File size: {human_readable_file_size(self.stage_io_dict['in']['input_train_npy_path'])}",
            self.out_log,
        )
        fu.log(
            f"  Number of atoms: {int(len(next(iter(self.train_dataloader))[0][0])/3)}",
            self.out_log,
        )
        fu.log(
            f"  Number of frames for training: {len(self.train_dataloader)*self.train_dataloader.batch_size} Total number of frames: {int((len(self.train_dataloader)*self.train_dataloader.batch_size)/self.partition) if self.partition is not None else 'Unknown'}",
            self.out_log,
        )  # type: ignore
        fu.log(f"  Number of epochs: {self.num_epochs}", self.out_log)
        fu.log(f"  Partition: {self.partition}", self.out_log)
        fu.log(f"  Batch size: {self.batch_size}", self.out_log)
        fu.log(f"  Learning rate: {self.lr}", self.out_log)
        fu.log(f"  Learning rate step size: {self.lr_step_size}", self.out_log)
        fu.log(f"  Learning rate gamma: {self.gamma}", self.out_log)
        fu.log(f"  Number of layers: {self.num_layers}", self.out_log)
        fu.log(f"  Input dimensions: {self.input_dimensions}", self.out_log)
        fu.log(f"  Latent dimensions: {self.latent_dimensions}", self.out_log)
        fu.log(
            f"  Loss function: {str(self.loss_function).split('(')[0]}", self.out_log
        )
        fu.log(f"  Optimizer: {str(self.optimizer).split('(')[0]}", self.out_log)
        fu.log(f"  Seed: {self.seed}", self.out_log)
        fu.log(f"  Checkpoint interval: {self.checkpoint_interval}", self.out_log)
        fu.log(f"  Log interval: {self.log_interval}\n", self.out_log)

        scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_step_size, gamma=self.gamma
        )
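        # StepLR multiplies the learning rate by gamma every lr_step_size epochs;
        # scheduler.step() is called once per epoch at the end of the loop below.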

        for epoch_index in range(self.num_epochs):
            loop_start_time: float = time.time()

            # Training & validation step
            avg_train_loss, avg_validation_loss = self.training_step(
                self.train_dataloader, self.optimizer, self.loss_function
            )
            train_losses.append(avg_train_loss)
            validation_losses.append(avg_validation_loss)

            # Logging
            if self.log_interval and (
                epoch_index % self.log_interval == 0
                or epoch_index == self.num_epochs - 1
            ):
                epoch_time: float = time.time() - loop_start_time
                fu.log(
                    f'{"Epoch":>4} {epoch_index+1}/{self.num_epochs}, Train Loss: {avg_train_loss:.3f}, Validation Loss: {avg_validation_loss:.3f}, LR: {scheduler.get_last_lr()[0]:.5f}, Duration: {format_time(epoch_time)}, ETA: {format_time((self.num_epochs-(epoch_index+1))*epoch_time)}',
                    self.out_log,
                )
                loop_start_time = time.time()

            # Save checkpoint
            if self.checkpoint_interval and (
                epoch_index % self.checkpoint_interval == 0
                or epoch_index == self.num_epochs - 1
            ):
                checkpoint_path = str(
                    Path(self.stage_io_dict.get("unique_dir", "")).joinpath(
                        f"{self.output_checkpoint_prefix}_{epoch_index}.pth"
                    )
                )
                fu.log(f'{"Saving: ":>4} {checkpoint_path}', self.out_log)
                torch.save(self.model.state_dict(), checkpoint_path)

            # Update learning rate
            scheduler.step()

            # Save best model
            if avg_validation_loss < best_valid_loss:
                best_valid_loss = avg_validation_loss
                best_model: dict = self.model.state_dict()
                best_model_epoch: int = epoch_index

        fu.log(
            f"End Training, total time: {format_time((time.time() - start_time))}",
            self.out_log,
        )

        return train_losses, validation_losses, best_model, best_model_epoch


    def training_step(
        self,
        dataloader: torch.utils.data.DataLoader,
        optimizer: Optimizer,
        loss_function: torch.nn.modules.loss._Loss,
    ) -> tuple[float, float]:
        self.model.train()
        train_losses: list[float] = []
        for data in dataloader:
            data = data[0].to(self.model.device)
            _, output = self.model(data)
            loss = loss_function(output, data)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        self.model.eval()
        valid_losses: list[float] = []
        with torch.no_grad():
            for data in dataloader:
                data = data[0].to(self.model.device)
                _, output = self.model(data)
                loss = loss_function(output, data)
                valid_losses.append(loss.item())

        return float(np.mean(train_losses)), float(
            torch.mean(torch.tensor(valid_losses))
        )
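    # Note: training_step computes its "validation" loss over the same dataloader it
    # just trained on (train_model passes self.train_dataloader for both phases);
    # self.validation_dataloader, built in __init__, is not referenced by this method.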


    def evaluate_model(
        self,
        dataloader: torch.utils.data.DataLoader,
        loss_function: torch.nn.modules.loss._Loss,
    ) -> tuple[float, np.ndarray, np.ndarray]:
        return execute_model(
            self.model,
            dataloader,
            self.input_dimensions,
            self.latent_dimensions,
            loss_function,
        )

501 

502def trainMDAE( 

503 input_train_npy_path: str, 

504 output_model_pth_path: str, 

505 input_model_pth_path: Optional[str] = None, 

506 output_train_data_npz_path: Optional[str] = None, 

507 output_performance_npz_path: Optional[str] = None, 

508 properties: Optional[dict] = None, 

509 **kwargs, 

510) -> int: 

511 """Execute the :class:`TrainMDAE <mdae.train_mdae.TrainMDAE>` class and 

512 execute the :meth:`launch() <mdae.train_mdae.TrainMDAE.launch>` method.""" 

513 

514 return TrainMDAE( 

515 input_train_npy_path=input_train_npy_path, 

516 output_model_pth_path=output_model_pth_path, 

517 input_model_pth_path=input_model_pth_path, 

518 output_train_data_npz_path=output_train_data_npz_path, 

519 output_performance_npz_path=output_performance_npz_path, 

520 properties=properties, 

521 **kwargs, 

522 ).launch() 

523 

524 

525def main(): 

526 """Command line execution of this building block. Please check the command line documentation.""" 

527 parser = argparse.ArgumentParser( 

528 description="Train a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.", 

529 formatter_class=lambda prog: argparse.RawTextHelpFormatter(prog, width=99999), 

530 ) 

531 parser.add_argument( 

532 "-c", 

533 "--config", 

534 required=False, 

535 help="This file can be a YAML file, JSON file or JSON string", 

536 ) 

537 

538 # Specific args of each building block 

539 required_args = parser.add_argument_group("required arguments") 

540 required_args.add_argument( 

541 "--input_train_npy_path", 

542 required=True, 

543 help="Path to the input train data file. Accepted formats: npy.", 

544 ) 

545 required_args.add_argument( 

546 "--output_model_pth_path", 

547 required=True, 

548 help="Path to the output model file. Accepted formats: pth.", 

549 ) 

550 parser.add_argument( 

551 "--input_model_pth_path", 

552 required=False, 

553 help="Path to the input model file. Accepted formats: pth.", 

554 ) 

555 parser.add_argument( 

556 "--output_train_data_npz_path", 

557 required=False, 

558 help="Path to the output train data file. Accepted formats: npz.", 

559 ) 

560 parser.add_argument( 

561 "--output_performance_npz_path", 

562 required=False, 

563 help="Path to the output performance file. Accepted formats: npz.", 

564 ) 

565 parser.add_argument( 

566 "--properties", 

567 required=False, 

568 help="Additional properties for the MDAE object.", 

569 ) 

570 args = parser.parse_args() 

571 config = args.config if args.config else None 

572 properties = settings.ConfReader(config=config).get_prop_dic() 

573 

574 trainMDAE( 

575 input_train_npy_path=args.input_train_npy_path, 

576 output_model_pth_path=args.output_model_pth_path, 

577 input_model_pth_path=args.input_model_pth_path, 

578 output_train_data_npz_path=args.output_train_data_npz_path, 

579 output_performance_npz_path=args.output_performance_npz_path, 

580 properties=properties, 

581 ) 

582 

583 

584if __name__ == "__main__": 

585 main()
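
A minimal sketch, not part of the module above, of how the files it produces might be inspected afterwards. It relies only on the array keys written by launch() (train_losses, validation_losses, evaluate_losses, latent_space, denormalized_reconstructed_data) and on the MDAE constructor arguments shown in __init__; the file paths and hyperparameter values are placeholders and must match the ones used for training.

    import numpy as np
    import torch

    from biobb_pytorch.mdae.mdae import MDAE

    # Training curves saved when output_train_data_npz_path is set.
    train_data = np.load("newTrainData.npz")
    print("final train loss:", train_data["train_losses"][-1])
    print("final validation loss:", train_data["validation_losses"][-1])

    # Evaluation data saved when output_performance_npz_path is set.
    performance = np.load("newPerformance.npz")
    latent_space = performance["latent_space"]  # (n_frames, latent_dimensions)
    reconstructed = performance["denormalized_reconstructed_data"]  # (n_frames, n_atoms, 3)
    print("latent space shape:", latent_space.shape)
    print("reconstructed trajectory shape:", reconstructed.shape)

    # The .pth file holds the best model's state_dict; rebuild an MDAE with the same
    # dimensions used for training (placeholder values here) before loading it.
    model = MDAE(
        input_dimensions=reconstructed.shape[1] * 3,
        num_layers=4,
        latent_dimensions=2,
    )
    model.load_state_dict(torch.load("newModel.pth", map_location="cpu"))
    model.eval()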