Coverage for biobb_pytorch/mdae/featurization/normalization.py: 67%


# --------------------------------------------------------------------------------------
# normalization.py
#
# from the mlcolvar repository
# https://github.com/mlcolvar/mlcolvar
# Copyright (c) 2023 Luigi Bonati, Enrico Trizio, Andrea Rizzi & Michele Parrinello
# Licensed under the MIT License (see project LICENSE file for full text)
# --------------------------------------------------------------------------------------

import torch
from mlcolvar.core.transform.utils import Statistics
from mlcolvar.core.transform import Transform

__all__ = ["Normalization"]


def batch_reshape(t: torch.Tensor, size: torch.Size) -> torch.Tensor:
    """Return t reshaped according to size.

    In case of a batch, unsqueeze and expand t along the first dimension.
    For single inputs, return t unchanged.

    Parameters
    ----------
    t : torch.Tensor
        tensor to reshape (e.g. the mean or range buffer)
    size : torch.Size
        target size, either (n_features) or (n_batch, n_features)
    """
    if len(size) == 1:
        return t
    if len(size) == 2:
        batch_size = size[0]
        x_size = size[1]
        t = t.unsqueeze(0).expand(batch_size, x_size)
    else:
        raise ValueError(
            f"Input tensor must be of shape (n_features) or (n_batch, n_features), not {size} (len={len(size)})."
        )
    return t
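

# A minimal check of the broadcasting behaviour above (added sketch, not part
# of the original mlcolvar file):
def test_batch_reshape():
    mean = torch.tensor([0.0, 1.0])  # per-feature statistic, shape (2,)
    x = torch.randn(5, 2)  # batch of 5 samples
    # batched size: the statistic is expanded along the batch dimension
    assert batch_reshape(mean, x.size()).shape == (5, 2)
    # single-input size: the tensor passes through unchanged
    assert batch_reshape(mean, mean.size()).shape == (2,)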


def sanitize_range(range: torch.Tensor) -> torch.Tensor:
    """Sanitize the range used for standardization.

    Features with a range smaller than 1e-6 would lead to divisions by a
    near-zero value, so their range is replaced with 1.0 and a warning is
    printed.

    Parameters
    ----------
    range : torch.Tensor
        range to be used for standardization
    """

    if (range < 1e-6).any():
        print(
            "[Warning] Normalization: the following features have a range of values < 1e-6:",
            (range < 1e-6).nonzero(),
        )
    range[range < 1e-6] = 1.0

    return range
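

# Quick illustration of the sanitization above (added sketch, not part of the
# original mlcolvar file): the second feature's tiny range is replaced by 1.0.
def test_sanitize_range():
    r = sanitize_range(torch.tensor([2.0, 1e-8]))
    assert torch.allclose(r, torch.tensor([2.0, 1.0]))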


class Normalization(Transform):
    """
    Normalizing block, used for computing standardized inputs/outputs.
    """

    def __init__(
        self,
        in_features: int,
        mean: torch.Tensor = None,
        range: torch.Tensor = None,
        stats: dict = None,
        mode: str = "mean_std",
    ):
        """Initialize a normalization object. Values will be subtracted by self.mean and then divided by self.range.
        The parameters for the standardization can either be given by the user (via the mean/range keywords), or they can be calculated from a datamodule or a statistics dictionary.
        In the former case the mode will be overridden as 'custom'. In the latter case the standardization mode can be either 'mean_std' (subtract the mean and divide by the standard deviation) or 'min_max' (scale and shift the range of values such that all inputs are between -1 and 1).

        Parameters
        ----------
        in_features : int
            number of inputs
        mean : torch.Tensor, optional
            values to be subtracted
        range : torch.Tensor, optional
            values to be divided by
        stats : dict, optional
            dictionary of statistics (see Statistics), used together with mode
        mode : str, optional
            normalization mode ('mean_std', 'min_max'), by default 'mean_std'
        """

        super().__init__(in_features=in_features, out_features=in_features)

        # buffers containing mean and range for standardization
        self.register_buffer("mean", torch.zeros(in_features))
        self.register_buffer("range", torch.ones(in_features))

        self.mode = mode
        self.is_initialized = False

        # set values based on args if provided
        self.set_custom(mean, range)
        if stats is not None:
            self.set_from_stats(stats, mode=mode)

        # save params
        self.in_features = in_features
        self.out_features = in_features

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, mode={self.mode}"

    def set_custom(self, mean: torch.Tensor = None, range: torch.Tensor = None):
        """Set the parameters of the normalization layer to custom values.

        Parameters
        ----------
        mean : torch.Tensor, optional
            value that will be subtracted
        range : torch.Tensor, optional
            value that will be divided by
        """

        if mean is not None:
            self.mean = mean
        if range is not None:
            self.range = sanitize_range(range)

        if mean is not None or range is not None:
            self.is_initialized = True
            self.mode = "custom"

    def set_from_stats(self, stats: dict, mode: str = None):
        """Set the parameters of the normalization layer based on a dictionary of statistics.

        Parameters
        ----------
        stats : dict or Statistics
            dictionary of statistics
        mode : str, optional
            standardization mode ('mean_std' or 'min_max'), by default None (will use self.mode)
        """

        if mode is None:
            mode = self.mode
        if isinstance(stats, Statistics):
            stats = stats.to_dict()

        if mode == "mean_std":
            self.mean = stats["mean"]
            range = stats["std"]
            self.range = sanitize_range(range)
        elif mode == "min_max":
            # shift by the midpoint and scale by half the range, so that the
            # observed data lands in [-1, 1] as documented in __init__
            self.mean = (stats["max"] + stats["min"]) / 2.0
            range = (stats["max"] - stats["min"]) / 2.0
            self.range = sanitize_range(range)
        elif mode == "custom":
            raise AttributeError(
                "If mode is 'custom', the parameters should be supplied via the mean and range "
                "keywords when creating the Normalization object, or with set_custom, "
                "not with set_from_stats."
            )
        else:
            raise ValueError(
                f'Mode {mode} unknown. Available modes: "mean_std", "min_max", "custom".'
            )

        self.is_initialized = True

        if mode != self.mode:
            self.mode = mode
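
    # For example (illustrative note, not in the original file): with
    # mode="min_max" the block computes
    #     y = (x - (max + min) / 2) / ((max - min) / 2)
    # so the observed minimum maps to -1 and the maximum to +1, while
    # mode="mean_std" computes y = (x - mean) / std.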

    def setup_from_datamodule(self, datamodule):
        if not self.is_initialized:
            # obtain statistics from the dataloader
            try:
                stats = datamodule.train_dataloader().get_stats()["data"]
            except KeyError:
                raise ValueError(
                    f"Impossible to initialize {self.__class__.__name__} "
                    'because the training dataloader does not have a "data" key '
                    "(are you using multiple datasets?). A manual initialization "
                    'of "mean" and "range" is necessary.'
                )
            self.set_from_stats(stats, self.mode)
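
    # Sketch of the expected wiring (assuming mlcolvar's DictDataset/DictModule
    # API; the exact calls are illustrative, not verified against this codebase):
    #
    #     from mlcolvar.data import DictDataset, DictModule
    #     datamodule = DictModule(DictDataset({"data": X}))
    #     norm = Normalization(X.shape[1])
    #     norm.setup_from_datamodule(datamodule)  # fills mean/range from stats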

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute standardized inputs.

        Parameters
        ----------
        x : torch.Tensor
            input tensor

        Returns
        -------
        out : torch.Tensor
            standardized inputs
        """

        # get mean and range, broadcast to the batch shape if needed
        mean = batch_reshape(self.mean, x.size())
        range = batch_reshape(self.range, x.size())

        return x.sub(mean).div(range)

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """
        Remove standardization.

        Parameters
        ----------
        x : torch.Tensor
            standardized tensor

        Returns
        -------
        out : torch.Tensor
            un-normalized inputs
        """

        # get mean and range, broadcast to the batch shape if needed
        mean = batch_reshape(self.mean, x.size())
        range = batch_reshape(self.range, x.size())

        return x.mul(range).add(mean)


def test_normalization():
    from mlcolvar.core.transform.utils import Inverse

    # create data
    torch.manual_seed(42)
    in_features = 2
    X = torch.randn((100, in_features)) * 10

    # get stats
    stats = Statistics(X).to_dict()
    norm = Normalization(in_features, mean=stats["mean"], range=stats["std"])

    y = norm(X)

    # test inverse
    z = norm.inverse(y)
    assert torch.allclose(X.mean(0), z.mean(0))
    assert torch.allclose(X.std(0), z.std(0))

    # test inverse class
    inverse = Inverse(norm)
    q = inverse(y)
    assert torch.allclose(X.mean(0), q.mean(0))
    assert torch.allclose(X.std(0), q.std(0))

    norm = Normalization(
        in_features, mean=stats["mean"], range=stats["std"], mode="min_max"
    )
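
    # Added sketch (not in the original test): since explicit mean/range were
    # passed above, set_custom overrides the requested mode with 'custom'.
    assert norm.mode == "custom"

    # Exercise the 'min_max' statistics path: outputs should land in [-1, 1]
    # (up to floating-point error) for the data used to compute the stats.
    norm = Normalization(in_features, stats=stats, mode="min_max")
    w = norm(X)
    assert (w >= -1.0 - 1e-5).all() and (w <= 1.0 + 1e-5).all()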


if __name__ == "__main__":
    test_normalization()