Coverage for biobb_pytorch / mdae / models / nn / utils.py: 38%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

import math
from typing import Union

import torch
import torch.nn.functional as F

4 

5 

class Shifted_Softplus(torch.nn.Softplus):
    """Element-wise softplus shifted so that it passes through the origin.

    Computes ``softplus(x; beta, threshold) - softplus(0; beta, threshold)``,
    so the output is exactly 0 at ``x == 0``.
    """

    def __init__(self, beta=1, threshold=20):
        super(Shifted_Softplus, self).__init__(beta, threshold)

    def forward(self, input):
        # F.softplus returns x itself when beta * x > threshold, otherwise
        # log(1 + exp(beta * x)) / beta. At x = 0 the linear branch is taken only
        # for a negative threshold; otherwise softplus(0) = log(2) / beta.
        # Computing the offset analytically avoids allocating a tensor and
        # calling .item() on every forward pass, and keeps full precision for
        # float64 inputs (the old float32 offset was off by ~2e-9).
        sp0 = 0.0 if self.threshold < 0 else math.log(2.0) / self.beta
        return F.softplus(input, self.beta, self.threshold) - sp0

15 

16 

class Custom_Sigmoid(torch.nn.Module):
    """Logistic sigmoid with tunable steepness ``p``: ``1 / (1 + exp(-p * x))``.

    Parameters
    ----------
    p : float
        Slope multiplier applied to the input before the sigmoid (default 3).
    """

    def __init__(self, p=3):
        super(Custom_Sigmoid, self).__init__()
        self.p = p

    def forward(self, input):
        # Mathematically identical to 1 / (1 + exp(-p * x)). The explicit form
        # overflows exp() to inf for large negative p * x, and autograd then
        # produces NaN gradients (0 * inf); torch.sigmoid is numerically stable.
        return torch.sigmoid(self.p * input)

24 

25 

def get_activation(activation: Union[str, None]):
    """Return a freshly constructed activation module given its name.

    Parameters
    ----------
    activation : str or None
        One of 'relu', 'elu', 'tanh', 'softplus', 'shifted_softplus',
        'sigmoid', 'logsoftmax', 'linear', or ``None``.

    Returns
    -------
    torch.nn.Module or None
        A new module instance on every call, or ``None`` for ``None`` and for
        'linear' (the latter also prints a warning).

    Raises
    ------
    ValueError
        If ``activation`` is not one of the names listed above.
    """
    if activation is None:
        return None
    if activation == "linear":
        print("WARNING: no activation selected")
        return None

    # Factories rather than instances, so each call returns a new module.
    # Lambdas also defer the Shifted_Softplus name lookup to the call itself.
    factories = {
        "relu": lambda: torch.nn.ReLU(inplace=True),
        # Historically written as ELU(True), which sets alpha (not inplace);
        # True == 1.0, so this is the same non-inplace ELU with alpha=1.
        "elu": lambda: torch.nn.ELU(alpha=1.0),
        "tanh": torch.nn.Tanh,
        "softplus": torch.nn.Softplus,
        "shifted_softplus": lambda: Shifted_Softplus(),
        "sigmoid": torch.nn.Sigmoid,
        "logsoftmax": lambda: torch.nn.LogSoftmax(dim=1),
    }
    # Guard the lookup so unhashable inputs still raise ValueError, not TypeError.
    factory = factories.get(activation) if isinstance(activation, str) else None
    if factory is None:
        raise ValueError(
            f"Unknown activation: {activation}. options: 'relu','elu','tanh','softplus','shifted_softplus','sigmoid','logsoftmax','linear'. "
        )
    return factory()

52 

53 

def parse_nn_options(options: Union[str, list], n_layers: int, last_layer_activation: Union[bool, str]):
    """Expand NN options into one entry per layer.

    If a single value is given, it is repeated for all layers but the output
    one, unless ``last_layer_activation`` is truthy (see below).

    Parameters
    ----------
    options : single value or sequence
        Either a single value (a string counts as a single value) to broadcast
        over the layers, or a non-string iterable that must have exactly
        ``n_layers`` entries; such an iterable is returned unchanged.
    n_layers : int
        Number of layers of the NN, including the output layer.
    last_layer_activation : bool or str
        Only used when ``options`` is a single value:

        - falsy (``False``/``None``): the output layer entry is ``None``;
        - ``True``: the value is repeated for the output layer as well;
        - a non-empty string: that string is used for the output layer.

    Returns
    -------
    list
        One entry per layer (or the given sequence itself).

    Raises
    ------
    ValueError
        If ``options`` is a sequence whose length differs from ``n_layers``.
    """
    # An iterable (other than a string, which is also iterable) is taken as
    # already being per-layer; only check that its length matches.
    if hasattr(options, "__iter__") and not isinstance(options, str):
        if len(options) != n_layers:
            raise ValueError(
                f"Length of options: {options} ({len(options)}) should be equal to number of layers ({n_layers})."
            )
        return options

    # A single value is broadcast to the hidden layers; the output layer gets
    # whatever last_layer_activation dictates.
    if isinstance(last_layer_activation, str) and last_layer_activation:
        return [options] * (n_layers - 1) + [last_layer_activation]
    if last_layer_activation:
        return [options] * n_layers
    return [options] * (n_layers - 1) + [None]