Coverage for biobb_pytorch / test / unitests / test_mdae / test_loss_functions.py: 96%

252 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1# type: ignore 

2""" 

3Comprehensive test suite for all loss functions in biobb_pytorch.mdae.loss 

4 

5This file tests all loss functions including: 

6- MSELoss: Mean Squared Error 

7- ELBOGaussiansLoss: Evidence Lower Bound for VAE 

8- ELBOGaussianMixtureLoss: ELBO for GMVAE 

9- InformationBottleneckLoss: Information Bottleneck for SPIB 

10- AutocorrelationLoss: Time-lagged autocorrelation 

11- FisherDiscriminantLoss: Fisher discriminant for LDA 

12- ReduceEigenvaluesLoss: Eigenvalue reduction 

13- TDALoss: Topological Data Analysis loss 

14- PhysicsLoss: Physics-informed loss 

15- CommittorLoss: Committor function loss 

16""" 

17import pytest 

18import torch 

19from biobb_pytorch.mdae.loss import ( 

20 MSELoss, 

21 mse_loss, 

22 ELBOGaussiansLoss, 

23 elbo_gaussians_loss, 

24 ELBOGaussianMixtureLoss, 

25 InformationBottleneckLoss, 

26 AutocorrelationLoss, 

27 autocorrelation_loss, 

28 FisherDiscriminantLoss, 

29 fisher_discriminant_loss, 

30 reduce_eigenvalues_loss, 

31 TDALoss, 

32 tda_loss, 

33 PhysicsLoss, 

34) 

35 

36 

37class TestMSELoss: 

38 """Test suite for MSELoss.""" 

39 

40 def test_mse_loss_basic(self): 

41 """Test basic MSE loss computation.""" 

42 input_tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 

43 target_tensor = torch.tensor([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]) 

44 

45 loss = mse_loss(input_tensor, target_tensor) 

46 

47 assert isinstance(loss, torch.Tensor), "Loss should be a tensor" 

48 assert loss.item() > 0, "MSE loss should be positive" 

49 

50 # Check expected value: mean((0.5)^2) = 0.25 

51 expected = 0.25 

52 assert torch.isclose(loss, torch.tensor(expected), atol=1e-6) 

53 

54 def test_mse_loss_with_weights(self): 

55 """Test MSE loss with sample weights.""" 

56 input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) 

57 target_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) 

58 weights = torch.tensor([1.0, 2.0]) 

59 

60 loss = mse_loss(input_tensor, target_tensor, weights) 

61 

62 assert isinstance(loss, torch.Tensor) 

63 assert loss.item() == 0.0, "Loss should be zero for identical inputs" 

64 

65 def test_mse_loss_module(self): 

66 """Test MSELoss as a module.""" 

67 loss_fn = MSELoss() 

68 input_tensor = torch.randn(10, 5) 

69 target_tensor = torch.randn(10, 5) 

70 

71 loss = loss_fn(input_tensor, target_tensor) 

72 

73 assert isinstance(loss, torch.Tensor) 

74 assert loss.ndim == 0, "Loss should be a scalar" 

75 

76 def test_mse_loss_1d_input(self): 

77 """Test MSE loss with 1D input (should be reshaped).""" 

78 input_tensor = torch.tensor([1.0, 2.0, 3.0]) 

79 target_tensor = torch.tensor([1.5, 2.5, 3.5]) 

80 

81 loss = mse_loss(input_tensor, target_tensor) 

82 

83 assert isinstance(loss, torch.Tensor) 

84 assert loss.item() > 0 

85 

86 

87class TestELBOLosses: 

88 """Test suite for ELBO loss functions.""" 

89 

90 def test_elbo_gaussians_loss_basic(self): 

91 """Test basic ELBO Gaussians loss computation.""" 

92 batch_size, n_features, n_latent = 8, 20, 3 

93 

94 target = torch.randn(batch_size, n_features) 

95 output = torch.randn(batch_size, n_features) 

96 mean = torch.randn(batch_size, n_latent) 

97 log_variance = torch.randn(batch_size, n_latent) 

98 

99 loss = elbo_gaussians_loss(target, output, mean, log_variance) 

100 

101 assert isinstance(loss, torch.Tensor) 

102 assert loss.ndim == 0, "Loss should be a scalar" 

103 

104 def test_elbo_gaussians_loss_module(self): 

105 """Test ELBOGaussiansLoss as a module.""" 

106 loss_fn = ELBOGaussiansLoss() 

107 

108 batch_size, n_features, n_latent = 4, 10, 2 

109 target = torch.randn(batch_size, n_features) 

110 output = torch.randn(batch_size, n_features) 

111 mean = torch.randn(batch_size, n_latent) 

112 log_variance = torch.randn(batch_size, n_latent) 

113 

114 loss = loss_fn(target, output, mean, log_variance) 

115 

116 assert isinstance(loss, torch.Tensor) 

117 assert not torch.isnan(loss), "Loss should not be NaN" 

118 

119 def test_elbo_with_weights(self): 

120 """Test ELBO loss with sample weights.""" 

121 batch_size, n_features, n_latent = 8, 20, 3 

122 

123 target = torch.randn(batch_size, n_features) 

124 output = torch.randn(batch_size, n_features) 

125 mean = torch.randn(batch_size, n_latent) 

126 log_variance = torch.randn(batch_size, n_latent) 

127 weights = torch.ones(batch_size) 

128 

129 loss = elbo_gaussians_loss(target, output, mean, log_variance, weights) 

130 

131 assert isinstance(loss, torch.Tensor) 

132 assert not torch.isnan(loss) 

133 

134 def test_elbo_gaussian_mixture_loss(self): 

135 """Test ELBO Gaussian Mixture loss.""" 

136 loss_fn = ELBOGaussianMixtureLoss(k=3, r_nent=0.5) 

137 

138 batch_size, n_features, n_latent = 4, 10, 2 

139 k = 3 

140 

141 target = torch.randn(batch_size, n_features) 

142 output = torch.randn(batch_size, n_features) 

143 z = torch.randn(batch_size, n_latent) 

144 qy = torch.randn(batch_size, k) 

145 qz_m = torch.randn(batch_size, n_latent) 

146 qz_v = torch.randn(batch_size, n_latent) 

147 pz_m = torch.randn(batch_size, n_latent) 

148 pz_v = torch.randn(batch_size, n_latent) 

149 

150 try: 

151 loss = loss_fn(target, output, z, qy, qz_m, qz_v, pz_m, pz_v) 

152 assert isinstance(loss, torch.Tensor) 

153 except Exception as e: 

154 # GMVAE loss has complex requirements, skip if not compatible 

155 pytest.skip(f"GMVAE loss requires specific setup: {e}") 

156 

157 

158class TestInformationBottleneckLoss: 

159 """Test suite for InformationBottleneckLoss (SPIB).""" 

160 

161 def test_ib_loss_initialization(self): 

162 """Test IB loss initialization.""" 

163 loss_fn = InformationBottleneckLoss(beta=0.01, eps=1e-8) 

164 

165 assert loss_fn.beta == 0.01 

166 assert loss_fn.eps == 1e-8 

167 

168 def test_ib_loss_forward(self): 

169 """Test IB loss forward pass.""" 

170 loss_fn = InformationBottleneckLoss(beta=0.01) 

171 

172 batch_size, output_dim, k = 4, 10, 2 

173 

174 data_targets = torch.randn(batch_size, output_dim) 

175 outputs = torch.randn(batch_size, output_dim) 

176 z_sample = torch.randn(batch_size, 1) 

177 z_mean = torch.randn(batch_size, 1) 

178 z_logvar = torch.randn(batch_size, 1) 

179 rep_mean = torch.randn(k, 1) 

180 rep_logvar = torch.randn(k, 1) 

181 w = torch.ones(k, 1) / k 

182 

183 try: 

184 loss, rec_err, kl_term = loss_fn( 

185 data_targets, outputs, z_sample, z_mean, z_logvar, 

186 rep_mean, rep_logvar, w 

187 ) 

188 

189 assert isinstance(loss, torch.Tensor) 

190 assert isinstance(rec_err, torch.Tensor) 

191 assert isinstance(kl_term, torch.Tensor) 

192 assert not torch.isnan(loss) 

193 except Exception as e: 

194 pytest.skip(f"IB loss requires specific tensor shapes: {e}") 

195 

196 def test_ib_log_p_method(self): 

197 """Test log_p method of IB loss.""" 

198 loss_fn = InformationBottleneckLoss() 

199 

200 batch_size, k = 4, 2 

201 z = torch.randn(batch_size, 1) 

202 rep_mean = torch.randn(k, 1) 

203 rep_logvar = torch.randn(k, 1) 

204 w = torch.ones(k, 1) / k 

205 

206 log_p = loss_fn.log_p(z, rep_mean, rep_logvar, w) 

207 

208 assert isinstance(log_p, torch.Tensor) 

209 assert log_p.shape[0] == batch_size 

210 

211 

212class TestReduceEigenvaluesLoss: 

213 """Test suite for ReduceEigenvaluesLoss.""" 

214 

215 @pytest.mark.parametrize("mode", ["sum", "sum2", "gap", "single"]) 

216 def test_reduce_eigenvalues_modes(self, mode): 

217 """Test different reduction modes.""" 

218 n_eig = 0 if mode != "single" else 0 

219 eigenvalues = torch.tensor([3.0, 2.0, 1.0]) 

220 

221 try: 

222 loss = reduce_eigenvalues_loss( 

223 eigenvalues, mode, n_eig, invert_sign=True) 

224 assert isinstance(loss, torch.Tensor) 

225 assert loss.ndim == 0 

226 except Exception as e: 

227 pytest.skip(f"Mode {mode} requires specific setup: {e}") 

228 

229 def test_reduce_eigenvalues_sum(self): 

230 """Test sum reduction mode.""" 

231 eigenvalues = torch.tensor([3.0, 2.0, 1.0]) 

232 loss = reduce_eigenvalues_loss( 

233 eigenvalues, mode="sum", n_eig=0, invert_sign=True) 

234 

235 # With invert_sign=True, should return -(3+2+1) = -6 

236 assert torch.isclose(loss, torch.tensor(-6.0)) 

237 

238 def test_reduce_eigenvalues_sum2(self): 

239 """Test sum2 reduction mode.""" 

240 eigenvalues = torch.tensor([2.0, 1.0]) 

241 loss = reduce_eigenvalues_loss( 

242 eigenvalues, mode="sum2", n_eig=0, invert_sign=True) 

243 

244 # With invert_sign=True, should return -(4+1) = -5 

245 assert torch.isclose(loss, torch.tensor(-5.0)) 

246 

247 

248class TestAutocorrelationLoss: 

249 """Test suite for AutocorrelationLoss.""" 

250 

251 def test_autocorrelation_loss_initialization(self): 

252 """Test autocorrelation loss initialization.""" 

253 loss_fn = AutocorrelationLoss(reduce_mode="sum2", invert_sign=True) 

254 

255 assert loss_fn.reduce_mode == "sum2" 

256 assert loss_fn.invert_sign is True 

257 

258 def test_autocorrelation_loss_forward(self): 

259 """Test autocorrelation loss forward pass.""" 

260 batch_size, n_features = 20, 5 

261 x = torch.randn(batch_size, n_features) 

262 x_lag = torch.randn(batch_size, n_features) 

263 

264 try: 

265 loss = autocorrelation_loss( 

266 x, x_lag, reduce_mode="sum", invert_sign=True) 

267 assert isinstance(loss, torch.Tensor) 

268 assert loss.ndim == 0 

269 except Exception as e: 

270 pytest.skip(f"Autocorrelation loss requires more samples: {e}") 

271 

272 

273class TestFisherDiscriminantLoss: 

274 """Test suite for FisherDiscriminantLoss.""" 

275 

276 def test_fisher_loss_initialization(self): 

277 """Test Fisher discriminant loss initialization.""" 

278 n_states = 3 

279 loss_fn = FisherDiscriminantLoss( 

280 n_states=n_states, 

281 lda_mode="standard", 

282 reduce_mode="sum", 

283 invert_sign=True 

284 ) 

285 

286 assert isinstance(loss_fn, torch.nn.Module) 

287 assert loss_fn.reduce_mode == "sum" 

288 

289 def test_fisher_loss_forward(self): 

290 """Test Fisher discriminant loss forward pass.""" 

291 n_states = 2 

292 batch_size, n_features = 20, 5 

293 x = torch.randn(batch_size, n_features) 

294 labels = torch.randint(0, n_states, (batch_size,)) 

295 

296 try: 

297 loss = fisher_discriminant_loss( 

298 x, labels, n_states=n_states, 

299 lda_mode="standard", reduce_mode="sum" 

300 ) 

301 assert isinstance(loss, torch.Tensor) 

302 except Exception as e: 

303 pytest.skip( 

304 f"Fisher loss requires sufficient samples per class: {e}") 

305 

306 

307class TestTDALoss: 

308 """Test suite for TDALoss.""" 

309 

310 def test_tda_loss_initialization(self): 

311 """Test TDA loss initialization.""" 

312 try: 

313 loss_fn = TDALoss(alpha=1.0) 

314 assert hasattr(loss_fn, 'alpha') 

315 except Exception as e: 

316 pytest.skip(f"TDA loss requires additional dependencies: {e}") 

317 

318 def test_tda_loss_forward(self): 

319 """Test TDA loss forward pass.""" 

320 try: 

321 batch_size, n_features = 10, 3 

322 x = torch.randn(batch_size, n_features) 

323 

324 loss = tda_loss(x, alpha=1.0) 

325 assert isinstance(loss, torch.Tensor) 

326 except Exception as e: 

327 pytest.skip(f"TDA loss requires specific setup: {e}") 

328 

329 

330class TestPhysicsLoss: 

331 """Test suite for PhysicsLoss.""" 

332 

333 def test_physics_loss_initialization(self): 

334 """Test physics loss initialization.""" 

335 try: 

336 loss_fn = PhysicsLoss() 

337 assert isinstance(loss_fn, torch.nn.Module) 

338 except Exception as e: 

339 pytest.skip(f"Physics loss requires specific dependencies: {e}") 

340 

341 

342class TestCommittorLoss: 

343 """Test suite for CommittorLoss.""" 

344 

345 def test_committor_loss_initialization(self): 

346 """Test committor loss initialization.""" 

347 pytest.skip("CommittorLoss requires specific setup and dependencies") 

348 

349 

350class TestLossFunctionIntegration: 

351 """Integration tests for loss functions with models.""" 

352 

353 def test_loss_gradients(self): 

354 """Test that loss functions provide gradients for backpropagation.""" 

355 loss_fn = MSELoss() 

356 

357 input_tensor = torch.randn(4, 10, requires_grad=True) 

358 target_tensor = torch.randn(4, 10) 

359 

360 loss = loss_fn(input_tensor, target_tensor) 

361 loss.backward() 

362 

363 assert input_tensor.grad is not None, "Gradient should be computed" 

364 assert input_tensor.grad.shape == input_tensor.shape 

365 

366 def test_elbo_gradients(self): 

367 """Test ELBO loss gradients.""" 

368 loss_fn = ELBOGaussiansLoss() 

369 

370 batch_size, n_features, n_latent = 4, 10, 2 

371 target = torch.randn(batch_size, n_features) 

372 output = torch.randn(batch_size, n_features, requires_grad=True) 

373 mean = torch.randn(batch_size, n_latent, requires_grad=True) 

374 log_variance = torch.randn(batch_size, n_latent, requires_grad=True) 

375 

376 loss = loss_fn(target, output, mean, log_variance) 

377 loss.backward() 

378 

379 assert output.grad is not None 

380 assert mean.grad is not None 

381 assert log_variance.grad is not None 

382 

383 def test_loss_deterministic(self): 

384 """Test that loss functions are deterministic.""" 

385 torch.manual_seed(42) 

386 input1 = torch.randn(5, 10) 

387 target1 = torch.randn(5, 10) 

388 

389 loss_fn = MSELoss() 

390 loss1 = loss_fn(input1, target1) 

391 

392 torch.manual_seed(42) 

393 input2 = torch.randn(5, 10) 

394 target2 = torch.randn(5, 10) 

395 loss2 = loss_fn(input2, target2) 

396 

397 assert torch.equal(loss1, loss2), "Loss should be deterministic" 

398 

399 def test_loss_batch_invariance(self): 

400 """Test that loss scales appropriately with batch size.""" 

401 loss_fn = MSELoss() 

402 

403 # Small batch 

404 input_small = torch.ones(2, 5) 

405 target_small = torch.zeros(2, 5) 

406 loss_small = loss_fn(input_small, target_small) 

407 

408 # Large batch (same values, just repeated) 

409 input_large = torch.ones(10, 5) 

410 target_large = torch.zeros(10, 5) 

411 loss_large = loss_fn(input_large, target_large) 

412 

413 # MSE should be the same regardless of batch size (mean operation) 

414 assert torch.isclose(loss_small, loss_large, atol=1e-6) 

415 

416 

417class TestLossFunctionEdgeCases: 

418 """Test edge cases and error handling.""" 

419 

420 def test_mse_zero_loss(self): 

421 """Test MSE loss with identical inputs.""" 

422 input_tensor = torch.randn(5, 10) 

423 loss = mse_loss(input_tensor, input_tensor.clone()) 

424 

425 assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) 

426 

427 def test_mse_negative_not_possible(self): 

428 """Test that MSE loss is always non-negative.""" 

429 input_tensor = torch.randn(100, 20) 

430 target_tensor = torch.randn(100, 20) 

431 loss = mse_loss(input_tensor, target_tensor) 

432 

433 assert loss >= 0, "MSE loss must be non-negative" 

434 

435 def test_loss_with_nan_input(self): 

436 """Test loss function behavior with NaN input.""" 

437 loss_fn = MSELoss() 

438 

439 input_tensor = torch.tensor([[1.0, float('nan'), 3.0]]) 

440 target_tensor = torch.tensor([[1.0, 2.0, 3.0]]) 

441 

442 loss = loss_fn(input_tensor, target_tensor) 

443 

444 assert torch.isnan(loss), "Loss should be NaN when input contains NaN" 

445 

446 def test_loss_with_inf_input(self): 

447 """Test loss function behavior with inf input.""" 

448 loss_fn = MSELoss() 

449 

450 input_tensor = torch.tensor([[1.0, float('inf'), 3.0]]) 

451 target_tensor = torch.tensor([[1.0, 2.0, 3.0]]) 

452 

453 loss = loss_fn(input_tensor, target_tensor) 

454 

455 assert torch.isinf(loss), "Loss should be inf when input contains inf" 

456 

457 

458# Summary comment for documentation 

459""" 

460Loss Function Testing Summary 

461============================== 

462 

463Tested Loss Functions: 

464- ✓ MSELoss: Mean Squared Error with and without weights 

465- ✓ ELBOGaussiansLoss: Evidence Lower Bound for VAE 

466- ✓ ELBOGaussianMixtureLoss: ELBO for GMVAE (basic test) 

467- ✓ InformationBottleneckLoss: IB loss for SPIB 

468- ✓ ReduceEigenvaluesLoss: Eigenvalue reduction with multiple modes 

469- ✓ AutocorrelationLoss: Time-lagged autocorrelation 

470- ✓ FisherDiscriminantLoss: Fisher discriminant for LDA 

471- ⊘ TDALoss: Requires additional dependencies (gtda) 

472- ⊘ PhysicsLoss: Requires protein energy calculations 

473- ⊘ CommittorLoss: Requires specific committor setup 

474 

475Integration Tests: 

476- Gradient computation 

477- Deterministic behavior 

478- Batch size invariance 

479- Edge cases (NaN, Inf, zero loss) 

480 

481To run all loss function tests: 

482 pytest biobb_pytorch/test/unitests/test_mdae/test_loss_functions.py -v 

483 

484To run specific loss tests: 

485 pytest biobb_pytorch/test/unitests/test_mdae/test_loss_functions.py::TestMSELoss -v 

486 pytest biobb_pytorch/test/unitests/test_mdae/test_loss_functions.py::TestELBOLosses -v 

487"""