Coverage for biobb_pytorch / mdae / models / molearn.py: 20%

179 statements  

coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from biobb_pytorch.mdae.loss.utils.torch_protein_energy import TorchProteinEnergy
from biobb_pytorch.mdae.loss.physics_loss import PhysicsLoss
import lightning.pytorch as pl
from mlcolvar.cvs import BaseCV


def index_points(point_clouds, index):
    '''
    Given a batch of point clouds and an index tensor, select sub-tensors.

    :param point_clouds: input point data, [B, N, C]
    :param index: sample index data, [B, N, k]
    :return: indexed point data, [B, N, k, C]
    '''
    device = point_clouds.device
    batch_size = point_clouds.shape[0]
    view_shape = list(index.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(index.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(batch_size, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape)
    new_points = point_clouds[batch_indices, index, :]
    return new_points
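
# Usage sketch (illustrative only, not part of the original module): gather the
# coordinates of k neighbors per point via batched advanced indexing.
#   pts = torch.randn(2, 100, 3)                # (B, N, C)
#   idx = torch.randint(0, 100, (2, 100, 16))   # (B, N, k)
#   nbrs = index_points(pts, idx)               # -> (2, 100, 16, 3)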


def knn(x, k):
    '''
    K-nearest neighbors.

    :param x: a tensor of size (B, C, N)
    :param k: the number of nearest neighbors
    :return: indices of the k nearest neighbors, of size (B, N, k)
    '''
    inner = -2 * torch.matmul(x.transpose(2, 1), x)   # (B, N, N)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)       # (B, 1, N)
    # Negative squared pairwise distances, built by broadcasting:
    # (B, 1, N) + (B, N, N) + (B, N, 1) -> (B, N, N)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)

    idx = pairwise_distance.topk(k=k, dim=-1)[1]      # (B, N, k)
    return idx
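
# Usage sketch (illustrative only): topk on the negated squared distances picks
# the nearest points, so each point's own index appears among its k results.
#   x = torch.randn(2, 3, 50)    # (B, C, N)
#   idx = knn(x, k=16)           # -> (2, 50, 16)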


class GraphLayer(nn.Module):
    '''
    Graph layer.

    in_channel: determined by the input to this layer.
    out_channel: a free choice of the architecture.
    '''

    def __init__(self, in_channel, out_channel, k=16):
        super(GraphLayer, self).__init__()
        self.k = k
        self.conv = nn.Conv1d(in_channel, out_channel, 1)
        self.bn = nn.BatchNorm1d(out_channel)

    def forward(self, x):
        '''
        :param x: tensor of size (B, C, N)
        '''
        # KNN
        knn_idx = knn(x, k=self.k)                          # (B, N, k)
        knn_x = index_points(x.permute(0, 2, 1), knn_idx)   # (B, N, k, C)

        # Local max pooling over the k neighbors
        x = torch.max(knn_x, dim=2)[0].permute(0, 2, 1)     # (B, C, N)

        # Feature map
        x = F.relu(self.bn(self.conv(x)))
        return x
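
# Usage sketch (illustrative only): one graph layer mixes neighborhood
# information and maps per-point features from in_channel to out_channel.
#   layer = GraphLayer(in_channel=64, out_channel=128, k=16)
#   x = torch.randn(2, 64, 100)   # (B, C, N)
#   y = layer(x)                  # -> (2, 128, 100)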


class Encoder(nn.Module):
    '''
    Graph-based encoder.
    '''

    def __init__(self, latent_dimension=2, **kwargs):
        super(Encoder, self).__init__()
        self.latent_dimension = latent_dimension
        self.conv1 = nn.Conv1d(12, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 64, 1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)

        self.graph_layer1 = GraphLayer(in_channel=64, out_channel=128, k=16)
        self.graph_layer2 = GraphLayer(in_channel=128, out_channel=1024, k=16)

        self.conv4 = nn.Conv1d(1024, 512, 1)
        self.bn4 = nn.BatchNorm1d(512)
        self.conv5 = nn.Conv1d(512, latent_dimension, 1)

    def forward(self, x):
        b, c, n = x.size()

        # Get the local covariances, reshape and concatenate with x
        knn_idx = knn(x, k=16)
        knn_x = index_points(x.permute(0, 2, 1), knn_idx)   # (B, N, 16, 3)
        mean = torch.mean(knn_x, dim=2, keepdim=True)
        knn_x = knn_x - mean
        covariances = torch.matmul(knn_x.transpose(2, 3), knn_x).view(b, n, -1).permute(0, 2, 1)
        x = torch.cat([x, covariances], dim=1)              # 3 + 9 channels -> (B, 12, N)

        # Three-layer MLP
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        # Two consecutive graph layers
        x = self.graph_layer1(x)
        x = self.graph_layer2(x)

        x = self.bn4(self.conv4(x))

        # Global max pooling over the point dimension
        x = torch.max(x, dim=-1)[0].unsqueeze(-1)

        x = self.conv5(x)
        return x
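
# Usage sketch (illustrative only): each point contributes its 3 coordinates
# plus the 9 entries of its local 3x3 neighborhood covariance, hence the
# 12-channel input of conv1; the encoder pools everything into one codeword.
#   enc = Encoder(latent_dimension=2)
#   x = torch.randn(4, 3, 128)    # (B, 3, N) coordinates
#   z = enc(x)                    # -> (4, 2, 1)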


class FoldingLayer(nn.Module):
    '''
    The folding operation of FoldingNet.
    '''

    def __init__(self, in_channel: int, out_channels: list):
        super(FoldingLayer, self).__init__()

        layers = []
        for oc in out_channels[:-1]:
            conv = nn.Conv1d(in_channel, oc, 3, 1, 1)
            bn = nn.BatchNorm1d(oc)
            active = nn.ReLU(inplace=True)
            layers.extend([conv, bn, active])
            in_channel = oc
        out_layer = nn.Conv1d(in_channel, out_channels[-1], 3, 1, 1)
        layers.append(out_layer)

        self.layers = nn.Sequential(*layers)

    def forward(self, *args):
        """
        :param args: reshaped grids and/or intermediate reconstructed point clouds
        """
        # Concatenate the inputs along the channel dimension
        x = torch.cat([*args], dim=1)
        # Shared MLP
        x = self.layers(x)

        return x


class Decoder_Layer(nn.Module):
    '''
    One decoding stage of the FoldingNet-style decoder.
    '''

    def __init__(self, in_features, out_features, in_channel, out_channel, **kwargs):
        super(Decoder_Layer, self).__init__()

        # Sample a 1D grid to fold; each input point is expanded m-fold
        self.out_features = out_features
        self.grid = torch.linspace(-0.5, 0.5, out_features).view(1, -1)
        assert out_features % in_features == 0
        self.m = out_features // in_features

        self.fold1 = FoldingLayer(in_channel + 1, [512, 512, out_channel])
        self.fold2 = FoldingLayer(in_channel + out_channel + 1, [512, 512, out_channel])

    def forward(self, x):
        '''
        :param x: (B, in_channel, in_features)
        '''
        batch_size = x.shape[0]

        # Repeat grid for batch operation
        grid = self.grid.to(x.device)                       # (1, out_features)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1)   # (B, 1, out_features)

        # Repeat codewords
        x = x.repeat_interleave(self.m, dim=-1)             # (B, in_channel, out_features)

        # Two folding operations
        recon1 = self.fold1(grid, x)
        recon2 = recon1 + self.fold2(grid, x, recon1)

        return recon2
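
# Usage sketch (illustrative only): each stage stretches in_features grid slots
# to out_features by repeating the codeword and folding the 1D grid twice.
#   layer = Decoder_Layer(in_features=1, out_features=128, in_channel=2, out_channel=384)
#   x = torch.randn(4, 2, 1)      # (B, in_channel, in_features)
#   y = layer(x)                  # -> (4, 384, 128)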


class Decoder(nn.Module):
    '''
    Decoder module of FoldingNet.
    '''

    def __init__(self, out_features, latent_dimension=2, **kwargs):
        super(Decoder, self).__init__()
        self.latent_dimension = latent_dimension

        start_out = (out_features // 128) + 1

        self.out_features = out_features

        self.layer1 = Decoder_Layer(1, start_out, latent_dimension, 3 * 128)
        self.layer2 = Decoder_Layer(start_out, start_out * 8, 3 * 128, 3 * 16)
        self.layer3 = Decoder_Layer(start_out * 8, start_out * 32, 3 * 16, 3 * 4)
        self.layer4 = Decoder_Layer(start_out * 32, start_out * 128, 3 * 4, 3)

    def forward(self, x):
        '''
        :param x: (B, C)
        '''
        x = x.view(-1, self.latent_dimension, 1)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x
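
# Note (illustrative only): the final stage emits start_out * 128 points, which
# is >= out_features; CNNAutoEncoder.encode_decode below trims the excess with
# a [:, :, : batch.size(2)] slice.
#   dec = Decoder(out_features=128, latent_dimension=2)
#   z = torch.randn(4, 2, 1)
#   y = dec(z)                    # -> (4, 3, 256), later sliced to (4, 3, 128)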


class CNNAutoEncoder(BaseCV, pl.LightningModule):
    '''
    Autoencoder architecture derived from FoldingNet.
    '''

    BLOCKS = ["norm_in", "encoder", "decoder"]

    def __init__(self, n_features, n_cvs, encoder_layers, decoder_layers, options=None, **kwargs):
        super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)

        options = self.parse_options(options)

        if "norm_in" in options and options["norm_in"] is not None:
            self.mean = options["norm_in"]["stats"]["parametric"][0]
            self.std = options["norm_in"]["stats"]["parametric"][1]
            # self.norm_in = Normalization(self.in_features, mean=self.mean, range=self.std)

        top = options["norm_in"]["stats"]["topology"][0]
        self.in_features = top.n_atoms
        self.out_features = n_features

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        x0_coords = torch.tensor(top.xyz[0]).permute(1, 0)

        atominfo = []
        for i in top.topology.atoms:
            atominfo.append([i.name, i.residue.name, i.residue.index + 1])
        atominfo = np.array(atominfo, dtype=object)

        # if trainer == "Torch":
        self.protein_energy = TorchProteinEnergy(x0_coords,
                                                 atominfo,
                                                 device=device,
                                                 method='roll')

        # if options["trainer"] == "OpenMM":
        #     from loss.utils.openmm_thread import openmm_energy_setup, openmm_energy, soft_xml_script
        #
        #     protein_energy_setup = openmm_energy_setup(top)
        #     atominfo = protein_energy_setup.get_atominfo()
        #     self.mol = protein_energy_setup.mol_dataframe()
        #
        #     atoms_selected = list(set([atom.name for atom in top.topology.atoms]))
        #
        #     openmm_kwargs = options["Trainer"]["OpenMM"]
        #
        #     if xml_file is None and openmm_kwargs.get("soft_NB", None):
        #         print("using soft nonbonded forces by default")
        #         import uuid
        #         random_string = str(uuid.uuid4())[:8]
        #         tmp_filename = f"soft_nonbonded_{random_string}.xml"
        #         with open(tmp_filename, "w") as f:
        #             f.write(soft_xml_script)
        #         xml_file = ["amber14-all.xml", tmp_filename]
        #         kwargs["remove_NB"] = True
        #     elif xml_file is None:
        #         xml_file = ["amber14-all.xml"]
        #     self.start_physics_at = openmm_kwargs.get("start_physics_at", 10)
        #     self.psf = physics_scaling_factor
        #     if openmm_kwargs.get("clamp", False):
        #         clamp_kwargs = dict(max=openmm_kwargs.get("clamp_threshold", 1e8),
        #                             min=-openmm_kwargs.get("clamp_threshold", 1e8))
        #     else:
        #         clamp_kwargs = None
        #
        #     self.protein_energy = openmm_energy(self.mol,
        #                                         self.std,
        #                                         clamp=clamp_kwargs,
        #                                         platform="CUDA" if self.device == torch.device("cuda") else "Reference",
        #                                         atoms=atoms_selected,
        #                                         xml_file=xml_file,
        #                                         **kwargs)
        #     os.remove(tmp_filename)

        self.loss_fn = PhysicsLoss(stats=None,
                                   protein_energy=self.protein_energy,
                                   physics_scaling_factor=0.1)

        self.encoder = Encoder(self.out_features, **kwargs)
        self.decoder = Decoder(self.in_features, self.out_features, **kwargs)

    def decode(self, x):
        x = self.decoder(x)
        return x

    def encode_decode(self, batch):
        z = self.forward_cv(batch)
        x = self.decoder(z)[:, :, : batch.size(2)]
        interpolated = self.decode_interpolation(x, z)[:, :, : batch.size(2)]
        return z, x, interpolated

    def forward_cv(self, x):
        if self.norm_in is not None:
            x = (x - self.mean) / self.std
        z = self.encoder(x)
        return z

    def training_step(self, train_batch, batch_idx):
        x = train_batch["data"]
        x = x.view(-1, 3, self.in_features)

        if "target" in train_batch:
            x_ref = train_batch["target"]
        else:
            x_ref = x

        _, x_hat, xhat_interpolated = self.encode_decode(x_ref)

        xhat_interpolated = xhat_interpolated * self.std + self.mean

        mse_loss, physics_loss, scale = self.loss_fn(x_ref, x_hat, xhat_interpolated)
        loss = mse_loss + scale * physics_loss['physics_loss']

        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_mse_loss", mse_loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_physics_loss", physics_loss['physics_loss'], on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_bond_energy", physics_loss['bond_energy'], on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_angle_energy", physics_loss['angle_energy'], on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log(f"{name}_torsion_energy", physics_loss['torsion_energy'], on_epoch=True, on_step=True, prog_bar=True, logger=True)

        return loss

    def decode_interpolation(self, batch, latent):
        """
        Decode interpolated latent vectors to protein structures.

        :param torch.Tensor batch: Batch used to size the interpolation.
        :param torch.Tensor latent: Latent vectors to interpolate and decode.
        :return: Decoded interpolated protein structures.
        :rtype: torch.Tensor
        """
        alpha = torch.rand(int(len(batch) // 2), 1, 1).type_as(latent)
        latent_interpolated = (1 - alpha) * latent[:-1:2] + alpha * latent[1::2]
        decoded_interpolation = self.decoder(latent_interpolated) * self.std + self.mean
        return decoded_interpolation

    @property
    def example_input_array(self):
        # Return a dummy tensor of shape (1, 3, N) on the correct device
        return torch.zeros(1, 3, self.in_features)

    # def configure_optimizers(self):
    #     optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
    #     return optimizer
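

# Minimal shape-check sketch (illustrative only, not part of the original
# module; it only exercises the standalone Encoder/Decoder and assumes the
# imports at the top of this file resolve):
if __name__ == "__main__":
    enc = Encoder(latent_dimension=2)
    dec = Decoder(out_features=128, latent_dimension=2)
    coords = torch.randn(4, 3, 128)            # (B, 3, n_atoms)
    z = enc(coords)                            # -> (4, 2, 1)
    recon = dec(z)[:, :, :coords.size(2)]      # trim to (4, 3, 128)
    print(z.shape, recon.shape)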