Coverage for biobb_pytorch / mdae / loss / tda_loss.py: 18%
45 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1#!/usr/bin/env python
3# =============================================================================
4# MODULE DOCSTRING
5# =============================================================================
7"""
8Target Discriminant Analysis Loss Function.
9"""
11__all__ = ["TDALoss", "tda_loss"]
14# =============================================================================
15# GLOBAL IMPORTS
16# =============================================================================
18from typing import Union
19from warnings import warn
21import torch
24# =============================================================================
25# LOSS FUNCTIONS
26# =============================================================================
class TDALoss(torch.nn.Module):
    """Compute a loss function as the distance from a simple Gaussian target distribution.

    Each labeled state is pushed toward a Gaussian target in CV space: the
    per-state batch mean is driven to ``target_centers`` and the per-state
    batch standard deviation to ``target_sigmas``. The actual computation is
    delegated to :func:`tda_loss`; this module only stores the (validated)
    targets and prefactors.
    """

    def __init__(
        self,
        n_states: int,
        target_centers: Union[list, torch.Tensor],
        target_sigmas: Union[list, torch.Tensor],
        alpha: float = 1,
        beta: float = 100,
    ):
        """Constructor.

        Parameters
        ----------
        n_states : int
            Number of states. The integer labels are expected to be in between 0
            and ``n_states-1``.
        target_centers : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
        target_sigmas : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
        alpha : float, optional
            Centers_loss component prefactor, by default 1.
        beta : float, optional
            Sigmas loss compontent prefactor, by default 100.

        Raises
        ------
        ValueError
            If ``target_centers`` and ``target_sigmas`` have different shapes,
            or if their first dimension does not equal ``n_states``.
        """
        super().__init__()
        self.n_states = n_states
        if not isinstance(target_centers, torch.Tensor):
            target_centers = torch.Tensor(target_centers)
        if not isinstance(target_sigmas, torch.Tensor):
            target_sigmas = torch.Tensor(target_sigmas)
        # Fail fast on inconsistent targets: a mismatch would otherwise only
        # surface as an obscure indexing/broadcast error inside tda_loss
        # during training.
        if target_centers.shape != target_sigmas.shape:
            raise ValueError(
                f"target_centers and target_sigmas must have the same shape, "
                f"got {tuple(target_centers.shape)} and {tuple(target_sigmas.shape)}"
            )
        if target_centers.shape[0] != n_states:
            raise ValueError(
                f"Expected {n_states} target centers/sigmas (one per state), "
                f"got {target_centers.shape[0]}"
            )
        # Buffers (not parameters): moved with .to(device) but not trained.
        self.register_buffer("target_centers", target_centers)
        self.register_buffer("target_sigmas", target_sigmas)
        self.alpha = alpha
        self.beta = beta

    def forward(
        self, H: torch.Tensor, labels: torch.Tensor, return_loss_terms: bool = False
    ) -> torch.Tensor:
        """Compute the value of the loss function.

        Parameters
        ----------
        H : torch.Tensor
            Shape ``(n_batches, n_features)``. Output of the NN.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Labels of the dataset.
        return_loss_terms : bool, optional
            If ``True``, the loss terms associated to the center and standard
            deviations of the target Gaussians are returned as well. Default
            is ``False``.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        loss_centers : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the centers of the target Gaussians.
        loss_sigmas : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the standard deviations of the target Gaussians.
        """
        return tda_loss(
            H,
            labels,
            self.n_states,
            self.target_centers,
            self.target_sigmas,
            self.alpha,
            self.beta,
            return_loss_terms,
        )
def tda_loss(
    H: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    target_centers: Union[list, torch.Tensor],
    target_sigmas: Union[list, torch.Tensor],
    alpha: float = 1,
    beta: float = 100,
    return_loss_terms: bool = False,
) -> torch.Tensor:
    """
    Compute a loss function as the distance from a simple Gaussian target distribution.

    Parameters
    ----------
    H : torch.Tensor
        Shape ``(n_batches, n_cvs)``. Output of the NN.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Labels of the dataset.
    n_states : int
        The integer labels are expected to be in between 0 and ``n_states-1``.
    target_centers : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
    target_sigmas : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
    alpha : float, optional
        Centers_loss component prefactor, by default 1.
    beta : float, optional
        Sigmas loss compontent prefactor, by default 100.
    return_loss_terms : bool, optional
        If ``True``, the loss terms associated to the center and standard deviations
        of the target Gaussians are returned as well. Default is ``False``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    loss_centers : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the centers of the target Gaussians.
    loss_sigmas : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the standard deviations of the target Gaussians.

    Raises
    ------
    ValueError
        If some state in ``0..n_states-1`` has no sample in the batch.
    """
    if not isinstance(target_centers, torch.Tensor):
        target_centers = torch.Tensor(target_centers)
    if not isinstance(target_sigmas, torch.Tensor):
        target_sigmas = torch.Tensor(target_sigmas)

    # Keep everything on the device of the NN output.
    device = H.device
    target_centers = target_centers.to(device)
    target_sigmas = target_sigmas.to(device)
    loss_centers = torch.zeros_like(target_centers, device=device)
    loss_sigmas = torch.zeros_like(target_sigmas, device=device)

    for i in range(n_states):
        # Build the membership mask once per state (the original recomputed
        # `labels == i` three times: for .any(), for indexing, and for the
        # single-sample count).
        mask = labels == i
        if not mask.any():
            raise ValueError(
                f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
        # Rows of H belonging to state i, shape (n_samples_i, n_cvs).
        H_red = H[mask]

        # Mean and standard deviation of the NN output over state i.
        mu = torch.mean(H_red, 0)
        if H_red.shape[0] == 1:
            warn(
                f"There is only one sample for state {i} in this batch! Std is set to 0, this may affect the training! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
            # torch.std would return NaN for a single sample (unbiased by
            # default), so fall back to 0 — broadcast against target_sigmas[i].
            sigma = 0
        else:
            sigma = torch.std(H_red, 0)

        # Squared deviations from the Gaussian targets for state i.
        loss_centers[i] = alpha * (mu - target_centers[i]).pow(2)
        loss_sigmas[i] = beta * (sigma - target_sigmas[i]).pow(2)

    # Total loss: sum over states and CVs of both contributions.
    loss_centers = torch.sum(loss_centers)
    loss_sigmas = torch.sum(loss_sigmas)
    loss = loss_centers + loss_sigmas

    if return_loss_terms:
        return loss, loss_centers, loss_sigmas
    return loss