Coverage for biobb_pytorch / mdae / loss / mse.py: 100%
18 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1#!/usr/bin/env python
3# =============================================================================
4# MODULE DOCSTRING
5# =============================================================================
7"""
8(Weighted) Mean Squared Error (MSE) loss function.
9"""
11__all__ = ["MSELoss", "mse_loss"]
14# =============================================================================
15# GLOBAL IMPORTS
16# =============================================================================
18from typing import Optional
20import torch
23# =============================================================================
24# LOSS FUNCTIONS
25# =============================================================================
class MSELoss(torch.nn.Module):
    """Module wrapper around the (weighted) mean squared error.

    The module holds no parameters or state of its own; calling it simply
    forwards the arguments to the functional form :func:`mse_loss`.
    """

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Evaluate the loss for a prediction/reference pair.

        Parameters
        ----------
        input : torch.Tensor
            prediction
        target : torch.Tensor
            reference
        weights : torch.Tensor, optional
            sample weights, by default None

        Returns
        -------
        loss : torch.Tensor
            value returned by :func:`mse_loss` for the same arguments
        """
        return mse_loss(input=input, target=target, weights=weights)
def mse_loss(
    input: torch.Tensor, target: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Compute the (weighted) mean squared error between two tensors.

    One-dimensional ``input``, ``target`` and ``weights`` are each expanded
    with a trailing axis of size 1 before use, so that flat vectors are
    handled in a (batch, size) layout.

    Parameters
    ----------
    input : torch.Tensor
        prediction
    target : torch.Tensor
        reference
    weights : torch.Tensor, optional
        sample weights, by default None. When given, they multiply the
        residual *before* it is squared, so every squared term ends up
        scaled by the square of its weight.

    Returns
    -------
    loss : torch.Tensor
        mean of the squared (optionally weighted) residuals, taken over
        all elements
    """
    # Bring flat vectors into column form so both operands share a layout.
    prediction = input.unsqueeze(1) if input.ndim == 1 else input
    reference = target.unsqueeze(1) if target.ndim == 1 else target

    # Element-wise difference between prediction and reference.
    residual = prediction - reference

    # Unweighted case: plain mean of the squared residuals.
    if weights is None:
        return residual.square().mean()

    # A 1-D weight vector becomes a column so it broadcasts across the
    # trailing axis of the residual.
    column_weights = weights.unsqueeze(1) if weights.ndim == 1 else weights

    # The weights are applied to the residual and the product is squared.
    return (residual * column_weights).square().mean()