Coverage for biobb_pytorch / mdae / models / ae.py: 85%
68 statements
coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
# --------------------------------------------------------------------------------------
# autoencoder.py
#
# from the mlcolvar repository
# https://github.com/mlcolvar/mlcolvar
# Copyright (c) 2023 Luigi Bonati, Enrico Trizio, Andrea Rizzi & Michele Parrinello
# Licensed under the MIT License (see project LICENSE file for full text)
# --------------------------------------------------------------------------------------

import torch
import lightning.pytorch as pl
from mlcolvar.cvs import BaseCV
from biobb_pytorch.mdae.models.nn.feedforward import FeedForward
from biobb_pytorch.mdae.featurization.normalization import Normalization
from mlcolvar.core.transform.utils import Inverse
from biobb_pytorch.mdae.loss import MSELoss

__all__ = ["AutoEncoder"]


class AutoEncoder(BaseCV, pl.LightningModule):
22 """AutoEncoding Collective Variable.
23 It is composed by a first neural network (encoder) which projects
24 the input data into a latent space (the CVs). Then a second network (decoder) takes
25 the CVs and tries to reconstruct the input data based on them. It is an unsupervised learning approach,
26 typically used when no labels are available. This CV is inspired by [1]_.
28 Furthermore, it can also be used lo learn a representation which can be used not to reconstruct the data but
29 to predict, e.g. future configurations.
31 **Data**: for training it requires a DictDataset with the key 'data' and optionally 'weights' to reweight the
32 data as done in [2]_. If a 'target' key is present this will be used as reference for the output of the decoder,
33 otherway this will be compared with the input 'data'. This feature can be used to train a time-lagged autoencoder [3]_
34 where the task is not to reconstruct the input but the output at a later step.
36 **Loss**: reconstruction loss (MSELoss)
38 References
39 ----------
40 .. [1] W. Chen and A. L. Ferguson, “ Molecular enhanced sampling with autoencoders: On-the-fly collective
41 variable discovery and accelerated free energy landscape exploration,” JCC 39, 2079–2102 (2018)
42 .. [2] Z. Belkacemi, P. Gkeka, T. Lelièvre, and G. Stoltz, “ Chasing collective variables using autoencoders and biased
43 trajectories,” JCTC 18, 59–78 (2022)
44 .. [3] C. Wehmeyer and F. Noé, “Time-lagged autoencoders: Deep learning of slow collective variables for molecular
45 kinetics,” JCP 148, 241703 (2018).
47 See also
48 --------
49 mlcolvar.core.loss.MSELoss
50 (weighted) Mean Squared Error (MSE) loss function.
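
    Examples
    --------
    A minimal sketch of the forward pass (illustrative only; tensor sizes are
    arbitrary, and training is done through a lightning Trainer as usual):

    >>> model = AutoEncoder(n_features=10, n_cvs=2, encoder_layers=[32, 16])
    >>> X = torch.rand(100, 10)
    >>> cvs = model(X)                  # output of the encoder, i.e. the CVs
    >>> x_hat = model.encode_decode(X)  # reconstruction of the input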
51 """

    BLOCKS = ["norm_in", "encoder", "decoder"]

    def __init__(
        self,
        n_features: int,
        n_cvs: int,
        encoder_layers: list,
        decoder_layers: list = None,
        options: dict = None,
        **kwargs,
    ):
64 """
65 Define a CV defined as the output layer of the encoder of an autoencoder model (latent space).
66 The decoder part is used only during the training for the reconstruction loss.
67 By default a module standardizing the inputs is also used.
69 Parameters
70 ----------
71 encoder_layers : list
72 Number of neurons per layer of the encoder
73 decoder_layers : list, optional
74 Number of neurons per layer of the decoder, by default None
75 If not set it takes automaically the reversed architecture of the encoder
76 options : dict[str,Any], optional
77 Options for the building blocks of the model, by default None.
78 Available blocks: ['norm_in', 'encoder','decoder'].
79 Set 'block_name' = None or False to turn off that block
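
        Example (illustrative only; the accepted keys depend on the
        Normalization and FeedForward implementations)::

            options = {"norm_in": None, "encoder": {"activation": "relu"}}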
80 """
        super().__init__(
            in_features=n_features, out_features=n_cvs, **kwargs
        )

        # ======= LOSS =======
        # Reconstruction (MSE) loss
        self.loss_fn = MSELoss()

        # ======= OPTIONS =======
        # parse and sanitize
        options = self.parse_options(options)

        # if the decoder layers are not given, mirror the encoder architecture
        if decoder_layers is None:
            decoder_layers = encoder_layers[::-1]

        # ======= BLOCKS =======

        # initialize norm_in
        o = "norm_in"
        if (options[o] is not False) and (options[o] is not None):
            self.norm_in = Normalization(self.in_features, **options[o])

        # initialize encoder
        o = "encoder"
        self.encoder = FeedForward([n_features] + encoder_layers + [n_cvs], **options[o])

        # initialize decoder
        o = "decoder"
        self.decoder = FeedForward([n_cvs] + decoder_layers + [n_features], **options[o])

        self.eval_variables = ["xhat", "z"]

    def forward_cv(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the CV without pre- or post-processing modules."""
        if self.norm_in is not None:
            x = self.norm_in(x)
        x = self.encoder(x)
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode the latent space into the original input space."""
        x = self.decoder(z)
        if self.norm_in is not None:
            x = self.norm_in.inverse(x)
        return x

    def encode_decode(self, x: torch.Tensor) -> torch.Tensor:
        """Pass the inputs through both the encoder and the decoder networks."""
        x = self.forward_cv(x)
        x = self.decoder(x)
        if self.norm_in is not None:
            x = self.norm_in.inverse(x)
        return x

    def evaluate_model(self, batch, batch_idx=None):
        """Evaluate the model on a batch, returning the reconstruction and the latent CVs."""

        x = batch["data"]
        z = self.forward_cv(x)
        x_hat = self.decoder(z)

        if self.norm_in is not None:
            x_hat = self.norm_in.inverse(x_hat)

        return x_hat, z

    def training_step(self, train_batch, batch_idx):
        """Compute and return the training loss and record metrics."""
        # =================get data===================
        x = train_batch["data"]
        loss_kwargs = {}
        if "weights" in train_batch:
            loss_kwargs["weights"] = train_batch["weights"]
        # =================forward====================
        x_hat = self.encode_decode(x)
        # ===================loss=====================
        if "target" in train_batch:
            x_ref = train_batch["target"]
        else:
            x_ref = x

        # if self.norm_in is not None:
        #     x_ref = self.norm_in(x_ref)

        loss = self.loss_fn(x_hat, x_ref, **loss_kwargs)

        # ====================log=====================
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        return loss

    def get_decoder(self, return_normalization=False):
        """Return a torch model with the decoder and, optionally, the inverse of the input normalization."""
        if return_normalization:
            if self.norm_in is not None:
                inv_norm = Inverse(module=self.norm_in)
                decoder_model = torch.nn.Sequential(self.decoder, inv_norm)
            else:
                raise ValueError(
                    "return_normalization is set to True but self.norm_in is None"
                )
        else:
            decoder_model = self.decoder
        return decoder_model
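

# Illustrative end-to-end sketch: build a toy dataset, fit the autoencoder, and
# export the decoder. It assumes mlcolvar's DictDataset/DictModule API and a
# Lightning Trainer; adjust to the actual biobb_pytorch training entry points
# if they differ.
if __name__ == "__main__":
    from mlcolvar.data import DictDataset, DictModule

    # toy data: 100 frames with 10 features each (optionally add 'weights'/'target')
    dataset = DictDataset({"data": torch.rand(100, 10)})
    datamodule = DictModule(dataset)

    # 10 -> 32 -> 16 -> 2 encoder; the decoder mirrors it by default
    model = AutoEncoder(n_features=10, n_cvs=2, encoder_layers=[32, 16])

    trainer = pl.Trainer(max_epochs=5, logger=False, enable_checkpointing=False)
    trainer.fit(model, datamodule)

    # export the decoder together with the inverse of the input normalization
    decoder = model.get_decoder(return_normalization=True)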