Coverage for biobb_ml/resampling/resampling.py: 82%
151 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
1#!/usr/bin/env python3
3"""Module containing the Resampling class and the command line interface."""
4import argparse
5import pandas as pd
6import numpy as np
7from collections import Counter
8from biobb_common.generic.biobb_object import BiobbObject
9from sklearn import preprocessing
10from sklearn.model_selection import cross_val_score
11from sklearn.model_selection import RepeatedStratifiedKFold
12from sklearn.ensemble import RandomForestClassifier
13from biobb_ml.resampling.reg_resampler import resampler
14from biobb_common.configuration import settings
15from biobb_common.tools import file_utils as fu
16from biobb_common.tools.file_utils import launchlogger
17from biobb_ml.resampling.common import check_input_path, check_output_path, getCombinedMethod, checkResamplingType, getSamplingStrategy, getHeader, getTargetValue, getTarget, resampling_methods
class Resampling(BiobbObject):
    """
    | biobb_ml Resampling
    | Wrapper of the imblearn.combine methods.
    | Combine over- and under-sampling methods to remove samples and supplement the dataset. If regression is specified as type, the data will be resampled to classes in order to apply the resampling model. Visit the imbalanced-learn official website for the different methods accepted in this wrapper: `SMOTETomek <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.combine.SMOTETomek.html>`_, `SMOTEENN <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.combine.SMOTEENN.html>`_.

    Args:
        input_dataset_path (str): Path to the input dataset. File type: input. `Sample file <https://github.com/bioexcel/biobb_ml/raw/master/biobb_ml/test/data/resampling/dataset_resampling.csv>`_. Accepted formats: csv (edam:format_3752).
        output_dataset_path (str): Path to the output dataset. File type: output. `Sample file <https://github.com/bioexcel/biobb_ml/raw/master/biobb_ml/test/reference/resampling/ref_output_resampling.csv>`_. Accepted formats: csv (edam:format_3752).
        properties (dic - Python dictionary object containing the tool parameters, not input/output files):
            * **method** (*str*) - (None) Resampling method. It's a mandatory property. Values: smotetomek (`SMOTETomek <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.combine.SMOTETomek.html>`_: Class to perform over-sampling using SMOTE and cleaning using Tomek links), smotenn (`SMOTEENN <https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.combine.SMOTEENN.html>`_: Class to perform over-sampling using SMOTE and cleaning using ENN).
            * **type** (*str*) - (None) Type of oversampling. It's a mandatory property. Values: regression (the oversampling will be applied on a continuous dataset), classification (the oversampling will be applied on a classified dataset).
            * **target** (*dict*) - ({}) Dependent variable you want to predict from your dataset. You can specify either a column name or a column index. Formats: { "column": "column3" } or { "index": 21 }. In case of multiple formats, the first one will be picked.
            * **evaluate** (*bool*) - (False) Whether or not to evaluate the dataset before and after applying the resampling.
            * **evaluate_splits** (*int*) - (3) [2~100|1] Number of folds to be applied by the Repeated Stratified K-Fold evaluation method. Must be at least 2.
            * **evaluate_repeats** (*int*) - (3) [2~100|1] Number of times Repeated Stratified K-Fold cross validator needs to be repeated.
            * **n_bins** (*int*) - (5) [1~100|1] Only for regression resampling. The number of classes that the user wants to generate with the target data.
            * **balanced_binning** (*bool*) - (False) Only for regression resampling. Decides whether samples are to be distributed roughly equally across all classes.
            * **sampling_strategy_over** (*dict*) - ({ "target": "auto" }) Sampling information applied in the dataset oversampling process. Formats: { "target": "auto" }, { "ratio": 0.3 } or { "dict": { 0: 300, 1: 200, 2: 100 } }. When "target", specify the class targeted by the resampling; the number of samples in the different classes will be equalized; possible choices are: minority (resample only the minority class), not minority (resample all classes but the minority class), not majority (resample all classes but the majority class), all (resample all classes), auto (equivalent to 'not majority'). When "ratio", it corresponds to the desired ratio of the number of samples in the minority class over the number of samples in the majority class after resampling (ONLY IN CASE OF BINARY CLASSIFICATION). When "dict", the keys correspond to the targeted classes and the values correspond to the desired number of samples for each targeted class.
            * **sampling_strategy_under** (*dict*) - ({ "target": "auto" }) Sampling information applied in the dataset cleaning process. Formats: { "target": "auto" } or { "list": [0, 2, 3] }. When "target", specify the class targeted by the resampling; the number of samples in the different classes will be equalized; possible choices are: majority (resample only the majority class), not minority (resample all classes but the minority class), not majority (resample all classes but the majority class), all (resample all classes), auto (equivalent to 'not minority'). When "list", the list contains the classes targeted by the resampling.
            * **random_state_method** (*int*) - (5) [1~1000|1] Controls the randomization of the algorithm.
            * **random_state_evaluate** (*int*) - (5) [1~1000|1] Controls the shuffling applied to the Repeated Stratified K-Fold evaluation method (used both before and after resampling).
            * **remove_tmp** (*bool*) - (True) [WF property] Remove temporary files.
            * **restart** (*bool*) - (False) [WF property] Do not execute if output files exist.
            * **sandbox_path** (*str*) - ("./") [WF property] Parent path to the sandbox directory.

    Examples:
        This is a use example of how to use the building block from Python::

            from biobb_ml.resampling.resampling import resampling
            prop = {
                'method': 'smotenn',
                'type': 'regression',
                'target': {
                    'column': 'target'
                },
                'evaluate': True,
                'n_bins': 10,
                'sampling_strategy_over': {
                    'dict': { '4': 1000, '5': 1000, '6': 1000, '7': 1000 }
                },
                'sampling_strategy_under': {
                    'list': [0,1]
                }
            }
            resampling(input_dataset_path='/path/to/myDataset.csv',
                       output_dataset_path='/path/to/newDataset.csv',
                       properties=prop)

    Info:
        * wrapped_software:
            * name: imbalanced-learn combine
            * version: >0.7.0
            * license: MIT
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl

    """

    def __init__(self, input_dataset_path, output_dataset_path,
                 properties=None, **kwargs) -> None:
        properties = properties or {}

        # Call parent class constructor
        super().__init__(properties)
        # Snapshot of the constructor arguments, taken before any other local is created
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {"input_dataset_path": input_dataset_path},
            "out": {"output_dataset_path": output_dataset_path}
        }

        # Properties specific for BB
        self.method = properties.get('method', None)
        self.type = properties.get('type', None)
        self.target = properties.get('target', {})
        self.evaluate = properties.get('evaluate', False)
        self.evaluate_splits = properties.get('evaluate_splits', 3)
        self.evaluate_repeats = properties.get('evaluate_repeats', 3)
        self.n_bins = properties.get('n_bins', 5)
        self.balanced_binning = properties.get('balanced_binning', False)
        self.sampling_strategy_over = properties.get('sampling_strategy_over', {'target': 'auto'})
        self.sampling_strategy_under = properties.get('sampling_strategy_under', {'target': 'auto'})
        self.random_state_method = properties.get('random_state_method', 5)
        self.random_state_evaluate = properties.get('random_state_evaluate', 5)
        self.properties = properties

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def check_data_params(self, out_log, err_log):
        """ Checks all the input/output paths and parameters """
        self.io_dict["in"]["input_dataset_path"] = check_input_path(self.io_dict["in"]["input_dataset_path"], "input_dataset_path", out_log, self.__class__.__name__)
        self.io_dict["out"]["output_dataset_path"] = check_output_path(self.io_dict["out"]["output_dataset_path"], "output_dataset_path", False, out_log, self.__class__.__name__)

    def _cross_validation_score(self, X, y):
        """Return the mean accuracy of a class-balanced RandomForestClassifier on ``(X, y)``.

        The Repeated Stratified K-Fold splitter is configured from the
        ``evaluate_splits``, ``evaluate_repeats`` and ``random_state_evaluate``
        properties. The same configuration is used for the evaluations before
        and after resampling, so the two scores are comparable. The result may
        be NaN when scikit-learn cannot score some folds.
        """
        cv = RepeatedStratifiedKFold(n_splits=self.evaluate_splits, n_repeats=self.evaluate_repeats, random_state=self.random_state_evaluate)
        scores = cross_val_score(RandomForestClassifier(class_weight='balanced'), X, y, scoring='accuracy', cv=cv, n_jobs=-1)
        return np.mean(scores)

    @staticmethod
    def _format_distribution(y, ranges):
        """Return a text table with one line per class present in ``y``.

        Each line has the form ``Class=<k>, n=<count> (<pct>%) <range>``. The
        ``<range>`` part is ``str(ranges[k])`` when ``ranges`` is truthy, as
        produced by the regression binning; otherwise it is empty.
        """
        total = len(y)
        lines = []
        for k, v in Counter(y).items():
            rng = str(ranges[k]) if ranges else ''
            lines.append('Class=%d, n=%d (%.3f%%) %s\n' % (k, v, v / total * 100, rng))
        return ''.join(lines)

    @launchlogger
    def launch(self) -> int:
        """Execute the :class:`Resampling <resampling.resampling.Resampling>` resampling.resampling.Resampling object.

        Returns:
            int: 0, both after a full run and when the step is skipped by ``restart``.
        """

        # check input/output paths and parameters
        self.check_data_params(self.out_log, self.err_log)

        # Setup Biobb
        if self.check_restart():
            return 0
        self.stage_files()

        # check mandatory properties
        method, over, under = getCombinedMethod(self.method, self.out_log, self.__class__.__name__)
        checkResamplingType(self.type, self.out_log, self.__class__.__name__)
        sampling_strategy_over = getSamplingStrategy(self.sampling_strategy_over, self.out_log, self.__class__.__name__)
        sampling_strategy_under = getSamplingStrategy(self.sampling_strategy_under, self.out_log, self.__class__.__name__)

        # load dataset
        fu.log('Getting dataset from %s' % self.io_dict["in"]["input_dataset_path"], self.out_log, self.global_log)
        if 'column' in self.target:
            # Named target: the first row holds the column names
            labels = getHeader(self.io_dict["in"]["input_dataset_path"])
            skiprows = 1
            header = 0
        else:
            labels = None
            skiprows = None
            header = None
        train_df = pd.read_csv(self.io_dict["in"]["input_dataset_path"], header=None, sep="\\s+|;|:|,|\t", engine="python", skiprows=skiprows, names=labels)
        ranges = None

        # Label-encode every object-typed column. One encoder is kept per column
        # so that each column is later decoded with its own fitted classes; a
        # single shared encoder would only remember the last column it was fitted on.
        encoders = {}
        for column in train_df:
            if train_df[column].dtypes == 'object':
                encoders[column] = preprocessing.LabelEncoder()
                train_df[column] = encoders[column].fit_transform(train_df[column])

        # defining X (every column except the target)
        X = train_df.loc[:, train_df.columns != getTargetValue(self.target, self.out_log, self.__class__.__name__)]
        # calling resample method
        if self.method == 'smotetomek':
            method = method(smote=over(sampling_strategy=sampling_strategy_over), tomek=under(sampling_strategy=sampling_strategy_under), random_state=self.random_state_method)
        elif self.method == 'smotenn':
            method = method(smote=over(sampling_strategy=sampling_strategy_over), enn=under(sampling_strategy=sampling_strategy_under), random_state=self.random_state_method)

        fu.log('Target: %s' % (getTargetValue(self.target, self.out_log, self.__class__.__name__)), self.out_log, self.global_log)

        # resampling
        if self.type == 'regression':
            fu.log('Resampling regression dataset, continuous data will be classified', self.out_log, self.global_log)
            # call resampler class for Regression ReSampling
            rs = resampler()
            # Create n_bins classes for the dataset
            ranges, y, target_pos = rs.fit(train_df, target=getTargetValue(self.target, self.out_log, self.__class__.__name__), bins=self.n_bins, balanced_binning=self.balanced_binning, verbose=0)
            # Get the re-sampled data
            final_X, final_y = rs.resample(method, train_df, y)
        elif self.type == 'classification':
            # get X and y
            y = getTarget(self.target, train_df, self.out_log, self.__class__.__name__)
            # fit and resample
            final_X, final_y = method.fit_resample(X, y)
            target_pos = None

        # evaluate resampling
        if self.evaluate:
            fu.log('Evaluating data before resampling with RandomForestClassifier', self.out_log, self.global_log)
            mean_score = self._cross_validation_score(X, y)
            if not np.isnan(mean_score):
                fu.log('Mean Accuracy before resampling: %.3f' % (mean_score), self.out_log, self.global_log)
            else:
                fu.log('Unable to calculate cross validation score, NaN was returned.', self.out_log, self.global_log)

        # log distribution before resampling
        fu.log('Classes distribution before resampling:\n\n%s' % self._format_distribution(y, ranges), self.out_log, self.global_log)

        # join final_X and final_y in the output dataframe
        if header is None:
            # numpy: the target is stacked as the last column, then wrapped back into pandas
            out_df = pd.DataFrame(data=np.column_stack((final_X, final_y)))
        else:
            # pandas
            out_df = final_X.join(final_y)

        # if cols encoded, decode each one with the encoder fitted on it
        # NOTE(review): without a header the output columns are positional
        # (target last), while encoder keys are the original column positions;
        # confirm they line up when the target is not the last input column.
        for column, encoder in encoders.items():
            if header is None:
                # column_stack may have upcast the codes; inverse_transform needs ints
                out_df = out_df.astype({column: int})
            out_df[column] = encoder.inverse_transform(out_df[column].values.ravel())

        # if no header, target is in a different column
        if target_pos:
            t = target_pos
        else:
            t = getTargetValue(self.target, self.out_log, self.__class__.__name__)
        # log distribution after resampling
        if self.type == 'regression':
            ranges, y_out, _ = rs.fit(out_df, target=t, bins=self.n_bins, balanced_binning=self.balanced_binning, verbose=0)
        elif self.type == 'classification':
            y_out = getTarget(self.target, out_df, self.out_log, self.__class__.__name__)

        fu.log('Classes distribution after resampling:\n\n%s' % self._format_distribution(y_out, ranges), self.out_log, self.global_log)

        # evaluate resampling, using the same CV configuration as the "before" evaluation
        if self.evaluate:
            fu.log('Evaluating data after resampling with RandomForestClassifier', self.out_log, self.global_log)
            mean_score = self._cross_validation_score(final_X, y_out)
            if not np.isnan(mean_score):
                fu.log('Mean Accuracy after resampling a %s dataset with %s method: %.3f' % (self.type, resampling_methods[self.method]['method'], mean_score), self.out_log, self.global_log)
            else:
                fu.log('Unable to calculate cross validation score, NaN was returned.', self.out_log, self.global_log)

        # save output; the header row is written only when the input had one
        hdr = header == 0
        fu.log('Saving resampled dataset to %s' % self.io_dict["out"]["output_dataset_path"], self.out_log, self.global_log)
        out_df.to_csv(self.io_dict["out"]["output_dataset_path"], index=False, header=hdr)

        # Copy files to host
        self.copy_to_host()

        self.tmp_files.extend([
            self.stage_io_dict.get("unique_dir")
        ])
        self.remove_tmp_files()

        self.check_arguments(output_files_created=True, raise_exception=False)

        return 0
def resampling(input_dataset_path: str, output_dataset_path: str, properties: dict = None, **kwargs) -> int:
    """Build a :class:`Resampling <resampling.resampling.Resampling>` step and run it.

    Thin functional entry point: instantiates the class with the given paths and
    properties, then calls its :meth:`launch() <resampling.resampling.Resampling.launch>` method.

    Returns:
        int: Whatever ``launch()`` returns.
    """
    step = Resampling(input_dataset_path=input_dataset_path,
                      output_dataset_path=output_dataset_path,
                      properties=properties,
                      **kwargs)
    return step.launch()
def main():
    """Command line execution of this building block. Please check the command line documentation."""
    cli = argparse.ArgumentParser(
        description="Wrapper of the imblearn.combine methods.",
        formatter_class=lambda prog: argparse.RawTextHelpFormatter(prog, width=99999),
    )
    cli.add_argument('--config', required=False, help='Configuration file')

    # Building-block specific arguments
    mandatory = cli.add_argument_group('required arguments')
    mandatory.add_argument('--input_dataset_path', required=True, help='Path to the input dataset. Accepted formats: csv.')
    mandatory.add_argument('--output_dataset_path', required=True, help='Path to the output dataset. Accepted formats: csv.')

    parsed = cli.parse_args()
    # Without --config, fall back to an empty JSON configuration
    config = parsed.config if parsed.config else "{}"
    properties = settings.ConfReader(config=config).get_prop_dic()

    # Run the building block with the parsed paths and properties
    resampling(input_dataset_path=parsed.input_dataset_path,
               output_dataset_path=parsed.output_dataset_path,
               properties=properties)
if __name__ == "__main__":
    # Entry point when the module is executed directly as a script
    main()