Coverage for biobb_pytorch/mdae/models/gnnae.py: 100%
0 statements (the module body is entirely commented out)
coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
# import torch
# import lightning
# from mlcolvar.cvs import BaseCV
# from torch_geometric.nn import GCNConv, global_mean_pool
# from mlcolvar.core.transform.utils import Inverse
# from biobb_pytorch.mdae.featurization.normalization import Normalization
# from biobb_pytorch.mdae.loss import MSELoss

# __all__ = ["GNNAutoEncoder"]
# class GNNEncoder(torch.nn.Module):
#     # Two-layer GCN encoder that pools node features into one latent vector per graph.
#     def __init__(self, in_channels, hidden_dim, latent_dim):
#         super().__init__()
#         self.conv1 = GCNConv(in_channels, hidden_dim)
#         self.conv2 = GCNConv(hidden_dim, hidden_dim)
#         self.fc_mu = torch.nn.Linear(hidden_dim, latent_dim)

#     def forward(self, x, edge_index, batch):
#         x = torch.relu(self.conv1(x, edge_index))
#         x = torch.relu(self.conv2(x, edge_index))
#         x = global_mean_pool(x, batch)  # [n_nodes, hidden] -> [n_graphs, hidden]
#         z = self.fc_mu(x)
#         return z
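# A quick shape check for the encoder (editorial sketch; the 4-node path
# graph and all dimensions below are invented for illustration):
# >>> enc = GNNEncoder(in_channels=3, hidden_dim=16, latent_dim=2)
# >>> x = torch.randn(4, 3)                              # 4 nodes, 3 features each
# >>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # path 0-1-2-3
# >>> batch = torch.zeros(4, dtype=torch.long)           # all nodes in graph 0
# >>> enc(x, edge_index, batch).shape                    # one latent per graph
# torch.Size([1, 2])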
# class GNNDecoder(torch.nn.Module):
#     # MLP decoder: broadcasts each graph's latent vector back to its nodes.
#     def __init__(self, latent_dim, hidden_dim, out_features):
#         super().__init__()
#         self.mlp = torch.nn.Sequential(
#             torch.nn.Linear(latent_dim, hidden_dim),
#             torch.nn.ReLU(),
#             torch.nn.Linear(hidden_dim, hidden_dim),
#             torch.nn.ReLU(),
#         )
#         self.out_layer = torch.nn.Linear(hidden_dim, out_features)

#     def forward(self, z, batch):
#         z_expanded = z[batch]  # gather: one latent row per node, [n_nodes, latent]
#         x_rec = self.out_layer(self.mlp(z_expanded))
#         return x_rec
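# How the gather in forward broadcasts per-graph latents to per-node rows
# (editorial sketch; dimensions match the encoder example above):
# >>> dec = GNNDecoder(latent_dim=2, hidden_dim=16, out_features=3)
# >>> z = torch.randn(1, 2)                     # one latent row per graph
# >>> batch = torch.zeros(4, dtype=torch.long)  # 4 nodes, all in graph 0
# >>> dec(z, batch).shape                       # z[batch]: [4, 2] -> mlp -> [4, 3]
# torch.Size([4, 3])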
# class GNNAutoEncoder(BaseCV, lightning.LightningModule):
#     BLOCKS = ["norm_in", "encoder", "decoder"]

#     def __init__(
#         self,
#         n_cvs: int,
#         encoder_layers: list,
#         decoder_layers: list = None,
#         edge_index: torch.Tensor = None,
#         options: dict = None,
#         **kwargs,
#     ):
#         super().__init__(in_features=encoder_layers[0], out_features=n_cvs, **kwargs)

#         self.loss_fn = MSELoss()
#         options = self.parse_options(options)

#         self.edge_index = edge_index
#         hidden_dim = encoder_layers[1] if len(encoder_layers) > 1 else 64

#         # NOTE: decoder_layers is accepted for API symmetry, but the MLP decoder
#         # below only uses hidden_dim; the reversed list is currently unused.
#         if decoder_layers is None:
#             decoder_layers = encoder_layers[::-1]

#         norm_in_opts = options.get("norm_in", True)
#         if norm_in_opts:
#             # norm_in may be True (use defaults) or a dict of Normalization kwargs;
#             # unpack only when it is a dict to avoid "**True" at construction time.
#             self.norm_in = Normalization(
#                 self.in_features,
#                 **(norm_in_opts if isinstance(norm_in_opts, dict) else {}),
#             )
#         else:
#             self.norm_in = None

#         self.encoder = GNNEncoder(in_channels=self.in_features, hidden_dim=hidden_dim, latent_dim=n_cvs)
#         self.decoder = GNNDecoder(latent_dim=n_cvs, hidden_dim=hidden_dim, out_features=self.in_features)
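# Construction sketch (editorial; assumes the biobb_pytorch and mlcolvar
# imports above resolve; norm_in is disabled so the example does not depend
# on Normalization defaults, and the 4-frame chain topology is invented):
# >>> ei = torch.tensor([[0, 1, 2], [1, 2, 3]])
# >>> model = GNNAutoEncoder(n_cvs=2, encoder_layers=[30, 64],
# ...                        edge_index=ei, options={"norm_in": False})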
#     def forward_cv(self, x):
#         # x: [B, F], F = in_features; every frame becomes one node of a single
#         # graph (batch of zeros), so the pooled CV has shape [1, n_cvs].
#         if self.norm_in is not None:
#             x = self.norm_in(x)
#         return self.encoder(x, self.edge_index, torch.zeros(x.shape[0], dtype=torch.long, device=x.device))
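# Evaluating the CV on a toy batch (editorial sketch, reusing model and ei
# from the construction sketch above; note the single pooled graph):
# >>> x = torch.randn(4, 30)       # 4 frames, 30 features each
# >>> model.forward_cv(x).shape    # all frames pooled into one graph
# torch.Size([1, 2])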
#     def encode_decode(self, x):
#         B, F = x.shape
#         if self.norm_in is not None:
#             x = self.norm_in(x)
#         # Simulate each frame as one node in a single graph (batch of zeros),
#         # then broadcast that graph's latent back to all B nodes. The decoder's
#         # batch vector must therefore also be zeros: z has one row per graph,
#         # so z[torch.arange(B)] would index out of bounds for B > 1.
#         z = self.encoder(x, self.edge_index, torch.zeros(B, dtype=torch.long, device=x.device))
#         x_rec = self.decoder(z, torch.zeros(B, dtype=torch.long, device=x.device))
#         if self.norm_in is not None:
#             x_rec = self.norm_in.inverse(x_rec)
#         return x_rec
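# Round-trip reconstruction (editorial sketch, reusing model and x from
# above; every frame is decoded from the same shared latent):
# >>> model.encode_decode(x).shape
# torch.Size([4, 30])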
#     def training_step(self, batch, batch_idx):
#         x = batch["data"]
#         target = batch.get("target", x)

#         x_hat = self.encode_decode(x)
#         loss = self.loss_fn(x_hat, target)

#         # mlcolvar convention: validation reuses training_step, so the logged
#         # metric name is switched on self.training.
#         name = "train" if self.training else "valid"
#         self.log(f"{name}_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
#         return loss
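# Training sketch (editorial; assumes BaseCV supplies configure_optimizers,
# as in mlcolvar, and that the batch size equals the number of frames the
# fixed edge_index was built for):
# >>> from torch.utils.data import DataLoader
# >>> frames = [{"data": torch.randn(30)} for _ in range(16)]
# >>> loader = DataLoader(frames, batch_size=4)  # default collate stacks the dicts
# >>> trainer = lightning.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
# >>> trainer.fit(model, loader)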
#     def get_decoder(self, return_normalization=False):
#         # NOTE: GNNDecoder.forward expects (z, batch), so the returned module
#         # (and the Sequential wrapper in particular) cannot be called with a
#         # single tensor; callers must supply the node-to-graph batch vector.
#         if return_normalization:
#             if self.norm_in is not None:
#                 inv_norm = Inverse(module=self.norm_in)
#                 decoder_model = torch.nn.Sequential(self.decoder, inv_norm)
#             else:
#                 raise ValueError("return_normalization is set to True but self.norm_in is None")
#         else:
#             decoder_model = self.decoder
#         return decoder_model
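# Extracting the decoder for standalone use (editorial sketch, reusing model
# and x from above; remember it still needs the batch vector):
# >>> dec = model.get_decoder()
# >>> z = model.forward_cv(x)                          # [1, 2]
# >>> dec(z, torch.zeros(4, dtype=torch.long)).shape
# torch.Size([4, 30])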
#     @property
#     def example_input_array(self):
#         # Returning None disables Lightning features that rely on an example
#         # input (e.g. model-graph logging), since a plain tensor would not
#         # carry the edge_index this model needs.
#         return None