Coverage for biobb_pytorch/mdae/models/nn/utils.py: 38% (52 statements)
import torch
import torch.nn.functional as F
from typing import Union

class Shifted_Softplus(torch.nn.Softplus):
    """Element-wise softplus function, shifted so that it passes through the origin."""

    def __init__(self, beta=1, threshold=20):
        super().__init__(beta, threshold)

    def forward(self, input):
        # Subtract softplus(0) so that forward(0) == 0.
        sp0 = F.softplus(torch.zeros(1), self.beta, self.threshold).item()
        return F.softplus(input, self.beta, self.threshold) - sp0
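
# Illustrative usage (a minimal sketch, not part of the original module): the
# subtraction of softplus(0) guarantees the activation maps zero to zero.
#
#   act = Shifted_Softplus()
#   act(torch.zeros(3))  # -> tensor([0., 0., 0.])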

class Custom_Sigmoid(torch.nn.Module):
    """Element-wise sigmoid with a tunable steepness parameter ``p``."""

    def __init__(self, p=3):
        super().__init__()
        self.p = p

    def forward(self, input):
        return 1 / (1 + torch.exp(-self.p * input))
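
# Illustrative usage (a sketch, not part of the original module): larger ``p``
# makes the transition around zero steeper than the standard sigmoid; at the
# origin both agree.
#
#   act = Custom_Sigmoid(p=3)
#   act(torch.tensor([0.0]))  # -> tensor([0.5000]), same as torch.sigmoid(0)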

def get_activation(activation: str):
    """Return the activation module corresponding to the given name."""
    activ = None
    if activation == "relu":
        activ = torch.nn.ReLU(inplace=True)
    elif activation == "elu":
        # Pass inplace by keyword: ELU's first positional argument is alpha.
        activ = torch.nn.ELU(inplace=True)
    elif activation == "tanh":
        activ = torch.nn.Tanh()
    elif activation == "softplus":
        activ = torch.nn.Softplus()
    elif activation == "shifted_softplus":
        activ = Shifted_Softplus()
    elif activation == "sigmoid":
        activ = torch.nn.Sigmoid()
    elif activation == "logsoftmax":
        activ = torch.nn.LogSoftmax(dim=1)
    elif activation == "linear":
        print("WARNING: no activation selected")
    elif activation is None:
        pass
    else:
        raise ValueError(
            f"Unknown activation: {activation}. Options: 'relu', 'elu', 'tanh', 'softplus', "
            f"'shifted_softplus', 'sigmoid', 'logsoftmax', 'linear'."
        )
    return activ
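
# Illustrative usage (a sketch, not part of the original module):
#
#   get_activation("tanh")    # -> torch.nn.Tanh()
#   get_activation("linear")  # prints a warning and returns None
#   get_activation(None)      # returns None silently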

def parse_nn_options(options: Union[str, list], n_layers: int, last_layer_activation: Union[bool, str]):
    """Parse per-layer options of the NN.

    If a single value is given, it is repeated for all layers except the output
    one, unless ``last_layer_activation`` is True (the option is repeated for the
    output layer as well) or a string (that value is used for the output layer).
    """
    if last_layer_activation is False:
        last_layer_activation = None
    # If an iterable is given, check that its length matches the number of NN layers
    if hasattr(options, "__iter__") and not isinstance(options, str):
        if len(options) != n_layers:
            raise ValueError(
                f"Length of options {options} ({len(options)}) should be equal to the number of layers ({n_layers})."
            )
        options_list = options
    # If a single value is given, repeat it for all layers except the output one
    else:
        if last_layer_activation:
            if isinstance(last_layer_activation, str):
                options_list = [options for _ in range(n_layers - 1)]
                options_list.append(last_layer_activation)
            else:
                options_list = [options for _ in range(n_layers)]
        else:
            options_list = [options for _ in range(n_layers - 1)]
            options_list.append(None)

    return options_list
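
# Illustrative usage (a sketch, not part of the original module), showing how a
# single value is broadcast across layers and how the output layer is handled:
#
#   parse_nn_options("relu", 3, False)    # -> ['relu', 'relu', None]
#   parse_nn_options("relu", 3, True)     # -> ['relu', 'relu', 'relu']
#   parse_nn_options("relu", 3, "tanh")   # -> ['relu', 'relu', 'tanh']
#   parse_nn_options(["relu", "elu", None], 3, False)  # -> ['relu', 'elu', None]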