Coverage for biobb_pytorch/mdae/loss/committor_loss.py: 12% (175 statements)


#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Committor function Loss Function and Utils.
"""

from mlcolvar.core.transform.descriptors.utils import sanitize_positions_shape

__all__ = ["CommittorLoss", "committor_loss", "SmartDerivatives", "compute_descriptors_derivatives"]

# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import torch
from typing import Optional

# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class CommittorLoss(torch.nn.Module):
    """Compute a loss function based on Kolmogorov's variational principle for the determination of the committor function."""

    def __init__(self,
                 mass: torch.Tensor,
                 alpha: float,
                 cell: float = None,
                 gamma: float = 10000,
                 delta_f: float = 0,
                 separate_boundary_dataset: bool = True,
                 descriptors_derivatives: torch.nn.Module = None
                 ):
        """Compute Kolmogorov's variational principle loss and impose boundary conditions on the metastable states.

        Parameters
        ----------
        mass : torch.Tensor
            Atomic masses of the atoms in the system
        alpha : float
            Hyperparameter that scales the boundary-conditions contribution to the loss, i.e. alpha*(loss_bound_A + loss_bound_B)
        cell : float, optional
            CUBIC cell side length, used to scale the positions from reduced to real coordinates, by default None
        gamma : float, optional
            Hyperparameter that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound), by default 10000
        delta_f : float, optional
            Free energy difference between A (label 0) and B (label 1), in units of kBT, by default 0.
            State B is supposed to be higher in energy.
        separate_boundary_dataset : bool, optional
            Switch to exclude boundary-condition labeled data from the variational loss, by default True
        descriptors_derivatives : torch.nn.Module, optional
            `SmartDerivatives` object to save memory and time when using descriptors.
            See also mlcolvar.core.loss.committor_loss.SmartDerivatives
        """
        super().__init__()
        self.register_buffer("mass", mass)
        self.alpha = alpha
        self.cell = cell
        self.gamma = gamma
        self.delta_f = delta_f
        self.descriptors_derivatives = descriptors_derivatives
        self.separate_boundary_dataset = separate_boundary_dataset

    def forward(
        self, x: torch.Tensor, q: torch.Tensor, labels: torch.Tensor, w: torch.Tensor, create_graph: bool = True
    ) -> torch.Tensor:
        return committor_loss(x=x,
                              q=q,
                              labels=labels,
                              w=w,
                              mass=self.mass,
                              alpha=self.alpha,
                              gamma=self.gamma,
                              delta_f=self.delta_f,
                              create_graph=create_graph,
                              cell=self.cell,
                              separate_boundary_dataset=self.separate_boundary_dataset,
                              descriptors_derivatives=self.descriptors_derivatives
                              )
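
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how
# CommittorLoss could be wired to a toy committor model; the network, shapes
# and values below are assumptions for demonstration only. The function is not
# called anywhere, so importing the module is unaffected.
# ---------------------------------------------------------------------------
def _example_committor_loss_usage():
    n_atoms = 3
    # toy batch: 2 samples in state A (label 0), 2 in state B (label 1), 2 variational samples (label 2)
    x = torch.randn(6, n_atoms * 3, requires_grad=True)
    labels = torch.tensor([0, 0, 1, 1, 2, 2])
    w = torch.ones(6)                  # reweighting factors (uniform here)
    mass = torch.ones(n_atoms * 3)     # one mass per atom, repeated for x, y, z

    # toy committor model: any differentiable map x -> q(x) in (0, 1)
    model = torch.nn.Sequential(torch.nn.Linear(n_atoms * 3, 1), torch.nn.Sigmoid())
    q = model(x)

    loss_fn = CommittorLoss(mass=mass, alpha=1.0, gamma=10000, delta_f=0)
    loss, loss_var, loss_A, loss_B = loss_fn(x, q, labels, w)
    loss.backward()
    return loss, loss_var, loss_A, loss_B
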

def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Broadcast util, from torch_scatter"""
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand(other.size())
    return src


def scatter_sum(src: torch.Tensor,
                index: torch.Tensor,
                dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Scatter sum function, from the torch_scatter module (https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py)"""
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)
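
# Minimal sketch of what scatter_sum does (the values are just an illustration):
# entries of `src` that share the same `index` are accumulated into one output slot.
#
#   src   = torch.tensor([1., 2., 3., 4.])
#   index = torch.tensor([0, 0, 1, 1])
#   scatter_sum(src, index)  # -> tensor([3., 7.])
#
# In this module it is used to accumulate, per atom and per spatial dimension,
# the contributions coming from the different descriptors.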

def committor_loss(x: torch.Tensor,
                   q: torch.Tensor,
                   labels: torch.Tensor,
                   w: torch.Tensor,
                   mass: torch.Tensor,
                   alpha: float,
                   gamma: float = 10000,
                   delta_f: float = 0,
                   create_graph: bool = True,
                   cell: float = None,
                   separate_boundary_dataset: bool = True,
                   descriptors_derivatives: torch.nn.Module = None
                   ):
    """Compute the variational loss for committor optimization with boundary conditions

    Parameters
    ----------
    x : torch.Tensor
        Input of the NN
    q : torch.Tensor
        Committor guess q(x), i.e. the output of the NN
    labels : torch.Tensor
        Labels for the states; the A and B states are used for the boundary conditions
    w : torch.Tensor
        Reweighting factors to the Boltzmann distribution. These should depend on the simulation in which the data were collected.
    mass : torch.Tensor
        Masses of all the atoms we are using; each atom's mass is repeated three times, once per x, y, z component.
        Can be created using `committor.utils.initialize_committor_masses`
    alpha : float
        Hyperparameter that scales the boundary-conditions contribution to the loss, i.e. alpha*(loss_bound_A + loss_bound_B)
    gamma : float
        Hyperparameter that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound).
        By default 10000
    delta_f : float
        Free energy difference between A (label 0) and B (label 1), in units of kBT, by default 0.
    create_graph : bool
        Make the loss backpropagatable; deactivate during validation to save memory, by default True
    cell : float
        CUBIC cell side length, used to scale the positions from reduced to real coordinates, by default None
    separate_boundary_dataset : bool, optional
        Switch to exclude boundary-condition labeled data from the variational loss, by default True
    descriptors_derivatives : torch.nn.Module, optional
        `SmartDerivatives` object to save memory and time when using descriptors.
        See also mlcolvar.core.loss.committor_loss.SmartDerivatives

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    gamma*loss_var : torch.Tensor
        The variational loss term
    gamma*alpha*loss_A : torch.Tensor
        The boundary loss term on basin A
    gamma*alpha*loss_B : torch.Tensor
        The boundary loss term on basin B
    """
    # inherit the right device
    device = x.device

    mass = mass.to(device)

    # Create masks to access the data of the different states
    mask_A = torch.nonzero(labels.squeeze() == 0, as_tuple=True)
    mask_B = torch.nonzero(labels.squeeze() == 1, as_tuple=True)
    if separate_boundary_dataset:
        mask_var = torch.nonzero(labels.squeeze() > 1, as_tuple=True)
    else:
        mask_var = torch.ones(len(x), dtype=torch.bool)

    # Update the weights of basin B using the information on delta_f
    delta_f = torch.Tensor([delta_f])
    if delta_f < 0:  # B higher in energy --> A-B < 0
        w[mask_B] = w[mask_B] * torch.exp(delta_f.to(device))
    elif delta_f > 0:  # A higher in energy --> A-B > 0
        w[mask_A] = w[mask_A] * torch.exp(-delta_f.to(device))

    # VARIATIONAL PRINCIPLE LOSS
    # Each loss contribution is scaled by the number of samples

    # We need the gradient of q(x)
    grad_outputs = torch.ones_like(q[mask_var])
    grad = torch.autograd.grad(q[mask_var], x, grad_outputs=grad_outputs, retain_graph=True, create_graph=create_graph)[0]
    grad = grad[mask_var]

    # TODO this fixes the cell size issue
    if cell is not None:
        grad = grad / cell

    if descriptors_derivatives is not None:
        grad_square = descriptors_derivatives(grad)
    else:
        # we take the square of grad(q) and multiply it by the weight
        grad_square = torch.pow(grad, 2)

    # we sanitize the shapes of the mass and weights tensors
    # mass should have size [1, n_atoms*spatial_dims]
    mass = mass.unsqueeze(0)
    # weights should have size [n_batch, 1]
    w = w.unsqueeze(-1)

    grad_square = torch.sum((grad_square * (1 / mass)), dim=1, keepdim=True)
    grad_square = grad_square * w[mask_var]

    # variational contribution to the loss: we average over the batch
    loss_var = torch.mean(grad_square)

    # boundary conditions
    q_A = q[mask_A]
    q_B = q[mask_B]
    loss_A = torch.mean(torch.pow(q_A, 2))
    loss_B = torch.mean(torch.pow((q_B - 1), 2))

    loss = gamma * (loss_var + alpha * (loss_A + loss_B))

    # TODO maybe there is no need to detach them for logging
    return loss, gamma * loss_var.detach(), alpha * gamma * loss_A.detach(), alpha * gamma * loss_B.detach()
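
# For reference, the quantity assembled in committor_loss can be written compactly
# as (this is just a plain-text transcription of the code above):
#
#   loss = gamma * [ < w * sum_j (1/m_j) * (dq/dx_j)^2 >_var
#                    + alpha * ( < q^2 >_A + < (q - 1)^2 >_B ) ]
#
# where <.>_var averages over the variational samples (labels > 1 when
# separate_boundary_dataset=True, all samples otherwise), <.>_A and <.>_B average
# over the boundary samples of states A and B, and the gradients are taken with
# respect to the (possibly cell-scaled) NN inputs x.
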

class SmartDerivatives(torch.nn.Module):
    """
    Utility to compute efficiently (time- and memory-wise) the derivatives of q wrt some input descriptors.
    Rather than computing explicitly the derivatives wrt the positions, we compute those wrt the descriptors (right input)
    and multiply them by the matrix of the derivatives of the descriptors wrt the positions (left input).
    """

    def __init__(self,
                 der_desc_wrt_pos: torch.Tensor,
                 n_atoms: int,
                 setup_device: str = 'cpu'
                 ):
        """Initialize the fixed matrices for smart derivatives, i.e. the matrix of derivatives of the descriptors wrt the positions.
        The derivatives wrt the positions are recovered by multiplying the derivatives of q wrt the descriptors (right input, computed at each epoch)
        by the non-zero elements of the derivatives of the descriptors wrt the positions (left input, computed once at the beginning on the whole dataset).
        The multiplications are done using scatter functions, keeping track of the indices of the batches, descriptors, atoms and dimensions.

        NB. It should be used with the training set only and a single batch, with shuffle and random_split disabled.

        Parameters
        ----------
        der_desc_wrt_pos : torch.Tensor
            Tensor containing the derivatives of the descriptors wrt the atomic positions
        n_atoms : int
            Number of atoms in the system; all the atoms should be used in at least one of the descriptors
        setup_device : str
            Device on which to perform the expensive calculations. Either 'cpu' or 'cuda', by default 'cpu'
        """
        super().__init__()
        self.batch_size = len(der_desc_wrt_pos)
        self.n_atoms = n_atoms

        # setup the fixed part of the computation, i.e. the left input and the indices for the scatter
        self.left, self.mat_ind, self.scatter_indeces = self._setup_left(der_desc_wrt_pos, setup_device=setup_device)

    def _setup_left(self, left_input: torch.Tensor, setup_device: str = 'cpu'):
        """Setup the fixed part of the computation, i.e. the left input"""
        # all the setup should be done on the CPU by default
        left_input = left_input.to(torch.device(setup_device))

        # the indices in mat_ind are: batch, atom, descriptor and dimension
        left, mat_ind = self._create_nonzero_left(left_input)

        # it is possible that some atoms are not involved in any descriptor
        n_effective_atoms = len(torch.unique(mat_ind[1]))
        if n_effective_atoms < self.n_atoms:
            raise ValueError(f"Some of the input atoms are not used in any descriptor. The unused atom IDs are: {[i for i in range(self.n_atoms) if i not in torch.unique(mat_ind[1]).numpy()]}")

        scatter_indeces = self._get_scatter_indices(batch_ind=mat_ind[0], atom_ind=mat_ind[1], dim_ind=mat_ind[3])
        return left, mat_ind, scatter_indeces

    def _create_nonzero_left(self, x):
        """Find the indices of the non-zero elements of the left input"""
        # find the indices of the nonzero entries of d_dist_d_x
        mat_ind = x.nonzero(as_tuple=True)

        # flatten the matrix --> big nonzero vector
        x = x.ravel()
        # find the indices of the nonzero entries of the flattened matrix
        vec_ind = x.nonzero(as_tuple=True)

        # create a vector with the nonzero entries only
        x_vec = x[vec_ind[0].long()]

        # del(vec_ind)
        return x_vec, mat_ind

    def _get_scatter_indices(self, batch_ind, atom_ind, dim_ind):
        """Compute the general indices to map the long vector of nonzero derivatives to the right atom, dimension and descriptor, also in the case of non-homogeneous input.
        We need to gather the derivatives with respect to the same atom coming from different descriptors to obtain the total gradient.
        """
        # ====================================== INITIAL NOTE ======================================
        # the comments follow the example of the distances in a 3-atom system with 4 batches,
        # i.e. 3 desc * 2 atoms per desc * 3 dims * 4 batches = 72 values need to be mapped to
        # 3 atoms * 3 dims * 4 batches = 36
        #
        # Ref_idx: tensor([ 0,  1,  2,  0,  1,  2,  3,  4,  5,  3,  4,  5,  6,  7,  8,  6,  7,  8,
        #                   9, 10, 11,  9, 10, 11, 12, 13, 14, 12, 13, 14, 15, 16, 17, 15, 16, 17,
        #                  18, 19, 20, 18, 19, 20, 21, 22, 23, 21, 22, 23, 24, 25, 26, 24, 25, 26,
        #                  27, 28, 29, 27, 28, 29, 30, 31, 32, 30, 31, 32, 33, 34, 35, 33, 34, 35])
        # ==========================================================================================

        # these would be the indices in the case of uniform batches and uniform atom/descriptor dependence;
        # it just repeats the atom index in a cycle
        # e.g. [0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6, 7, 8, 0, 1, 2, 0, 1, 2,
        #       3, 4, 5, 3, 4, 5, 6, 7, 8, 6, 7, 8, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5,
        #       6, 7, 8, 6, 7, 8, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6, 7, 8]
        not_shifted_indeces = atom_ind * 3 + dim_ind

        # get the number of elements in each batch
        # e.g. [17, 18, 18, 18]
        batch_elements = scatter_sum(torch.ones_like(batch_ind), batch_ind)
        batch_elements[0] -= 1  # to make the later indexing consistent

        # compute the pointer idxs to the beginning of each batch by summing the number of elements in each batch
        # e.g. [ 0., 17., 35., 53.]  NB. these are indices!
        batch_pointers = torch.Tensor([batch_elements[:i].sum() for i in range(len(batch_elements))])
        del (batch_elements)

        # number of entries in the scattered vector before each batch
        # e.g. [ 0., 9., 18., 27.]
        markers = not_shifted_indeces[batch_pointers.long()]  # highest not_shifted index for each batch
        del (not_shifted_indeces)
        del (batch_pointers)
        cumulative_markers = torch.Tensor([markers[:i + 1].sum() for i in range(len(markers))]).to(batch_ind.device)  # cumulative sum of the indices
        del (markers)
        cumulative_markers += torch.unique(batch_ind)  # correct the markers by the batch index

        # get the index shift in the scattered vector based on the batch
        # e.g. [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  9,  9,  9,  9,  9,  9,
        #        9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        #       18, 18, 18, 18, 18, 18, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27]
        batch_shift = torch.gather(cumulative_markers, 0, batch_ind)
        del (cumulative_markers)

        # finally compute the scatter indices, including the shift due to the batch
        shifted_indeces = atom_ind * 3 + dim_ind + batch_shift

        return shifted_indeces

    def forward(self, x: torch.Tensor):
        # ensure device consistency
        left = self.left.to(x.device)

        # get the vector with the derivatives of q wrt the descriptors
        right = self._create_right(x=x, batch_ind=self.mat_ind[0], des_ind=self.mat_ind[2])

        # do the element-wise product
        src = left * right

        # compute the square modulus
        out = self._compute_square_modulus(x=src, indeces=self.scatter_indeces, n_atoms=self.n_atoms, batch_size=self.batch_size)

        return out

    def _create_right(self, x: torch.Tensor, batch_ind: torch.Tensor, des_ind: torch.Tensor):
        # keep only the non-zero elements of the right input
        desc_vec = x[batch_ind, des_ind]
        return desc_vec

    def _compute_square_modulus(self, x: torch.Tensor, indeces: torch.Tensor, n_atoms: int, batch_size: int):
        indeces = indeces.long().to(x.device)

        # this sums the elements of x according to the indices; this way we gather the contributions of different descriptors to the same atom
        out = scatter_sum(x, indeces.long())
        # now take the square
        out = out.pow(2)
        # reshape; this needs to have the correct number of atoms as we later multiply it by the mass vector
        out = out.reshape((batch_size, n_atoms * 3))
        return out
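
# For reference, SmartDerivatives(d_desc_d_pos)(dq_dd) is meant to reproduce the
# dense chain-rule computation (see test_smart_derivatives below for a concrete check):
#
#   grad = torch.einsum('badx,bd->bax', d_desc_d_pos, dq_dd)   # dq/dpos per batch, atom, dimension
#   out  = grad.pow(2).reshape(batch_size, n_atoms * 3)        # squared components, flattened per atom/dim
#
# but it only stores and multiplies the non-zero entries of d_desc_d_pos, which
# saves memory and time when each descriptor involves only a few atoms.
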

def compute_descriptors_derivatives(dataset, descriptor_function, n_atoms, separate_boundary_dataset=True):
    """Compute the derivatives of a set of descriptors wrt the input positions in a dataset for committor optimization

    Parameters
    ----------
    dataset : DictDataset
        DictDataset with the positions under the 'data' key and the state labels under the 'labels' key
    descriptor_function : torch.nn.Module
        Transform module for the computation of the descriptors
    n_atoms : int
        Number of atoms in the system
    separate_boundary_dataset : bool, optional
        Switch to exclude boundary-condition labeled data from the variational loss, by default True

    Returns
    -------
    pos : torch.Tensor
        Sanitized positions, with gradients enabled
    desc : torch.Tensor
        Computed descriptors
    d_desc_d_pos : torch.Tensor
        Derivatives of desc wrt pos
    """
    pos = dataset['data']
    labels = dataset['labels']
    pos = sanitize_positions_shape(pos=pos, n_atoms=n_atoms)[0]
    pos.requires_grad = True

    desc = descriptor_function(pos)
    if separate_boundary_dataset:
        mask_var = torch.nonzero(labels.squeeze() > 1, as_tuple=True)[0]
        der_desc = desc[mask_var]
        if len(der_desc) == 0:
            raise (ValueError("No points left after separating the boundary and variational datasets.\n If you are using only unbiased data, set separate_boundary_dataset=False both here and in Committor, or don't use SmartDerivatives!"))
    else:
        der_desc = desc

    # compute the derivatives of the descriptors wrt the positions, looping over the descriptors
    aux = []
    for i in range(len(der_desc[0])):
        aux_der = torch.autograd.grad(der_desc[:, i], pos, grad_outputs=torch.ones_like(der_desc[:, i]), retain_graph=True)[0]
        if separate_boundary_dataset:
            aux_der = aux_der[mask_var]
        aux.append(aux_der)

    d_desc_d_pos = torch.stack(aux, dim=2)
    return pos, desc, d_desc_d_pos.squeeze(-1)
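
# ---------------------------------------------------------------------------
# Illustrative wiring sketch (not part of the original module). It shows how the
# helpers above could be combined for a descriptor-based committor loss; the
# dataset, descriptor function and model arguments are placeholders, and the
# uniform weights are an assumption made for brevity. See test_smart_derivatives
# below for a concrete, self-contained check.
# ---------------------------------------------------------------------------
def _example_smart_derivatives_pipeline(dataset, descriptor_function, model, mass, n_atoms):
    # derivatives of the descriptors wrt the positions, computed once on the whole dataset
    pos, desc, d_desc_d_pos = compute_descriptors_derivatives(dataset, descriptor_function, n_atoms)

    # fixed sparse "left input" for the chain rule dq/dpos = dq/ddesc * ddesc/dpos
    smart_derivatives = SmartDerivatives(d_desc_d_pos, n_atoms=n_atoms)

    # loss acting on descriptors: gradients wrt the descriptors are converted internally
    loss_fn = CommittorLoss(mass=mass, alpha=1.0, descriptors_derivatives=smart_derivatives)

    q = model(desc)
    w = torch.ones(len(q))  # placeholder weights; real reweighting factors come from the sampling
    return loss_fn(desc, q, dataset['labels'], w)
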

def test_smart_derivatives():
    from mlcolvar.core.transform import PairwiseDistances
    from mlcolvar.core.nn import FeedForward
    from mlcolvar.data import DictDataset

    # compute some descriptors from positions --> distances
    n_atoms = 10
    pos = torch.Tensor([[1.4970, 1.3861, -0.0273, -1.4933, 1.5070, -0.1133, -1.4473, -1.4193,
                         -0.0553, 1.4940, 1.4990, -0.2403, 1.4780, -1.4173, -0.3363, -1.4243,
                         -1.4093, -0.4293, 1.3530, -1.4313, -0.4183, 1.3060, 1.4750, -0.4333,
                         1.2970, -1.3233, -0.4643, 1.1670, -1.3253, -0.5354]])
    pos = pos.repeat(4, 1)
    labels = torch.arange(0, 4)

    dataset = DictDataset({'data': pos, 'labels': labels})

    cell = torch.Tensor([3.0233])
    ref_distances = torch.Tensor([[0.1521, 0.2335, 0.2412, 0.3798, 0.4733, 0.4649, 0.4575, 0.5741, 0.6815,
                                   0.1220, 0.1323, 0.2495, 0.3407, 0.3627, 0.3919, 0.4634, 0.5885, 0.2280,
                                   0.2976, 0.3748, 0.4262, 0.4821, 0.5043, 0.6376, 0.1447, 0.2449, 0.2454,
                                   0.2705, 0.3597, 0.4833, 0.1528, 0.1502, 0.2370, 0.2408, 0.3805, 0.2472,
                                   0.3243, 0.3159, 0.4527, 0.1270, 0.1301, 0.2440, 0.2273, 0.2819, 0.1482]])
    ref_distances = ref_distances.repeat(4, 1)

    ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                           PBC=True,
                                           cell=cell,
                                           scaled_coords=False)

    for separate_boundary_dataset in [False, True]:
        if separate_boundary_dataset:
            mask = [labels > 1]
        else:
            mask = torch.ones_like(labels, dtype=torch.bool)

        pos, desc, d_desc_d_x = compute_descriptors_derivatives(dataset=dataset,
                                                                descriptor_function=ComputeDescriptors,
                                                                n_atoms=n_atoms,
                                                                separate_boundary_dataset=separate_boundary_dataset)

        assert (torch.allclose(desc, ref_distances, atol=1e-3))

        # apply a simple NN
        NN = FeedForward(layers=[45, 2, 1])
        out = NN(desc)

        # compute derivatives of out wrt the input positions
        d_out_d_x = torch.autograd.grad(out, pos, grad_outputs=torch.ones_like(out), retain_graph=True, create_graph=True)[0]
        # compute derivatives of out wrt the descriptors
        d_out_d_d = torch.autograd.grad(out, desc, grad_outputs=torch.ones_like(out), retain_graph=True, create_graph=True)[0]
        ref = torch.einsum('badx,bd->bax', d_desc_d_x, d_out_d_d[mask])
        ref = ref.pow(2).sum(dim=(-2, -1))

        Ref = d_out_d_x[mask].pow(2).sum(dim=(-2, -1))

        # apply smart derivatives
        smart_derivatives = SmartDerivatives(d_desc_d_x, n_atoms=n_atoms)
        right_input = d_out_d_d.squeeze(-1)
        smart_out = smart_derivatives(right_input).sum(dim=1)

        # do the checks
        assert (torch.allclose(smart_out, ref))
        assert (torch.allclose(smart_out, Ref))

        smart_out.sum().backward()


if __name__ == "__main__":
    test_smart_derivatives()