Coverage for biobb_pytorch / mdae / models / ae.py: 85%


# --------------------------------------------------------------------------------------
# autoencoder.py
#
# from the mlcolvar repository
# https://github.com/mlcolvar/mlcolvar
# Copyright (c) 2023 Luigi Bonati, Enrico Trizio, Andrea Rizzi & Michele Parrinello
# Licensed under the MIT License (see project LICENSE file for full text)
# --------------------------------------------------------------------------------------

from typing import Optional

import torch
import lightning.pytorch as pl
from mlcolvar.cvs import BaseCV
from biobb_pytorch.mdae.models.nn.feedforward import FeedForward
from biobb_pytorch.mdae.featurization.normalization import Normalization
from mlcolvar.core.transform.utils import Inverse
from biobb_pytorch.mdae.loss import MSELoss

__all__ = ["AutoEncoder"]

class AutoEncoder(BaseCV, pl.LightningModule):
    """AutoEncoding Collective Variable.

    It is composed of a first neural network (the encoder), which projects the
    input data into a latent space (the CVs), and a second network (the decoder),
    which takes the CVs and tries to reconstruct the input data from them. It is
    an unsupervised learning approach, typically used when no labels are
    available. This CV is inspired by [1]_.

    Furthermore, it can also be used to learn a representation whose task is not
    to reconstruct the data but to predict, e.g., future configurations.

    **Data**: for training it requires a DictDataset with the key 'data' and
    optionally 'weights' to reweight the data as done in [2]_. If a 'target' key
    is present it will be used as the reference for the output of the decoder;
    otherwise the output will be compared with the input 'data'. This feature can
    be used to train a time-lagged autoencoder [3]_, where the task is not to
    reconstruct the input but the output at a later step.

    **Loss**: reconstruction loss (MSELoss)

    References
    ----------
    .. [1] W. Chen and A. L. Ferguson, “Molecular enhanced sampling with
        autoencoders: On-the-fly collective variable discovery and accelerated
        free energy landscape exploration,” JCC 39, 2079–2102 (2018).
    .. [2] Z. Belkacemi, P. Gkeka, T. Lelièvre, and G. Stoltz, “Chasing collective
        variables using autoencoders and biased trajectories,” JCTC 18, 59–78 (2022).
    .. [3] C. Wehmeyer and F. Noé, “Time-lagged autoencoders: Deep learning of slow
        collective variables for molecular kinetics,” JCP 148, 241703 (2018).

    See also
    --------
    mlcolvar.core.loss.MSELoss
        (weighted) Mean Squared Error (MSE) loss function.
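
    Examples
    --------
    A minimal usage sketch (illustrative, not from the original module; the
    DictDataset import path follows the mlcolvar conventions referenced above,
    and the layer sizes are arbitrary):

    >>> import torch
    >>> from mlcolvar.data import DictDataset
    >>> model = AutoEncoder(n_features=10, n_cvs=2, encoder_layers=[8, 4])
    >>> x = torch.randn(5, 10)
    >>> z = model.forward_cv(x)         # latent CVs, shape (5, 2)
    >>> x_hat = model.encode_decode(x)  # reconstruction, shape (5, 10)
    >>> dataset = DictDataset({"data": x})  # minimal training dataset (see **Data**)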

51 """ 

52 

53 BLOCKS = ["norm_in", "encoder", "decoder"] 

54 

    def __init__(
        self,
        n_features: int,
        n_cvs: int,
        encoder_layers: list,
        decoder_layers: Optional[list] = None,
        options: Optional[dict] = None,
        **kwargs,
    ):
        """
        Define a CV as the output of the encoder of an autoencoder model (the
        latent space). The decoder part is used only during training, for the
        reconstruction loss. By default a module standardizing the inputs is also used.

        Parameters
        ----------
        n_features : int
            Number of input features
        n_cvs : int
            Number of collective variables (dimension of the latent space)
        encoder_layers : list
            Number of neurons per layer of the encoder
        decoder_layers : list, optional
            Number of neurons per layer of the decoder, by default None.
            If not set, it automatically takes the reversed architecture of the encoder
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default None.
            Available blocks: ['norm_in', 'encoder', 'decoder'].
            Set 'block_name' = None or False to turn off that block
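
        Examples
        --------
        An illustrative options sketch (the per-block keyword arguments shown
        are assumptions and must match whatever Normalization and FeedForward
        actually accept):

        >>> options = {
        ...     "norm_in": False,                   # turn off input standardization
        ...     "encoder": {"activation": "relu"},  # assumed FeedForward keyword
        ... }
        >>> model = AutoEncoder(n_features=10, n_cvs=2,
        ...                     encoder_layers=[8, 4], options=options)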

80 """ 

81 super().__init__( 

82 in_features=n_features, out_features=n_cvs, **kwargs 

83 ) 

84 

85 # ======= LOSS ======= 

86 # Reconstruction (MSE) loss 

87 self.loss_fn = MSELoss() 

88 

89 # ======= OPTIONS ======= 

90 # parse and sanitize 

91 options = self.parse_options(options) 

92 

93 # if decoder is not given reverse the encoder 

94 if decoder_layers is None: 

95 decoder_layers = encoder_layers[::-1] 

96 

97 # ======= BLOCKS ======= 

98 

99 # initialize norm_in 

100 o = "norm_in" 

101 if (options[o] is not False) and (options[o] is not None): 

102 self.norm_in = Normalization(self.in_features, **options[o]) 

103 

104 # initialize encoder 

105 o = "encoder" 

106 self.encoder = FeedForward([n_features] + encoder_layers + [n_cvs], **options[o]) 

107 

108 # initialize decoder 

109 o = "decoder" 

110 self.decoder = FeedForward([n_cvs] + decoder_layers + [n_features], **options[o]) 

111 

112 self.eval_variables = ["xhat", "z"] 

113 

    def forward_cv(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the CV without pre- or post-processing modules."""
        if self.norm_in is not None:
            x = self.norm_in(x)
        x = self.encoder(x)
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode the latent space into the original input space."""
        x = self.decoder(z)
        if self.norm_in is not None:
            x = self.norm_in.inverse(x)
        return x

    def encode_decode(self, x: torch.Tensor) -> torch.Tensor:
        """Pass the inputs through both the encoder and the decoder networks."""
        x = self.forward_cv(x)
        x = self.decoder(x)
        if self.norm_in is not None:
            x = self.norm_in.inverse(x)
        return x

    def evaluate_model(self, batch, batch_idx=None):
        """Evaluate the model on a batch, returning the reconstruction and the latent CVs."""

        x = batch['data']
        z = self.forward_cv(x)
        x_hat = self.decoder(z)

        if self.norm_in is not None:
            x_hat = self.norm_in.inverse(x_hat)

        return x_hat, z
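
    # Note (illustrative comment, not in the original): the tuple returned by
    # evaluate_model matches self.eval_variables, i.e. "xhat" is the
    # reconstruction mapped back to input space and "z" holds the latent CVs.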


    def training_step(self, train_batch, batch_idx):
        """Compute and return the training loss and record metrics."""
        # =================get data===================
        x = train_batch["data"]
        loss_kwargs = {}
        if "weights" in train_batch:
            loss_kwargs["weights"] = train_batch["weights"]
        # =================forward====================
        x_hat = self.encode_decode(x)
        # ===================loss=====================
        if "target" in train_batch:
            x_ref = train_batch["target"]
        else:
            x_ref = x

        # if self.norm_in is not None:
        #     x_ref = self.norm_in(x_ref)

        loss = self.loss_fn(x_hat, x_ref, **loss_kwargs)

        # ====================log=====================
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        return loss
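
    # Illustrative note (not in the original): to train the time-lagged variant
    # referenced in the class docstring, build the dataset with a lagged 'target',
    # e.g. DictDataset({"data": X[:-tau], "target": X[tau:]}); training_step then
    # reconstructs the lagged frames instead of the input.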


    def get_decoder(self, return_normalization=False):
        """Return a torch model with the decoder and optionally the normalization inverse."""
        if return_normalization:
            if self.norm_in is not None:
                inv_norm = Inverse(module=self.norm_in)
                decoder_model = torch.nn.Sequential(self.decoder, inv_norm)
            else:
                raise ValueError(
                    "return_normalization is set to True but self.norm_in is None"
                )
        else:
            decoder_model = self.decoder
        return decoder_model
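

# ------------------------------------------------------------------------------
# Minimal smoke-test sketch (added for illustration, not part of the original
# module). It assumes mlcolvar's DictDataset/DictModule helpers, as in the
# upstream repository this file is adapted from; shapes, split lengths, and
# hyperparameters are arbitrary.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    from mlcolvar.data import DictDataset, DictModule

    X = torch.randn(256, 10)
    datamodule = DictModule(DictDataset({"data": X}), lengths=[0.8, 0.2])

    model = AutoEncoder(n_features=10, n_cvs=2, encoder_layers=[8, 4])
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, datamodule)

    with torch.no_grad():
        z = model.forward_cv(X)  # latent CVs, shape (256, 2)
        decoder = model.get_decoder(return_normalization=True)
        X_rec = decoder(z)       # reconstruction in input space, shape (256, 10)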