Coverage for biobb_pytorch/mdae/mdae.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-21 09:06 +0000

1"""Module containing the MDAutoEncoder class and the command line interface.""" 

2 

3import torch 

4 

5 

class MDAE(torch.nn.Module):
    """Fully connected autoencoder with a sigmoid-bounded latent space.

    The encoder shrinks the feature width by ``delta`` units per hidden layer.
    A linear + sigmoid projection then produces the latent vector. The decoder
    grows the width by ``delta`` units per hidden layer, and a final
    linear + sigmoid layer maps back to ``input_dimensions``.

    Args:
        input_dimensions: Number of features per sample (size of the last
            dimension of the tensor passed to :meth:`forward`).
        num_layers: Number of linear stages on each side of the latent space.
            ``num_layers - 1`` of them are hidden Linear/LeakyReLU/Dropout
            stages. Must be >= 1.
        latent_dimensions: Width of the latent representation.
        dropout: Dropout probability applied after every hidden layer.
        leaky_relu: Negative slope of the LeakyReLU activations (``0.0``
            makes them behave like a plain ReLU).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(
        self,
        input_dimensions: int,
        num_layers: int,
        latent_dimensions: int,
        dropout: float = 0.0,
        leaky_relu: float = 0.0,
    ):
        super().__init__()
        # Fail early with a clear message. num_layers == 0 would otherwise raise
        # ZeroDivisionError in the delta computation below, and negative values
        # silently build a model with no hidden layers at all.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.input_dimensions: int = input_dimensions
        self.num_layers: int = num_layers
        self.latent_dimensions: int = latent_dimensions
        # NOTE(review): this only records the preferred device; the module is
        # not moved to it here. Confirm that callers do ``.to(model.device)``.
        self.device: torch.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Width change per hidden layer. Kept as int(a / b), which truncates
        # toward zero, rather than a // b. This keeps layer sizes, and hence
        # saved state_dicts, identical to previous versions.
        self.delta: int = int((input_dimensions - latent_dimensions) / num_layers)
        self.dropout: float = dropout
        self.leaky_relu: float = leaky_relu

        # Encoder: starts at input_dimensions and shrinks by delta per hidden layer.
        self.encoder, nunits = self._hidden_stack(self.input_dimensions, -self.delta)

        # Latent space: linear projection squashed into [0, 1] by the sigmoid.
        self.lv = torch.nn.Sequential(
            torch.nn.Linear(nunits, self.latent_dimensions), torch.nn.Sigmoid()
        )

        # Decoder: starts at latent_dimensions and grows by delta per hidden layer.
        self.decoder, nunits = self._hidden_stack(self.latent_dimensions, self.delta)

        # Output: project back to the input width; the sigmoid bounds values
        # to [0, 1].
        self.output_layer = torch.nn.Sequential(
            torch.nn.Linear(nunits, input_dimensions), torch.nn.Sigmoid()
        )

    def _hidden_stack(self, in_units: int, step: int):
        """Build ``num_layers - 1`` Linear/LeakyReLU/Dropout stages.

        Each stage changes the feature width by ``step`` units. The layers are
        appended in the same order as the original hand-written loops, so
        ``Sequential`` indices (and therefore ``state_dict`` keys) are unchanged.

        Args:
            in_units: Width of the tensor entering the first stage.
            step: Signed width change per stage (negative shrinks, positive grows).

        Returns:
            A ``(torch.nn.Sequential, out_units)`` pair, where ``out_units`` is
            the width produced by the last stage (``in_units`` when there are
            no hidden stages).
        """
        layers: list = []
        units = in_units
        for _ in range(self.num_layers - 1):
            layers.extend(
                (
                    torch.nn.Linear(units, units + step),
                    torch.nn.LeakyReLU(self.leaky_relu),
                    torch.nn.Dropout(self.dropout),
                )
            )
            units += step
        return torch.nn.Sequential(*layers), units

    def forward(self, x: torch.Tensor):
        """Encode ``x`` into the latent space and reconstruct it.

        Args:
            x: Tensor whose last dimension has size ``input_dimensions``.

        Returns:
            A ``(latent_space, output)`` tuple. ``latent_space`` has last
            dimension ``latent_dimensions``, and ``output`` has last dimension
            ``input_dimensions``. Both pass through a sigmoid.
        """
        latent_space = self.lv(self.encoder(x))
        output = self.output_layer(self.decoder(latent_space))
        return latent_space, output