Coverage for biobb_ml/resampling/oversampling.py: 78%

156 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-03 14:57 +0000

1#!/usr/bin/env python3 

2 

3"""Module containing the Oversampling class and the command line interface.""" 

4import argparse 

5import numpy as np 

6import pandas as pd 

7from collections import Counter 

8from biobb_common.generic.biobb_object import BiobbObject 

9from sklearn import preprocessing 

10from sklearn.model_selection import cross_val_score 

11from sklearn.model_selection import RepeatedStratifiedKFold 

12from sklearn.ensemble import RandomForestClassifier 

13from biobb_ml.resampling.reg_resampler import resampler 

14from biobb_common.configuration import settings 

15from biobb_common.tools import file_utils as fu 

16from biobb_common.tools.file_utils import launchlogger 

17from biobb_ml.resampling.common import check_input_path, check_output_path, getResamplingMethod, checkResamplingType, getSamplingStrategy, getTargetValue, getHeader, getTarget, oversampling_methods 

18 

19 

class Oversampling(BiobbObject):
    """
    | biobb_ml Oversampling
    | Wrapper of most of the imblearn.over_sampling methods.
    | Involves supplementing the training data with multiple copies of some of the minority classes of a given dataset. If regression is specified as type, the data will be resampled to classes in order to apply the oversampling model. Visit the imbalanced-learn official website for the different methods accepted in this wrapper: `RandomOverSampler <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.RandomOverSampler.html>`_, `SMOTE <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.SMOTE.html>`_, `BorderlineSMOTE <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.BorderlineSMOTE.html>`_, `SVMSMOTE <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.SVMSMOTE.html>`_, `ADASYN <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.ADASYN.html>`_

    Args:
        input_dataset_path (str): Path to the input dataset. File type: input. `Sample file <https://github.com/bioexcel/biobb_ml/raw/master/biobb_ml/test/data/resampling/dataset_resampling.csv>`_. Accepted formats: csv (edam:format_3752).
        output_dataset_path (str): Path to the output dataset. File type: output. `Sample file <https://github.com/bioexcel/biobb_ml/raw/master/biobb_ml/test/reference/resampling/ref_output_oversampling.csv>`_. Accepted formats: csv (edam:format_3752).
        properties (dic - Python dictionary object containing the tool parameters, not input/output files):
            * **method** (*str*) - (None) Oversampling method. It's a mandatory property. Values: random (`RandomOverSampler <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.RandomOverSampler.html>`_: Object to over-sample the minority classes by picking samples at random with replacement), smote (`SMOTE <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.SMOTE.html>`_: This object is an implementation of SMOTE - Synthetic Minority Over-sampling Technique), borderline (`BorderlineSMOTE <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.BorderlineSMOTE.html>`_: This algorithm is a variant of the original SMOTE algorithm. Borderline samples will be detected and used to generate new synthetic samples), svmsmote (`SVMSMOTE <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.SVMSMOTE.html>`_: Variant of SMOTE algorithm which uses an SVM algorithm to detect sample to use for generating new synthetic samples), adasyn (`ADASYN <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.ADASYN.html>`_: Perform over-sampling using Adaptive Synthetic -ADASYN- sampling approach for imbalanced datasets).
            * **type** (*str*) - (None) Type of oversampling. It's a mandatory property. Values: regression (the oversampling will be applied on a continuous dataset), classification (the oversampling will be applied on a classified dataset).
            * **target** (*dict*) - ({}) Dependent variable you want to predict from your dataset. You can specify either a column name or a column index. Formats: { "column": "column3" } or { "index": 21 }. In case of multiple formats, the first one will be picked.
            * **evaluate** (*bool*) - (False) Whether or not to evaluate the dataset before and after applying the resampling.
            * **evaluate_splits** (*int*) - (3) [2~100|1] Number of folds to be applied by the Repeated Stratified K-Fold evaluation method. Must be at least 2.
            * **evaluate_repeats** (*int*) - (3) [2~100|1] Number of times Repeated Stratified K-Fold cross validator needs to be repeated.
            * **n_bins** (*int*) - (5) [1~100|1] Only for regression oversampling. The number of classes that the user wants to generate with the target data.
            * **balanced_binning** (*bool*) - (False) Only for regression oversampling. Decides whether samples are to be distributed roughly equally across all classes.
            * **sampling_strategy** (*dict*) - ({ "target": "auto" }) Sampling information to sample the data set. Formats: { "target": "auto" }, { "ratio": 0.3 }, { "dict": { 0: 300, 1: 200, 2: 100 } } or { "list": [0, 2, 3] }. When "target", specify the class targeted by the resampling; the number of samples in the different classes will be equalized; possible choices are: minority (resample only the minority class), not minority (resample all classes but the minority class), not majority (resample all classes but the majority class), all (resample all classes), auto (equivalent to 'not majority'). When "ratio", it corresponds to the desired ratio of the number of samples in the minority class over the number of samples in the majority class after resampling (ONLY IN CASE OF BINARY CLASSIFICATION). When "dict", the keys correspond to the targeted classes, the values correspond to the desired number of samples for each targeted class. When "list", the list contains the classes targeted by the resampling.
            * **k_neighbors** (*int*) - (5) [1~100|1] Only for SMOTE, BorderlineSMOTE, SVMSMOTE, ADASYN. The number of nearest neighbours used to construct synthetic samples.
            * **random_state_method** (*int*) - (5) [1~1000|1] Controls the randomization of the algorithm.
            * **random_state_evaluate** (*int*) - (5) [1~1000|1] Controls the shuffling applied to the Repeated Stratified K-Fold evaluation method.
            * **remove_tmp** (*bool*) - (True) [WF property] Remove temporary files.
            * **restart** (*bool*) - (False) [WF property] Do not execute if output files exist.
            * **sandbox_path** (*str*) - ("./") [WF property] Parent path to the sandbox directory.

    Examples:
        This is a use example of how to use the building block from Python::

            from biobb_ml.resampling.oversampling import oversampling
            prop = {
                'method': 'random',
                'type': 'regression',
                'target': {
                    'column': 'target'
                },
                'evaluate': True,
                'n_bins': 10,
                'sampling_strategy': {
                    'target': 'minority'
                }
            }
            oversampling(input_dataset_path='/path/to/myDataset.csv',
                         output_dataset_path='/path/to/newDataset.csv',
                         properties=prop)

    Info:
        * wrapped_software:
            * name: imbalanced-learn over_sampling
            * version: >0.7.0
            * license: MIT
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl

    """

    def __init__(self, input_dataset_path, output_dataset_path,
                 properties=None, **kwargs) -> None:
        properties = properties or {}

        # Call parent class constructor
        super().__init__(properties)
        # Snapshot of the constructor arguments; must stay right after the
        # parent constructor so no extra locals are captured.
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {"input_dataset_path": input_dataset_path},
            "out": {"output_dataset_path": output_dataset_path}
        }

        # Properties specific for BB
        self.method = properties.get('method', None)
        self.type = properties.get('type', None)
        self.target = properties.get('target', {})
        self.evaluate = properties.get('evaluate', False)
        self.evaluate_splits = properties.get('evaluate_splits', 3)
        self.evaluate_repeats = properties.get('evaluate_repeats', 3)
        self.n_bins = properties.get('n_bins', 5)
        self.balanced_binning = properties.get('balanced_binning', False)
        self.sampling_strategy = properties.get('sampling_strategy', {'target': 'auto'})
        self.k_neighbors = properties.get('k_neighbors', 5)
        self.random_state_method = properties.get('random_state_method', 5)
        self.random_state_evaluate = properties.get('random_state_evaluate', 5)
        self.properties = properties

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def check_data_params(self, out_log, err_log):
        """ Checks all the input/output paths and parameters """
        self.io_dict["in"]["input_dataset_path"] = check_input_path(self.io_dict["in"]["input_dataset_path"], "input_dataset_path", out_log, self.__class__.__name__)
        self.io_dict["out"]["output_dataset_path"] = check_output_path(self.io_dict["out"]["output_dataset_path"], "output_dataset_path", False, out_log, self.__class__.__name__)

    @staticmethod
    def _label_encode_columns(df):
        """Label-encode, in place, every object-dtype column of ``df``.

        Each column gets its own ``LabelEncoder`` so it can later be decoded
        with its own class mapping (a single shared encoder would only
        remember the classes of the last column it was fitted on).

        Returns:
            dict: Maps each encoded column label to its fitted encoder, in
            column order.
        """
        encoders = {}
        for column in df:
            if df[column].dtypes == 'object':
                encoder = preprocessing.LabelEncoder()
                df[column] = encoder.fit_transform(df[column])
                encoders[column] = encoder
        return encoders

    def _instantiate_sampler(self, method, sampling_strategy):
        """Instantiate the imblearn sampler class ``method`` for ``self.method``.

        ``random`` takes no neighbour count. The SMOTE variants take
        ``k_neighbors``, while ADASYN names the same setting ``n_neighbors``.
        For any other ``self.method`` value, ``method`` is returned unchanged.
        """
        kwargs = {'sampling_strategy': sampling_strategy, 'random_state': self.random_state_method}
        if self.method in ('smote', 'borderline', 'svmsmote'):
            kwargs['k_neighbors'] = self.k_neighbors
        elif self.method == 'adasyn':
            kwargs['n_neighbors'] = self.k_neighbors
        elif self.method != 'random':
            return method
        return method(**kwargs)

    def _evaluate_model(self, X, y, label):
        """Log the mean accuracy of a RandomForestClassifier on (X, y).

        The score comes from a RepeatedStratifiedKFold cross validation
        configured with the ``evaluate_*`` properties. ``label`` is the text
        placed before the formatted score in the log message.
        """
        cv = RepeatedStratifiedKFold(n_splits=self.evaluate_splits, n_repeats=self.evaluate_repeats, random_state=self.random_state_evaluate)
        scores = cross_val_score(RandomForestClassifier(), X, y, scoring='accuracy', cv=cv, n_jobs=-1)
        mean_score = np.mean(scores)
        if np.isnan(mean_score):
            fu.log('Unable to calculate cross validation score, NaN was returned.', self.out_log, self.global_log)
        else:
            fu.log('%s%.3f' % (label, mean_score), self.out_log, self.global_log)

    def _log_class_distribution(self, y, ranges, stage):
        """Log the number and percentage of samples per class in ``y``.

        Args:
            y: Sequence of class labels.
            ranges: Indexable by class label giving the value range of each
                binned class (regression only); falsy to omit the ranges.
            stage (str): Word inserted in the header ('before' / 'after').
        """
        total = len(y)
        rows = []
        for label, count in Counter(y).items():
            try:
                label_txt = '%d' % label
            except (TypeError, ValueError):
                # Decoded (e.g. string) class labels cannot be formatted with %d.
                label_txt = str(label)
            rng = str(ranges[label]) if ranges else ''
            rows.append('Class=%s, n=%d (%.3f%%) %s\n' % (label_txt, count, count / total * 100, rng))
        fu.log('Classes distribution %s oversampling:\n\n%s' % (stage, ''.join(rows)), self.out_log, self.global_log)

    @launchlogger
    def launch(self) -> int:
        """Execute the :class:`Oversampling <resampling.oversampling.Oversampling>` resampling.oversampling.Oversampling object."""

        class_name = self.__class__.__name__

        # check input/output paths and parameters
        self.check_data_params(self.out_log, self.err_log)

        # Setup Biobb
        if self.check_restart():
            return 0
        self.stage_files()

        # check mandatory properties
        method = getResamplingMethod(self.method, 'oversampling', self.out_log, class_name)
        checkResamplingType(self.type, self.out_log, class_name)
        sampling_strategy = getSamplingStrategy(self.sampling_strategy, self.out_log, class_name)

        # load dataset
        fu.log('Getting dataset from %s' % self.io_dict["in"]["input_dataset_path"], self.out_log, self.global_log)
        if 'column' in self.target:
            labels = getHeader(self.io_dict["in"]["input_dataset_path"])
            skiprows = 1
            header = 0
        else:
            labels = None
            skiprows = None
            header = None
        data = pd.read_csv(self.io_dict["in"]["input_dataset_path"], header=None, sep="\\s+|;|:|,|\t", engine="python", skiprows=skiprows, names=labels)

        train_df = data
        ranges = None

        # Encode categorical (object) columns, one encoder per column.
        encoders = self._label_encode_columns(train_df)

        target_value = getTargetValue(self.target, self.out_log, class_name)

        # defining X
        X = train_df.loc[:, train_df.columns != target_value]
        # calling oversample method
        method = self._instantiate_sampler(method, sampling_strategy)

        fu.log('Target: %s' % (target_value), self.out_log, self.global_log)

        # oversampling
        if self.type == 'regression':
            fu.log('Oversampling regression dataset, continuous data will be classified', self.out_log, self.global_log)
            # call resampler class for Regression ReSampling
            rs = resampler()
            # Create n_bins classes for the dataset
            ranges, y, target_pos = rs.fit(train_df, target=target_value, bins=self.n_bins, balanced_binning=self.balanced_binning, verbose=0)
            # Get the over-sampled data
            final_X, final_y = rs.resample(method, train_df, y)
        elif self.type == 'classification':
            # get X and y
            y = getTarget(self.target, train_df, self.out_log, class_name)
            # fit and resample
            final_X, final_y = method.fit_resample(X, y)
            target_pos = None

        # evaluate oversampling
        if self.evaluate:
            fu.log('Evaluating data before oversampling with RandomForestClassifier', self.out_log, self.global_log)
            self._evaluate_model(X, y, 'Mean Accuracy before oversampling: ')

        # log distribution before oversampling
        self._log_class_distribution(y, ranges, 'before')

        # join final_X and final_y in the output dataframe
        # NOTE(review): in both branches the target ends up as the LAST column.
        # Without a header the positional column labels after the original
        # target index shift by one, which affects the decoding below and the
        # index-based target lookup; confirm against getTarget/resampler.
        if header is None:
            out_df = pd.DataFrame(data=np.column_stack((final_X, final_y)))
        else:
            out_df = final_X.join(final_y)

        # decode every label-encoded column with its own encoder
        for column, encoder in encoders.items():
            if header is None:
                # column_stack yields floats; encoder codes are integers
                out_df = out_df.astype({column: int})
            out_df[column] = encoder.inverse_transform(out_df[column].values.ravel())

        # if no header, target is in a different column
        # NOTE(review): a target_pos of 0 is treated as "not provided" here.
        t = target_pos if target_pos else target_value
        # log distribution after oversampling
        if self.type == 'regression':
            ranges, y_out, _ = rs.fit(out_df, target=t, bins=self.n_bins, balanced_binning=self.balanced_binning, verbose=0)
        elif self.type == 'classification':
            y_out = getTarget(self.target, out_df, self.out_log, class_name)

        self._log_class_distribution(y_out, ranges, 'after')

        # evaluate oversampling
        if self.evaluate:
            fu.log('Evaluating data after oversampling with RandomForestClassifier', self.out_log, self.global_log)
            self._evaluate_model(final_X, y_out, 'Mean Accuracy after oversampling a %s dataset with %s method: ' % (self.type, oversampling_methods[self.method]['method']))

        # save output
        hdr = header == 0
        fu.log('Saving oversampled dataset to %s' % self.io_dict["out"]["output_dataset_path"], self.out_log, self.global_log)
        out_df.to_csv(self.io_dict["out"]["output_dataset_path"], index=False, header=hdr)

        # Copy files to host
        self.copy_to_host()

        self.tmp_files.extend([
            self.stage_io_dict.get("unique_dir")
        ])
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0

277 

278 

def oversampling(input_dataset_path: str, output_dataset_path: str, properties: dict = None, **kwargs) -> int:
    """Execute the :class:`Oversampling <resampling.oversampling.Oversampling>` class and
    execute the :meth:`launch() <resampling.oversampling.Oversampling.launch>` method."""

    # Build the building block first, then run it and forward its exit code.
    block = Oversampling(input_dataset_path=input_dataset_path,
                         output_dataset_path=output_dataset_path,
                         properties=properties,
                         **kwargs)
    return block.launch()

286 

287 

def main():
    """Command line execution of this building block. Please check the command line documentation."""
    parser = argparse.ArgumentParser(
        description="Wrapper of most of the imblearn.over_sampling methods.",
        formatter_class=lambda prog: argparse.RawTextHelpFormatter(prog, width=99999),
    )
    parser.add_argument('--config', required=False, help='Configuration file')

    # Specific args of each building block
    required_group = parser.add_argument_group('required arguments')
    cli_paths = (
        ('--input_dataset_path', 'Path to the input dataset. Accepted formats: csv.'),
        ('--output_dataset_path', 'Path to the output dataset. Accepted formats: csv.'),
    )
    for flag, description in cli_paths:
        required_group.add_argument(flag, required=True, help=description)

    args = parser.parse_args()
    # An absent --config falls back to an empty JSON object.
    config = args.config if args.config else "{}"
    properties = settings.ConfReader(config=config).get_prop_dic()

    # Specific call of each building block
    oversampling(input_dataset_path=args.input_dataset_path,
                 output_dataset_path=args.output_dataset_path,
                 properties=properties)

306 

307 

if __name__ == "__main__":
    # Run the command line interface only when executed as a script.
    main()