Coverage for biobb_pytorch / mdae / loss / committor_loss.py: 12%
175 statements
coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Committor function Loss Function and Utils.
"""

__all__ = ["CommittorLoss", "committor_loss", "SmartDerivatives", "compute_descriptors_derivatives"]

# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import torch
from typing import Optional

from mlcolvar.core.transform.descriptors.utils import sanitize_positions_shape

# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class CommittorLoss(torch.nn.Module):
    """Compute a loss function based on Kolmogorov's variational principle for the determination of the committor function"""

    def __init__(self,
                 mass: torch.Tensor,
                 alpha: float,
                 cell: float = None,
                 gamma: float = 10000,
                 delta_f: float = 0,
                 separate_boundary_dataset: bool = True,
                 descriptors_derivatives: torch.nn.Module = None
                 ):
        """Compute Kolmogorov's variational principle loss and impose boundary conditions on the metastable states

        Parameters
        ----------
        mass : torch.Tensor
            Atomic masses of the atoms in the system
        alpha : float
            Hyperparameter that scales the boundary conditions contribution to the loss, i.e. alpha*(loss_bound_A + loss_bound_B)
        cell : float, optional
            CUBIC cell side length, used to scale the positions from reduced coordinates to real coordinates, by default None
        gamma : float, optional
            Hyperparameter that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound), by default 10000
        delta_f : float, optional
            Delta free energy between A (label 0) and B (label 1), in units of kBT, by default 0.
            State B is supposed to be higher in energy.
        separate_boundary_dataset : bool, optional
            Switch to exclude boundary-condition labeled data from the variational loss, by default True
        descriptors_derivatives : torch.nn.Module, optional
            `SmartDerivatives` object to save memory and time when using descriptors.
            See also mlcolvar.core.loss.committor_loss.SmartDerivatives
        """
        super().__init__()
        self.register_buffer("mass", mass)
        self.alpha = alpha
        self.cell = cell
        self.gamma = gamma
        self.delta_f = delta_f
        self.descriptors_derivatives = descriptors_derivatives
        self.separate_boundary_dataset = separate_boundary_dataset

    def forward(
        self, x: torch.Tensor, q: torch.Tensor, labels: torch.Tensor, w: torch.Tensor, create_graph: bool = True
    ) -> torch.Tensor:
        return committor_loss(x=x,
                              q=q,
                              labels=labels,
                              w=w,
                              mass=self.mass,
                              alpha=self.alpha,
                              gamma=self.gamma,
                              delta_f=self.delta_f,
                              create_graph=create_graph,
                              cell=self.cell,
                              separate_boundary_dataset=self.separate_boundary_dataset,
                              descriptors_derivatives=self.descriptors_derivatives
                              )
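

# Minimal illustrative sketch of how CommittorLoss might be wired to a network output.
# The atom count, masses, labels and the tiny network below are made-up placeholders;
# in practice the masses would come from e.g. initialize_committor_masses (see the
# docstring above) and q from a trained committor model.
def _example_committor_loss_usage():
    n_atoms, n_batch = 3, 8
    mass = torch.ones(n_atoms * 3)                              # per-coordinate masses (x, y, z repeated per atom)
    x = torch.randn(n_batch, n_atoms * 3, requires_grad=True)   # positions: gradients of q wrt x are needed
    model = torch.nn.Sequential(
        torch.nn.Linear(n_atoms * 3, 10), torch.nn.Tanh(),
        torch.nn.Linear(10, 1), torch.nn.Sigmoid(),
    )
    q = model(x)                                                 # committor guess q(x)
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 2, 2])              # 0 = state A, 1 = state B, >1 = variational data
    w = torch.ones(n_batch)                                      # Boltzmann reweighting factors
    loss_fn = CommittorLoss(mass=mass, alpha=1.0)
    loss, loss_var, loss_A, loss_B = loss_fn(x, q, labels, w)
    return loss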


def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Broadcast util, from torch_scatter"""
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand(other.size())
    return src


def scatter_sum(src: torch.Tensor,
                index: torch.Tensor,
                dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Scatter sum function, from the torch_scatter module (https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py)"""
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)
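

# Tiny sanity-check example for scatter_sum (illustrative values): elements of `src`
# that share the same entry in `index` are summed together, e.g.
#   scatter_sum(torch.tensor([1., 2., 3., 4.]), torch.tensor([0, 0, 1, 1]))
#   -> tensor([3., 7.])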


def committor_loss(x: torch.Tensor,
                   q: torch.Tensor,
                   labels: torch.Tensor,
                   w: torch.Tensor,
                   mass: torch.Tensor,
                   alpha: float,
                   gamma: float = 10000,
                   delta_f: float = 0,
                   create_graph: bool = True,
                   cell: float = None,
                   separate_boundary_dataset: bool = True,
                   descriptors_derivatives: torch.nn.Module = None
                   ):
    """Compute the variational loss for committor optimization with boundary conditions

    Parameters
    ----------
    x : torch.Tensor
        Input of the NN
    q : torch.Tensor
        Committor guess q(x), i.e. the output of the NN
    labels : torch.Tensor
        Labels for the states; A and B states are used for the boundary conditions
    w : torch.Tensor
        Reweighting factors to the Boltzmann distribution. These should depend on the simulation in which the data were collected.
    mass : torch.Tensor
        List of masses of all the atoms we are using; each atom's mass is repeated three times, once per x, y, z component.
        Can be created using `committor.utils.initialize_committor_masses`
    alpha : float
        Hyperparameter that scales the boundary conditions contribution to the loss, i.e. alpha*(loss_bound_A + loss_bound_B)
    gamma : float
        Hyperparameter that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound).
        By default 10000
    delta_f : float
        Delta free energy between A (label 0) and B (label 1), in units of kBT, by default 0.
    create_graph : bool
        Make the loss backwardable; deactivate for validation to save memory, by default True
    cell : float
        CUBIC cell side length, used to scale the positions from reduced coordinates to real coordinates, by default None
    separate_boundary_dataset : bool, optional
        Switch to exclude boundary-condition labeled data from the variational loss, by default True
    descriptors_derivatives : torch.nn.Module, optional
        `SmartDerivatives` object to save memory and time when using descriptors.
        See also mlcolvar.core.loss.committor_loss.SmartDerivatives

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    gamma*loss_var : torch.Tensor
        The variational loss term
    gamma*alpha*loss_A : torch.Tensor
        The boundary loss term on basin A
    gamma*alpha*loss_B : torch.Tensor
        The boundary loss term on basin B
    """
    # inherit the right device
    device = x.device

    mass = mass.to(device)

    # Create masks to access the different states' data
    mask_A = torch.nonzero(labels.squeeze() == 0, as_tuple=True)
    mask_B = torch.nonzero(labels.squeeze() == 1, as_tuple=True)
    if separate_boundary_dataset:
        mask_var = torch.nonzero(labels.squeeze() > 1, as_tuple=True)
    else:
        mask_var = torch.ones(len(x), dtype=torch.bool)

    # Update the weights of basin B using the information on delta_f
    delta_f = torch.Tensor([delta_f])
    if delta_f < 0:  # B higher in energy --> A-B < 0
        w[mask_B] = w[mask_B] * torch.exp(delta_f.to(device))
    elif delta_f > 0:  # A higher in energy --> A-B > 0
        w[mask_A] = w[mask_A] * torch.exp(-delta_f.to(device))

    # VARIATIONAL PRINCIPLE LOSS
    # Each loss contribution is scaled by the number of samples

    # We need the gradient of q(x)
    grad_outputs = torch.ones_like(q[mask_var])
    grad = torch.autograd.grad(q[mask_var], x, grad_outputs=grad_outputs, retain_graph=True, create_graph=create_graph)[0]
    grad = grad[mask_var]

    # TODO this fixes cell size issue
    if cell is not None:
        grad = grad / cell

    if descriptors_derivatives is not None:
        grad_square = descriptors_derivatives(grad)
    else:
        # we take the square of grad(q); the weight is applied below
        grad_square = torch.pow(grad, 2)

    # we sanitize the shapes of the mass and weights tensors
    # mass should have size [1, n_atoms*spatial_dims]
    mass = mass.unsqueeze(0)
    # weights should have size [n_batch, 1]
    w = w.unsqueeze(-1)

    grad_square = torch.sum((grad_square * (1 / mass)), axis=1, keepdim=True)
    grad_square = grad_square * w[mask_var]

    # variational contribution to the loss: we sum over the batch
    loss_var = torch.mean(grad_square)

    # boundary conditions
    q_A = q[mask_A]
    q_B = q[mask_B]
    loss_A = torch.mean(torch.pow(q_A, 2))
    loss_B = torch.mean(torch.pow((q_B - 1), 2))

    loss = gamma * (loss_var + alpha * (loss_A + loss_B))

    # TODO maybe there is no need to detach them for logging
    return loss, gamma * loss_var.detach(), alpha * gamma * loss_A.detach(), alpha * gamma * loss_B.detach()
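

# In short: up to the overall gamma scaling, the variational term above is
#   loss_var = mean_i[ w_i * sum_j (1/m_j) * (dq/dx_j)_i^2 ]
# over the variational samples i and the mass-weighted coordinates j, while the
# boundary terms penalize q deviating from 0 in state A and from 1 in state B.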


class SmartDerivatives(torch.nn.Module):
    """
    Utils to compute efficiently (time- and memory-wise) the derivatives of q wrt some input descriptors.
    Rather than computing explicitly the derivatives wrt the positions, we compute those wrt the descriptors (right input)
    and multiply them by the matrix of the derivatives of the descriptors wrt the positions (left input).
    """

    def __init__(self,
                 der_desc_wrt_pos: torch.Tensor,
                 n_atoms: int,
                 setup_device: str = 'cpu'
                 ):
        """Initialize the fixed matrices for smart derivatives, i.e. the matrix of derivatives of the descriptors wrt the positions.
        The derivatives wrt the positions are recovered by multiplying the derivatives of q wrt the descriptors (right input, computed at each epoch)
        by the non-zero elements of the derivatives of the descriptors wrt the positions (left input, computed once at the beginning on the whole dataset).
        The multiplications are done using scatter functions, keeping track of the indices of the batches, descriptors, atoms and dimensions.

        NB: it should be used with the training set only and a single batch, with shuffle and random_split disabled.

        Parameters
        ----------
        der_desc_wrt_pos : torch.Tensor
            Tensor containing the derivatives of the descriptors wrt the atomic positions
        n_atoms : int
            Number of atoms in the system; every atom should enter at least one of the descriptors
        setup_device : str
            Device on which to perform the expensive calculations. Either 'cpu' or 'cuda', by default 'cpu'
        """
        super().__init__()
        self.batch_size = len(der_desc_wrt_pos)
        self.n_atoms = n_atoms

        # setup the fixed part of the computation, i.e. the left input and the indices for the scatter
        self.left, self.mat_ind, self.scatter_indeces = self._setup_left(der_desc_wrt_pos, setup_device=setup_device)

    def _setup_left(self, left_input: torch.Tensor, setup_device: str = 'cpu'):
        """Setup the fixed part of the computation, i.e. the left input"""
        # all the setup should be done on the CPU by default
        left_input = left_input.to(torch.device(setup_device))

        # the indices in mat_ind are: batch, atom, descriptor and dimension
        left, mat_ind = self._create_nonzero_left(left_input)

        # it is possible that some atoms are not involved in any descriptor
        n_effective_atoms = len(torch.unique(mat_ind[1]))
        if n_effective_atoms < self.n_atoms:
            raise ValueError(f"Some of the input atoms do not contribute to any descriptor. The unused atom IDs are: {[i for i in range(self.n_atoms) if i not in torch.unique(mat_ind[1]).numpy()]}")

        scatter_indeces = self._get_scatter_indices(batch_ind=mat_ind[0], atom_ind=mat_ind[1], dim_ind=mat_ind[3])
        return left, mat_ind, scatter_indeces

    def _create_nonzero_left(self, x):
        """Find the indices of the non-zero elements of the left input"""
        # find the indices of the nonzero entries of d_dist_d_x
        mat_ind = x.nonzero(as_tuple=True)

        # flatten the matrix --> long vector of its entries
        x = x.ravel()
        # find the indices of the nonzero entries of the flattened matrix
        vec_ind = x.nonzero(as_tuple=True)

        # create the vector with the nonzero entries only
        x_vec = x[vec_ind[0].long()]

        # del(vec_ind)
        return x_vec, mat_ind

    def _get_scatter_indices(self, batch_ind, atom_ind, dim_ind):
        """Compute the general indices to map the long vector of nonzero derivatives to the right atom, dimension and descriptor, also in the case of non-homogeneous input.
        We need to gather the derivatives with respect to the same atom coming from different descriptors to obtain the total gradient.
        """
        # ====================================== INITIAL NOTE ======================================
        # the comments follow the example of the pairwise distances in a 3-atom system with 4 batches,
        # i.e. 3 desc * 2 atoms each * 3 dims * 4 batches = 72 values need to be mapped
        # onto 3 atoms * 3 dims * 4 batches = 36
        #
        # Ref_idx: tensor([ 0,  1,  2,  0,  1,  2,  3,  4,  5,  3,  4,  5,  6,  7,  8,  6,  7,  8,
        #                   9, 10, 11,  9, 10, 11, 12, 13, 14, 12, 13, 14, 15, 16, 17, 15, 16, 17,
        #                  18, 19, 20, 18, 19, 20, 21, 22, 23, 21, 22, 23, 24, 25, 26, 24, 25, 26,
        #                  27, 28, 29, 27, 28, 29, 30, 31, 32, 30, 31, 32, 33, 34, 35, 33, 34, 35])
        # ==========================================================================================

        # these would be the indices in the case of uniform batches and a uniform number of atom/descriptor dependencies:
        # it just repeats the atom index in a cycle
        # e.g. [0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6, 7, 8, 0, 1, 2, 0, 1, 2,
        #       3, 4, 5, 3, 4, 5, 6, 7, 8, 6, 7, 8, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5,
        #       6, 7, 8, 6, 7, 8, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6, 7, 8]
        not_shifted_indeces = atom_ind * 3 + dim_ind

        # get the number of elements in each batch
        # e.g. [17, 18, 18, 18]
        batch_elements = scatter_sum(torch.ones_like(batch_ind), batch_ind)
        batch_elements[0] -= 1  # to make the later indexing consistent

        # compute the pointer idxs to the beginning of each batch by summing the number of elements in each batch
        # e.g. [ 0., 17., 35., 53.]  NB: these are indices!
        batch_pointers = torch.Tensor([batch_elements[:i].sum() for i in range(len(batch_elements))])
        del batch_elements

        # number of entries in the scattered vector before each batch
        # e.g. [ 0., 9., 18., 27.]
        markers = not_shifted_indeces[batch_pointers.long()]  # highest not_shifted index for each batch
        del not_shifted_indeces
        del batch_pointers
        cumulative_markers = torch.Tensor([markers[:i + 1].sum() for i in range(len(markers))]).to(batch_ind.device)  # cumulative sum of the markers
        del markers
        cumulative_markers += torch.unique(batch_ind)  # correct the markers by the batch index

        # get the index shift in the scattered vector based on the batch
        # e.g. [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  9,  9,  9,  9,  9,  9,
        #        9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        #       18, 18, 18, 18, 18, 18, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27]
        batch_shift = torch.gather(cumulative_markers, 0, batch_ind)
        del cumulative_markers

        # finally compute the scatter indices by including also their shift due to the batch
        shifted_indeces = atom_ind * 3 + dim_ind + batch_shift

        return shifted_indeces

    def forward(self, x: torch.Tensor):
        # ensure device consistency
        left = self.left.to(x.device)

        # get the vector with the derivatives of q wrt the descriptors
        right = self._create_right(x=x, batch_ind=self.mat_ind[0], des_ind=self.mat_ind[2])

        # do the element-wise product
        src = left * right

        # compute the square modulus
        out = self._compute_square_modulus(x=src, indeces=self.scatter_indeces, n_atoms=self.n_atoms, batch_size=self.batch_size)

        return out

    def _create_right(self, x: torch.Tensor, batch_ind: torch.Tensor, des_ind: torch.Tensor):
        # keep only the non-zero elements of the right input
        desc_vec = x[batch_ind, des_ind]
        return desc_vec

    def _compute_square_modulus(self, x: torch.Tensor, indeces: torch.Tensor, n_atoms: int, batch_size: int):
        indeces = indeces.long().to(x.device)

        # this sums the elements of x according to the indices; this way we gather the contributions of different descriptors to the same atom
        out = scatter_sum(x, indeces.long())
        # now take the square
        out = out.pow(2)
        # reshape; this needs the correct number of atoms, as we need to multiply it by the mass vector later
        out = out.reshape((batch_size, n_atoms * 3))
        return out


def compute_descriptors_derivatives(dataset, descriptor_function, n_atoms, separate_boundary_dataset=True):
    """Compute the derivatives of a set of descriptors wrt the input positions in a dataset for committor optimization

    Parameters
    ----------
    dataset : DictDataset
        DictDataset with the positions under the 'data' key and the state labels under the 'labels' key
    descriptor_function : torch.nn.Module
        Transform module for the computation of the descriptors
    n_atoms : int
        Number of atoms in the system
    separate_boundary_dataset : bool, optional
        Switch to exclude boundary-condition labeled data from the variational loss, by default True

    Returns
    -------
    pos : torch.Tensor
        Sanitized positions with requires_grad enabled
    desc : torch.Tensor
        Computed descriptors
    d_desc_d_pos : torch.Tensor
        Derivatives of desc wrt pos
    """
    pos = dataset['data']
    labels = dataset['labels']
    pos = sanitize_positions_shape(pos=pos, n_atoms=n_atoms)[0]
    pos.requires_grad = True

    desc = descriptor_function(pos)
    if separate_boundary_dataset:
        mask_var = torch.nonzero(labels.squeeze() > 1, as_tuple=True)[0]
        der_desc = desc[mask_var]
        if len(der_desc) == 0:
            raise ValueError("No points left after separating the boundary and variational datasets.\n"
                             "If you are using only unbiased data, set separate_boundary_dataset=False here and in Committor, or do not use SmartDerivatives!")
    else:
        der_desc = desc

    # compute the derivatives of the descriptors wrt the positions, looping over the number of descriptors
    aux = []
    for i in range(len(der_desc[0])):
        aux_der = torch.autograd.grad(der_desc[:, i], pos, grad_outputs=torch.ones_like(der_desc[:, i]), retain_graph=True)[0]
        if separate_boundary_dataset:
            aux_der = aux_der[mask_var]
        aux.append(aux_der)

    d_desc_d_pos = torch.stack(aux, axis=2)
    return pos, desc, d_desc_d_pos.squeeze(-1)
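

# Illustrative sketch of how the pieces above might be combined; `desc_fn`, `dataset`,
# `mass` and `n_atoms` are placeholder names for a descriptor Transform, a DictDataset
# with 'data'/'labels' keys, the mass vector and the atom count:
#
#   pos, desc, d_desc_d_pos = compute_descriptors_derivatives(dataset, desc_fn, n_atoms)
#   smart = SmartDerivatives(d_desc_d_pos, n_atoms=n_atoms)
#   loss_fn = CommittorLoss(mass=mass, alpha=1.0, descriptors_derivatives=smart)
#
# With this setup the loss receives the descriptors as `x`, so only the cheap derivatives
# of q wrt the descriptors are recomputed at each step, while the expensive derivatives
# of the descriptors wrt the positions are precomputed once.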


def test_smart_derivatives():
    from mlcolvar.core.transform import PairwiseDistances
    from mlcolvar.core.nn import FeedForward
    from mlcolvar.data import DictDataset

    # compute some descriptors from the positions --> pairwise distances
    n_atoms = 10
    pos = torch.Tensor([[1.4970, 1.3861, -0.0273, -1.4933, 1.5070, -0.1133, -1.4473, -1.4193,
                         -0.0553, 1.4940, 1.4990, -0.2403, 1.4780, -1.4173, -0.3363, -1.4243,
                         -1.4093, -0.4293, 1.3530, -1.4313, -0.4183, 1.3060, 1.4750, -0.4333,
                         1.2970, -1.3233, -0.4643, 1.1670, -1.3253, -0.5354]])
    pos = pos.repeat(4, 1)
    labels = torch.arange(0, 4)

    dataset = DictDataset({'data': pos, 'labels': labels})

    cell = torch.Tensor([3.0233])
    ref_distances = torch.Tensor([[0.1521, 0.2335, 0.2412, 0.3798, 0.4733, 0.4649, 0.4575, 0.5741, 0.6815,
                                   0.1220, 0.1323, 0.2495, 0.3407, 0.3627, 0.3919, 0.4634, 0.5885, 0.2280,
                                   0.2976, 0.3748, 0.4262, 0.4821, 0.5043, 0.6376, 0.1447, 0.2449, 0.2454,
                                   0.2705, 0.3597, 0.4833, 0.1528, 0.1502, 0.2370, 0.2408, 0.3805, 0.2472,
                                   0.3243, 0.3159, 0.4527, 0.1270, 0.1301, 0.2440, 0.2273, 0.2819, 0.1482]])
    ref_distances = ref_distances.repeat(4, 1)

    ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                           PBC=True,
                                           cell=cell,
                                           scaled_coords=False)

    for separate_boundary_dataset in [False, True]:
        if separate_boundary_dataset:
            mask = [labels > 1]
        else:
            mask = torch.ones_like(labels, dtype=torch.bool)

        pos, desc, d_desc_d_x = compute_descriptors_derivatives(dataset=dataset,
                                                                descriptor_function=ComputeDescriptors,
                                                                n_atoms=n_atoms,
                                                                separate_boundary_dataset=separate_boundary_dataset)

        assert torch.allclose(desc, ref_distances, atol=1e-3)

        # apply a simple NN
        NN = FeedForward(layers=[45, 2, 1])
        out = NN(desc)

        # compute the derivatives of out wrt the input positions
        d_out_d_x = torch.autograd.grad(out, pos, grad_outputs=torch.ones_like(out), retain_graph=True, create_graph=True)[0]
        # compute the derivatives of out wrt the descriptors
        d_out_d_d = torch.autograd.grad(out, desc, grad_outputs=torch.ones_like(out), retain_graph=True, create_graph=True)[0]
        ref = torch.einsum('badx,bd->bax', d_desc_d_x, d_out_d_d[mask])
        ref = ref.pow(2).sum(dim=(-2, -1))

        Ref = d_out_d_x[mask].pow(2).sum(dim=(-2, -1))

        # apply the smart derivatives
        smart_derivatives = SmartDerivatives(d_desc_d_x, n_atoms=n_atoms)
        right_input = d_out_d_d.squeeze(-1)
        smart_out = smart_derivatives(right_input).sum(dim=1)

        # do checks
        assert torch.allclose(smart_out, ref)
        assert torch.allclose(smart_out, Ref)

        smart_out.sum().backward()


if __name__ == "__main__":
    test_smart_derivatives()