Coverage for biobb_pytorch/mdae/models/spib.py: 39%

175 statements

coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1""" 

2SPIB: A deep learning-based framework to learn RCs 

3from MD trajectories. Code maintained by Dedi. 

4 

5Read and cite the following when using this method: 

6https://aip.scitation.org/doi/abs/10.1063/5.0038198 

7""" 

8 

9# -------------------- 

10# Model 

11# -------------------- 

12 

13import torch 

14import torch.nn as nn 

15import torch.nn.functional as F 

16import lightning.pytorch as pl 

17from mlcolvar.cvs import BaseCV 

18from biobb_pytorch.mdae.models.nn.feedforward import FeedForward 

19from biobb_pytorch.mdae.loss import InformationBottleneckLoss 

20from typing import Optional 

21 

22__all__ = ["SPIB"] 

23 

24 

class SPIB(BaseCV, pl.LightningModule):
    BLOCKS = ["norm_in", "encoder", "decoder", "k",
              "UpdateLabel", "beta", "threshold", "patience", "refinements",
              "learning_rate", "lr_step_size", "lr_gamma"]

    def __init__(self, n_features, n_cvs, encoder_layers, decoder_layers, options=None, **kwargs):
        super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)

        options = self.parse_options(options)

        self._n_cvs = n_cvs
        self.k = options.get("k", 2)
        self.output_dim = n_features

        self.learning_rate = options.get("optimizer", {}).get("lr", 0.001)
        self.lr_step_size = options.get("optimizer", {}).get("step_size", 10)
        self.lr_gamma = options.get("optimizer", {}).get("gamma", 0.1)

        self.beta = 0.01
        self.threshold = options.get("threshold", 0.01)
        self.patience = options.get("patience", 10)
        self.refinements = options.get("refinements", 5)
        # Assumed default: on_train_epoch_end() reads self.UpdateLabel, but the
        # original listing never assigned it, so it is set here from options.
        self.UpdateLabel = options.get("UpdateLabel", True)

        self.update_times = 0
        self.unchanged_epochs = 0
        self.state_population0 = None
        self.eps = 1e-10

        # Representative inputs
        self.representative_inputs = torch.eye(
            self.k, self.output_dim, device=self.device, requires_grad=False
        )
        self.idle_input = torch.eye(
            self.k, self.k, device=self.device, requires_grad=False
        )
        self.representative_weights = nn.Sequential(
            nn.Linear(self.k, 1, bias=False),
            nn.Softmax(dim=0)
        )
        nn.init.ones_(self.representative_weights[0].weight)

        # Encoder / Decoder
        o = "encoder"
        self.encoder = FeedForward([n_features] + encoder_layers, **options[o])
        self.encoder_mean = torch.nn.Linear(
            in_features=encoder_layers[-1], out_features=n_cvs
        )
        self.encoder_logvar = torch.nn.Linear(
            in_features=encoder_layers[-1], out_features=n_cvs
        )

        o = "decoder"
        self.decoder = FeedForward([n_cvs] + decoder_layers + [n_features], **options[o])

        # IB loss
        self.loss_fn = InformationBottleneckLoss(beta=self.beta, eps=self.eps)

        self.eval_variables = ["xhat", "z", "mu", "logvar", "labels"]
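
    # Variational core (summary comment added for readability; not in the
    # original file): encode() maps (normalized) features to the parameters
    # (mu, logvar) of a Gaussian posterior, with logvar squashed into (-10, 0);
    # reparameterize() samples z = mu + sigma * eps, eps ~ N(0, 1); decode()
    # returns log-probabilities whose argmax is read downstream as a state
    # label (see evaluate_model and update_labels).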

    def encode(self, inputs: torch.Tensor):
        h = self.encoder(inputs)
        mu = self.encoder_mean(h)
        logvar = -10 * F.sigmoid(self.encoder_logvar(h))
        return mu, logvar

    def decode(self, z: torch.Tensor):
        return F.log_softmax(self.decoder(z), dim=1)

    def forward_cv(self, x: torch.Tensor) -> torch.Tensor:
        if self.norm_in is not None:
            x = self.norm_in(x)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor):
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def encode_decode(self, x: torch.Tensor):
        flat = x.view(x.size(0), -1)
        if self.norm_in is not None:
            flat = self.norm_in(flat)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        if self.norm_in is not None:
            x_hat = self.norm_in.inverse(x_hat)
        return x_hat, z, mu, logvar

    def evaluate_model(self, batch, batch_idx):

        xhat, z, mu, logvar = self.encode_decode(batch['data'])

        pred = xhat.exp()
        labels = pred.argmax(1)

        return xhat, z, mu, logvar, labels

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=self.lr_step_size, gamma=self.lr_gamma)
        return {'optimizer': opt, 'lr_scheduler': {'scheduler': sch, 'interval': 'epoch'}}
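
    # Self-consistent label refinement (summary comment added for readability;
    # not in the original file): update_labels() re-assigns each sample to the
    # state favored by the decoded distribution of its posterior mean. At every
    # epoch end, the change in state populations is monitored; once it stays
    # below `threshold` for more than `patience` epochs, labels and
    # representative inputs are refreshed (up to `refinements` times); once
    # refinements are exhausted, or if fewer than two states remain populated,
    # training stops.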

    @torch.no_grad()
    def update_labels(self, inputs: torch.Tensor) -> Optional[torch.Tensor]:
        if not self.UpdateLabel:
            return None
        loader = self.trainer.datamodule.train_dataloader()
        bs = loader.batch_size
        labels = []
        for i in range(0, len(inputs), bs):
            batch = inputs[i:i + bs].to(self.device)
            mu, _ = self.encode(batch)
            logp = self.decode(mu)
            labels.append(logp.exp())
        preds = torch.cat(labels, dim=0)
        idx = preds.argmax(dim=1)
        return F.one_hot(idx, num_classes=self.k)

    @torch.no_grad()
    def get_representative_z(self):
        return self.encode(self.representative_inputs)

    def reset_representative(self, rep_inputs: torch.Tensor):
        self.representative_inputs = rep_inputs.detach().clone()
        dim = rep_inputs.size(0)
        self.idle_input = torch.eye(dim, dim, device=self.device, requires_grad=False)
        self.representative_weights = nn.Sequential(
            nn.Linear(dim, 1, bias=False), nn.Softmax(dim=0)
        )
        nn.init.ones_(self.representative_weights[0].weight)

    def training_step(self, batch, batch_idx):
        x, y = batch['data'], batch['labels']
        w_batch = batch.get('weights', None)
        preds, z, mu, logvar = self.encode_decode(x)
        rep_mu, rep_logvar = self.get_representative_z()
        w = self.representative_weights(self.idle_input)
        loss, recon_err, kl = self.loss_fn(
            y.to(self.device), preds, z, mu, logvar, rep_mu, rep_logvar, w, w_batch
        )
        name = "train" if self.training else "valid"
        self.log(f'{name}_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log(f'{name}_recon', recon_err, on_step=True, on_epoch=True, prog_bar=True)
        self.log(f'{name}_kl', kl, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    @torch.no_grad()
    def on_train_epoch_start(self):
        if self.trainer.current_epoch == 0:
            ds = self.trainer.datamodule.train_dataloader().dataset
            self.state_population0 = ds['labels'].float().mean(dim=0)
            self.representative_inputs = torch.eye(
                self.k, self.output_dim, device=self.device, requires_grad=False
            )

    @torch.no_grad()
    def on_train_epoch_end(self):
        ds = self.trainer.datamodule.train_dataloader().dataset
        new_labels = self.update_labels(ds['target'])
        if new_labels is None:
            return
        state_pop = new_labels.float().mean(dim=0)
        delta = torch.norm(state_pop - self.state_population0)
        self.log('state_population_change', delta)
        self.state_population0 = state_pop
        if delta < self.threshold:
            self.unchanged_epochs += 1
            if self.unchanged_epochs > self.patience:
                if torch.sum(state_pop > 0) < 2:
                    self.trainer.should_stop = True
                    return
                if self.UpdateLabel and self.update_times < self.refinements:
                    self.update_times += 1
                    self.unchanged_epochs = 0
                    ds['labels'] = new_labels
                    reps = self.estimate_representative_inputs(
                        ds['data'], getattr(ds, 'weights', None)
                    ).to(self.device)
                    self.reset_representative(reps)
                    self.log(f'refinement_{self.update_times}', 1)
                    # Force reload of the DataLoader to reflect updated labels
                    loop = self.trainer.fit_loop
                    loop._combined_loader = None
                    loop.setup_data()
                else:
                    self.trainer.should_stop = True
        else:
            self.unchanged_epochs = 0
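
    # Note added for readability (not in the original file): each
    # representative input is the bias-weighted mean of the samples currently
    # assigned to that state; it is re-estimated after every label refinement.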

    @torch.no_grad()
    def estimate_representative_inputs(
        self, inputs: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        self.eval()
        N = len(inputs)
        bs = self.trainer.datamodule.train_dataloader().batch_size
        preds = []
        for i in range(0, N, bs):
            batch = inputs[i:i + bs].to(self.device)
            mu, _ = self.encode(batch)
            logp = self.decode(mu)
            preds.append(logp.exp())
        preds = torch.cat(preds, dim=0)
        labels = F.one_hot(preds.argmax(dim=1), num_classes=self.k)

        if bias is None:
            bias = torch.ones(N, device=self.device)

        data_shape = inputs.shape[1:]
        state_sums = torch.zeros(self.k, *data_shape, device=self.device)
        state_counts = torch.zeros(self.k, device=self.device)

        for state in range(self.k):
            mask = labels[:, state] == 1
            if mask.any():
                weights_expanded = bias[mask].view(-1, *([1] * len(data_shape)))
                weighted_inputs = inputs[mask] * weights_expanded
                state_sums[state] += weighted_inputs.sum(dim=0)
                state_counts[state] += bias[mask].sum()

        reps = torch.zeros(self.k, *data_shape, device=self.device)
        for state in range(self.k):
            if state_counts[state] > 0:
                reps[state] = state_sums[state] / state_counts[state]

        return reps
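
For orientation, a minimal construction sketch follows. It is not part of spib.py
(and is not counted in the coverage figures above); the layer sizes and option
values are hypothetical, and it assumes the mlcolvar BaseCV defaults, i.e. that
parse_options accepts the keys listed in BLOCKS plus "optimizer" and that no
norm_in block is configured:

    model = SPIB(
        n_features=45,                    # hypothetical number of input descriptors
        n_cvs=2,                          # dimensionality of the learned RC space
        encoder_layers=[64, 64],
        decoder_layers=[64, 64],
        options={"k": 4, "optimizer": {"lr": 1e-3}},
    )
    x = torch.randn(8, 45)                # dummy batch of descriptors
    mu, logvar = model.encode(x)          # Gaussian posterior parameters, shape (8, 2)
    z = model.forward_cv(x)               # sampled latent CVs, shape (8, 2)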