Coverage for biobb_pytorch / mdae / loss / mse.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1#!/usr/bin/env python 

2 

3# ============================================================================= 

4# MODULE DOCSTRING 

5# ============================================================================= 

6 

"""
(Weighted) Mean Squared Error (MSE) loss.

Provides the functional form ``mse_loss`` and the ``torch.nn.Module``
wrapper ``MSELoss``, which delegates to it.
"""

__all__ = ["MSELoss", "mse_loss"]

12 

13 

14# ============================================================================= 

15# GLOBAL IMPORTS 

16# ============================================================================= 

17 

18from typing import Optional 

19 

20import torch 

21 

22 

23# ============================================================================= 

24# LOSS FUNCTIONS 

25# ============================================================================= 

26 

27 

class MSELoss(torch.nn.Module):
    """(Weighted) Mean Square Error as a ``torch.nn.Module``.

    The module holds no parameters or state of its own; ``forward`` simply
    delegates to :func:`mse_loss`.
    """

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the value of the loss function.

        Parameters
        ----------
        input : torch.Tensor
            prediction
        target : torch.Tensor
            reference
        weights : torch.Tensor, optional
            sample weights, by default None

        Returns
        -------
        torch.Tensor
            whatever ``mse_loss(input, target, weights)`` returns
        """
        value = mse_loss(input=input, target=target, weights=weights)
        return value

39 

40 

def mse_loss(
    input: torch.Tensor, target: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """(Weighted) Mean Square Error

    The squared residual ``(input - target) ** 2`` of every element is
    multiplied by the weight of its sample, and the result is averaged over
    all elements::

        loss = mean(weights * (input - target) ** 2)

    With ``weights=None`` (or all-ones weights) this reduces to the plain MSE.

    Parameters
    ----------
    input : torch.Tensor
        prediction, typically of shape ``(batch,)`` or ``(batch, size)``;
        1D tensors are treated as ``(batch, 1)``
    target : torch.Tensor
        reference, broadcastable against ``input`` (1D tensors are treated
        as ``(batch, 1)``)
    weights : torch.Tensor, optional
        sample weights, by default None. A 1D tensor of length ``batch``
        gives one weight per sample, shared across all features; any other
        shape must broadcast against the residuals.

    Returns
    -------
    loss: torch.Tensor
        0-dim tensor holding the loss value
    """
    # reshape in the correct format (batch, size) so that 1D inputs and
    # 1D sample weights line up along the batch dimension
    if input.ndim == 1:
        input = input.unsqueeze(1)
    if target.ndim == 1:
        target = target.unsqueeze(1)
    # squared residuals, elementwise
    sq_err = (input - target).square()
    if weights is None:
        return sq_err.mean()
    if weights.ndim == 1:
        # (batch,) -> (batch, 1): one weight per sample, broadcast over features
        weights = weights.unsqueeze(1)
    # Weight the *squared* error. Weighting the residual before squaring
    # would effectively apply weights ** 2 instead of the given weights.
    return (weights * sq_err).mean()