Coverage for biobb_ml/resampling/common.py: 67%
113 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
1""" Common functions for package biobb_ml.resampling """
2from pathlib import Path, PurePath
3from importlib import import_module
4from biobb_common.tools import file_utils as fu
5import warnings
6import csv
7import re
9# UNDERSAMPLING METHODS
# Registry of under-sampling strategies: each short key maps to the attribute
# name to fetch ('method') and the module it is fetched from ('module').
undersampling_methods = {
    key: {'method': class_name, 'module': 'imblearn.under_sampling'}
    for key, class_name in (
        ('random', 'RandomUnderSampler'),
        ('nearmiss', 'NearMiss'),
        ('cnn', 'CondensedNearestNeighbour'),
        ('tomeklinks', 'TomekLinks'),
        ('enn', 'EditedNearestNeighbours'),
        ('ncr', 'NeighbourhoodCleaningRule'),
        ('cluster', 'ClusterCentroids'),
    )
}
41# OVERSAMPLING METHODS
# Registry of over-sampling strategies: each short key maps to the attribute
# name to fetch ('method') and the module it is fetched from ('module').
oversampling_methods = {
    key: {'method': class_name, 'module': 'imblearn.over_sampling'}
    for key, class_name in (
        ('random', 'RandomOverSampler'),
        ('smote', 'SMOTE'),
        ('borderline', 'BorderlineSMOTE'),
        ('svmsmote', 'SVMSMOTE'),
        ('adasyn', 'ADASYN'),
    )
}
65# RESAMPLING METHODS
# Registry of combined strategies. Besides the combined entry ('method' /
# 'module'), every record also names the over-sampling and under-sampling
# pieces that are looked up alongside it.
resampling_methods = dict(
    smotetomek=dict(
        method='SMOTETomek',
        module='imblearn.combine',
        method_over='SMOTE',
        module_over='imblearn.over_sampling',
        method_under='TomekLinks',
        module_under='imblearn.under_sampling',
    ),
    smotenn=dict(
        method='SMOTEENN',
        module='imblearn.combine',
        method_over='SMOTE',
        module_over='imblearn.over_sampling',
        method_under='EditedNearestNeighbours',
        module_under='imblearn.under_sampling',
    ),
)
86# CHECK PARAMETERS
def check_input_path(path, argument, out_log, classname):
    """ Validates an input file path.

    The file must exist and its extension (without the leading dot) must be
    accepted by is_valid_file for the given argument name. On failure the
    problem is logged to out_log and SystemExit is raised; on success the
    path is returned unchanged.
    """
    if not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)

    # suffix includes the dot, so drop the first character
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    error = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(error, out_log)
    raise SystemExit(error)
def check_output_path(path, argument, optional, out_log, classname):
    """ Validates an output file path.

    Returns None when the output is optional and no path was given.
    Otherwise the parent folder of the path must exist and the extension
    (without the leading dot) must be accepted by is_valid_file for the given
    argument name. On failure the problem is logged to out_log and SystemExit
    is raised; on success the path is returned unchanged.
    """
    if optional and not path:
        return None

    folder = PurePath(path).parent
    if folder and not Path(folder).exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    # suffix includes the dot, so drop the first character
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    error = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(error, out_log)
    raise SystemExit(error)
def is_valid_file(ext, argument):
    """ Tells whether extension ext is accepted for the given argument.

    ext is the extension without the leading dot; argument is one of the
    known path argument names listed below. An unknown argument name raises
    KeyError.
    """
    accepted_by_argument = dict(
        input_dataset_path=['csv', 'txt'],
        output_dataset_path=['csv'],
        output_plot_path=['png'],
        input_model_path=['pkl'],
    )
    accepted = accepted_by_argument[argument]
    return ext in accepted
def getTarget(target, data, out_log, classname):
    """ Extracts the target from data.

    target is a mapping: with an 'index' key the target is selected
    positionally through data.iloc, with a 'column' key it is selected by
    label. 'index' takes precedence when both are present. Any other shape is
    logged to out_log and ends with SystemExit.
    """
    if 'index' in target:
        position = target['index']
        return data.iloc[:, position]

    if 'column' in target:
        label = target['column']
        return data[label]

    error = classname + ': Incorrect target format'
    fu.log(error, out_log)
    raise SystemExit(error)
def getTargetValue(target, out_log, classname):
    """ Returns the raw target selector.

    Yields target['index'] when present, otherwise target['column']. A
    mapping with neither key is logged to out_log and ends with SystemExit.
    """
    # 'index' is checked first, so it wins when both keys are present
    for key in ('index', 'column'):
        if key in target:
            return target[key]

    error = classname + ': Incorrect target format'
    fu.log(error, out_log)
    raise SystemExit(error)
def getHeader(file):
    """ Gets the column names from the first row of a delimited text file.

    The first row is parsed as comma-separated values. If that yields a
    single field, the file is assumed to use another delimiter and the field
    is split on runs of whitespace, ';', ':' or ','.

    Returns a list of strings. An empty file yields an empty list, the same
    result a blank first line gives.
    """
    with open(file, newline='') as f:
        # the default avoids leaking a bare StopIteration on an empty file
        header = next(csv.reader(f), [])

    if len(header) == 1:
        # one field only: the file is probably not comma-delimited
        return re.split(r'\s+|[;:,]', header[0])
    return header
def checkResamplingType(type_, out_log, classname):
    """ Validates the resampling type property.

    type_ must be either 'regression' or 'classification'. A missing or
    unknown value is logged to out_log and ends with SystemExit. Nothing is
    returned on success.
    """
    if not type_:
        error = classname + ': Missed mandatory type property'
        fu.log(error, out_log)
        raise SystemExit(error)

    if type_ not in ('regression', 'classification'):
        error = classname + ': Unknown %s type property' % type_
        fu.log(error, out_log)
        raise SystemExit(error)
def getResamplingMethod(method, type_, out_log, classname):
    """ Gets resampling method

    Looks up ``method`` in the registry selected by ``type_``
    ('undersampling', 'oversampling' or 'resampling'), imports the module
    named in the registry entry and returns the attribute named there.

    A missing method, an unknown type or an unknown method is logged to
    out_log and ends with SystemExit.
    """
    if type_ == 'undersampling':
        methods = undersampling_methods
    elif type_ == 'oversampling':
        methods = oversampling_methods
    elif type_ == 'resampling':
        methods = resampling_methods
    else:
        # resolved below, after the mandatory-method check, so that check
        # keeps its precedence
        methods = None

    if not method:
        fu.log(classname + ': Missed mandatory method property', out_log)
        raise SystemExit(classname + ': Missed mandatory method property')
    if methods is None:
        # previously this fell through and crashed with UnboundLocalError
        fu.log(classname + ': Unknown %s type property' % type_, out_log)
        raise SystemExit(classname + ': Unknown %s type property' % type_)
    if method not in methods:
        fu.log(classname + ': Unknown %s method property' % method, out_log)
        raise SystemExit(classname + ': Unknown %s method property' % method)

    selected = methods[method]
    mod = import_module(selected['module'])
    # NOTE: this silences warnings process-wide, not only for this call
    warnings.filterwarnings("ignore")
    method_to_call = getattr(mod, selected['method'])

    fu.log('%s method selected' % selected['method'], out_log)
    return method_to_call
def getCombinedMethod(method, out_log, classname):
    """ Gets combined method

    Looks up ``method`` in resampling_methods and returns a 3-tuple with the
    combined, the over-sampling and the under-sampling attributes named in
    that registry entry, each fetched from its own module. A missing or
    unknown method is logged to out_log and ends with SystemExit.
    """
    if not method:
        error = classname + ': Missed mandatory method property'
        fu.log(error, out_log)
        raise SystemExit(error)
    if method not in resampling_methods:
        error = classname + ': Unknown %s method property' % method
        fu.log(error, out_log)
        raise SystemExit(error)

    spec = resampling_methods[method]

    combined_module = import_module(spec['module'])
    # NOTE: this silences warnings process-wide, not only for this call
    warnings.filterwarnings("ignore")
    combined = getattr(combined_module, spec['method'])

    fu.log('%s method selected' % spec['method'], out_log)

    over = getattr(import_module(spec['module_over']), spec['method_over'])
    under = getattr(import_module(spec['module_under']), spec['method_under'])

    return combined, over, under
def getSamplingStrategy(sampling_strategy, out_log, classname):
    """ Gets sampling strategy

    ``sampling_strategy`` is a mapping holding one of these keys, tried in
    this order; an entry whose value has the wrong type is skipped and the
    next key is tried:

    * 'target': a string, returned as is.
    * 'ratio': a number between 0 and 1 (both included), returned as float.
      Whole numbers given as int (for example 1 from a JSON config) are
      accepted; booleans are not.
    * 'dict': a dict, returned as a new dict with its keys cast to int.
    * 'list': a list, returned as is.

    If nothing matches, the problem is logged to out_log and SystemExit is
    raised.
    """
    if 'target' in sampling_strategy:
        if isinstance(sampling_strategy['target'], str):
            return sampling_strategy['target']
    if 'ratio' in sampling_strategy:
        ratio = sampling_strategy['ratio']
        # bool is a subclass of int, so it has to be excluded explicitly
        is_number = isinstance(ratio, (int, float)) and not isinstance(ratio, bool)
        if is_number and 0 <= ratio <= 1:
            return float(ratio)
    if 'dict' in sampling_strategy:
        if isinstance(sampling_strategy['dict'], dict):
            # trick to ensure the keys are integers
            return {int(key): item for key, item in sampling_strategy['dict'].items()}
    if 'list' in sampling_strategy:
        if isinstance(sampling_strategy['list'], list):
            return sampling_strategy['list']

    fu.log(classname + ': Incorrect sampling_strategy format', out_log)
    raise SystemExit(classname + ': Incorrect sampling_strategy format')