Coverage for biobb_pytorch / mdae / models / vae.py: 58%

77 statements  

coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

# --------------------------------------------------------------------------------------
# vae.py
#
# from the mlcolvar repository
# https://github.com/mlcolvar/mlcolvar
# Copyright (c) 2023 Luigi Bonati, Enrico Trizio, Andrea Rizzi & Michele Parrinello
# Licensed under the MIT License (see project LICENSE file for full text)
# --------------------------------------------------------------------------------------

from typing import Optional, Tuple
import torch
import lightning.pytorch as pl
from mlcolvar.cvs import BaseCV
from biobb_pytorch.mdae.models.nn.feedforward import FeedForward
from biobb_pytorch.mdae.featurization.normalization import Normalization
from mlcolvar.core.transform.utils import Inverse
from biobb_pytorch.mdae.loss import ELBOGaussiansLoss

__all__ = ["VariationalAutoEncoder"]


class VariationalAutoEncoder(BaseCV, pl.LightningModule):
    """Variational AutoEncoder Collective Variable.

    At training time, the encoder outputs a mean and a variance for each CV,
    defining a Gaussian distribution associated with the input. One sample is
    drawn from this Gaussian and passed through the decoder. The ELBO loss is
    then minimized; it sums the MSE of the reconstruction and the KL divergence
    between the encoded Gaussian and a N(0, 1) Gaussian.

    At evaluation time, the encoder's mean output is used as the CV, while the
    variance output and the decoder are ignored.

    **Data**: for training, it requires a DictDataset with the key ``'data'`` and
    optionally ``'weights'``. If a ``'target'`` key is present, it is used as the
    reference for the output of the decoder; otherwise the output is compared with
    the input ``'data'``. This feature can be used to train (variational)
    time-lagged autoencoders as in [1]_.

    **Loss**: Evidence Lower BOund (ELBO)

    References
    ----------
    .. [1] C. X. Hernández, H. K. Wayment-Steele, M. M. Sultan, B. E. Husic, and V. S. Pande,
       “Variational encoding of complex dynamics,” Physical Review E 97, 062412 (2018).

    See also
    --------
    biobb_pytorch.mdae.loss.ELBOGaussiansLoss
        Evidence Lower BOund loss function used by this model
    """

    BLOCKS = ["norm_in", "encoder", "decoder"]

    def __init__(
        self,
        n_features: int,
        n_cvs: int,
        encoder_layers: list,
        decoder_layers: Optional[list] = None,
        options: Optional[dict] = None,
        **kwargs,
    ):
        """
        Variational autoencoder constructor. Initializes two neural network modules
        (encoder and decoder). By default, a module standardizing the inputs is also used.

        Parameters
        ----------
        n_features : int
            Number of input features (descriptors) of the model.
        n_cvs : int
            The dimension of the CV or, equivalently, the dimension of the latent
            space of the autoencoder.
        encoder_layers : list
            Number of neurons per layer of the encoder up to the last hidden layer.
            The size of the output layer is instead specified with ``n_cvs``.
        decoder_layers : list, optional
            Number of neurons per layer of the decoder, except for the input and
            output layers, which are set to ``n_cvs`` and ``n_features`` respectively.
            If ``None`` (default), it automatically takes the reversed architecture
            of the encoder.
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default ``None``.
            Available blocks are: ``'norm_in'``, ``'encoder'``, and ``'decoder'``.
            Set ``'block_name' = None`` or ``False`` to turn off a block. The encoder
            and decoder cannot be turned off.
        """

        super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)

        # ======= LOSS =======
        # ELBO loss function when latent space and reconstruction distributions are Gaussians.
        self.loss_fn = ELBOGaussiansLoss()

        # ======= OPTIONS =======
        # parse and sanitize
        options = self.parse_options(options)

        # if the decoder is not given, reverse the encoder architecture
        if decoder_layers is None:
            decoder_layers = encoder_layers[::-1]

        # ======= BLOCKS =======

        # initialize norm_in
        o = "norm_in"
        if (options[o] is not False) and (options[o] is not None):
            self.norm_in = Normalization(self.in_features, **options[o])

        # initialize encoder
        # The encoder is followed by two linear heads returning, for each CV,
        # the mean and the log variance of the latent Gaussian.
        o = "encoder"

        # Note: the FeedForward implementing the encoder needs to apply the nonlinearity
        # (and, if requested, dropout/batchnorm) to its output layer as well, since the
        # two separate linear heads for the mean and the log variance come after it.
        if "last_layer_activation" not in options[o]:
            options[o]["last_layer_activation"] = True

        self.encoder = FeedForward([n_features] + encoder_layers, **options[o])
        self.mean_nn = torch.nn.Linear(
            in_features=encoder_layers[-1], out_features=n_cvs
        )
        self.log_var_nn = torch.nn.Linear(
            in_features=encoder_layers[-1], out_features=n_cvs
        )

        # initialize decoder
        o = "decoder"
        self.decoder = FeedForward([n_cvs] + decoder_layers + [n_features], **options[o])

        self.eval_variables = ["xhat", "z", "z_mean", "z_logvar"]
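        # Example (illustrative, not from the original code): with n_features=30, n_cvs=2
        # and encoder_layers=[16, 8], the encoder is a FeedForward over [30, 16, 8] with
        # two 8 -> 2 linear heads (mean and log variance), and the default decoder mirrors
        # it as a FeedForward over [2, 8, 16, 30].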

    def forward_cv(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the value of the CV from preprocessed input.

        Return the mean output (ignoring the variance output) of the encoder
        after (optionally) applying the normalization to the input.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            input descriptors of the CV after preprocessing.

        Returns
        -------
        cv : torch.Tensor
            Shape ``(n_batches, n_cvs)``. The CVs, i.e., the mean output of the
            encoder (the variance output is discarded).
        """
        if self.norm_in is not None:
            x = self.norm_in(x)
        x = self.encoder(x)

        # Take only the means and ignore the log variances.
        return self.mean_nn(x)

    def encode_decode(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run a pass of encoding + decoding.

        The function applies the normalization to the inputs; the decoded output
        is returned in normalized space (the inverse normalization is applied by
        the caller, e.g. :meth:`evaluate_model`).

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            input descriptors of the CV after preprocessing.

        Returns
        -------
        z : torch.Tensor
            Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The sample drawn from
            the latent Gaussian distribution associated with the input.
        mean : torch.Tensor
            Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The mean of the
            Gaussian distribution associated with the input in latent space.
        log_variance : torch.Tensor
            Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The logarithm of the
            variance of the Gaussian distribution associated with the input in
            latent space.
        x_hat : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            reconstructed descriptors (in normalized space).
        """

        # Normalize inputs.
        if self.norm_in is not None:
            x = self.norm_in(x)

        # Encode input into a Gaussian distribution.
        x = self.encoder(x)
        mean, log_variance = self.mean_nn(x), self.log_var_nn(x)

        # Sample from the Gaussian distribution in latent space.
        std = torch.exp(log_variance / 2)
        z = torch.distributions.Normal(mean, std).rsample()
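        # rsample() draws z = mean + std * eps with eps ~ N(0, 1), so the sampling step
        # remains differentiable with respect to mean and std (reparameterization trick).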

        # Decode sample.
        x_hat = self.decoder(z)

        # if self.norm_in is not None:
        #     x_hat = self.norm_in.inverse(x_hat)

        return z, mean, log_variance, x_hat

    def evaluate_model(self, batch, batch_idx):
        """Evaluate the model on a batch, returning the reconstruction and latent variables."""

        x = batch['data']

        if 'target' in batch:
            x_ref = batch['target']
            if self.norm_in is not None:
                x_ref = self.norm_in(x_ref)
        else:
            x_ref = x

        z, mean, log_variance, x_hat = self.encode_decode(x)

        if self.norm_in is not None:
            x_hat = self.norm_in.inverse(x_hat)

        return x_hat, z, mean, log_variance
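    # Note: evaluate_model returns x_hat mapped back to the original (unnormalized) feature
    # space, whereas training_step below computes the loss in normalized space (x_ref is
    # passed through norm_in and x_hat is not inverted).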

    def training_step(self, train_batch, batch_idx):
        """Single training step performed by the PyTorch Lightning Trainer."""
        x = train_batch["data"]
        loss_kwargs = {}
        if "weights" in train_batch:
            loss_kwargs["weights"] = train_batch["weights"]

        # Encode/decode.
        z, mean, log_variance, x_hat = self.encode_decode(x)

        # Reference output (compare with a 'target' key if any, otherwise with input 'data')
        if "target" in train_batch:
            x_ref = train_batch["target"]
        else:
            x_ref = x

        if self.norm_in is not None:
            x_ref = self.norm_in(x_ref)

        # Loss function.
        loss = self.loss_fn(x_ref, x_hat, mean, log_variance, **loss_kwargs)

        # Log.
        name = "train" if self.training else "valid"
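        # The metric name is prefixed with "train" or "valid" depending on self.training,
        # so this same step can also be used to report the validation loss.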

        self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)

        return loss

    def get_decoder(self, return_normalization=False):
        """Return a torch model with the decoder and, optionally, the inverse of the normalization."""
        if return_normalization:
            if self.norm_in is not None:
                inv_norm = Inverse(module=self.norm_in)
                decoder_model = torch.nn.Sequential(self.decoder, inv_norm)
            else:
                raise ValueError(
                    "return_normalization is set to True but self.norm_in is None"
                )
        else:
            decoder_model = self.decoder

        return decoder_model
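# ----------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module; layer sizes and tensor shapes
# are arbitrary examples, and an untrained model with the default 'norm_in' block is assumed):
#
#   model = VariationalAutoEncoder(n_features=30, n_cvs=2, encoder_layers=[16, 8])
#   x = torch.randn(5, 30)                            # batch of 5 frames, 30 descriptors
#   cv = model.forward_cv(x)                          # (5, 2): encoder means used as CVs
#   z, mean, log_var, x_hat = model.encode_decode(x)  # full VAE pass (normalized space)
#   decoder = model.get_decoder(return_normalization=True)
# ----------------------------------------------------------------------------------------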