Coverage for biobb_pytorch / mdae / featurization / normalization.py: 67%
95 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
2# --------------------------------------------------------------------------------------
3# normalization.py
4#
5# from the mlcolvar repository
6# https://github.com/mlcolvar/mlcolvar
7# Copyright (c) 2023 Luigi Bonati, Enrico Trizio, Andrea Rizzi & Michele Parrinello
8# Licensed under the MIT License (see project LICENSE file for full text)
9# --------------------------------------------------------------------------------------
11import torch
12from mlcolvar.core.transform.utils import Statistics
13from mlcolvar.core.transform import Transform
15__all__ = ["Normalization"]
def batch_reshape(t: torch.Tensor, size: torch.Size) -> torch.Tensor:
    """Reshape a per-feature tensor so it can be broadcast against an input of shape ``size``.

    For a single (non-batched) input of shape ``(n_features,)`` the tensor is
    returned unchanged. For a batched input of shape ``(n_batch, n_features)``
    the tensor is unsqueezed along dim 0 and expanded (no copy) to match.

    Parameters
    ----------
    t : torch.Tensor
        per-feature values (e.g. mean or range), shape ``(n_features,)``
    size : torch.Size
        target size, either ``(n_features,)`` or ``(n_batch, n_features)``

    Returns
    -------
    torch.Tensor
        tensor broadcastable to ``size``

    Raises
    ------
    ValueError
        if ``size`` has more than two dimensions
    """
    if len(size) == 1:
        return t
    if len(size) == 2:
        batch_size = size[0]
        x_size = size[1]
        # expand creates a broadcast view, no data is copied
        t = t.unsqueeze(0).expand(batch_size, x_size)
    else:
        raise ValueError(
            f"Input tensor must of shape (n_features) or (n_batch,n_features), not {size} (len={len(size)})."
        )
    return t
def sanitize_range(range: torch.Tensor) -> torch.Tensor:
    """Sanitize a range tensor for use as a standardization divisor.

    Features whose range is below ``1e-6`` are replaced with ``1.0`` to avoid
    division by (near) zero, and a warning listing their indices is printed.

    Note: the input tensor is modified in place and also returned.

    Parameters
    ----------
    range : torch.Tensor
        range to be used for standardization

    Returns
    -------
    torch.Tensor
        the sanitized range (same tensor object as the input)
    """
    degenerate = range < 1e-6
    # Bug fix: the old check `(range < 1e-6).nonzero().sum() > 0` summed the
    # *indices* of degenerate features, so a lone degenerate feature at index 0
    # was patched without any warning. `.any()` tests for presence correctly.
    if degenerate.any():
        print(
            "[Warning] Normalization: the following features have a range of values < 1e-6:",
            degenerate.nonzero(),
        )
        range[degenerate] = 1.0

    return range
class Normalization(Transform):
    """
    Normalizing block, used for computing standardized inputs/outputs.

    Values are transformed as ``(x - mean) / range``; the inverse transform
    restores the original values as ``x * range + mean``.
    """

    def __init__(
        self,
        in_features: int,
        mean: torch.Tensor = None,
        range: torch.Tensor = None,
        stats: dict = None,
        mode: str = "mean_std",
    ):
        """Initialize a normalization object. Values will be subtracted by self.mean and then divided by self.range.

        The parameters for the standardization can be either given from the user (via mean/range keywords),
        or they can be calculated from a datamodule. In the former case, the mode will be overridden as 'custom'.
        In the latter, the standardization mode can be either 'mean_std' (remove by the mean and divide by the
        standard deviation) or 'min_max' (scale and shift the range of values such that all inputs are between -1 and 1).

        Parameters
        ----------
        in_features : int
            number of inputs
        mean : torch.Tensor, optional
            values to be subtracted
        range : torch.Tensor, optional
            values to be scaled by
        stats : dict, optional
            statistics dictionary used to initialize the parameters (see set_from_stats)
        mode : str, optional
            normalization mode (mean_std, min_max), by default 'mean_std'
        """
        super().__init__(in_features=in_features, out_features=in_features)

        # buffers containing mean and range for standardization; registered as
        # buffers so they follow the module across devices and into state_dict
        self.register_buffer("mean", torch.zeros(in_features))
        self.register_buffer("range", torch.ones(in_features))

        self.mode = mode
        self.is_initialized = False

        # set values based on args if provided
        self.set_custom(mean, range)
        if stats is not None:
            self.set_from_stats(stats, mode=mode)

        # save params
        self.in_features = in_features
        self.out_features = in_features

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, mode={self.mode}"

    def set_custom(self, mean: torch.Tensor = None, range: torch.Tensor = None):
        """Set parameters of the normalization layer from user-supplied values.

        If either value is given, the layer is marked as initialized and the
        mode is switched to 'custom'.

        Parameters
        ----------
        mean : torch.Tensor, optional
            Value that will be removed.
        range : torch.Tensor, optional
            Value that will be divided for.
        """
        if mean is not None:
            self.mean = mean
        if range is not None:
            self.range = sanitize_range(range)

        if mean is not None or range is not None:
            self.is_initialized = True
            self.mode = "custom"

    def set_from_stats(self, stats: dict, mode: str = None):
        """Set parameters of the normalization layer based on a dictionary with statistics.

        Parameters
        ----------
        stats : dict or Statistics
            dictionary with statistics (keys: 'mean'/'std' or 'min'/'max' depending on mode)
        mode : str, optional
            standardization mode ('mean_std' or 'min_max'), by default None (will use self.mode)

        Raises
        ------
        AttributeError
            if mode is 'custom' (custom parameters must go through set_custom)
        ValueError
            if mode is not one of 'mean_std', 'min_max', 'custom'
        """
        if mode is None:
            mode = self.mode
        if isinstance(stats, Statistics):
            stats = stats.to_dict()

        if mode == "mean_std":
            self.mean = stats["mean"]
            range = stats["std"]
            self.range = sanitize_range(range)
        elif mode == "min_max":
            min = stats["min"]
            max = stats["max"]
            self.mean = min
            range = (max - min)
            self.range = sanitize_range(range)
        elif mode == "custom":
            raise AttributeError(
                "If mode is custom the parameters should be supplied via mean and range values when creating the Normalization object or with the set_custom, not with set_from_stats."
            )
        else:
            # Bug fix: the message previously interpolated self.mode (typo'd as
            # "unknonwn") instead of the local `mode` actually being validated.
            raise ValueError(
                f'Mode {mode} unknown. Available modes: "mean_std", "min_max", "custom"'
            )

        self.is_initialized = True

        if mode != self.mode:
            self.mode = mode

    def setup_from_datamodule(self, datamodule):
        """Initialize mean/range from the datamodule's training dataloader statistics (no-op if already initialized)."""
        if not self.is_initialized:
            # obtain statistics from the dataloader
            try:
                stats = datamodule.train_dataloader().get_stats()["data"]
            except KeyError:
                raise ValueError(
                    f"Impossible to initialize {self.__class__.__name__} "
                    'because the training dataloader does not have a "data" key '
                    "(are you using multiple datasets?). A manual initialization "
                    'of "mean" and "range" is necessary.'
                )
            self.set_from_stats(stats, self.mode)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute standardized inputs.

        Parameters
        ----------
        x : torch.Tensor
            input/output

        Returns
        -------
        out : torch.Tensor
            standardized inputs
        """
        # broadcast mean and range to match a possibly batched input
        mean = batch_reshape(self.mean, x.size())
        range = batch_reshape(self.range, x.size())

        return x.sub(mean).div(range)

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """
        Remove standardization.

        Parameters
        ----------
        x : torch.Tensor
            input

        Returns
        -------
        out : torch.Tensor
            un-normalized inputs
        """
        # broadcast mean and range to match a possibly batched input
        mean = batch_reshape(self.mean, x.size())
        range = batch_reshape(self.range, x.size())

        return x.mul(range).add(mean)
def test_normalization():
    """Smoke test: round-trip forward/inverse standardization on random data."""
    from mlcolvar.core.transform.utils import Inverse

    # create data
    torch.manual_seed(42)
    in_features = 2
    X = torch.randn((100, in_features)) * 10

    # get stats (Statistics is already imported at module level)
    stats = Statistics(X).to_dict()
    norm = Normalization(in_features, mean=stats["mean"], range=stats["std"])
    y = norm(X)

    # test inverse
    z = norm.inverse(y)
    assert torch.allclose(X.mean(0), z.mean(0))
    assert torch.allclose(X.std(0), z.std(0))

    # test inverse class
    inverse = Inverse(norm)
    q = inverse(y)
    assert torch.allclose(X.mean(0), q.mean(0))
    assert torch.allclose(X.std(0), q.std(0))

    # min_max mode: previously constructed but never exercised — check round trip
    norm = Normalization(
        in_features, mean=stats["mean"], range=stats["std"], mode="min_max"
    )
    w = norm.inverse(norm(X))
    assert torch.allclose(X, w, atol=1e-5)
# Run the module's smoke test when executed directly as a script.
if __name__ == "__main__":
    test_normalization()