# Coverage report artifact: biobb_pytorch/mdae/loss/eigvals.py — 61% of 36 statements (coverage.py v7.13.2, created 2026-02-02 16:33 +0000)
1#!/usr/bin/env python
3# =============================================================================
4# MODULE DOCSTRING
5# =============================================================================
7"""
8Reduce eigenvalues loss.
9"""
11__all__ = ["ReduceEigenvaluesLoss", "reduce_eigenvalues_loss"]
14# =============================================================================
15# GLOBAL IMPORTS
16# =============================================================================
18import torch
21# =============================================================================
22# LOSS FUNCTIONS
23# =============================================================================
class ReduceEigenvaluesLoss(torch.nn.Module):
    """Loss reducing a set of eigenvalues to a single scalar.

    A monotonic function f(x) of the eigenvalues is computed (the sum by
    default). With ``invert_sign=True`` (the default) the opposite, -f(x),
    is returned, so that maximizing the eigenvalues corresponds to
    minimizing the loss in gradient-descent schemes.

    Implemented reduce functions:
    - sum     : sum_i (lambda_i)
    - sum2    : sum_i (lambda_i)**2
    - gap     : (lambda_1 - lambda_2)
    - its     : sum_i (1/log(lambda_i))
    - single  : (lambda_i)
    - single2 : (lambda_i)**2
    """

    def __init__(
        self,
        mode: str = "sum",
        n_eig: int = 0,
        invert_sign: bool = True,
    ):
        """Constructor.

        Parameters
        ----------
        mode : str, optional
            Function of the eigenvalues to optimize (see class notes).
            Default is ``'sum'``.
        n_eig : int, optional
            Number of eigenvalues to include in the loss (default: 0 --> all).
            For ``'single'`` and ``'single2'`` it selects which eigenvalue
            to use.
        invert_sign : bool, optional
            Whether to return the opposite of the function (in order to be
            minimized with GD methods). Default is ``True``.
        """
        super().__init__()
        # Plain configuration attributes; the actual reduction is delegated
        # to the module-level reduce_eigenvalues_loss at forward time.
        self.mode = mode
        self.n_eig = n_eig
        self.invert_sign = invert_sign

    def forward(self, evals: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Parameters
        ----------
        evals : torch.Tensor
            Eigenvalues. Shape ``(n_batches, n_eigenvalues)``.

        Returns
        -------
        loss : torch.Tensor
        """
        return reduce_eigenvalues_loss(
            evals, self.mode, self.n_eig, self.invert_sign
        )
def reduce_eigenvalues_loss(
    evals: torch.Tensor,
    mode: str = "sum",
    n_eig: int = 0,
    invert_sign: bool = True,
) -> torch.Tensor:
    """Calculate a monotonic function f(x) of the eigenvalues, by default the sum.

    By default it returns -f(x) to be used as loss function to maximize
    eigenvalues in gradient descent schemes.

    Parameters
    ----------
    evals : torch.Tensor
        Eigenvalues. NOTE(review): the indexing below (``evals[:n_eig]``,
        ``evals[0] - evals[1]``, ``len(evals)``) treats the *first*
        dimension as the eigenvalue axis, so a 1D ``(n_eigenvalues,)``
        tensor behaves as described; confirm the intended shape with
        callers before passing a batched ``(n_batches, n_eigenvalues)``
        tensor.
    mode : str, optional
        Function of the eigenvalues to optimize (see notes). Default is ``'sum'``.
    n_eig : int, optional
        Number of eigenvalues to include in the loss (default: 0 --> all).
        In case of ``'single'`` and ``'single2'`` it is used to specify which
        eigenvalue to use.
    invert_sign : bool, optional
        Whether to return the opposite of the function (in order to be minimized
        with GD methods). Default is ``True``.

    Notes
    -----
    The following functions are implemented:
    - sum     : sum_i (lambda_i)
    - sum2    : sum_i (lambda_i)**2
    - gap     : (lambda_1 - lambda_2)
    - its     : sum_i (1/log(lambda_i))
    - single  : (lambda_i)
    - single2 : (lambda_i)**2

    Returns
    -------
    loss : torch.Tensor (scalar)
        Loss value.

    Raises
    ------
    ValueError
        If ``n_eig`` exceeds the number of available eigenvalues, if
        ``n_eig`` is not specified for the 'single'/'single2' modes, or if
        ``mode`` is not one of the implemented functions.
    """
    # Validate n_eig against the number of available eigenvalues.
    if n_eig > 0 and len(evals) < n_eig:
        raise ValueError("n_eig must be lower than the number of eigenvalues.")
    elif n_eig == 0:
        if mode == "single" or mode == "single2":
            raise ValueError("n_eig must be specified when using single or single2.")
        else:
            # 0 means "use all eigenvalues".
            n_eig = len(evals)

    if mode == "sum":
        loss = torch.sum(evals[:n_eig])
    elif mode == "sum2":
        # Slice first so the square is only computed for the entries used.
        loss = torch.sum(torch.pow(evals[:n_eig], 2))
    elif mode == "gap":
        loss = evals[0] - evals[1]
    elif mode == "its":
        # Slice first so log is only evaluated on the entries used.
        loss = torch.sum(1 / torch.log(evals[:n_eig]))
    elif mode == "single":
        # n_eig is 1-based here: it selects *which* eigenvalue to use.
        loss = evals[n_eig - 1]
    elif mode == "single2":
        loss = torch.pow(evals[n_eig - 1], 2)
    else:
        raise ValueError(
            f"unknown mode : {mode}."
            " options: 'sum', 'sum2', 'gap', 'its', 'single', 'single2'."
        )

    if invert_sign:
        # Out-of-place negation. The previous in-place `loss *= -1` mutated
        # the caller's `evals` tensor in 'single' mode (where `loss` is a
        # view of the input) and would fail autograd for leaf tensors that
        # require grad.
        loss = -loss

    return loss