# biobb_pytorch/mdae/models/gmvae.py

1 import torch
2 import torch.nn as nn
3 import lightning.pytorch as pl
4 from mlcolvar.cvs import BaseCV
5 from biobb_pytorch.mdae.featurization.normalization import Normalization
6 from mlcolvar.core.transform.utils import Inverse
7 from biobb_pytorch.mdae.models.nn.feedforward import FeedForward
8 from biobb_pytorch.mdae.loss import ELBOGaussianMixtureLoss
9  
10  
11 __all__ = ["GaussianMixtureVariationalAutoEncoder"]
12  
13  
class GaussianMixtureVariationalAutoEncoder(BaseCV, pl.LightningModule):
    """Gaussian Mixture Variational AutoEncoder (GMVAE) Collective Variable.

    Combines a Gaussian Mixture Model with a Variational AutoEncoder: the
    input is modeled as a mixture of ``k`` Gaussians in latent space, with a
    cluster-assignment network q(y|x), a per-cluster posterior q(z|x, y), a
    per-cluster prior p(z|y) and a Gaussian likelihood p(x|z). Training
    maximizes the evidence lower bound (ELBO) via
    :class:`ELBOGaussianMixtureLoss`, optionally weighted by an entropy
    regularization term ``r_nent``.

    Parameters
    ----------
    n_features : int
        Dimension of the input data.
    n_cvs : int
        Dimension of the CV — equivalently, the latent dimension.
    encoder_layers : dict
        Hidden-layer sizes of the encoder networks, with keys
        ``"qy_dims"`` (cluster assignment q(y|x)) and
        ``"qz_dims"`` (latent posterior q(z|x, y)).
    decoder_layers : dict
        Hidden-layer sizes of the decoder networks, with keys
        ``"pz_dims"`` (latent prior p(z|y)) and
        ``"px_dims"`` (reconstruction p(x|z)).
    options : dict, optional
        Additional options. Must provide ``"k"`` (number of mixture
        components); may provide ``"norm_in"`` (input normalization),
        per-network FeedForward settings under ``"encoder"``/``"decoder"``,
        and ``"loss_function": {"r_nent": float}`` (defaults to 0.5).
    """

    BLOCKS = ["norm_in", "encoder", "decoder", "k"]

    def __init__(self, n_features, n_cvs, encoder_layers, decoder_layers, options=None, **kwargs):
        super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)

        options = self.parse_options(options)

        # Default to no input normalization; FIX: previously the attribute was
        # never assigned when the option was absent, so later
        # ``self.norm_in is not None`` checks could raise AttributeError.
        self.norm_in = None
        if "norm_in" in options and options["norm_in"] is not None:
            self.norm_in = Normalization(self.in_features, **options["norm_in"])

        self.k = options["k"]  # number of mixture components
        self.r_nent = options.get('loss_function', {}).get("r_nent", 0.5)

        qy_dims = encoder_layers["qy_dims"]
        qz_dims = encoder_layers["qz_dims"]
        pz_dims = decoder_layers["pz_dims"]
        px_dims = decoder_layers["px_dims"]

        self.loss_fn = ELBOGaussianMixtureLoss(r_nent=self.r_nent, k=self.k)

        self.encoder = nn.ModuleDict()
        self.decoder = nn.ModuleDict()

        # Learnable embedding of the one-hot cluster label y.
        self.encoder['y_transform'] = nn.Linear(self.k, self.k)

        # q(y|x): cluster-assignment logits.
        self.encoder['qy_nn'] = FeedForward([n_features] + qy_dims + [self.k], **options["encoder"]['qy_nn'])

        # q(z|x, y): Gaussian posterior over the latent CVs.
        self.encoder['qz_nn'] = FeedForward([n_features + self.k] + qz_dims, **options["encoder"]['qz_nn'])
        self.encoder['zm_layer'] = nn.Linear(qz_dims[-1], n_cvs)
        self.encoder['zv_layer'] = nn.Linear(qz_dims[-1], n_cvs)

        # p(z|y): per-cluster Gaussian prior.
        self.decoder['pz_nn'] = FeedForward([self.k] + pz_dims, **options["decoder"]['pz_nn'])
        self.decoder['zm_prior_layer'] = nn.Linear(pz_dims[-1], n_cvs)
        self.decoder['zv_prior_layer'] = nn.Linear(pz_dims[-1], n_cvs)

        # p(x|z): Gaussian likelihood over the (normalized) input features.
        self.decoder['px_nn'] = FeedForward([n_cvs] + px_dims, **options["decoder"]['px_nn'])
        self.decoder['xm_layer'] = nn.Linear(px_dims[-1], n_features)
        self.decoder['xv_layer'] = nn.Linear(px_dims[-1], n_features)

        # Names of the quantities returned by ``evaluate_model``.
        self.eval_variables = ["xhat", "z", "qy"]

    @staticmethod
    def log_normal(x, mu, var, eps=1e-10):
        """Log-density of a diagonal Gaussian N(mu, var), summed over the last dim.

        ``eps`` is kept for interface compatibility but is currently unused.
        """
        return -0.5 * torch.sum(
            torch.log(torch.tensor(2.0) * torch.pi) + (x - mu).pow(2) / var + var.log(),
            dim=-1,
        )

    @staticmethod
    def _reparameterize(mean, var):
        """Reparameterization trick: ``mean + eps * sqrt(var)`` with eps ~ N(0, I).

        FIX: the previous implementation multiplied the noise by the variance
        (``mean + noise * var``) instead of the standard deviation, which is
        inconsistent with ``log_normal`` treating ``var`` as a variance.
        """
        std = torch.sqrt(var)
        return mean + torch.randn_like(std) * std

    def loss_function(self, x, xm, xv, z, zm, zv, zm_prior, zv_prior):
        """Per-sample negative ELBO contribution for one mixture component.

        Reconstruction term - KL term between posterior and prior, plus a
        constant entropy term ``log(1/k)`` from the uniform cluster prior.
        """
        reconstruction = -self.log_normal(x, xm, xv)
        kl = self.log_normal(z, zm, zv) - self.log_normal(z, zm_prior, zv_prior)
        uniform_prior = torch.log(torch.tensor(1 / self.k, device=x.device))
        return reconstruction + kl - uniform_prior

    def _posterior_z(self, data, y):
        """q(z|x, y): posterior mean, variance (softplus) and a reparameterized sample."""
        h0 = self.encoder['y_transform'](y)
        xy = torch.cat([data, h0], dim=1)
        qz_logit = self.encoder['qz_nn'](xy)
        zm = self.encoder['zm_layer'](qz_logit)
        zv = torch.nn.functional.softplus(self.encoder['zv_layer'](qz_logit))
        return zm, zv, self._reparameterize(zm, zv)

    def _prior_z(self, y):
        """p(z|y): prior mean, variance (softplus) and a reparameterized sample."""
        pz_logit = self.decoder['pz_nn'](y)
        zm_prior = self.decoder['zm_prior_layer'](pz_logit)
        zv_prior = torch.nn.functional.softplus(self.decoder['zv_prior_layer'](pz_logit))
        return zm_prior, zv_prior, self._reparameterize(zm_prior, zv_prior)

    def _likelihood_x(self, z):
        """p(x|z): reconstruction mean, variance (softplus) and a reparameterized sample."""
        px_logit = self.decoder['px_nn'](z)
        xm = self.decoder['xm_layer'](px_logit)
        xv = torch.nn.functional.softplus(self.decoder['xv_layer'](px_logit))
        return xm, xv, self._reparameterize(xm, xv)

    def encode_decode(self, x):
        """Run encoder and decoder once per mixture component.

        Returns the (normalized) input, the cluster-assignment logits, and
        per-component lists of reconstruction parameters, posterior samples
        and posterior/prior Gaussian parameters — in the order expected by
        ``self.loss_fn``.
        """
        # FIX: previously ``data`` was undefined (NameError) when norm_in was None.
        data = self.norm_in(x) if self.norm_in is not None else x

        qy_logit = self.encoder['qy_nn'](data)

        # Hoisted out of the loop: one identity matrix provides every one-hot y.
        eye = torch.eye(self.k, device=data.device)

        zm_list, zv_list, z_list = [], [], []
        xm_list, xv_list = [], []
        zm_prior_list, zv_prior_list = [], []

        for i in range(self.k):
            # One-hot cluster label broadcast over the batch.
            y = eye[i].expand(data.shape[0], self.k)

            # Posterior q(z|x, y)
            zm, zv, z_sample = self._posterior_z(data, y)
            zm_list.append(zm)
            zv_list.append(zv)
            z_list.append(z_sample)

            # Prior p(z|y)
            zm_prior, zv_prior, z_prior_sample = self._prior_z(y)
            zm_prior_list.append(zm_prior)
            zv_prior_list.append(zv_prior)

            # Likelihood p(x|z).
            # NOTE(review): x is decoded from the PRIOR sample, not the
            # posterior sample — kept as in the original design, but this
            # disconnects the reconstruction from the encoder; verify intended.
            xm, xv, _ = self._likelihood_x(z_prior_sample)
            xm_list.append(xm)
            xv_list.append(xv)

        return (
            data, qy_logit, xm_list, xv_list,
            z_list, zm_list, zv_list,
            zm_prior_list, zv_prior_list
        )

    def evaluate_model(self, batch, batch_idx):
        """Evaluate the model on a batch, returning (xhat, z, qy).

        ``xhat`` and ``z`` are the reconstruction and latent samples averaged
        over clusters with weights q(y|x); ``xhat`` is mapped back to the
        un-normalized input space when normalization is active.
        """
        x = batch['data']

        # Reuse encode_decode instead of duplicating the per-component loop.
        (data, qy_logit, xm_list, xv_list, z_list,
         _zm, _zv, _zmp, _zvp) = self.encode_decode(x)

        qy = torch.softmax(qy_logit, dim=1)

        # Draw one reconstruction sample per component from p(x|z) parameters.
        x_list = [self._reparameterize(xm, xv) for xm, xv in zip(xm_list, xv_list)]

        xhat = torch.sum(qy.unsqueeze(-1) * torch.stack(x_list, dim=1), dim=1)

        if self.norm_in is not None:
            xhat = self.norm_in.inverse(xhat)

        z = torch.sum(qy.unsqueeze(-1) * torch.stack(z_list, dim=1), dim=1)

        return xhat, z, qy

    def decode(self, z):
        """Reconstruct x' from an aggregated latent z (in un-normalized space)."""
        if z.dim() == 1:
            z = z.unsqueeze(0)

        _xm, _xv, x = self._likelihood_x(z)

        if self.norm_in is not None:
            x = self.norm_in.inverse(x)

        return x

    def forward_cv(self, x):
        """Map inputs to CVs: posterior z samples averaged over clusters by q(y|x).

        NOTE(review): the CV is built from reparameterized SAMPLES, so it is
        stochastic; a deterministic CV would aggregate the posterior means
        ``zm`` instead — confirm which is intended.
        """
        if self.norm_in is not None:
            x = self.norm_in(x)

        qy_logit = self.encoder['qy_nn'](x)
        qy = torch.softmax(qy_logit, dim=1)

        eye = torch.eye(self.k, device=x.device)

        z_list = []
        for i in range(self.k):
            y = eye[i].expand(x.shape[0], self.k)
            _zm, _zv, z_sample = self._posterior_z(x, y)
            z_list.append(z_sample)

        Z = torch.stack(z_list, dim=1)
        return torch.sum(qy.unsqueeze(-1) * Z, dim=1)

    def training_step(self, train_batch, batch_idx):
        """One optimization step: forward pass + ELBO loss, with logging."""
        x = train_batch["data"]

        # Allow an explicit reconstruction target; default to the input itself.
        x_ref = train_batch.get("target", x)

        (data, qy_logit, xm_list, xv_list, z_list, zm_list, zv_list,
         zm_prior_list, zv_prior_list) = self.encode_decode(x_ref)

        batch_loss, nent = self.loss_fn(data,
                                        qy_logit,
                                        xm_list, xv_list,
                                        z_list, zm_list, zv_list,
                                        zm_prior_list, zv_prior_list)

        loss = batch_loss.mean()
        ce_loss = nent.mean()

        # Same step is reused for validation; the prefix reflects the mode.
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_cross_entropy", ce_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def get_decoder(self, return_normalization=False):
        """Return a torch model with the decoder and optionally the normalization inverse.

        NOTE(review): ``self.decoder`` is an ``nn.ModuleDict``, which has no
        ``forward``; wrapping it in ``nn.Sequential`` yields a module that
        cannot be called directly — confirm how callers use this.
        """
        if return_normalization:
            if self.norm_in is not None:
                inv_norm = Inverse(module=self.norm_in)
                decoder_model = torch.nn.Sequential(*[self.decoder, inv_norm])
            else:
                raise ValueError(
                    "return_normalization is set to True but self.norm_in is None"
                )
        else:
            decoder_model = self.decoder

        return decoder_model
319  
320  
# # Example of usage:

# # Define dimensions
# n_features = 1551   # Input dimension
# n_clusters = 5      # Number of mixture components (k)
# n_cvs = 3           # Latent dimension (CVs)

# # Encoder hidden sizes
# encoder_layers = {"qy_dims": [32], "qz_dims": [16, 16]}

# # Decoder hidden sizes
# decoder_layers = {"pz_dims": [16, 16], "px_dims": [128]}

# options = {
#     "k": n_clusters,
#     "norm_in": {"mode": "mean_std"},
#     "optimizer": {"lr": 1e-4},
#     "loss_function": {"r_nent": 0.5},  # weight of the entropy regularization
# }

# # Instantiate the GMVAE CV
# model = GaussianMixtureVariationalAutoEncoder(n_features=n_features,
#                                               n_cvs=n_cvs,
#                                               encoder_layers=encoder_layers,
#                                               decoder_layers=decoder_layers,
#                                               options=options)