Coverage for biobb_pytorch/mdae/models/gmvae.py: 26% (174 statements)
coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
import torch
import torch.nn as nn
import lightning.pytorch as pl
from mlcolvar.cvs import BaseCV
from biobb_pytorch.mdae.featurization.normalization import Normalization
from mlcolvar.core.transform.utils import Inverse
from biobb_pytorch.mdae.models.nn.feedforward import FeedForward
from biobb_pytorch.mdae.loss import ELBOGaussianMixtureLoss

__all__ = ["GaussianMixtureVariationalAutoEncoder"]
class GaussianMixtureVariationalAutoEncoder(BaseCV, pl.LightningModule):
    """Gaussian Mixture Variational AutoEncoder Collective Variable.

    This class implements a Gaussian Mixture Variational AutoEncoder (GMVAE) for
    collective variable (CV) learning. The GMVAE is a generative model that combines
    Gaussian Mixture Models (GMM) with Variational Autoencoders (VAE): it learns a
    latent representation of the input data by modeling it as a mixture of Gaussians,
    where each Gaussian corresponds to a different cluster in the data.

    The model consists of an encoder that maps the input data to a latent space and a
    decoder that reconstructs the input data from the latent representation. Training
    uses variational inference: the model maximizes the evidence lower bound (ELBO) on
    the data likelihood. The ELBO combines a reconstruction term with a KL-divergence
    term between the learned latent distribution and a cluster-conditional Gaussian
    prior, plus a term for the uniform prior over cluster assignments (see the
    objective sketched in the comments below). The GMVAE can be used for clustering,
    dimensionality reduction, and generative modeling.

    The model is built on PyTorch and PyTorch Lightning, making it easy to integrate
    into existing workflows and to leverage GPU acceleration.

    Parameters
    ----------
    n_features : int
        The dimension of the input data.
    n_cvs : int
        The dimension of the CV or, equivalently, the dimension of the latent
        space of the autoencoder.
    encoder_layers : dict
        Hidden-layer sizes of the encoder networks, with keys ``qy_dims`` (cluster
        assignment network) and ``qz_dims`` (latent variable network).
    decoder_layers : dict
        Hidden-layer sizes of the decoder networks, with keys ``pz_dims`` (latent
        prior network) and ``px_dims`` (reconstruction network).
    options : dict, optional
        Additional options for the model blocks: ``k`` (number of mixture
        components), ``norm_in`` (input normalization settings), ``encoder`` and
        ``decoder`` (keyword arguments for the corresponding feed-forward networks),
        and ``loss_function`` (e.g. ``r_nent``, the weight of the entropy
        regularization term, defaulting to 0.5).
    """
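    # A sketch of the per-component objective assembled in ``loss_function`` below,
    # assuming the standard GMVAE decomposition (notation only, not executable):
    #
    #   L_i(x) = -log N(x | xm_i, xv_i)                                        (reconstruction)
    #            + log N(z_i | zm_i, zv_i) - log N(z_i | zm_prior_i, zv_prior_i)  (sampled KL term)
    #            - log(1 / k)                                                  (uniform prior over clusters)
    #
    # The training loss (``ELBOGaussianMixtureLoss``) is expected to combine these
    # per-component terms with the cluster probabilities q(y|x) predicted by
    # ``qy_nn``; its internals are not part of this module.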
    BLOCKS = ["norm_in", "encoder", "decoder", "k"]

    def __init__(self, n_features, n_cvs, encoder_layers, decoder_layers, options=None, **kwargs):
        super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)

        options = self.parse_options(options)

        if "norm_in" in options and options["norm_in"] is not None:
            self.norm_in = Normalization(self.in_features, **options["norm_in"])

        self.k = options["k"]
        self.r_nent = options.get('loss_function', {}).get("r_nent", 0.5)

        qy_dims = encoder_layers["qy_dims"]
        qz_dims = encoder_layers["qz_dims"]
        pz_dims = decoder_layers["pz_dims"]
        px_dims = decoder_layers["px_dims"]

        self.loss_fn = ELBOGaussianMixtureLoss(r_nent=self.r_nent, k=self.k)

        self.encoder = nn.ModuleDict()
        self.decoder = nn.ModuleDict()

        self.encoder['y_transform'] = nn.Linear(self.k, self.k)

        self.encoder['qy_nn'] = FeedForward([n_features] + qy_dims + [self.k], **options["encoder"]['qy_nn'])

        self.encoder['qz_nn'] = FeedForward([n_features + self.k] + qz_dims, **options["encoder"]['qz_nn'])
        self.encoder['zm_layer'] = nn.Linear(qz_dims[-1], n_cvs)
        self.encoder['zv_layer'] = nn.Linear(qz_dims[-1], n_cvs)

        self.decoder['pz_nn'] = FeedForward([self.k] + pz_dims, **options["decoder"]['pz_nn'])
        self.decoder['zm_prior_layer'] = nn.Linear(pz_dims[-1], n_cvs)
        self.decoder['zv_prior_layer'] = nn.Linear(pz_dims[-1], n_cvs)

        self.decoder['px_nn'] = FeedForward([n_cvs] + px_dims, **options["decoder"]['px_nn'])
        self.decoder['xm_layer'] = nn.Linear(px_dims[-1], n_features)
        self.decoder['xv_layer'] = nn.Linear(px_dims[-1], n_features)

        self.eval_variables = ["xhat", "z", "qy"]
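    # Tensor shapes through the sub-networks for a batch of size B (derived from the
    # layer definitions above):
    #   y_transform:           (B, k)              -> (B, k)
    #   qy_nn:                 (B, n_features)     -> (B, k)             cluster logits q(y|x)
    #   qz_nn:                 (B, n_features + k) -> (B, qz_dims[-1])
    #   zm_layer / zv_layer:   (B, qz_dims[-1])    -> (B, n_cvs)         posterior mean / variance
    #   pz_nn:                 (B, k)              -> (B, pz_dims[-1])
    #   zm/zv_prior_layer:     (B, pz_dims[-1])    -> (B, n_cvs)         prior mean / variance
    #   px_nn:                 (B, n_cvs)          -> (B, px_dims[-1])
    #   xm_layer / xv_layer:   (B, px_dims[-1])    -> (B, n_features)    reconstruction mean / variance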
    @staticmethod
    def log_normal(x, mu, var, eps=1e-10):
        # Log-probability of a normal (Gaussian) distribution with diagonal covariance.
        return -0.5 * torch.sum(torch.log(torch.tensor(2.0) * torch.pi) + (x - mu).pow(2) / var + var.log(), dim=-1)
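    # log_normal evaluates the log-density of a factorized (diagonal) Gaussian, summed
    # over the last dimension, where ``var`` is the variance (not the standard
    # deviation); the ``eps`` argument is currently unused:
    #
    #   log N(x | mu, var) = -1/2 * sum_i [ log(2*pi) + log(var_i) + (x_i - mu_i)^2 / var_i ]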
    def loss_function(self, x, xm, xv, z, zm, zv, zm_prior, zv_prior):
        return (
            -self.log_normal(x, xm, xv)  # reconstruction loss
            + self.log_normal(z, zm, zv) - self.log_normal(z, zm_prior, zv_prior)  # sampled KL-divergence estimate
            - torch.log(torch.tensor(1 / self.k, device=x.device))  # constant -log p(y) term for the uniform cluster prior p(y) = 1/k
        )
    def encode_decode(self, x):

        data = self.norm_in(x) if self.norm_in is not None else x

        qy_logit = self.encoder['qy_nn'](data)

        y_ = torch.zeros([data.shape[0], self.k]).to(data.device)

        zm_list, zv_list, z_list = [], [], []
        xm_list, xv_list, x_list = [], [], []
        zm_prior_list, zv_prior_list = [], []

        for i in range(self.k):
            # One-hot y
            y = y_ + torch.eye(self.k).to(data.device)[i]

            # Qz
            h0 = self.encoder['y_transform'](y)
            xy = torch.cat([data, h0], dim=1)
            qz_logit = self.encoder['qz_nn'](xy)
            zm = self.encoder['zm_layer'](qz_logit)
            zv = torch.nn.functional.softplus(self.encoder['zv_layer'](qz_logit))
            # Reparameterization trick: scale the noise by the standard deviation (sqrt of the variance)
            noise = torch.randn_like(zv)
            z_sample = zm + noise * torch.sqrt(zv)

            zm_list.append(zm)
            zv_list.append(zv)
            z_list.append(z_sample)

            # Pz (prior)
            pz_logit = self.decoder['pz_nn'](y)
            zm_prior = self.decoder['zm_prior_layer'](pz_logit)
            zv_prior = torch.nn.functional.softplus(self.decoder['zv_prior_layer'](pz_logit))
            noise = torch.randn_like(zv_prior)
            z_prior_sample = zm_prior + noise * torch.sqrt(zv_prior)

            zm_prior_list.append(zm_prior)
            zv_prior_list.append(zv_prior)

            # Px
            px_logit = self.decoder['px_nn'](z_prior_sample)
            xm = self.decoder['xm_layer'](px_logit)
            xv = torch.nn.functional.softplus(self.decoder['xv_layer'](px_logit))
            noise = torch.randn_like(xv)
            x_sample = xm + noise * torch.sqrt(xv)

            xm_list.append(xm)
            xv_list.append(xv)
            x_list.append(x_sample)

        return (
            data, qy_logit, xm_list, xv_list,
            z_list, zm_list, zv_list,
            zm_prior_list, zv_prior_list
        )
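    # Note on encode_decode: instead of sampling a single cluster assignment, the loop
    # above enumerates all k one-hot values of y and returns per-component lists of
    # length k, so downstream code (e.g. the loss and ``evaluate_model``) can weight
    # them by the cluster probabilities rather than sampling a single assignment.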
    def evaluate_model(self, batch, batch_idx):
        """Run the model on a batch and return the mixed reconstruction ``xhat``,
        the latent CV ``z``, and the cluster probabilities ``qy``."""

        x = batch['data']

        data = self.norm_in(x) if self.norm_in is not None else x

        qy_logit = self.encoder['qy_nn'](data)
        qy = torch.softmax(qy_logit, dim=1)

        y_ = torch.zeros([data.shape[0], self.k]).to(data.device)

        zm_list, zv_list, z_list = [], [], []
        xm_list, xv_list, x_list = [], [], []
        zm_prior_list, zv_prior_list = [], []

        for i in range(self.k):
            # One-hot y
            y = y_ + torch.eye(self.k).to(data.device)[i]

            # Qz
            h0 = self.encoder['y_transform'](y)
            xy = torch.cat([data, h0], dim=1)
            qz_logit = self.encoder['qz_nn'](xy)
            zm = self.encoder['zm_layer'](qz_logit)
            zv = torch.nn.functional.softplus(self.encoder['zv_layer'](qz_logit))
            noise = torch.randn_like(zv)
            z_sample = zm + noise * torch.sqrt(zv)

            zm_list.append(zm)
            zv_list.append(zv)
            z_list.append(z_sample)

            # Pz (prior)
            pz_logit = self.decoder['pz_nn'](y)
            zm_prior = self.decoder['zm_prior_layer'](pz_logit)
            zv_prior = torch.nn.functional.softplus(self.decoder['zv_prior_layer'](pz_logit))
            noise = torch.randn_like(zv_prior)
            z_prior_sample = zm_prior + noise * torch.sqrt(zv_prior)

            zm_prior_list.append(zm_prior)
            zv_prior_list.append(zv_prior)

            # Px
            px_logit = self.decoder['px_nn'](z_prior_sample)
            xm = self.decoder['xm_layer'](px_logit)
            xv = torch.nn.functional.softplus(self.decoder['xv_layer'](px_logit))
            noise = torch.randn_like(xv)
            x_sample = xm + noise * torch.sqrt(xv)

            xm_list.append(xm)
            xv_list.append(xv)
            x_list.append(x_sample)

        # Mix the per-component samples with the cluster probabilities q(y|x)
        xhat = torch.sum(qy.unsqueeze(-1) * torch.stack(x_list, dim=1), dim=1)

        if self.norm_in is not None:
            xhat = self.norm_in.inverse(xhat)

        z = torch.sum(qy.unsqueeze(-1) * torch.stack(z_list, dim=1), dim=1)

        return xhat, z, qy
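    # The names in ``self.eval_variables`` ("xhat", "z", "qy") match, in order, the
    # tuple returned by evaluate_model above.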
    def decode(self, z):
        """Reconstruct x' from an aggregated latent vector z."""
        if z.dim() == 1:
            z = z.unsqueeze(0)

        px_logit = self.decoder['px_nn'](z)
        xm = self.decoder['xm_layer'](px_logit)
        xv = torch.nn.functional.softplus(self.decoder['xv_layer'](px_logit))
        noise = torch.randn_like(xv)
        x = xm + noise * torch.sqrt(xv)

        if self.norm_in is not None:
            x = self.norm_in.inverse(x)

        return x
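    # decode() draws a sample from the decoder output distribution, so repeated calls
    # with the same z give different reconstructions; a deterministic reconstruction
    # would return the mean ``xm`` instead.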
    def forward_cv(self, x):

        if self.norm_in is not None:
            x = self.norm_in(x)

        qy_logit = self.encoder['qy_nn'](x)
        qy = torch.softmax(qy_logit, dim=1)

        y_ = torch.zeros([x.shape[0], self.k]).to(x.device)

        zm_list, zv_list, z_list = [], [], []

        for i in range(self.k):
            # One-hot y
            y = y_ + torch.eye(self.k).to(x.device)[i]

            # Qz
            h0 = self.encoder['y_transform'](y)
            xy = torch.cat([x, h0], dim=1)
            qz_logit = self.encoder['qz_nn'](xy)
            zm = self.encoder['zm_layer'](qz_logit)
            zv = torch.nn.functional.softplus(self.encoder['zv_layer'](qz_logit))
            noise = torch.randn_like(zv)
            z_sample = zm + noise * torch.sqrt(zv)

            zm_list.append(zm)
            zv_list.append(zv)
            z_list.append(z_sample)

        # CV = expectation of z over the cluster assignment probabilities q(y|x)
        Z = torch.stack(z_list, dim=1)
        cv = torch.sum(qy.unsqueeze(-1) * Z, dim=1)

        return cv
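    # forward_cv samples z for each component before averaging, so the CV is
    # stochastic; using the posterior means ``zm`` would give a deterministic CV.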
    def training_step(self, train_batch, batch_idx):

        x = train_batch["data"]

        if "target" in train_batch:
            x_ref = train_batch["target"]
        else:
            x_ref = x

        data, qy_logit, xm_list, xv_list, z_list, zm_list, zv_list, zm_prior_list, zv_prior_list = self.encode_decode(x_ref)

        batch_loss, nent = self.loss_fn(data,
                                        qy_logit,
                                        xm_list, xv_list,
                                        z_list, zm_list, zv_list,
                                        zm_prior_list, zv_prior_list)

        loss = batch_loss.mean()
        ce_loss = nent.mean()

        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_cross_entropy", ce_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss
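    # The prefix of the logged metrics toggles between "train" and "valid" based on
    # ``self.training``, so the same step can also serve validation batches.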
    def get_decoder(self, return_normalization=False):
        """Return a torch model with the decoder and optionally the inverse normalization."""
        if return_normalization:
            if self.norm_in is not None:
                inv_norm = Inverse(module=self.norm_in)
                decoder_model = torch.nn.Sequential(self.decoder, inv_norm)
            else:
                raise ValueError(
                    "return_normalization is set to True but self.norm_in is None"
                )
        else:
            decoder_model = self.decoder

        return decoder_model
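    # Note: ``self.decoder`` is an nn.ModuleDict, so the object returned here exposes
    # the decoder sub-networks (``pz_nn``, ``px_nn``, ...) rather than a single
    # callable forward pass; see ``decode`` for how they are applied.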
# # Example of usage:

# # Define dimensions
# n_features = 1551   # Input dimension
# n_clusters = 5      # Number of mixture components (k)
# n_cvs = 3           # Latent dimension (CVs)
# r_nent = 0.5        # Weight for the entropy regularization term

# # Encoder hidden-layer sizes
# encoder_layers = {"qy_dims": [32], "qz_dims": [16, 16]}

# # Decoder hidden-layer sizes
# decoder_layers = {"pz_dims": [16, 16], "px_dims": [128]}

# options = {
#     "k": n_clusters,
#     "norm_in": {"mode": "mean_std"},
#     "encoder": {"qy_nn": {}, "qz_nn": {}},   # per-network FeedForward keyword arguments
#     "decoder": {"pz_nn": {}, "px_nn": {}},
#     "loss_function": {"r_nent": r_nent},
#     "optimizer": {"lr": 1e-4},
# }

# # Instantiate the GMVAE CV
# model = GaussianMixtureVariationalAutoEncoder(n_features=n_features,
#                                               n_cvs=n_cvs,
#                                               encoder_layers=encoder_layers,
#                                               decoder_layers=decoder_layers,
#                                               options=options)
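# # A minimal training sketch (an assumption-based example, not part of the module):
# # ``datamodule`` is a placeholder for a user-provided LightningDataModule whose
# # batches contain a "data" key (and optionally a "target" key), as expected by
# # ``training_step`` above.
# trainer = pl.Trainer(max_epochs=100)
# trainer.fit(model, datamodule=datamodule)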