Coverage for biobb_pytorch/mdae/common.py: 95%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-28 11:48 +0000

1"""Common functions for package biobb_pytorch.models""" 

2 

3from pathlib import Path 

4from typing import Callable, Optional, Union 

5 

6import numpy as np 

7import torch 

8 

9 

def ndarray_normalization(
    ndarray: np.ndarray, max_values: np.ndarray, min_values: np.ndarray
) -> np.ndarray:
    """
    Min-max scale an ndarray so that ``min_values`` maps to 0 and ``max_values`` to 1.

    ``max_values`` and ``min_values`` are combined with ``ndarray`` using normal
    NumPy broadcasting (e.g. one bound per column).

    Wherever ``max_values == min_values`` the range is treated as 1 instead of 0,
    so those entries become ``ndarray - min_values`` (0 when the bounds were
    computed from this same data) rather than NaN/inf from a zero division.
    ``ndarray_denormalization`` with the same bounds still recovers the original
    values for such entries, since ``0 * 0 + min_values == min_values``.

    Args:
        ndarray (np.ndarray): The input ndarray to be normalized.
        max_values (np.ndarray): The maximum values for normalization.
        min_values (np.ndarray): The minimum values for normalization.

    Returns:
        np.ndarray: The normalized ndarray.
    """
    value_range = max_values - min_values
    # Avoid dividing by zero for constant features.
    safe_range = np.where(value_range == 0, 1, value_range)
    return (ndarray - min_values) / safe_range

25 

26 

def ndarray_denormalization(
    normalized_ndarray: np.ndarray, max_values: np.ndarray, min_values: np.ndarray
) -> np.ndarray:
    """
    Invert min-max normalization, mapping 0 back to ``min_values`` and 1 to ``max_values``.

    The bounds are combined with ``normalized_ndarray`` using normal NumPy
    broadcasting, so they should match those used when normalizing.

    Args:
        normalized_ndarray (np.ndarray): Values previously scaled by min-max normalization.
        max_values (np.ndarray): The maximum values that were used for normalization.
        min_values (np.ndarray): The minimum values that were used for normalization.

    Returns:
        np.ndarray: The values expressed in the original scale.
    """
    span = max_values - min_values
    rescaled = normalized_ndarray * span
    return rescaled + min_values

42 

43 

def get_loss_function(loss_function: str) -> Callable:
    """
    Look up a ``torch.nn`` loss by its name.

    Any attribute of ``torch.nn`` whose name ends in ``"Loss"`` is accepted
    (e.g. ``"MSELoss"``, ``"L1Loss"``, ``"CrossEntropyLoss"``). The lookup is
    exact and case-sensitive.

    Args:
        loss_function (str): Name of the loss to be used.

    Returns:
        Callable: The matching ``torch.nn`` attribute (the class, not an instance).

    Raises:
        ValueError: If ``torch.nn`` has no attribute with that name ending in "Loss".
    """
    loss_function_dict = {
        name: obj for name, obj in vars(torch.nn).items() if name.endswith("Loss")
    }
    try:
        return loss_function_dict[loss_function]
    except KeyError:
        # Hide the internal KeyError so callers only see the ValueError.
        raise ValueError(f"Invalid loss function: {loss_function}") from None

61 

62 

def get_optimizer_function(optimizer_function: str) -> Callable:
    """
    Look up a ``torch.optim`` optimizer class by its name.

    Only concrete optimizer classes are accepted, meaning subclasses of
    ``torch.optim.Optimizer`` other than the abstract base itself
    (e.g. ``"Adam"``, ``"SGD"``). Other public names exported by
    ``torch.optim``, such as submodules like ``lr_scheduler``, are rejected.

    Args:
        optimizer_function (str): Exact, case-sensitive optimizer class name.

    Returns:
        Callable: The optimizer class (not an instance).

    Raises:
        ValueError: If the name does not refer to a concrete optimizer class.
    """
    base = torch.optim.Optimizer
    optimizer_function_dict = {
        name: obj
        for name, obj in vars(torch.optim).items()
        if not name.startswith("_")
        and isinstance(obj, type)
        and issubclass(obj, base)
        and obj is not base
    }
    try:
        return optimizer_function_dict[optimizer_function]
    except KeyError:
        # Hide the internal KeyError so callers only see the ValueError.
        raise ValueError(f"Invalid optimizer function: {optimizer_function}") from None

80 

81 

def execute_model(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    input_dimensions: int,
    latent_dimensions: int,
    loss_function: Optional[torch.nn.modules.loss._Loss] = None,
) -> tuple[float, np.ndarray, np.ndarray]:
    """
    Run ``model`` over every batch of ``dataloader`` without gradient tracking.

    The model is switched to evaluation mode with ``model.eval()`` and is left
    in that mode on return.

    Args:
        model (torch.nn.Module): Module whose forward pass returns a
            ``(latent, output)`` pair. It must expose a ``device`` attribute
            that input batches are moved to.
        dataloader (torch.utils.data.DataLoader): Batches to evaluate. Only the
            first element of each batch (``batch[0]``) is fed to the model.
        input_dimensions (int): Width used to reshape the reconstructed data.
        latent_dimensions (int): Width used to reshape the latent space.
        loss_function (Optional[torch.nn.modules.loss._Loss]): Optional
            criterion, called as ``loss_function(output, data)``.

    Returns:
        tuple[float, np.ndarray, np.ndarray]: ``(loss, latent_space,
        reconstructed_data)``.

        - ``loss`` is the mean of the per-batch losses. It is ``-1.0`` when no
          loss function is given or no batch was seen.
        - ``latent_space`` has shape ``(-1, latent_dimensions)``.
        - ``reconstructed_data`` has shape ``(-1, input_dimensions)``.
        - For an empty dataloader both arrays are empty (0 rows, float32).
    """
    model.eval()
    losses: list[float] = []
    z_list: list[np.ndarray] = []
    x_hat_list: list[np.ndarray] = []
    with torch.no_grad():
        for batch in dataloader:
            data = batch[0].to(model.device)
            latent, output = model(data)
            # Explicit None check: container Modules define __len__, so
            # truthiness is not a reliable "was a loss given" test.
            if loss_function is not None:
                losses.append(loss_function(output, data).item())
            z_list.append(latent.cpu().numpy())
            x_hat_list.append(output.cpu().numpy())
    loss = float(np.mean(losses)) if losses else -1.0
    if not z_list:
        # np.concatenate raises on an empty list; return empty arrays instead.
        return (
            loss,
            np.empty((0, latent_dimensions), dtype=np.float32),
            np.empty((0, input_dimensions), dtype=np.float32),
        )
    latent_space: np.ndarray = np.reshape(
        np.concatenate(z_list, axis=0), (-1, latent_dimensions)
    )
    reconstructed_data: np.ndarray = np.reshape(
        np.concatenate(x_hat_list, axis=0), (-1, input_dimensions)
    )
    return loss, latent_space, reconstructed_data

110 

111 

def format_time(seconds: Union[float, int]) -> str:
    """
    Format a duration given in seconds as a compact human-readable string.

    Fields are two-digit ``h``/``m``/``s`` values, and leading fields that are
    zero are omitted. Fractional seconds are truncated. Negative durations are
    rendered with a leading ``-``.

    >>> format_time(3661)
    '01h 01m 01s'
    >>> format_time(61)
    '01m 01s'
    >>> format_time(5)
    '05s'

    Args:
        seconds (Union[float, int]): Duration in seconds.

    Returns:
        str: A string such as ``"01h 02m 03s"``, ``"02m 03s"`` or ``"03s"``.
    """
    if seconds < 0:
        # divmod floors toward -inf, which would mix signs across fields
        # (e.g. -5 -> "-1h 59m 55s"); format the magnitude instead.
        return "-" + format_time(-seconds)
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    if hours:
        return "{:02}h {:02}m {:02}s".format(int(hours), int(minutes), int(secs))
    if minutes:
        return "{:02}m {:02}s".format(int(minutes), int(secs))
    return "{:02}s".format(int(secs))

122 

123 

def human_readable_file_size(file_path: Union[str, Path]) -> str:
    """
    Return the size of a file formatted with a binary (1024-based) unit.

    Args:
        file_path (Union[str, Path]): Path of the file to measure. ``str``
            values are converted to ``pathlib.Path``.

    Returns:
        str: The size with two decimals followed by one of ``Bytes``, ``KB``,
        ``MB``, ``GB``, ``TB`` or ``PB`` (e.g. ``"1.50 KB"``). Sizes of
        1024 PB or more are still expressed in PB.

    Raises:
        OSError: Propagated from ``Path.stat`` (e.g. when the file does not exist).
    """
    size: float = Path(file_path).stat().st_size
    units = ("Bytes", "KB", "MB", "GB", "TB", "PB")
    # Only divide while a larger unit exists, so the value is never scaled past
    # the unit it is labelled with. The last unit absorbs anything larger.
    for unit in units[:-1]:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} {units[-1]}"