Coverage for biobb_pytorch/mdae/common.py: 95%
55 statements
coverage.py v7.6.10, created at 2025-01-28 11:48 +0000
1"""Common functions for package biobb_pytorch.models"""
3from pathlib import Path
4from typing import Callable, Optional, Union
6import numpy as np
7import torch
def ndarray_normalization(
    ndarray: np.ndarray, max_values: np.ndarray, min_values: np.ndarray
) -> np.ndarray:
    """
    Min-max normalize an ndarray using the given maximum and minimum values.

    Args:
        ndarray (np.ndarray): The input ndarray to be normalized.
        max_values (np.ndarray): The maximum values used for normalization.
        min_values (np.ndarray): The minimum values used for normalization.

    Returns:
        np.ndarray: The normalized ndarray, scaled to the [0, 1] range.
    """
    return (ndarray - min_values) / (max_values - min_values)
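

# Illustrative usage (a minimal sketch, not part of the module): max_values and
# min_values are assumed to be per-feature statistics of the array being normalized.
_example_data = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
_example_norm = ndarray_normalization(
    _example_data, _example_data.max(axis=0), _example_data.min(axis=0)
)
# Each column of _example_norm now spans [0.0, 1.0]:
# [[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]]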
def ndarray_denormalization(
    normalized_ndarray: np.ndarray, max_values: np.ndarray, min_values: np.ndarray
) -> np.ndarray:
    """
    Denormalize a normalized ndarray using the given max and min values.

    Args:
        normalized_ndarray (np.ndarray): The normalized ndarray to be denormalized.
        max_values (np.ndarray): The maximum values used for normalization.
        min_values (np.ndarray): The minimum values used for normalization.

    Returns:
        np.ndarray: The denormalized ndarray.
    """
    return normalized_ndarray * (max_values - min_values) + min_values
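

# Illustrative round trip (a sketch, not part of the module), reusing
# _example_data and _example_norm from the normalization sketch above:
# denormalizing with the same max/min values recovers the original array.
_example_restored = ndarray_denormalization(
    _example_norm, _example_data.max(axis=0), _example_data.min(axis=0)
)
# np.allclose(_example_restored, _example_data) -> True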
def get_loss_function(loss_function: str) -> Callable:
    """
    Get the loss function class from the given string.

    Args:
        loss_function (str): The name of the torch.nn loss function to be used.

    Returns:
        Callable: The loss function class.
    """
    loss_function_dict = dict(
        filter(lambda pair: pair[0].endswith("Loss"), vars(torch.nn).items())
    )
    try:
        return loss_function_dict[loss_function]
    except KeyError:
        raise ValueError(f"Invalid loss function: {loss_function}")
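

# Illustrative usage (a sketch, not part of the module): the string must match
# a torch.nn class name ending in "Loss", e.g. "MSELoss" or "L1Loss".
_mse_cls = get_loss_function("MSELoss")  # -> torch.nn.MSELoss
_example_loss = _mse_cls(reduction="mean")(torch.zeros(4), torch.ones(4))  # tensor(1.)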
def get_optimizer_function(optimizer_function: str) -> Callable:
    """
    Get the optimizer class from the given string.

    Args:
        optimizer_function (str): The name of the torch.optim optimizer to be used.

    Returns:
        Callable: The optimizer class.
    """
    optimizer_function_dict = dict(
        filter(lambda pair: not pair[0].startswith("_"), vars(torch.optim).items())
    )
    try:
        return optimizer_function_dict[optimizer_function]
    except KeyError:
        raise ValueError(f"Invalid optimizer function: {optimizer_function}")
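

# Illustrative usage (a sketch, not part of the module): the string must match
# a torch.optim class name, e.g. "Adam" or "SGD".
_adam_cls = get_optimizer_function("Adam")  # -> torch.optim.Adam
_example_optimizer = _adam_cls([torch.nn.Parameter(torch.zeros(2))], lr=1e-3)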
def execute_model(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    input_dimensions: int,
    latent_dimensions: int,
    loss_function: Optional[torch.nn.modules.loss._Loss] = None,
) -> tuple[float, np.ndarray, np.ndarray]:
    """
    Run the model in evaluation mode over a dataloader.

    Returns the mean loss (-1.0 if no loss function is given), the latent space
    and the reconstructed data as numpy arrays.
    """
    model.eval()
    losses: list[float] = []
    z_list: list[np.ndarray] = []
    x_hat_list: list[np.ndarray] = []
    with torch.no_grad():
        for data in dataloader:
            data = data[0].to(model.device)
            latent, output = model(data)
            if loss_function:
                loss = loss_function(output, data)
                losses.append(loss.item())
            z_list.append(latent.cpu().numpy())
            x_hat_list.append(output.cpu().numpy())
    loss = float(np.mean(losses)) if losses else -1.0
    latent_space: np.ndarray = np.reshape(
        np.concatenate(z_list, axis=0), (-1, latent_dimensions)
    )
    reconstructed_data: np.ndarray = np.reshape(
        np.concatenate(x_hat_list, axis=0), (-1, input_dimensions)
    )
    return loss, latent_space, reconstructed_data
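

# Illustrative usage (a sketch under assumptions, not part of the module):
# execute_model expects the model to expose a `device` attribute and to return
# a (latent, output) tuple from forward(). The tiny autoencoder below is a
# hypothetical stand-in for such a model, not the package's own architecture.
class _ToyAutoEncoder(torch.nn.Module):
    def __init__(self, n_in: int, n_latent: int) -> None:
        super().__init__()
        self.device = torch.device("cpu")
        self.encoder = torch.nn.Linear(n_in, n_latent)
        self.decoder = torch.nn.Linear(n_latent, n_in)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        latent = self.encoder(x)
        return latent, self.decoder(latent)


_example_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(8, 6)), batch_size=4
)
_loss, _latent, _reconstructed = execute_model(
    _ToyAutoEncoder(6, 2),
    _example_loader,
    input_dimensions=6,
    latent_dimensions=2,
    loss_function=torch.nn.MSELoss(),
)
# _latent.shape == (8, 2) and _reconstructed.shape == (8, 6)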
def format_time(seconds: Union[float, int]) -> str:
    """Convert a time in seconds to a string such as '01h 02m 05s', omitting leading zero units."""
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    if hours:
        return "{:02}h {:02}m {:02}s".format(int(hours), int(minutes), int(seconds))
    elif minutes:
        return "{:02}m {:02}s".format(int(minutes), int(seconds))
    else:
        return "{:02}s".format(int(seconds))
def human_readable_file_size(file_path: Union[str, Path]) -> str:
    """Get the size of a file and return it in a human-readable format."""
    file_path = Path(file_path)  # Ensure file_path is a Path object
    size_in_bytes: float = file_path.stat().st_size
    units = ["Bytes", "KB", "MB", "GB", "TB", "PB"]
    for unit in units:
        if size_in_bytes < 1024:
            return f"{size_in_bytes:.2f} {unit}"
        size_in_bytes /= 1024
    return f"{size_in_bytes:.2f} {unit}"