Coverage for biobb_pytorch / mdae / models / gnnae.py: 100%

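0 statements (every line of the module listed below is commented out, so the 100% figure above reflects no executed code)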

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1# import torch 

2# import lightning 

3# from mlcolvar.cvs import BaseCV 

4# from torch_geometric.nn import GCNConv, global_mean_pool 

5# from mlcolvar.core.transform.utils import Inverse 

6# from biobb_pytorch.mdae.featurization.normalization import Normalization 

7# from biobb_pytorch.mdae.loss import MSELoss 

8 

9# __all__ = ["GNNAutoEncoder"] 

10 

11# class GNNEncoder(torch.nn.Module): 

12# def __init__(self, in_channels, hidden_dim, latent_dim): 

13# super().__init__() 

14# self.conv1 = GCNConv(in_channels, hidden_dim) 

15# self.conv2 = GCNConv(hidden_dim, hidden_dim) 

16# self.fc_mu = torch.nn.Linear(hidden_dim, latent_dim) 

17 

18# def forward(self, x, edge_index, batch): 

19# x = torch.relu(self.conv1(x, edge_index)) 

20# x = torch.relu(self.conv2(x, edge_index)) 

21# x = global_mean_pool(x, batch) 

22# z = self.fc_mu(x) 

23# return z 

24 

25# class GNNDecoder(torch.nn.Module): 

26# def __init__(self, latent_dim, hidden_dim, out_features): 

27# super().__init__() 

28# self.mlp = torch.nn.Sequential( 

29# torch.nn.Linear(latent_dim, hidden_dim), 

30# torch.nn.ReLU(), 

31# torch.nn.Linear(hidden_dim, hidden_dim), 

32# torch.nn.ReLU() 

33# ) 

34# self.out_layer = torch.nn.Linear(hidden_dim, out_features) 

35 

36# def forward(self, z, batch): 

37# z_expanded = z[batch] 

38# x_rec = self.out_layer(self.mlp(z_expanded)) 

39# return x_rec 

40 

41# class GNNAutoEncoder(BaseCV, lightning.LightningModule): 

42# BLOCKS = ["norm_in", "encoder", "decoder"] 

43 

44# def __init__( 

45# self, 

46# n_cvs: int, 

47# encoder_layers: list, 

48# decoder_layers: list = None, 

49# edge_index: torch.Tensor = None, 

50# options: dict = None, 

51# **kwargs, 

52# ): 

53# super().__init__(in_features=encoder_layers[0], out_features=n_cvs, **kwargs) 

54 

55# self.loss_fn = MSELoss() 

56# options = self.parse_options(options) 

57 

58# self.edge_index = edge_index 

59# hidden_dim = encoder_layers[1] if len(encoder_layers) > 1 else 64 

60 

61# if decoder_layers is None: 

62# decoder_layers = encoder_layers[::-1] 

63 

64# if options.get("norm_in", True): 

65# self.norm_in = Normalization(self.in_features, **options.get("norm_in", {})) 

66# else: 

67# self.norm_in = None 

68 

69# self.encoder = GNNEncoder(in_channels=self.in_features, hidden_dim=hidden_dim, latent_dim=n_cvs) 

70# self.decoder = GNNDecoder(latent_dim=n_cvs, hidden_dim=hidden_dim, out_features=self.in_features) 

71 

72# def forward_cv(self, x): 

73# # x: [B, F], F = in_features 

74# if self.norm_in is not None: 

75# x = self.norm_in(x) 

76# return self.encoder(x, self.edge_index, torch.zeros(x.shape[0], dtype=torch.long, device=x.device)) 

77 

78# def encode_decode(self, x): 

79# B, F = x.shape 

80# if self.norm_in is not None: 

81# x = self.norm_in(x) 

82# # Simulate each frame as one node in a graph 

83# z = self.encoder(x, self.edge_index, torch.zeros(B, dtype=torch.long, device=x.device)) 

84# x_rec = self.decoder(z, torch.arange(B, device=x.device)) 

85# if self.norm_in is not None: 

86# x_rec = self.norm_in.inverse(x_rec) 

87# return x_rec 

88 

89# def training_step(self, batch, batch_idx): 

90# x = batch["data"] 

91# target = batch.get("target", x) 

92 

93# x_hat = self.encode_decode(x) 

94# loss = self.loss_fn(x_hat, target) 

95 

96# name = "train" if self.training else "valid" 

97# self.log(f"{name}_loss", loss, on_step=True, on_epoch=True, prog_bar=True) 

98# return loss 

99 

100# def get_decoder(self, return_normalization=False): 

101# if return_normalization: 

102# if self.norm_in is not None: 

103# inv_norm = Inverse(module=self.norm_in) 

104# decoder_model = torch.nn.Sequential(self.decoder, inv_norm) 

105# else: 

106# raise ValueError("return_normalization is set to True but self.norm_in is None") 

107# else: 

108# decoder_model = self.decoder 

109# return decoder_model 

110 

111# @property 

112# def example_input_array(self): 

113# return None