Coverage for biobb_pytorch / mdae / models / molearn.py: 20%
179 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4import numpy as np
5from biobb_pytorch.mdae.loss.utils.torch_protein_energy import TorchProteinEnergy
6from biobb_pytorch.mdae.loss.physics_loss import PhysicsLoss
7import lightning.pytorch as pl
8from mlcolvar.cvs import BaseCV
def index_points(point_clouds, index):
    '''
    Gather sub-tensors from a batched point cloud according to an index tensor.

    :param point_clouds: input points data, [B, N, C]
    :param index: sample index data, [B, N, k]
    :return: indexed points data, [B, N, k, C]
    '''
    n_batch = point_clouds.shape[0]
    # Build a batch-index tensor shaped like `index` whose entry (b, ...) is
    # simply b, so advanced indexing pairs every neighbour index with its own
    # batch element.
    shape_for_view = [index.shape[0]] + [1] * (index.dim() - 1)
    shape_for_repeat = [1] + list(index.shape[1:])
    batch_idx = torch.arange(n_batch, dtype=torch.long, device=point_clouds.device)
    batch_idx = batch_idx.view(shape_for_view).repeat(shape_for_repeat)
    return point_clouds[batch_idx, index, :]
def knn(x, k):
    '''
    Indices of the k nearest neighbours of every point (self included).

    :param x: a tensor with size of (B, C, N)
    :param k: the number of nearest neighborhoods
    :return: indices of the k nearest neighborhoods with size of (B, N, k)
    '''
    # Squared-distance expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2,
    # stored negated so that topk (largest=True) yields the closest points.
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)   # (B, 1, N)
    cross = torch.matmul(x.transpose(2, 1), x)         # (B, N, N)
    neg_sq_dist = 2 * cross - sq_norm - sq_norm.transpose(2, 1)  # (B, N, N)
    return neg_sq_dist.topk(k=k, dim=-1)[1]            # (B, N, k)
class GraphLayer(nn.Module):
    '''
    Graph layer: max-pools each point's k-NN neighbourhood features, then
    applies a shared pointwise convolution + batch norm + ReLU.

    in_channel: it depends on the input of this network.
    out_channel: given by ourselves.
    '''

    def __init__(self, in_channel, out_channel, k=16):
        super(GraphLayer, self).__init__()
        self.k = k
        self.conv = nn.Conv1d(in_channel, out_channel, 1)
        self.bn = nn.BatchNorm1d(out_channel)

    def forward(self, x):
        '''
        :param x: tensor with size of (B, C, N)
        '''
        # Gather the features of every point's k nearest neighbours.
        neighbour_idx = knn(x, k=self.k)                              # (B, N, k)
        neighbours = index_points(x.permute(0, 2, 1), neighbour_idx)  # (B, N, k, C)

        # Local max pooling over the neighbourhood axis.
        pooled = neighbours.max(dim=2)[0].permute(0, 2, 1)            # (B, C, N)

        # Shared pointwise feature map.
        return F.relu(self.bn(self.conv(pooled)))
class Encoder(nn.Module):
    '''
    Graph based encoder: a pointwise MLP over coordinates + local covariances,
    two graph layers, global max pooling and a final projection to the latent
    space.
    '''

    def __init__(self, latent_dimension=2, **kwargs):
        super(Encoder, self).__init__()
        self.latent_dimension = latent_dimension
        # 3 coordinates + 9 flattened covariance entries = 12 input channels.
        self.conv1 = nn.Conv1d(12, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 64, 1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)

        self.graph_layer1 = GraphLayer(in_channel=64, out_channel=128, k=16)
        self.graph_layer2 = GraphLayer(in_channel=128, out_channel=1024, k=16)

        self.conv4 = nn.Conv1d(1024, 512, 1)
        self.bn4 = nn.BatchNorm1d(512)
        self.conv5 = nn.Conv1d(512, latent_dimension, 1)

    def forward(self, x):
        b, c, n = x.size()

        # Local covariance of every point's 16-NN patch, flattened to 9
        # channels and concatenated with the raw coordinates.
        nn_idx = knn(x, k=16)
        patches = index_points(x.permute(0, 2, 1), nn_idx)   # (B, N, 16, C)
        patches = patches - torch.mean(patches, dim=2, keepdim=True)
        cov = torch.matmul(patches.transpose(2, 3), patches)
        cov = cov.view(b, n, -1).permute(0, 2, 1)
        x = torch.cat([x, cov], dim=1)                       # (B, 12, N)

        # Three-layer pointwise MLP.
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3)):
            x = F.relu(bn(conv(x)))

        # Two consecutive graph layers, then a projection.
        x = self.graph_layer2(self.graph_layer1(x))
        x = self.bn4(self.conv4(x))

        # Global max pooling down to one codeword per cloud.
        x = x.max(dim=-1)[0].unsqueeze(-1)

        return self.conv5(x)
class FoldingLayer(nn.Module):
    '''
    The folding operation of FoldingNet: a shared MLP (stack of kernel-3,
    padding-1 Conv1d blocks) applied to its inputs concatenated along the
    channel axis.
    '''

    def __init__(self, in_channel: int, out_channels: list):
        super(FoldingLayer, self).__init__()

        blocks = []
        for width in out_channels[:-1]:
            blocks.append(nn.Conv1d(in_channel, width, 3, 1, 1))
            blocks.append(nn.BatchNorm1d(width))
            blocks.append(nn.ReLU(inplace=True))
            in_channel = width
        # Final convolution carries no activation so outputs stay unbounded.
        blocks.append(nn.Conv1d(in_channel, out_channels[-1], 3, 1, 1))

        self.layers = nn.Sequential(*blocks)

    def forward(self, *args):
        """
        :param args: grids, codewords and/or intermediate reconstructions;
            concatenated along the channel dimension before the shared MLP.
        """
        stacked = torch.cat(list(args), dim=1)
        return self.layers(stacked)
class Decoder_Layer(nn.Module):
    '''
    One upsampling stage of the FoldingNet decoder: repeats the incoming
    codewords, attaches a fixed 1-D grid and applies two folding operations,
    the second one as a residual refinement of the first.
    '''

    def __init__(self, in_features, out_features, in_channel, out_channel, **kwargs):
        super(Decoder_Layer, self).__init__()

        self.out_features = out_features
        # Fixed 1-D sampling grid in [-0.5, 0.5], shared across the batch.
        self.grid = torch.linspace(-0.5, 0.5, out_features).view(1, -1)

        # Each input feature must expand into a whole number of outputs.
        assert out_features % in_features == 0
        self.m = out_features // in_features

        self.fold1 = FoldingLayer(in_channel + 1, [512, 512, out_channel])
        self.fold2 = FoldingLayer(in_channel + out_channel + 1, [512, 512, out_channel])

    def forward(self, x):
        '''
        :param x: (B, in_channel, in_features) codeword tensor
        :return: (B, out_channel, out_features) folded features
        '''
        n_batch = x.shape[0]

        # Broadcast the fixed grid over the batch.
        grid = self.grid.to(x.device).unsqueeze(0).repeat(n_batch, 1, 1)

        # Upsample codewords so every grid point gets one.
        codes = x.repeat_interleave(self.m, dim=-1)

        # Two folding operations with a residual connection.
        first = self.fold1(grid, codes)
        return first + self.fold2(grid, codes, first)
class Decoder(nn.Module):
    '''
    Decoder Module of FoldingNet: four cascaded Decoder_Layer stages that
    progressively unfold a latent vector into a (B, 3, start_out * 128)
    point cloud.
    '''

    def __init__(self, out_features, latent_dimension=2, **kwargs):
        super(Decoder, self).__init__()
        self.latent_dimension = latent_dimension

        # Coarsest stage size, chosen so the final x128 expansion covers
        # at least `out_features` points.
        start_out = (out_features // 128) + 1

        self.out_features = out_features

        self.layer1 = Decoder_Layer(1, start_out, latent_dimension, 3 * 128)
        self.layer2 = Decoder_Layer(start_out, start_out * 8, 3 * 128, 3 * 16)
        self.layer3 = Decoder_Layer(start_out * 8, start_out * 32, 3 * 16, 3 * 4)
        self.layer4 = Decoder_Layer(start_out * 32, start_out * 128, 3 * 4, 3)

    def forward(self, x):
        '''
        :param x: latent codes; any shape reshapeable to (B, latent_dimension, 1)
        '''
        out = x.view(-1, self.latent_dimension, 1)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return out
class CNNAutoEncoder(BaseCV, pl.LightningModule):
    '''
    Autoencoder architecture derived from FoldingNet.

    Couples a graph-based Encoder and a folding Decoder, trained with a
    physics-informed loss (PhysicsLoss over TorchProteinEnergy) inside the
    mlcolvar BaseCV / Lightning framework.
    '''

    # Block names expected by the mlcolvar BaseCV machinery.
    BLOCKS = ["norm_in", "encoder", "decoder"]

    def __init__(self, n_features, n_cvs, encoder_layers, decoder_layers, options=None, **kwargs):
        # NOTE(review): encoder_layers / decoder_layers are accepted but never
        # used in this constructor — confirm whether they are consumed
        # elsewhere or are dead parameters.

        super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)

        options = self.parse_options(options)

        if "norm_in" in options and options["norm_in"] is not None:
            # Per-feature standardization statistics: mean / std used to
            # normalize inputs and de-normalize reconstructions.
            self.mean = options["norm_in"]["stats"]["parametric"][0]
            self.std = options["norm_in"]["stats"]["parametric"][1]
            # self.norm_in = Normalization(self.in_features, mean=self.mean, range=self.std)

            # Topology object (exposes n_atoms, xyz and topology.atoms —
            # presumably an mdtraj Trajectory; TODO confirm).
            top = options["norm_in"]["stats"]["topology"][0]
            # Re-purpose the feature counts: in_features becomes the atom
            # count, out_features the raw feature count.
            self.in_features = top.n_atoms
            self.out_features = n_features

        # NOTE(review): `top` is only bound inside the branch above; if
        # "norm_in" is absent the line below raises NameError — confirm
        # whether norm_in is effectively mandatory.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Reference coordinates of frame 0, transposed to (3, n_atoms).
        x0_coords = torch.tensor(top.xyz[0]).permute(1, 0)

        # Per-atom [name, residue name, 1-based residue index] table for the
        # torch energy terms.
        atominfo = []
        for i in top.topology.atoms:
            atominfo.append([i.name, i.residue.name, i.residue.index + 1])
        atominfo = np.array(atominfo, dtype=object)

        # if trainer == "Torch":
        self.protein_energy = TorchProteinEnergy(x0_coords,
                                                 atominfo,
                                                 device=device,
                                                 method='roll')

        # if options["trainer"] == "OpenMM":

        #     from loss.utils.openmm_thread import openmm_energy_setup, openmm_energy, soft_xml_script

        #     protein_energy_setup = openmm_energy_setup(top)
        #     atominfo = protein_energy_setup.get_atominfo()
        #     self.mol = protein_energy_setup.mol_dataframe()

        #     atoms_selected = list(set([atom.name for atom in top.topology.atoms]))

        #     openmm_kwargs = options["Trainer"]["OpenMM"]

        #     if xml_file is None and openmm_kwargs.get("soft_NB", None):
        #         print("using soft nonbonded forces by default")
        #         import uuid
        #         random_string = str(uuid.uuid4())[:8]
        #         tmp_filename = f"soft_nonbonded_{random_string()}.xml"
        #         with open(tmp_filename, "w") as f:
        #             f.write(soft_xml_script)
        #         xml_file = ["amber14-all.xml", tmp_filename]
        #         kwargs["remove_NB"] = True
        #     elif xml_file is None:
        #         xml_file = ["amber14-all.xml"]
        #     self.start_physics_at = openmm_kwargs.get("start_physics_at", 10)
        #     self.psf = physics_scaling_factor
        #     if openmm_kwargs.get("clamp", False):
        #         clamp_kwargs = dict(max=openmm_kwargs.get("clamp_threshold",1e8),
        #                             min=-openmm_kwargs.get("clamp_threshold", 1e8))
        #     else:
        #         clamp_kwargs = None

        #     self.protein_energy = openmm_energy(self.mol,
        #                                         self.std,
        #                                         clamp=clamp_kwargs,
        #                                         platform="CUDA" if self.device == torch.device("cuda") else "Reference",
        #                                         atoms=atoms_selected,
        #                                         xml_file=xml_file,
        #                                         **kwargs)
        #     os.remove(tmp_filename)

        self.loss_fn = PhysicsLoss(stats=None,
                                   protein_energy=self.protein_energy,
                                   physics_scaling_factor=0.1)

        # NOTE(review): Encoder's first positional argument is
        # latent_dimension and Decoder's second is also latent_dimension, so
        # the latent size used here is self.out_features (= n_features) and
        # n_cvs never reaches the networks — confirm this is intended.
        self.encoder = Encoder(self.out_features, **kwargs)
        self.decoder = Decoder(self.in_features, self.out_features, **kwargs)

    def decode(self, x):
        """Decode latent vectors into reconstructed coordinates."""
        x = self.decoder(x)
        return x

    def encode_decode(self, batch):
        """Encode a batch, decode it, and build an interpolated decoding.

        :return: tuple (latent, reconstruction, interpolated reconstruction);
            decoder outputs are truncated to the input's point count.
        """
        z = self.forward_cv(batch)
        x = self.decoder(z)[:, :, : batch.size(2)]
        interpolated = self.decode_interpolation(x, z)[:, :, : batch.size(2)]
        return z, x, interpolated

    def forward_cv(self, x):
        """Standardize the input (if configured) and encode it to latent space."""
        # NOTE(review): self.norm_in is never assigned in this __init__ (the
        # Normalization line is commented out); presumably BaseCV creates it
        # from BLOCKS — confirm, otherwise this raises AttributeError.
        if self.norm_in is not None:
            x = (x - self.mean) / self.std
        z = self.encoder(x)
        return z

    def training_step(self, train_batch, batch_idx):
        """One optimization step: reconstruction MSE plus scaled physics energy."""
        x = train_batch["data"]
        x = x.view(-1, 3, self.in_features)

        # Optional supervised target; fall back to plain autoencoding.
        if "target" in train_batch:
            x_ref = train_batch["target"]
        else:
            x_ref = x

        _, x_hat, xhat_interpolated = self.encode_decode(x_ref)

        # De-standardize the interpolated reconstruction for the energy terms.
        # NOTE(review): decode_interpolation already applies `* std + mean`,
        # so this scales twice — confirm which of the two is intended.
        xhat_interpolated = xhat_interpolated * self.std + self.mean

        mse_loss, physics_loss, scale = self.loss_fn(x_ref, x_hat, xhat_interpolated)
        loss = mse_loss + scale * physics_loss['physics_loss']

        # The same step is reused for validation; self.training picks the prefix.
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_mse_loss", mse_loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_physics_loss", physics_loss['physics_loss'], on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_bond_energy", physics_loss['bond_energy'], on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_angle_energy", physics_loss['angle_energy'], on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_torsion_energy", physics_loss['torsion_energy'], on_epoch=True, on_step=True, prog_bar=True, logger=True)

        return loss

    def decode_interpolation(self, batch, latent):
        """
        Decode random interpolations between consecutive latent pairs.

        :param batch: decoded batch; only its length is used to size alpha.
        :param torch.Tensor latent: Latent vector to decode.
        :return: Decoded protein structure.
        :rtype: torch.Tensor
        """
        # One random mixing coefficient per (even, odd) latent pair.
        alpha = torch.rand(int(len(batch) // 2), 1, 1).type_as(latent)
        latent_interpolated = (1 - alpha) * latent[:-1:2] + alpha * latent[1::2]
        # De-standardize back to coordinate space.
        decoded_interpolation = self.decoder(latent_interpolated) * self.std + self.mean
        return decoded_interpolation

    @property
    def example_input_array(self):
        # return a dummy tensor of shape [1, C, N] on the correct device
        return torch.zeros(1, 3, self.in_features)

    # def configure_optimizers(self):
    #     optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
    #     return optimizer