1 import torch
2 import torch.nn as nn
3 import lightning.pytorch as pl
4 from mlcolvar.cvs import BaseCV
5 from biobb_pytorch.mdae.featurization.normalization import Normalization
6 from mlcolvar.core.transform.utils import Inverse
7 from biobb_pytorch.mdae.models.nn.feedforward import FeedForward
8 from biobb_pytorch.mdae.loss import ELBOGaussianMixtureLoss
9
10
11 __all__ = ["GaussianMixtureVariationalAutoEncoder"]
12
13
14 class GaussianMixtureVariationalAutoEncoder(BaseCV, pl.LightningModule):
15 """Gaussian Mixture Variational AutoEncoder Collective Variable.
16 This class implements a Gaussian Mixture Variational AutoEncoder (GMVAE) for
17 collective variable (CV) learning. The GMVAE is a generative model that combines
18 the principles of Gaussian Mixture Models (GMM) and Variational Autoencoders (VAE).
19 It learns a latent representation of the input data by modeling it as a mixture
20 of Gaussians, where each Gaussian corresponds to a different cluster in the data.
21 The model consists of an encoder that maps the input data to a latent space,
22 and a decoder that reconstructs the input data from the latent representation.
23 The GMVAE is trained using a variational inference approach, where the model
24 learns to maximize the evidence lower bound (ELBO) on the data likelihood.
25 The ELBO consists of two terms: the reconstruction loss and the KL divergence
26 between the learned latent distribution and a prior distribution (usually a
27 standard normal distribution). The GMVAE can be used for various tasks such as
28 clustering, dimensionality reduction, and generative modeling.
29 The model is designed to work with PyTorch and PyTorch Lightning, making it easy
30 to integrate into existing workflows and leverage GPU acceleration.
31 Parameters
32 ----------
33 k : int
34 The number of clusters in the Gaussian Mixture Model.
35 n_cvs : int
36 The dimension of the CV or, equivalently, the dimension of the latent
37 space of the autoencoder.
38 n_features : int
39 The dimension of the input data.
40 r_nent : float
41 The weight for the entropy regularization term.
42 qy_dims : list
43 The dimensions of the layers in the encoder for the cluster assignment.
44 qz_dims : list
45 The dimensions of the layers in the encoder for the latent variable.
46 pz_dims : list
47 The dimensions of the layers in the decoder for the latent variable.
48 px_dims : list
49 The dimensions of the layers in the decoder for the reconstruction.
50 options : dict, optional
51 Additional options for the model, such as normalization and dropout rates.
52 """
53
54 BLOCKS = ["norm_in", "encoder", "decoder", "k"]
55
56 def __init__(self, n_features, n_cvs, encoder_layers, decoder_layers, options=None, **kwargs):
57 super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)
58
59 options = self.parse_options(options)
60
61 if "norm_in" in options and options["norm_in"] is not None:
62 self.norm_in = Normalization(self.in_features, **options["norm_in"])
63
64 self.k = options["k"]
65 self.r_nent = options.get('loss_function', {}).get("r_nent", 0.5)
66
67 qy_dims = encoder_layers["qy_dims"]
68 qz_dims = encoder_layers["qz_dims"]
69 pz_dims = decoder_layers["pz_dims"]
70 px_dims = decoder_layers["px_dims"]
71
72 self.loss_fn = ELBOGaussianMixtureLoss(r_nent=self.r_nent, k=self.k)
73
74 self.encoder = nn.ModuleDict()
75 self.decoder = nn.ModuleDict()
76
77 self.encoder['y_transform'] = nn.Linear(self.k, self.k)
78
79 self.encoder['qy_nn'] = FeedForward([n_features] + qy_dims + [self.k], **options["encoder"]['qy_nn'])
80
81 self.encoder['qz_nn'] = FeedForward([n_features + self.k] + qz_dims, **options["encoder"]['qz_nn'])
82 self.encoder['zm_layer'] = nn.Linear(qz_dims[-1], n_cvs)
83 self.encoder['zv_layer'] = nn.Linear(qz_dims[-1], n_cvs)
84
85 self.decoder['pz_nn'] = FeedForward([self.k] + pz_dims, **options["decoder"]['pz_nn'])
86 self.decoder['zm_prior_layer'] = nn.Linear(pz_dims[-1], n_cvs)
87 self.decoder['zv_prior_layer'] = nn.Linear(pz_dims[-1], n_cvs)
88
89 self.decoder['px_nn'] = FeedForward([n_cvs] + px_dims, **options["decoder"]['px_nn'])
90 self.decoder['xm_layer'] = nn.Linear(px_dims[-1], n_features)
91 self.decoder['xv_layer'] = nn.Linear(px_dims[-1], n_features)
92
93 self.eval_variables = ["xhat", "z", "qy"]
94
95 @staticmethod
96 def log_normal(x, mu, var, eps=1e-10):
97 return -0.5 * torch.sum(torch.log(torch.tensor(2.0) * torch.pi) + (x - mu).pow(2) / var + var.log(), dim=-1) # log probability of a normal (Gaussian) distribution
98
99 def loss_function(self, x, xm, xv, z, zm, zv, zm_prior, zv_prior):
100 return (
-
W504
Line break after binary operator
101 -self.log_normal(x, xm, xv) + # Reconstruction Loss
-
W504
Line break after binary operator
102 self.log_normal(z, zm, zv) - self.log_normal(z, zm_prior, zv_prior) - # Regularization Loss (KL Divergence)
103 torch.log(torch.tensor(1 / self.k, device=x.device)) # Entropy Regularization
104 )
105
106 def encode_decode(self, x):
107
108 if self.norm_in is not None:
109 data = self.norm_in(x)
110
111 qy_logit = self.encoder['qy_nn'](data)
112
113 y_ = torch.zeros([data.shape[0], self.k]).to(data.device)
114
115 zm_list, zv_list, z_list = [], [], []
116 xm_list, xv_list, x_list = [], [], []
117 zm_prior_list, zv_prior_list = [], []
118
119 for i in range(self.k):
120 # One-hot y
121 y = y_ + torch.eye(self.k).to(data.device)[i]
122
123 # Qz
124 h0 = self.encoder['y_transform'](y)
125 xy = torch.cat([data, h0], dim=1)
126 qz_logit = self.encoder['qz_nn'](xy)
127 zm = self.encoder['zm_layer'](qz_logit)
128 zv = torch.nn.functional.softplus(self.encoder['zv_layer'](qz_logit))
129 noise = torch.randn_like(torch.sqrt(zv))
130 z_sample = zm + noise * zv
131
132 zm_list.append(zm)
133 zv_list.append(zv)
134 z_list.append(z_sample)
135
136 # Pz (prior)
137 pz_logit = self.decoder['pz_nn'](y)
138 zm_prior = self.decoder['zm_prior_layer'](pz_logit)
139 zv_prior = torch.nn.functional.softplus(self.decoder['zv_prior_layer'](pz_logit))
140 noise = torch.randn_like(torch.sqrt(zv_prior))
141 z_prior_sample = zm_prior + noise * zv_prior
142
143 zm_prior_list.append(zm_prior)
144 zv_prior_list.append(zv_prior)
145
146 # Px
147 px_logit = self.decoder['px_nn'](z_prior_sample)
148 xm = self.decoder['xm_layer'](px_logit)
149 xv = torch.nn.functional.softplus(self.decoder['xv_layer'](px_logit))
150 noise = torch.randn_like(torch.sqrt(xv))
151 x_sample = xm + noise * xv
152
153 xm_list.append(xm)
154 xv_list.append(xv)
155 x_list.append(x_sample)
156
157 return (
158 data, qy_logit, xm_list, xv_list,
159 z_list, zm_list, zv_list,
160 zm_prior_list, zv_prior_list
161 )
162
163 def evaluate_model(self, batch, batch_idx):
164 """Evaluate the model on the data, computing average loss."""
165
166 x = batch['data']
167
168 if self.norm_in is not None:
169 data = self.norm_in(x)
170
171 qy_logit = self.encoder['qy_nn'](data)
172 qy = torch.softmax(qy_logit, dim=1)
173
174 y_ = torch.zeros([data.shape[0], self.k]).to(data.device)
175
176 zm_list, zv_list, z_list = [], [], []
177 xm_list, xv_list, x_list = [], [], []
178 zm_prior_list, zv_prior_list = [], []
179
180 for i in range(self.k):
181 # One-hot y
182 y = y_ + torch.eye(self.k).to(data.device)[i]
183
184 # Qz
185 h0 = self.encoder['y_transform'](y)
186 xy = torch.cat([data, h0], dim=1)
187 qz_logit = self.encoder['qz_nn'](xy)
188 zm = self.encoder['zm_layer'](qz_logit)
189 zv = torch.nn.functional.softplus(self.encoder['zv_layer'](qz_logit))
190 noise = torch.randn_like(torch.sqrt(zv))
191 z_sample = zm + noise * zv
192
193 zm_list.append(zm)
194 zv_list.append(zv)
195 z_list.append(z_sample)
196
197 # Pz (prior)
198 pz_logit = self.decoder['pz_nn'](y)
199 zm_prior = self.decoder['zm_prior_layer'](pz_logit)
200 zv_prior = torch.nn.functional.softplus(self.decoder['zv_prior_layer'](pz_logit))
201 noise = torch.randn_like(torch.sqrt(zv_prior))
202 z_prior_sample = zm_prior + noise * zv_prior
203
204 zm_prior_list.append(zm_prior)
205 zv_prior_list.append(zv_prior)
206
207 # Px
208 px_logit = self.decoder['px_nn'](z_prior_sample)
209 xm = self.decoder['xm_layer'](px_logit)
210 xv = torch.nn.functional.softplus(self.decoder['xv_layer'](px_logit))
211 noise = torch.randn_like(torch.sqrt(xv))
212 x_sample = xm + noise * xv
213
214 xm_list.append(xm)
215 xv_list.append(xv)
216 x_list.append(x_sample)
217
218 xhat = torch.sum(qy.unsqueeze(-1) * torch.stack(x_list, dim=1), dim=1)
219
220 if self.norm_in is not None:
221 xhat = self.norm_in.inverse(xhat)
222
223 z = torch.sum(qy.unsqueeze(-1) * torch.stack(z_list, dim=1), dim=1)
224
225 return xhat, z, qy
226
227 def decode(self, z):
228 """
229 Reconstruct x' from aggregated z
230 """
231 if z.dim() == 1:
232 z = z.unsqueeze(0)
233
234 px_logit = self.decoder['px_nn'](z)
235 xm = self.decoder['xm_layer'](px_logit)
236 xv = torch.nn.functional.softplus(self.decoder['xv_layer'](px_logit))
237 noise = torch.randn_like(torch.sqrt(xv))
238 x = xm + noise * xv
239
240 if self.norm_in is not None:
241 x = self.norm_in.inverse(x)
242
243 return x
244
245 def forward_cv(self, x):
246
247 if self.norm_in is not None:
248 x = self.norm_in(x)
249
250 qy_logit = self.encoder['qy_nn'](x)
251 qy = torch.softmax(qy_logit, dim=1)
252
253 y_ = torch.zeros([x.shape[0], self.k]).to(x.device)
254
255 zm_list, zv_list, z_list = [], [], []
256
257 for i in range(self.k):
258 # One-hot y
259 y = y_ + torch.eye(self.k).to(x.device)[i]
260
261 # Qz
262 h0 = self.encoder['y_transform'](y)
263 xy = torch.cat([x, h0], dim=1)
264 qz_logit = self.encoder['qz_nn'](xy)
265 zm = self.encoder['zm_layer'](qz_logit)
266 zv = torch.nn.functional.softplus(self.encoder['zv_layer'](qz_logit))
267 noise = torch.randn_like(torch.sqrt(zv))
268 z_sample = zm + noise * zv
269
270 zm_list.append(zm)
271 zv_list.append(zv)
272 z_list.append(z_sample)
273
274 Z = torch.stack(z_list, dim=1)
275 a = torch.sum(qy.unsqueeze(-1) * Z, dim=1)
276
277 return a
278
279 def training_step(self, train_batch, batch_idx):
280
281 x = train_batch["data"]
282
283 if "target" in train_batch:
284 x_ref = train_batch["target"]
285 else:
286 x_ref = x
287
288 data, qy_logit, xm_list, xv_list, z_list, zm_list, zv_list, zm_prior_list, zv_prior_list = self.encode_decode(x_ref)
289
290 batch_loss, nent = self.loss_fn(data,
291 qy_logit,
292 xm_list, xv_list,
293 z_list, zm_list, zv_list,
294 zm_prior_list, zv_prior_list)
295
296 loss = batch_loss.mean()
297 ce_loss = nent.mean()
298
299 name = "train" if self.training else "valid"
300 self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
301 self.log(f"{name}_cross_entropy", ce_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
302
303 return loss
304
305 def get_decoder(self, return_normalization=False):
306 """Return a torch model with the decoder and optionally the normalization inverse"""
307 if return_normalization:
308 if self.norm_in is not None:
309 inv_norm = Inverse(module=self.norm_in)
310 decoder_model = torch.nn.Sequential(*[self.decoder, inv_norm])
311 else:
312 raise ValueError(
313 "return_normalization is set to True but self.norm_in is None"
314 )
315 else:
316 decoder_model = self.decoder
317
318 return decoder_model
319
320
321 # # Example of usage:
322
323 # # Define dimensions
324 # n_features = 1551 # Input dimension
325 # n_clusters = 5 # Output dimension for Qy
326 # n_cvs = 3 # Latent dimension (CVs)
327 # r_nent = 0.5 # Weight for the entropy regularization term.
328
329 # # Encoder sizes
330 # qy_dims = [32]
331 # qz_dims = [16, 16]
332
333 # # Decoder sizes
334 # pz_dims = [16, 16]
335 # px_dims = [128]
336
337 # options = {
338 # "norm_in": {
339 # "mode": "mean_std"
340 # },
341 # "optimizer": {
342 # "lr": 1e-4
343 # }
344 # }
345
346 # # Instantiate your GMVAECV
347 # model = GaussianMixtureVariationalAutoEncoder(k=n_clusters,
348 # n_cvs=n_cvs,
349 # n_features=n_features,
350 # r_nent=r_nent,
351 # qy_dims=qy_dims,
352 # qz_dims=qz_dims,
353 # pz_dims=pz_dims,
354 # px_dims=px_dims,
355 # options=options)