Coverage for biobb_ml/resampling/common.py: 67%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-03 14:57 +0000

1""" Common functions for package biobb_ml.resampling """ 

2from pathlib import Path, PurePath 

3from importlib import import_module 

4from biobb_common.tools import file_utils as fu 

5import warnings 

6import csv 

7import re 

8 

9# UNDERSAMPLING METHODS 

# Maps the user-facing undersampling alias to the imbalanced-learn class name
# and the module it is imported from (every entry lives in
# imblearn.under_sampling).
undersampling_methods = {
    alias: {
        'method': class_name,
        'module': 'imblearn.under_sampling'
    }
    for alias, class_name in (
        ('random', 'RandomUnderSampler'),
        ('nearmiss', 'NearMiss'),
        ('cnn', 'CondensedNearestNeighbour'),
        ('tomeklinks', 'TomekLinks'),
        ('enn', 'EditedNearestNeighbours'),
        ('ncr', 'NeighbourhoodCleaningRule'),
        ('cluster', 'ClusterCentroids'),
    )
}

40 

41# OVERSAMPLING METHODS 

# Maps the user-facing oversampling alias to the imbalanced-learn class name
# and the module it is imported from (every entry lives in
# imblearn.over_sampling).
oversampling_methods = {
    alias: {
        'method': class_name,
        'module': 'imblearn.over_sampling'
    }
    for alias, class_name in (
        ('random', 'RandomOverSampler'),
        ('smote', 'SMOTE'),
        ('borderline', 'BorderlineSMOTE'),
        ('svmsmote', 'SVMSMOTE'),
        ('adasyn', 'ADASYN'),
    )
}

64 

65# RESAMPLING METHODS 

# Combined (over + under) samplers. Besides the combined class itself, each
# entry names the SMOTE over-sampler and the under-sampler it is built from,
# together with the modules those classes are imported from.
resampling_methods = {
    alias: {
        'method': combined_name,
        'module': 'imblearn.combine',
        'method_over': 'SMOTE',
        'module_over': 'imblearn.over_sampling',
        'method_under': under_name,
        'module_under': 'imblearn.under_sampling'
    }
    for alias, combined_name, under_name in (
        ('smotetomek', 'SMOTETomek', 'TomekLinks'),
        ('smotenn', 'SMOTEENN', 'EditedNearestNeighbours'),
    )
}

84 

85 

86# CHECK PARAMETERS 

87 

def check_input_path(path, argument, out_log, classname):
    """ Checks input file

    Verifies that ``path`` exists on disk and that its extension (without the
    leading dot) is accepted by ``is_valid_file`` for the ``argument`` role.

    Args:
        path (str): Path of the input file.
        argument (str): Role name of the file, used both as the lookup key for
            ``is_valid_file`` and in the error messages.
        out_log: Logger handed to ``fu.log``.
        classname (str): Caller name prefixed to log and error messages.

    Returns:
        str: ``path``, unchanged.

    Raises:
        SystemExit: If the file does not exist or its format is not accepted
            (the message is logged first).
    """
    if not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)

    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    error_message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(error_message, out_log)
    raise SystemExit(error_message)

98 

99 

def check_output_path(path, argument, optional, out_log, classname):
    """ Checks output file

    Validates an output destination: its parent folder must already exist and
    its extension (without the leading dot) must be accepted by
    ``is_valid_file`` for the ``argument`` role.

    Args:
        path (str): Destination path; may be empty/None when ``optional``.
        argument (str): Role name of the file, used as the ``is_valid_file``
            key and in error messages.
        optional (bool): When true, a falsy ``path`` is allowed and skips all checks.
        out_log: Logger handed to ``fu.log``.
        classname (str): Caller name prefixed to log and error messages.

    Returns:
        str | None: ``path`` unchanged, or ``None`` for an omitted optional path.

    Raises:
        SystemExit: If the parent folder is missing or the format is not
            accepted (the message is logged first).
    """
    if optional and not path:
        return None

    # A bare file name has parent '.', which always exists.
    parent_dir = PurePath(path).parent
    if not Path(parent_dir).exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    error_message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(error_message, out_log)
    raise SystemExit(error_message)

112 

113 

def is_valid_file(ext, argument):
    """ Checks if file format is compatible

    Args:
        ext (str): File extension without the leading dot (e.g. ``'csv'``).
        argument (str): File role; must be one of the keys below.

    Returns:
        bool: True when ``ext`` is among the extensions allowed for ``argument``.

    Raises:
        KeyError: If ``argument`` is not a known file role.
    """
    allowed_extensions = dict([
        ('input_dataset_path', ['csv', 'txt']),
        ('output_dataset_path', ['csv']),
        ('output_plot_path', ['png']),
        ('input_model_path', ['pkl']),
    ])
    return ext in allowed_extensions[argument]

123 

124 

def getTarget(target, data, out_log, classname):
    """ Gets targets

    Selects the target column from ``data``. A positional ``target['index']``
    takes precedence (selected with ``data.iloc[:, index]``); otherwise the
    label ``target['column']`` is used (``data[column]``).

    Args:
        target (dict): Holds either an ``'index'`` or a ``'column'`` key.
        data: Tabular object supporting ``.iloc`` and label indexing
            (presumably a pandas DataFrame; confirm at call sites).
        out_log: Logger handed to ``fu.log``.
        classname (str): Caller name prefixed to log and error messages.

    Raises:
        SystemExit: If ``target`` has neither key (the message is logged first).
    """
    if 'index' in target:
        position = target['index']
        return data.iloc[:, position]

    if 'column' in target:
        label = target['column']
        return data[label]

    error_message = classname + ': Incorrect target format'
    fu.log(error_message, out_log)
    raise SystemExit(error_message)

134 

135 

def getTargetValue(target, out_log, classname):
    """ Gets target value

    Returns the raw selector stored in ``target``: the ``'index'`` entry when
    present, otherwise the ``'column'`` entry.

    Args:
        target (dict): Holds either an ``'index'`` or a ``'column'`` key.
        out_log: Logger handed to ``fu.log``.
        classname (str): Caller name prefixed to log and error messages.

    Raises:
        SystemExit: If ``target`` has neither key (the message is logged first).
    """
    # Order matters: 'index' wins when both keys are supplied.
    for selector_key in ('index', 'column'):
        if selector_key in target:
            return target[selector_key]

    error_message = classname + ': Incorrect target format'
    fu.log(error_message, out_log)
    raise SystemExit(error_message)

145 

146 

def getHeader(file):
    """ Gets header

    Reads the first row of ``file`` with the ``csv`` module (comma delimiter,
    platform-default text encoding) and returns it as a list of column names.

    If that row parses as a single field, the field is re-split on runs of
    whitespace and on ``;``, ``:``, ``,`` and tab, so files using those
    separators still yield one name per column. Each separator match is its own
    split point, so adjacent separators give empty names (``'a; b'`` ->
    ``['a', '', 'b']``).

    Raises:
        StopIteration: If the file is empty (there is no first row).
    """
    with open(file, newline='') as handle:
        first_row = next(csv.reader(handle))

    if len(first_row) != 1:
        return first_row

    # Normalise every supported separator to ',' and split once.
    return re.sub('\\s+|;|:|,|\t', ',', first_row[0]).split(",")

157 

158 

def checkResamplingType(type_, out_log, classname):
    """ Gets resampling type

    Validates the mandatory ``type`` property; only ``'regression'`` and
    ``'classification'`` are accepted. Returns nothing on success.

    Args:
        type_ (str): The requested type.
        out_log: Logger handed to ``fu.log``.
        classname (str): Caller name prefixed to log and error messages.

    Raises:
        SystemExit: If ``type_`` is empty/None or not one of the accepted
            values (the message is logged first).
    """
    if not type_:
        fu.log(classname + ': Missed mandatory type property', out_log)
        raise SystemExit(classname + ': Missed mandatory type property')

    if type_ not in ('regression', 'classification'):
        error_message = classname + ': Unknown %s type property' % type_
        fu.log(error_message, out_log)
        raise SystemExit(error_message)

167 

168 

def getResamplingMethod(method, type_, out_log, classname):
    """ Gets resampling method

    Resolves the alias ``method`` to the imbalanced-learn class registered for
    it in the table selected by ``type_``, importing that class's module.

    Args:
        method (str): Alias key in the selected table (e.g. ``'random'``).
        type_ (str): ``'undersampling'``, ``'oversampling'`` or ``'resampling'``.
        out_log: Logger handed to ``fu.log``.
        classname (str): Caller name prefixed to log and error messages.

    Returns:
        The class object named by the table entry's ``'method'`` field.

    Raises:
        SystemExit: If ``method`` is empty, ``type_`` is unknown, or ``method``
            is not registered for ``type_`` (the message is logged first).
    """
    methods_by_type = {
        'undersampling': undersampling_methods,
        'oversampling': oversampling_methods,
        'resampling': resampling_methods,
    }
    methods = methods_by_type.get(type_)

    if not method:
        fu.log(classname + ': Missed mandatory method property', out_log)
        raise SystemExit(classname + ': Missed mandatory method property')
    # Without this guard an unrecognised type_ leaves no table to search and
    # the membership test below would fail with an unhelpful Python error.
    if methods is None:
        fu.log(classname + ': Unknown %s type property' % type_, out_log)
        raise SystemExit(classname + ': Unknown %s type property' % type_)
    if method not in methods:
        fu.log(classname + ': Unknown %s method property' % method, out_log)
        raise SystemExit(classname + ': Unknown %s method property' % method)

    entry = methods[method]
    mod = import_module(entry['module'])
    # NOTE: this installs a process-wide "ignore" filter, silencing all
    # subsequent warnings, not just those from the imported module.
    warnings.filterwarnings("ignore")
    method_to_call = getattr(mod, entry['method'])

    fu.log('%s method selected' % entry['method'], out_log)
    return method_to_call

191 

192 

def getCombinedMethod(method, out_log, classname):
    """ Gets combined method

    Resolves the alias ``method`` against ``resampling_methods`` and imports
    three classes: the combined sampler and its over- and under-sampling
    components.

    Args:
        method (str): Alias key in ``resampling_methods`` (e.g. ``'smotetomek'``).
        out_log: Logger handed to ``fu.log``.
        classname (str): Caller name prefixed to log and error messages.

    Returns:
        tuple: ``(combined_class, over_sampler_class, under_sampler_class)``.

    Raises:
        SystemExit: If ``method`` is empty or not a known alias (the message is
            logged first).
    """
    if not method:
        fu.log(classname + ': Missed mandatory method property', out_log)
        raise SystemExit(classname + ': Missed mandatory method property')
    if method not in resampling_methods:
        fu.log(classname + ': Unknown %s method property' % method, out_log)
        raise SystemExit(classname + ': Unknown %s method property' % method)

    entry = resampling_methods[method]

    combine_module = import_module(entry['module'])
    # Installs a process-wide "ignore" warnings filter before the class lookup.
    warnings.filterwarnings("ignore")
    combined_cls = getattr(combine_module, entry['method'])

    fu.log('%s method selected' % entry['method'], out_log)

    over_cls = getattr(import_module(entry['module_over']), entry['method_over'])
    under_cls = getattr(import_module(entry['module_under']), entry['method_under'])

    return combined_cls, over_cls, under_cls

216 

217 

def getSamplingStrategy(sampling_strategy, out_log, classname):
    """ Gets sampling strategy

    Translates the ``sampling_strategy`` property into the value handed to a
    resampler. Keys are tried in the order ``'target'``, ``'ratio'``,
    ``'dict'``, ``'list'``; the first whose value has the expected type wins,
    and a key with a wrongly-typed value is skipped.

    Args:
        sampling_strategy (dict): e.g. ``{'target': 'minority'}``,
            ``{'ratio': 0.3}``, ``{'dict': {'0': 100, '1': 50}}`` or
            ``{'list': [0, 1]}``.
        out_log: Logger handed to ``fu.log``.
        classname (str): Caller name prefixed to log and error messages.

    Returns:
        str | float | dict | list: The strategy value. ``'ratio'`` must lie in
        [0, 1] and is returned as a float (an integer 0 or 1 is accepted);
        ``'dict'`` keys are cast to ``int``.

    Raises:
        SystemExit: If ``sampling_strategy`` is not a dict or no key holds a
            usable value (the message is logged first).
    """
    # Non-dict input (None, a bare string, a list, ...) used to crash with a
    # TypeError on the membership tests / indexing; route it to the error below.
    if isinstance(sampling_strategy, dict):
        if 'target' in sampling_strategy:
            if isinstance(sampling_strategy['target'], str):
                return sampling_strategy['target']
        if 'ratio' in sampling_strategy:
            ratio = sampling_strategy['ratio']
            # JSON/YAML configs parse 0 and 1 as int; bool is excluded even
            # though it subclasses int.
            if isinstance(ratio, (int, float)) and not isinstance(ratio, bool) and 0 <= ratio <= 1:
                return float(ratio)
        if 'dict' in sampling_strategy:
            if isinstance(sampling_strategy['dict'], dict):
                # Keys typically arrive as strings from JSON; cast them to int.
                return {int(key): item for key, item in sampling_strategy['dict'].items()}
        if 'list' in sampling_strategy:
            if isinstance(sampling_strategy['list'], list):
                return sampling_strategy['list']

    fu.log(classname + ': Incorrect sampling_strategy format', out_log)
    raise SystemExit(classname + ': Incorrect sampling_strategy format')