Coverage for biobb_pytorch / mdae / models / spib.py: 39%
175 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1"""
2SPIB: A deep learning-based framework to learn RCs
3from MD trajectories. Code maintained by Dedi.
5Read and cite the following when using this method:
6https://aip.scitation.org/doi/abs/10.1063/5.0038198
7"""
9# --------------------
10# Model
11# --------------------
13import torch
14import torch.nn as nn
15import torch.nn.functional as F
16import lightning.pytorch as pl
17from mlcolvar.cvs import BaseCV
18from biobb_pytorch.mdae.models.nn.feedforward import FeedForward
19from biobb_pytorch.mdae.loss import InformationBottleneckLoss
20from typing import Optional
22__all__ = ["SPIB"]
class SPIB(BaseCV, pl.LightningModule):
    """State Predictive Information Bottleneck (SPIB).

    Learns reaction coordinates (CVs) from MD trajectories by encoding input
    features into a low-dimensional latent space and decoding the latent
    vector into a prediction of state labels, which are iteratively refined
    during training (label refinement with early stopping).

    Read and cite: https://aip.scitation.org/doi/abs/10.1063/5.0038198
    """

    BLOCKS = ["norm_in", "encoder", "decoder", "k",
              "UpdateLabel", "beta", "threshold", "patience", "refinements",
              "learning_rate", "lr_step_size", "lr_gamma"]

    def __init__(self, n_features, n_cvs, encoder_layers, decoder_layers, options=None, **kwargs):
        """Initialize the SPIB model.

        Parameters
        ----------
        n_features : int
            Number of input features (also the decoder output dimension).
        n_cvs : int
            Latent dimension (number of collective variables).
        encoder_layers : list[int]
            Hidden-layer sizes of the encoder network.
        decoder_layers : list[int]
            Hidden-layer sizes of the decoder network.
        options : dict, optional
            Per-block options (keys listed in ``BLOCKS``), parsed by
            ``parse_options``.
        """
        super().__init__(in_features=n_features, out_features=n_cvs, **kwargs)

        options = self.parse_options(options)

        self._n_cvs = n_cvs
        # k: number of (initial) metastable states / label classes.
        self.k = options.get("k", 2)
        self.output_dim = n_features

        # Optimizer / scheduler hyper-parameters.
        self.learning_rate = options.get("optimizer", {}).get("lr", 0.001)
        self.lr_step_size = options.get("optimizer", {}).get("step_size", 10)
        self.lr_gamma = options.get("optimizer", {}).get("gamma", 0.1)

        # FIX: ``UpdateLabel`` is declared in BLOCKS and read in
        # update_labels() and on_train_epoch_end(), but was never
        # initialized, raising AttributeError at the end of the first
        # training epoch. Default True preserves the refinement behavior
        # the rest of the class is written for.
        self.UpdateLabel = options.get("UpdateLabel", True)

        # NOTE(review): beta is listed in BLOCKS but hard-coded here —
        # confirm whether it should be read from ``options`` instead.
        self.beta = 0.01
        # Label-refinement schedule: stop/refine when the state-population
        # change stays below ``threshold`` for more than ``patience`` epochs,
        # up to ``refinements`` refinement rounds.
        self.threshold = options.get("threshold", 0.01)
        self.patience = options.get("patience", 10)
        self.refinements = options.get("refinements", 5)

        # Refinement bookkeeping.
        self.update_times = 0
        self.unchanged_epochs = 0
        self.state_population0 = None  # set at the start of epoch 0
        self.eps = 1e-10

        # Representative inputs: one prototype input per state.
        # NOTE(review): these are plain tensors (not registered buffers), so
        # they are neither moved by ``.to(device)`` nor stored in
        # checkpoints — confirm whether buffers are intended.
        self.representative_inputs = torch.eye(
            self.k, self.output_dim, device=self.device, requires_grad=False
        )
        self.idle_input = torch.eye(
            self.k, self.k, device=self.device, requires_grad=False
        )
        # Learnable (softmax-normalized) weights over representative inputs.
        self.representative_weights = nn.Sequential(
            nn.Linear(self.k, 1, bias=False),
            nn.Softmax(dim=0)
        )
        nn.init.ones_(self.representative_weights[0].weight)

        # Encoder / Decoder networks.
        o = "encoder"
        self.encoder = FeedForward([n_features] + encoder_layers, **options[o])
        self.encoder_mean = torch.nn.Linear(
            in_features=encoder_layers[-1], out_features=n_cvs
        )
        self.encoder_logvar = torch.nn.Linear(
            in_features=encoder_layers[-1], out_features=n_cvs
        )

        o = "decoder"
        self.decoder = FeedForward([n_cvs] + decoder_layers + [n_features], **options[o])

        # Information-bottleneck loss (reconstruction + beta * KL).
        self.loss_fn = InformationBottleneckLoss(beta=self.beta, eps=self.eps)

        self.eval_variables = ["xhat", "z", "mu", "logvar", "labels"]

    def encode(self, inputs: torch.Tensor):
        """Return latent mean and log-variance for ``inputs``.

        The log-variance is squashed into (-10, 0), keeping the latent
        standard deviation within (e^-5, 1).
        """
        h = self.encoder(inputs)
        mu = self.encoder_mean(h)
        # FIX: torch.sigmoid replaces the deprecated F.sigmoid (identical
        # behavior).
        logvar = -10 * torch.sigmoid(self.encoder_logvar(h))
        return mu, logvar

    def decode(self, z: torch.Tensor):
        """Return per-state log-probabilities predicted from latent ``z``."""
        return F.log_softmax(self.decoder(z), dim=1)

    def forward_cv(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs to a (stochastic) latent sample — the CV value."""
        if self.norm_in is not None:
            x = self.norm_in(x)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor):
        """Sample z ~ N(mu, exp(logvar)) with the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def encode_decode(self, x: torch.Tensor):
        """Full pass: normalize, encode, sample, decode.

        Returns ``(x_hat, z, mu, logvar)`` where ``x_hat`` are the decoder's
        per-state log-probabilities.
        """
        flat = x.view(x.size(0), -1)
        if self.norm_in is not None:
            flat = self.norm_in(flat)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        # NOTE(review): norm_in.inverse is applied to log-probabilities
        # here — confirm this de-normalization is intended for label
        # predictions rather than feature reconstructions.
        if self.norm_in is not None:
            x_hat = self.norm_in.inverse(x_hat)
        return x_hat, z, mu, logvar

    def evaluate_model(self, batch, batch_idx):
        """Evaluate one batch; returns (xhat, z, mu, logvar, hard labels)."""
        xhat, z, mu, logvar = self.encode_decode(batch['data'])
        # Convert log-probabilities to probabilities, then take the argmax
        # as the hard state assignment.
        pred = xhat.exp()
        labels = pred.argmax(1)
        return xhat, z, mu, logvar, labels

    def configure_optimizers(self):
        """Adam optimizer with an epoch-wise StepLR schedule."""
        opt = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=self.lr_step_size, gamma=self.lr_gamma)
        return {'optimizer': opt, 'lr_scheduler': {'scheduler': sch, 'interval': 'epoch'}}

    @torch.no_grad()
    def update_labels(self, inputs: torch.Tensor) -> Optional[torch.Tensor]:
        """Relabel ``inputs`` with the current model's hard predictions.

        Returns a one-hot (N, k) tensor, or ``None`` when label updating is
        disabled. Uses the deterministic latent mean (no sampling).
        """
        if not self.UpdateLabel:
            return None
        loader = self.trainer.datamodule.train_dataloader()
        bs = loader.batch_size
        labels = []
        for i in range(0, len(inputs), bs):
            batch = inputs[i:i + bs].to(self.device)
            mu, _ = self.encode(batch)
            logp = self.decode(mu)
            labels.append(logp.exp())
        preds = torch.cat(labels, dim=0)
        idx = preds.argmax(dim=1)
        return F.one_hot(idx, num_classes=self.k)

    @torch.no_grad()
    def get_representative_z(self):
        """Encode the representative inputs; returns (mu, logvar)."""
        return self.encode(self.representative_inputs)

    def reset_representative(self, rep_inputs: torch.Tensor):
        """Replace the representative inputs and re-init their weights."""
        self.representative_inputs = rep_inputs.detach().clone()
        dim = rep_inputs.size(0)
        self.idle_input = torch.eye(dim, dim, device=self.device, requires_grad=False)
        self.representative_weights = nn.Sequential(
            nn.Linear(dim, 1, bias=False), nn.Softmax(dim=0)
        )
        nn.init.ones_(self.representative_weights[0].weight)

    def training_step(self, batch, batch_idx):
        """Compute and log the IB loss for one batch (train or valid)."""
        x, y = batch['data'], batch['labels']
        w_batch = batch.get('weights', None)
        preds, z, mu, logvar = self.encode_decode(x)
        rep_mu, rep_logvar = self.get_representative_z()
        w = self.representative_weights(self.idle_input)
        loss, recon_err, kl = self.loss_fn(
            y.to(self.device), preds, z, mu, logvar, rep_mu, rep_logvar, w, w_batch
        )
        # Same step is reused for validation; pick the log prefix from mode.
        name = "train" if self.training else "valid"
        self.log(f'{name}_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log(f'{name}_recon', recon_err, on_step=True, on_epoch=True, prog_bar=True)
        self.log(f'{name}_kl', kl, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    @torch.no_grad()
    def on_train_epoch_start(self):
        """On epoch 0, record the initial state populations and reset the
        representative inputs to the identity prototypes."""
        if self.trainer.current_epoch == 0:
            ds = self.trainer.datamodule.train_dataloader().dataset
            self.state_population0 = ds['labels'].float().mean(dim=0)
            self.representative_inputs = torch.eye(
                self.k, self.output_dim, device=self.device, requires_grad=False
            )

    @torch.no_grad()
    def on_train_epoch_end(self):
        """Track state-population drift; refine labels or stop training.

        When populations stay within ``threshold`` for more than ``patience``
        epochs, either relabel the dataset (up to ``refinements`` times) or
        stop training.
        """
        ds = self.trainer.datamodule.train_dataloader().dataset
        # NOTE(review): labels are recomputed from ds['target'] while the
        # loss uses ds['data'] — presumably the time-lagged frames; confirm
        # against the datamodule.
        new_labels = self.update_labels(ds['target'])
        if new_labels is None:
            return
        state_pop = new_labels.float().mean(dim=0)
        delta = torch.norm(state_pop - self.state_population0)
        self.log('state_population_change', delta)
        self.state_population0 = state_pop
        if delta < self.threshold:
            self.unchanged_epochs += 1
            if self.unchanged_epochs > self.patience:
                # Fewer than two populated states: nothing left to separate.
                if torch.sum(state_pop > 0) < 2:
                    self.trainer.should_stop = True
                    return
                if self.UpdateLabel and self.update_times < self.refinements:
                    self.update_times += 1
                    self.unchanged_epochs = 0
                    ds['labels'] = new_labels
                    reps = self.estimate_representative_inputs(
                        ds['data'], getattr(ds, 'weights', None)
                    ).to(self.device)
                    self.reset_representative(reps)
                    self.log(f'refinement_{self.update_times}', 1)
                    # Force reload of the DataLoader to reflect updated labels
                    loop = self.trainer.fit_loop
                    loop._combined_loader = None
                    loop.setup_data()
                else:
                    self.trainer.should_stop = True
        else:
            self.unchanged_epochs = 0

    @torch.no_grad()
    def estimate_representative_inputs(
        self, inputs: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return one representative input per state: the (bias-)weighted
        mean of the inputs the model currently assigns to that state.

        States with no assigned samples keep a zero representative.
        """
        self.eval()
        N = len(inputs)
        bs = self.trainer.datamodule.train_dataloader().batch_size
        preds = []
        for i in range(0, N, bs):
            batch = inputs[i:i + bs].to(self.device)
            mu, _ = self.encode(batch)
            # FIX: use self.decode (log-softmax) instead of the raw decoder
            # output, matching update_labels(); exponentiating raw logits
            # can overflow, while the argmax below is unchanged.
            logp = self.decode(mu)
            preds.append(logp.exp())
        preds = torch.cat(preds, dim=0)
        labels = F.one_hot(preds.argmax(dim=1), num_classes=self.k)

        if bias is None:
            bias = torch.ones(N, device=self.device)

        data_shape = inputs.shape[1:]
        state_sums = torch.zeros(self.k, *data_shape, device=self.device)
        state_counts = torch.zeros(self.k, device=self.device)

        for state in range(self.k):
            mask = labels[:, state] == 1
            if mask.any():
                # Broadcast per-sample bias over the feature dimensions.
                weights_expanded = bias[mask].view(-1, *([1] * len(data_shape)))
                weighted_inputs = inputs[mask] * weights_expanded
                state_sums[state] += weighted_inputs.sum(dim=0)
                state_counts[state] += bias[mask].sum()

        reps = torch.zeros(self.k, *data_shape, device=self.device)
        for state in range(self.k):
            if state_counts[state] > 0:
                reps[state] = state_sums[state] / state_counts[state]

        return reps