Coverage for biobb_pytorch / mdae / models / gmvae.py: 26%

174 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import torch 

2import torch.nn as nn 

3import lightning.pytorch as pl 

4from mlcolvar.cvs import BaseCV 

5from biobb_pytorch.mdae.featurization.normalization import Normalization 

6from mlcolvar.core.transform.utils import Inverse 

7from biobb_pytorch.mdae.models.nn.feedforward import FeedForward 

8from biobb_pytorch.mdae.loss import ELBOGaussianMixtureLoss 

9 

10 

11__all__ = ["GaussianMixtureVariationalAutoEncoder"] 

12 

13 

class GaussianMixtureVariationalAutoEncoder(BaseCV, pl.LightningModule):
    """Gaussian Mixture Variational AutoEncoder Collective Variable.

    This class implements a Gaussian Mixture Variational AutoEncoder (GMVAE) for
    collective variable (CV) learning. The GMVAE combines Gaussian Mixture Models
    (GMM) and Variational Autoencoders (VAE): it learns a latent representation of
    the input data by modeling it as a mixture of Gaussians, where each Gaussian
    corresponds to a different cluster in the data. An encoder maps the input data
    to a latent space and a decoder reconstructs the input from the latent
    representation. Training maximizes the evidence lower bound (ELBO), which
    combines a reconstruction term, a KL-divergence term between the learned
    latent distribution and the per-cluster prior, and an entropy regularization
    term on the cluster assignment. The model works with PyTorch and PyTorch
    Lightning.

    Parameters
    ----------
    n_features : int
        The dimension of the input data.
    n_cvs : int
        The dimension of the CV or, equivalently, the dimension of the latent
        space of the autoencoder.
    encoder_layers : dict
        Hidden-layer sizes of the two encoder networks, with keys:
        ``"qy_dims"`` (cluster-assignment network q(y|x)) and
        ``"qz_dims"`` (latent network q(z|x, y)).
    decoder_layers : dict
        Hidden-layer sizes of the two decoder networks, with keys:
        ``"pz_dims"`` (conditional prior p(z|y)) and
        ``"px_dims"`` (reconstruction network p(x|z)).
    options : dict, optional
        Additional options. Must provide ``"k"`` (number of mixture components);
        may provide ``"norm_in"`` (input normalization settings),
        ``"loss_function"`` (e.g. ``{"r_nent": 0.5}``, the weight of the entropy
        regularization term) and per-network settings under ``"encoder"`` /
        ``"decoder"``.
    """

    BLOCKS = ["norm_in", "encoder", "decoder", "k"]

    def __init__(self, n_features, n_cvs, encoder_layers, decoder_layers, options=None, **kwargs):
        super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)

        options = self.parse_options(options)

        # Default to no input normalization so that self.norm_in is always
        # defined: every forward path tests it against None.
        self.norm_in = None
        if "norm_in" in options and options["norm_in"] is not None:
            self.norm_in = Normalization(self.in_features, **options["norm_in"])

        self.k = options["k"]  # number of Gaussian mixture components
        # Weight of the entropy regularization term of the ELBO.
        self.r_nent = options.get("loss_function", {}).get("r_nent", 0.5)

        qy_dims = encoder_layers["qy_dims"]
        qz_dims = encoder_layers["qz_dims"]
        pz_dims = decoder_layers["pz_dims"]
        px_dims = decoder_layers["px_dims"]

        self.loss_fn = ELBOGaussianMixtureLoss(r_nent=self.r_nent, k=self.k)

        self.encoder = nn.ModuleDict()
        self.decoder = nn.ModuleDict()

        # Learned embedding of the one-hot cluster label fed into q(z|x, y).
        self.encoder['y_transform'] = nn.Linear(self.k, self.k)

        # q(y|x): cluster-assignment logits.
        self.encoder['qy_nn'] = FeedForward([n_features] + qy_dims + [self.k], **options["encoder"]['qy_nn'])

        # q(z|x, y): Gaussian posterior over the latent space.
        self.encoder['qz_nn'] = FeedForward([n_features + self.k] + qz_dims, **options["encoder"]['qz_nn'])
        self.encoder['zm_layer'] = nn.Linear(qz_dims[-1], n_cvs)
        self.encoder['zv_layer'] = nn.Linear(qz_dims[-1], n_cvs)

        # p(z|y): conditional Gaussian prior, one per cluster.
        self.decoder['pz_nn'] = FeedForward([self.k] + pz_dims, **options["decoder"]['pz_nn'])
        self.decoder['zm_prior_layer'] = nn.Linear(pz_dims[-1], n_cvs)
        self.decoder['zv_prior_layer'] = nn.Linear(pz_dims[-1], n_cvs)

        # p(x|z): Gaussian reconstruction.
        self.decoder['px_nn'] = FeedForward([n_cvs] + px_dims, **options["decoder"]['px_nn'])
        self.decoder['xm_layer'] = nn.Linear(px_dims[-1], n_features)
        self.decoder['xv_layer'] = nn.Linear(px_dims[-1], n_features)

        self.eval_variables = ["xhat", "z", "qy"]

    @staticmethod
    def log_normal(x, mu, var, eps=1e-10):
        """Log-density of a diagonal Gaussian N(mu, diag(var)), summed over the last dim.

        NOTE(review): ``eps`` is currently unused; kept for interface
        compatibility (presumably intended to guard ``var.log()`` — confirm).
        """
        return -0.5 * torch.sum(
            torch.log(torch.tensor(2.0) * torch.pi) + (x - mu).pow(2) / var + var.log(),
            dim=-1,
        )

    def loss_function(self, x, xm, xv, z, zm, zv, zm_prior, zv_prior):
        """Per-sample negative ELBO contribution for one mixture component."""
        return (
            -self.log_normal(x, xm, xv)                             # reconstruction loss
            + self.log_normal(z, zm, zv)                            # KL divergence: posterior term
            - self.log_normal(z, zm_prior, zv_prior)                # KL divergence: prior term
            - torch.log(torch.tensor(1 / self.k, device=x.device))  # uniform prior over clusters
        )

    @staticmethod
    def _sample(mean, var):
        """Reparameterized draw from N(mean, diag(var)).

        Fixed: the sample is mean + eps * sqrt(var); the previous code scaled
        the noise by the variance itself instead of the standard deviation.
        """
        return mean + torch.randn_like(mean) * torch.sqrt(var)

    def _posterior(self, data, y):
        """q(z|x, y): return (mean, variance, reparameterized sample)."""
        h0 = self.encoder['y_transform'](y)
        qz_logit = self.encoder['qz_nn'](torch.cat([data, h0], dim=1))
        zm = self.encoder['zm_layer'](qz_logit)
        # softplus keeps the predicted variance strictly positive.
        zv = torch.nn.functional.softplus(self.encoder['zv_layer'](qz_logit))
        return zm, zv, self._sample(zm, zv)

    def _prior(self, y):
        """p(z|y): return (mean, variance, reparameterized sample)."""
        pz_logit = self.decoder['pz_nn'](y)
        zm_prior = self.decoder['zm_prior_layer'](pz_logit)
        zv_prior = torch.nn.functional.softplus(self.decoder['zv_prior_layer'](pz_logit))
        return zm_prior, zv_prior, self._sample(zm_prior, zv_prior)

    def _reconstruction(self, z):
        """p(x|z): return (mean, variance, reparameterized sample)."""
        px_logit = self.decoder['px_nn'](z)
        xm = self.decoder['xm_layer'](px_logit)
        xv = torch.nn.functional.softplus(self.decoder['xv_layer'](px_logit))
        return xm, xv, self._sample(xm, xv)

    def encode_decode(self, x):
        """Training forward pass: per-cluster posterior, prior and reconstruction stats.

        Returns
        -------
        tuple
            (data, qy_logit, xm_list, xv_list, z_list, zm_list, zv_list,
            zm_prior_list, zv_prior_list) where ``data`` is the (optionally
            normalized) input and each ``*_list`` holds one tensor per cluster.
        """
        # Fixed: fall back to the raw input when normalization is disabled
        # (previously `data` was undefined in that case).
        data = self.norm_in(x) if self.norm_in is not None else x

        qy_logit = self.encoder['qy_nn'](data)

        # Build the one-hot basis once instead of once per cluster.
        eye = torch.eye(self.k, device=data.device)

        zm_list, zv_list, z_list = [], [], []
        xm_list, xv_list = [], []
        zm_prior_list, zv_prior_list = [], []

        for i in range(self.k):
            # One-hot cluster label, broadcast over the batch.
            y = eye[i].expand(data.shape[0], self.k)

            # Qz
            zm, zv, z_sample = self._posterior(data, y)
            zm_list.append(zm)
            zv_list.append(zv)
            z_list.append(z_sample)

            # Pz (prior)
            zm_prior, zv_prior, z_prior_sample = self._prior(y)
            zm_prior_list.append(zm_prior)
            zv_prior_list.append(zv_prior)

            # Px — NOTE(review): decoded from the prior sample, not the
            # posterior sample, mirroring the original code; confirm intended.
            xm, xv, _ = self._reconstruction(z_prior_sample)
            xm_list.append(xm)
            xv_list.append(xv)

        return (
            data, qy_logit, xm_list, xv_list,
            z_list, zm_list, zv_list,
            zm_prior_list, zv_prior_list,
        )

    def evaluate_model(self, batch, batch_idx):
        """Evaluate the model on the data.

        Returns the qy-weighted reconstruction (de-normalized), the qy-weighted
        latent representation, and the cluster probabilities qy.
        """
        x = batch['data']

        # Fixed: fall back to the raw input when normalization is disabled.
        data = self.norm_in(x) if self.norm_in is not None else x

        qy = torch.softmax(self.encoder['qy_nn'](data), dim=1)

        eye = torch.eye(self.k, device=data.device)
        z_list, x_list = [], []

        for i in range(self.k):
            y = eye[i].expand(data.shape[0], self.k)
            _, _, z_sample = self._posterior(data, y)
            z_list.append(z_sample)
            _, _, z_prior_sample = self._prior(y)
            _, _, x_sample = self._reconstruction(z_prior_sample)
            x_list.append(x_sample)

        # Mixture-weighted average over the k components.
        xhat = torch.sum(qy.unsqueeze(-1) * torch.stack(x_list, dim=1), dim=1)

        if self.norm_in is not None:
            xhat = self.norm_in.inverse(xhat)

        z = torch.sum(qy.unsqueeze(-1) * torch.stack(z_list, dim=1), dim=1)

        return xhat, z, qy

    def decode(self, z):
        """Reconstruct x' from an aggregated latent vector z."""
        if z.dim() == 1:
            z = z.unsqueeze(0)  # promote a single sample to a batch of one

        _, _, x = self._reconstruction(z)

        if self.norm_in is not None:
            x = self.norm_in.inverse(x)

        return x

    def forward_cv(self, x):
        """Map inputs to CV space: the qy-weighted mixture of per-cluster posterior samples."""
        if self.norm_in is not None:
            x = self.norm_in(x)

        qy = torch.softmax(self.encoder['qy_nn'](x), dim=1)

        eye = torch.eye(self.k, device=x.device)
        z_list = []

        for i in range(self.k):
            y = eye[i].expand(x.shape[0], self.k)
            _, _, z_sample = self._posterior(x, y)
            z_list.append(z_sample)

        return torch.sum(qy.unsqueeze(-1) * torch.stack(z_list, dim=1), dim=1)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the mean Gaussian-mixture ELBO loss for one batch."""
        x = train_batch["data"]

        # Optional reconstruction target (e.g. time-lagged data); defaults to x.
        x_ref = train_batch["target"] if "target" in train_batch else x

        (data, qy_logit, xm_list, xv_list,
         z_list, zm_list, zv_list,
         zm_prior_list, zv_prior_list) = self.encode_decode(x_ref)

        batch_loss, nent = self.loss_fn(data,
                                        qy_logit,
                                        xm_list, xv_list,
                                        z_list, zm_list, zv_list,
                                        zm_prior_list, zv_prior_list)

        loss = batch_loss.mean()
        ce_loss = nent.mean()

        # The same step serves validation: Lightning toggles self.training.
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_cross_entropy", ce_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def get_decoder(self, return_normalization=False):
        """Return a torch model with the decoder and optionally the normalization inverse.

        NOTE(review): ``self.decoder`` is an ``nn.ModuleDict``, which defines no
        ``forward``; wrapping it in ``nn.Sequential`` produces a module that
        cannot be called directly — confirm the intended downstream usage.
        """
        if return_normalization:
            if self.norm_in is None:
                raise ValueError(
                    "return_normalization is set to True but self.norm_in is None"
                )
            inv_norm = Inverse(module=self.norm_in)
            decoder_model = torch.nn.Sequential(self.decoder, inv_norm)
        else:
            decoder_model = self.decoder

        return decoder_model

319 

320 

321# # Example of usage: 

322 

323# # Define dimensions 

324# n_features = 1551 # Input dimension 

325# n_clusters = 5 # Output dimension for Qy 

326# n_cvs = 3 # Latent dimension (CVs) 

327# r_nent = 0.5 # Weight for the entropy regularization term. 

328 

329# # Encoder sizes 

330# qy_dims = [32] 

331# qz_dims = [16, 16] 

332 

333# # Decoder sizes 

334# pz_dims = [16, 16] 

335# px_dims = [128] 

336 

337# options = { 

338# "norm_in": { 

339# "mode": "mean_std" 

340# }, 

341# "optimizer": { 

342# "lr": 1e-4 

343# } 

344# } 

345 

346# # Instantiate your GMVAECV 

347# model = GaussianMixtureVariationalAutoEncoder(k=n_clusters, 

348# n_cvs=n_cvs, 

349# n_features=n_features, 

350# r_nent=r_nent, 

351# qy_dims=qy_dims, 

352# qz_dims=qz_dims, 

353# pz_dims=pz_dims, 

354# px_dims=px_dims, 

355# options=options)