Coverage for biobb_pytorch/mdae/models/nn/feedforward.py: 84%

31 statements  

coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

#!/usr/bin/env python

# -----------------------------------------------------------------------------
# feedforward.py
#
# from the mlcolvar repository
# -----------------------------------------------------------------------------

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

13""" 

14Variational Autoencoder collective variable. 

15""" 

16 

__all__ = ["FeedForward"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional, Union

import torch
import lightning
from .utils import get_activation, parse_nn_options

# =============================================================================
# STANDARD FEED FORWARD
# =============================================================================


class FeedForward(lightning.LightningModule):
    """Define a feedforward neural network given the list of layers.

    Optionally, dropout and batchnorm can be applied (the order is
    activation -> dropout -> batchnorm).
    """

    def __init__(
        self,
        layers: list,
        activation: Optional[Union[str, list]] = None,
        dropout: Optional[Union[float, list]] = None,
        batchnorm: Union[bool, list] = False,
        last_layer_activation: bool = False,
        **kwargs,
    ):
        """Constructor.

        Parameters
        ----------
        layers : list
            Number of neurons per layer.
        activation : str or list[str], optional
            Add activation function (options: relu, tanh, elu, linear). If a
            ``list``, this must have length ``len(layers)-1``, and ``activation[i]``
            controls whether to add the activation to the ``i``-th layer.
        dropout : float or list[float], optional
            Add dropout with this probability after each layer. If a ``list``,
            this must have length ``len(layers)-1``, and ``dropout[i]`` specifies
            the dropout probability for the ``i``-th layer.
        batchnorm : bool or list[bool], optional
            Add batchnorm after each layer. If a ``list``, this must have
            length ``len(layers)-1``, and ``batchnorm[i]`` controls whether to
            add the batchnorm to the ``i``-th layer.
        last_layer_activation : bool, optional
            If ``True``, then activation, dropout, and batchnorm are also added
            to the output layer when ``activation``, ``dropout``, or ``batchnorm``
            (respectively) are not lists. Otherwise, the output layer will be
            linear. This option is ignored for the arguments among ``activation``,
            ``dropout``, and ``batchnorm`` that are passed as lists.
        **kwargs:
            Optional arguments passed to ``lightning.LightningModule``.
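
        Examples
        --------
        A minimal usage sketch (illustrative; ``relu`` is one of the
        activation options listed above):

        >>> model = FeedForward(layers=[10, 64, 64, 2], activation="relu", dropout=0.1)
        >>> x = torch.rand(8, 10)
        >>> model(x).shape
        torch.Size([8, 2])

        Per-layer control is possible by passing lists of length ``len(layers)-1``:

        >>> model = FeedForward(
        ...     layers=[10, 64, 64, 2],
        ...     activation=["relu", "relu", None],
        ...     batchnorm=[True, True, False],
        ... )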

76 """ 

77 super().__init__(**kwargs) 

78 

        # Parse layers
        if not isinstance(layers[0], int):
            raise TypeError("layers should be a list-type of integers.")

        # Parse the per-layer options
        n_layers = len(layers) - 1

        # -- activation
        activation_list = parse_nn_options(activation, n_layers, last_layer_activation)
        # Note: last_layer_activation is annotated as bool, but a string that
        # slips through is consumed here so that the dropout and batchnorm
        # parsing below does not receive it as their last-layer option.
        if isinstance(last_layer_activation, str):
            last_layer_activation = None
        # -- dropout
        dropout_list = parse_nn_options(dropout, n_layers, last_layer_activation)
        # -- batchnorm
        batchnorm_list = parse_nn_options(batchnorm, n_layers, last_layer_activation)
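
        # Illustrative note (inferred from the Parameters contract above, not
        # from the implementation of parse_nn_options): with layers=[10, 64, 64, 2],
        # activation="relu", and last_layer_activation=False, the expected
        # expansion is activation_list == ["relu", "relu", None], so the
        # output layer built below stays linear.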

        # Create network
        modules = []
        for i in range(n_layers):
            modules.append(torch.nn.Linear(layers[i], layers[i + 1]))
            activ, drop, norm = activation_list[i], dropout_list[i], batchnorm_list[i]

            if activ is not None:
                modules.append(get_activation(activ))

            if drop is not None:
                modules.append(torch.nn.Dropout(p=drop))

            if norm:
                modules.append(torch.nn.BatchNorm1d(layers[i + 1]))

        # store model and attributes
        self.nn = torch.nn.Sequential(*modules)
        self.in_features = layers[0]
        self.out_features = layers[-1]

    # def extra_repr(self) -> str:
    #     repr = f"in_features={self.in_features}, out_features={self.out_features}"
    #     return repr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.nn(x)
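

# A minimal smoke test, as an illustrative sketch; it is not part of the
# original module and assumes parse_nn_options follows the contract in the
# constructor docstring. Run it from within the package (the relative import
# above prevents direct script execution) to check the per-layer ordering:
# Linear -> activation -> dropout -> batchnorm, with a linear output layer.
if __name__ == "__main__":
    net = FeedForward(layers=[4, 8, 2], activation="relu", dropout=0.2, batchnorm=True)
    print(net.nn)  # expected: Linear, ReLU, Dropout, BatchNorm1d, Linear
    x = torch.rand(5, 4)
    print(net(x).shape)  # expected: torch.Size([5, 2])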