Coverage for biobb_pytorch/mdae/models/vae.py: 58% (77 statements)

# --------------------------------------------------------------------------------------
# vae.py
#
# from the mlcolvar repository
# https://github.com/mlcolvar/mlcolvar
# Copyright (c) 2023 Luigi Bonati, Enrico Trizio, Andrea Rizzi & Michele Parrinello
# Licensed under the MIT License (see project LICENSE file for full text)
# --------------------------------------------------------------------------------------

from typing import Optional, Tuple
import torch
import lightning.pytorch as pl
from mlcolvar.cvs import BaseCV
from biobb_pytorch.mdae.models.nn.feedforward import FeedForward
from biobb_pytorch.mdae.featurization.normalization import Normalization
from mlcolvar.core.transform.utils import Inverse
from biobb_pytorch.mdae.loss import ELBOGaussiansLoss

__all__ = ["VariationalAutoEncoder"]


class VariationalAutoEncoder(BaseCV, pl.LightningModule):
    r"""Variational AutoEncoder Collective Variable.

    At training time, the encoder outputs a mean and a variance for each CV,
    defining a Gaussian distribution associated with the input. One sample is
    drawn from this Gaussian and passed through the decoder, and the ELBO loss
    is minimized. The ELBO sums the MSE of the reconstruction and the KL
    divergence between the generated Gaussian and a N(0, 1) Gaussian.

    At evaluation time, the encoder's mean output is used as the CV, while the
    variance output and the decoder are ignored.

    **Data**: for training, it requires a DictDataset with the key ``'data'`` and
    optionally ``'weights'``. If a ``'target'`` key is present, it is used as the
    reference for the decoder output; otherwise the output is compared with the
    input ``'data'``. This feature can be used to train (variational) time-lagged
    autoencoders as in [1]_.

    **Loss**: Evidence Lower BOund (ELBO)
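
    Up to constants and averaging conventions (see ``ELBOGaussiansLoss`` for the
    exact implementation), this amounts to

    .. math::

        \mathcal{L} = \mathrm{MSE}(x_{\mathrm{ref}}, \hat{x})
        - \frac{1}{2} \sum_i \left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right),

    where :math:`\mu` and :math:`\log \sigma^2` are the mean and log-variance
    produced by the encoder.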
    References
    ----------
    .. [1] C. X. Hernández, H. K. Wayment-Steele, M. M. Sultan, B. E. Husic, and V. S. Pande,
        “Variational encoding of complex dynamics,” Physical Review E 97, 062412 (2018).

    See also
    --------
    biobb_pytorch.mdae.loss.ELBOGaussiansLoss
        Evidence Lower BOund loss function used by this model
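
    Examples
    --------
    A minimal usage sketch; the sizes below (``n_features=10``, ``n_cvs=2``,
    ``encoder_layers=[32, 16]``) are hypothetical and chosen only for illustration:

    >>> import torch
    >>> model = VariationalAutoEncoder(n_features=10, n_cvs=2, encoder_layers=[32, 16])
    >>> x = torch.randn(8, 10)
    >>> model.forward_cv(x).shape  # the CV is the mean output of the encoder
    torch.Size([8, 2])
    >>> z, mean, log_var, x_hat = model.encode_decode(x)
    >>> decoder = model.get_decoder()  # decoder only, without inverse normalization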
51 """
53 BLOCKS = ["norm_in", "encoder", "decoder"]

    def __init__(
        self,
        n_features: int,
        n_cvs: int,
        encoder_layers: list,
        decoder_layers: Optional[list] = None,
        options: Optional[dict] = None,
        **kwargs,
    ):
64 """
65 Variational autoencoder constructor. Initializes two neural network modules
66 (encoder and decoder). By default a module standardizing the inputs is also used.
68 Parameters
69 ----------
70 n_cvs : int
71 The dimension of the CV or, equivalently, the dimension of the latent
72 space of the autoencoder.
73 encoder_layers : list
74 Number of neurons per layer of the encoder up to the last hidden layer.
75 The size of the output layer is instead specified with ``n_cvs``
76 decoder_layers : list, optional
77 Number of neurons per layer of the decoder, except for the input layer
78 which is specified by ``n_cvs``. If ``None`` (default), it takes automatically
79 the reversed architecture of the encoder.
80 options : dict[str, Any], optional
81 Options for the building blocks of the model, by default ``None``.
82 Available blocks are: ``'norm_in'``, ``'encoder'``, and ``'decoder'``.
83 Set ``'block_name' = None`` or ``False`` to turn off a block. Encoder
84 and decoder cannot be turned off.
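
        Examples
        --------
        A sketch of the ``options`` mechanism, turning off the input normalization
        block (the sizes used here are hypothetical, for illustration only):

        >>> model = VariationalAutoEncoder(
        ...     n_features=10, n_cvs=2, encoder_layers=[32, 16],
        ...     options={"norm_in": None},
        ... )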
85 """
86 super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)
88 # ======= LOSS =======
89 # ELBO loss function when latent space and reconstruction distributions are Gaussians.
90 self.loss_fn = ELBOGaussiansLoss()
92 # ======= OPTIONS =======
93 # parse and sanitize
94 options = self.parse_options(options)
96 # if decoder is not given reverse the encoder
97 if decoder_layers is None:
98 decoder_layers = encoder_layers[::-1]
100 # ======= BLOCKS =======
102 # initialize norm_in
103 o = "norm_in"
104 if (options[o] is not False) and (options[o] is not None):
105 self.norm_in = Normalization(self.in_features, **options[o])
107 # initialize encoder
108 # The encoder outputs two values for each CV representing mean and std.
109 o = "encoder"
111 # Note: The FeedForward implementing the encoder by default needs to have also
112 # the nonlinearity (and eventually dropout/batchnorm) also for the output
113 # layer since we'll have two separate linear layers for the mean and variance.
114 if "last_layer_activation" not in options[o]:
115 options[o]["last_layer_activation"] = True

        self.encoder = FeedForward([n_features] + encoder_layers, **options[o])
        self.mean_nn = torch.nn.Linear(
            in_features=encoder_layers[-1], out_features=n_cvs
        )
        self.log_var_nn = torch.nn.Linear(
            in_features=encoder_layers[-1], out_features=n_cvs
        )

        # initialize decoder
        o = "decoder"
        self.decoder = FeedForward([n_cvs] + decoder_layers + [n_features], **options[o])

        self.eval_variables = ["xhat", "z", "z_mean", "z_logvar"]

    def forward_cv(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the value of the CV from preprocessed input.

        Return the mean output (ignoring the variance output) of the encoder
        after (optionally) applying the normalization to the input.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            input descriptors of the CV after preprocessing.

        Returns
        -------
        cv : torch.Tensor
            Shape ``(n_batches, n_cvs)``. The CVs, i.e., the mean output of the
            encoder (the variance output is discarded).
        """
        if self.norm_in is not None:
            x = self.norm_in(x)
        x = self.encoder(x)

        # Take only the means and ignore the log variances.
        return self.mean_nn(x)

    def encode_decode(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run a pass of encoding + decoding.

        The function applies the normalization to the inputs; the inverse
        normalization of the reconstructed output is left to the caller (see
        ``evaluate_model``).

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            input descriptors of the CV after preprocessing.

        Returns
        -------
        z : torch.Tensor
            Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The sample drawn from
            the Gaussian distribution associated with the input in latent space.
        mean : torch.Tensor
            Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The mean of the
            Gaussian distribution associated with the input in latent space.
        log_variance : torch.Tensor
            Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The logarithm of the
            variance of the Gaussian distribution associated with the input in
            latent space.
        x_hat : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            reconstructed descriptors, in the normalized space.
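
        Examples
        --------
        A shape sketch with hypothetical sizes (``n_features=10``, ``n_cvs=2``):

        >>> import torch
        >>> model = VariationalAutoEncoder(n_features=10, n_cvs=2, encoder_layers=[32, 16])
        >>> z, mean, log_var, x_hat = model.encode_decode(torch.randn(8, 10))
        >>> z.shape, x_hat.shape
        (torch.Size([8, 2]), torch.Size([8, 10]))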
182 """
183 # Normalize inputs.
184 if self.norm_in is not None:
185 x = self.norm_in(x)
187 # Encode input into a Gaussian distribution.
188 x = self.encoder(x)
189 mean, log_variance = self.mean_nn(x), self.log_var_nn(x)
191 # Sample from the Gaussian distribution in latent space.
192 std = torch.exp(log_variance / 2)
193 z = torch.distributions.Normal(mean, std).rsample()
195 # Decode sample.
196 x_hat = self.decoder(z)
198 # if self.norm_in is not None:
199 # x_hat = self.norm_in.inverse(x_hat)
201 return z, mean, log_variance, x_hat

    def evaluate_model(self, batch, batch_idx):
        """Evaluate the model on a batch, returning the reconstruction (mapped back
        to the original input space) and the latent-space variables."""

        x = batch["data"]

        if "target" in batch:
            x_ref = batch["target"]
            if self.norm_in is not None:
                x_ref = self.norm_in(x_ref)
        else:
            x_ref = x

        z, mean, log_variance, x_hat = self.encode_decode(x)

        if self.norm_in is not None:
            x_hat = self.norm_in.inverse(x_hat)

        return x_hat, z, mean, log_variance

    def training_step(self, train_batch, batch_idx):
        """Single training step performed by the PyTorch Lightning Trainer."""
        x = train_batch["data"]
        loss_kwargs = {}
        if "weights" in train_batch:
            loss_kwargs["weights"] = train_batch["weights"]

        # Encode/decode.
        z, mean, log_variance, x_hat = self.encode_decode(x)

        # Reference output (compare with a 'target' key if any, otherwise with input 'data').
        if "target" in train_batch:
            x_ref = train_batch["target"]
        else:
            x_ref = x

        if self.norm_in is not None:
            x_ref = self.norm_in(x_ref)

        # Loss function.
        loss = self.loss_fn(x_ref, x_hat, mean, log_variance, **loss_kwargs)

        # Log.
        name = "train" if self.training else "valid"
        self.log(
            f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True
        )

        return loss

    def get_decoder(self, return_normalization=False):
        """Return a torch model with the decoder and, optionally, the inverse of the
        input normalization appended after it.
        if return_normalization:
            if self.norm_in is not None:
                inv_norm = Inverse(module=self.norm_in)
                decoder_model = torch.nn.Sequential(self.decoder, inv_norm)
            else:
                raise ValueError(
                    "return_normalization is set to True but self.norm_in is None"
                )
        else:
            decoder_model = self.decoder

        return decoder_model