Coverage for biobb_ml/classification/common.py: 76%
190 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
1""" Common functions for package biobb_ml.classification """
2from pathlib import Path, PurePath
3import matplotlib.pyplot as plt
4import itertools
5import csv
6import re
7import numpy as np
8import pandas as pd
9import seaborn as sns
10from sklearn.metrics import roc_curve, auc
11from biobb_common.tools import file_utils as fu
12from warnings import simplefilter
13# ignore all future warnings
14simplefilter(action='ignore', category=FutureWarning)
15sns.set()
18# CHECK PARAMETERS
def check_input_path(path, argument, out_log, classname):
    """Validate an input file path and return it unchanged.

    Args:
        path: Location of the input file.
        argument: Name of the argument the path was given for; it selects the
            accepted extensions in ``is_valid_file`` and appears in messages.
        out_log: Logger handed to ``fu.log``.
        classname: Name of the calling class, used as the message prefix.

    Returns:
        ``path`` itself when the file exists and its extension is accepted.

    Raises:
        SystemExit: If the file does not exist or its extension is not
            accepted for ``argument``.
    """
    if not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)

    # The suffix includes the leading dot; strip it before the lookup.
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    incompatible = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(incompatible, out_log)
    raise SystemExit(incompatible)
def check_output_path(path, argument, optional, out_log, classname):
    """Validate an output file path and return it unchanged.

    Args:
        path: Location where the output file will be written.
        argument: Name of the argument the path was given for; it selects the
            accepted extensions in ``is_valid_file`` and appears in messages.
        optional: Whether the output may be omitted.
        out_log: Logger handed to ``fu.log``.
        classname: Name of the calling class, used as the message prefix.

    Returns:
        ``None`` when the output is optional and no path was given, otherwise
        ``path`` itself.

    Raises:
        SystemExit: If the parent folder does not exist or the extension is
            not accepted for ``argument``.
    """
    if optional and not path:
        return None

    folder = PurePath(path).parent
    if folder and not Path(folder).exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    # The suffix includes the leading dot; strip it before the lookup.
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    incompatible = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(incompatible, out_log)
    raise SystemExit(incompatible)
def is_valid_file(ext, argument):
    """Tell whether extension ``ext`` is accepted for the given argument.

    Args:
        ext: File extension without the leading dot.
        argument: Name of the path argument being checked.

    Returns:
        ``True`` if ``ext`` is one of the formats listed for ``argument``.

    Raises:
        KeyError: If ``argument`` is not one of the known path arguments.
    """
    csv_only, pickle_only, png_only = ['csv'], ['pkl'], ['png']
    accepted_formats = {
        'input_dataset_path': csv_only,
        'output_model_path': pickle_only,
        'output_dataset_path': csv_only,
        'output_results_path': csv_only,
        'input_model_path': pickle_only,
        'output_test_table_path': csv_only,
        'output_plot_path': png_only,
    }
    return ext in accepted_formats[argument]
def check_mandatory_property(property, name, out_log, classname):
    """Return ``property`` if it is truthy, otherwise log and exit.

    Args:
        property: Value of the property to check.
        name: Property name used in the messages.
        out_log: Logger handed to ``fu.log``.
        classname: Name of the calling class, used as the message prefix.

    Raises:
        SystemExit: If ``property`` is falsy.
    """
    if property:
        return property

    fu.log(classname + ': Unexisting %s property, exiting' % name, out_log)
    raise SystemExit(classname + ': Unexisting %s property' % name)
67# UTILITIES
def get_list_of_predictors(predictions):
    """Return, for each mapping in ``predictions``, the list of its values.

    The values keep the iteration order of each mapping.
    """
    return [list(entry.values()) for entry in predictions]
def get_keys_of_predictors(predictions):
    """Return the keys of the first entry of ``predictions`` as a list.

    Only the first entry is inspected; later entries are not compared with it.

    Args:
        predictions: Sequence of mappings.

    Returns:
        List with the keys of ``predictions[0]`` in iteration order, or an
        empty list when ``predictions`` has no entries.
    """
    try:
        first_entry = predictions[0]
    except IndexError:
        # An empty sequence has no keys; previously this raised IndexError.
        return []
    return list(first_entry)
def CMPlotBinary(position, cm, group_names, title, normalize):
    """Draw an annotated heatmap of a 2x2 confusion matrix in a subplot.

    Args:
        position: Subplot specifier passed to ``plt.subplot``.
        cm: Confusion matrix; its flattened values are paired with
            ``group_names`` and the labels are reshaped to 2x2.
        group_names: One name per cell, in flattened order.
        title: Title of the subplot.
        normalize: When truthy, cell values are printed with two decimals,
            otherwise with none.
    """
    plt.subplot(position)
    plt.title(title, size=15)

    cell_format = "{0:0.2f}" if normalize else "{0:0.0f}"
    cell_values = [cell_format.format(value) for value in cm.flatten()]

    # Each annotation shows the value on the first line and its name below.
    annotations = np.asarray(
        [f"{shown}\n{name}" for shown, name in zip(cell_values, group_names)]
    ).reshape(2, 2)

    sns.heatmap(cm, annot=annotations, fmt='', cmap='Blues', square=True)
    plt.ylabel('True Values', size=13)
    plt.xlabel('Predicted Values', size=13)
    plt.yticks(rotation=0)
def distPredPlot(position, y, pos_p, labels, title):
    """Plot histograms of the predicted positive probability for each class.

    Samples whose target equals 1 are drawn in green with ``labels[0]`` and
    samples whose target equals 0 in red with ``labels[1]``. A dashed vertical
    line marks the 0.5 boundary.

    Args:
        position: Subplot specifier passed to ``plt.subplot``.
        y: Target values, one per sample.
        pos_p: Predicted probability of the positive class, one per sample.
        labels: Legend labels; index 0 is used for target 1, index 1 for
            target 0.
        title: Title of the subplot.
    """
    frame = pd.DataFrame({'probPos': pos_p, 'target': y})
    plt.subplot(position)

    # (target value, colour); the legend label is looked up by position.
    for label_idx, (target_value, colour) in enumerate(((1, 'green'), (0, 'red'))):
        selected = frame[frame['target'] == target_value]['probPos']
        plt.hist(selected, density=True, bins=25,
                 alpha=.5, color=colour, label=labels[label_idx])

    plt.axvline(.5, color='blue', linestyle='--', label='Boundary')
    plt.xlim([0, 1])
    plt.title(title, size=15)
    plt.xlabel('Positive Probability (predicted)', size=13)
    plt.ylabel('Samples (normalized scale)', size=13)
    plt.legend(loc="upper right")
def ROCPlot(position, y, p, cm, title):
    """Plot a ROC curve and mark the current decision point on it.

    Args:
        position: Subplot specifier passed to ``plt.subplot``.
        y: True target values.
        p: Two-dimensional array of predicted probabilities; column 1 is
            used as the score for ``roc_curve``.
        cm: Confusion matrix whose raveled values are read as
            ``tn, fp, fn, tp``.
        title: Title of the subplot.
    """
    scores = p[:, 1]
    fp_rates, tp_rates, _ = roc_curve(y, scores)
    area_under_curve = auc(fp_rates, tp_rates)

    plt.subplot(position)
    curve_label = 'ROC curve (area = %0.2f)' % area_under_curve
    plt.plot(fp_rates, tp_rates, color='green', lw=1, label=curve_label)
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='grey')

    # Decision point: rates derived from the confusion matrix.
    tn, fp, fn, tp = cm.ravel()
    false_positive_rate = fp / (fp + tn)
    true_positive_rate = tp / (tp + fn)
    plt.plot(false_positive_rate, true_positive_rate, 'bo', markersize=8, label='Decision Point')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', size=13)
    plt.ylabel('True Positive Rate', size=13)
    plt.title(title, size=15)
    plt.legend(loc="lower right")
134# Visualize the performance of a Logistic Regression Binary Classifier.
135# https://towardsdatascience.com/how-to-interpret-a-binary-logistic-regressor-with-scikit-learn-6d56c5783b49
def plotBinaryClassifier(model, proba_train, proba_test, cm_train, cm_test, y_train, y_test,
                         normalize=False, labels=('Positives', 'Negatives'), cmticks=(0, 1), get_plot=True):
    """Draw a 2x3 summary figure for a binary classifier.

    The first three subplots (231-233) describe the training split and the
    last three (234-236) the test split: confusion matrix, distribution of
    the predicted positive probability, and ROC curve with decision point.

    Args:
        model: Fitted estimator exposing ``classes_``; one of its two classes
            must be equal to ``1``.
        proba_train: Predicted class probabilities for the training split,
            indexed as ``proba[:, column]``.
        proba_test: Predicted class probabilities for the test split.
        cm_train: Confusion matrix of the training split.
        cm_test: Confusion matrix of the test split.
        y_train: True targets of the training split.
        y_test: True targets of the test split.
        normalize: Forwarded to ``CMPlotBinary`` to choose the cell format.
        labels: Legend labels forwarded to ``distPredPlot``.
        cmticks: Not used by this function; kept for call compatibility.
        get_plot: Not used by this function; kept for call compatibility.

    Returns:
        The ``matplotlib.pyplot`` module, with the figure as current figure.

    Raises:
        ValueError: If ``model`` does not have exactly two classes, or if
            neither class is equal to ``1``.
    """
    cell_names = ['True Negatives', 'False Positives', 'False Negatives', 'True Positives']

    # Resolve both probability columns before drawing, so that invalid input
    # fails with a clear error and no half-built figure is left behind.
    pos_p_train = _positive_class_probabilities(model, proba_train)
    pos_p_test = _positive_class_probabilities(model, proba_test)

    plt.figure(figsize=[15, 8])

    # TRAINING: confusion matrix, probability distributions, ROC curve.
    CMPlotBinary(231, cm_train, cell_names, 'Confusion Matrix Train', normalize)
    distPredPlot(232, y_train, pos_p_train, labels, 'Distributions of Predictions Train')
    ROCPlot(233, y_train, proba_train, cm_train, 'ROC Curve Train')

    # TESTING: the same three views for the test split.
    CMPlotBinary(234, cm_test, cell_names, 'Confusion Matrix Test', normalize)
    distPredPlot(235, y_test, pos_p_test, labels, 'Distributions of Predictions Test')
    ROCPlot(236, y_test, proba_test, cm_test, 'ROC Curve Test')

    plt.tight_layout()

    return plt


def _positive_class_probabilities(model, proba):
    """Return the column of ``proba`` belonging to the class equal to ``1``."""
    classes = model.classes_
    if len(classes) != 2:
        raise ValueError('A binary class problem is required')
    if classes[1] == 1:
        return proba[:, 1]
    if classes[0] == 1:
        return proba[:, 0]
    # Previously this case left the probabilities undefined and failed later
    # with an UnboundLocalError.
    raise ValueError('A binary class problem with a class labelled 1 is required')
def CMplotNonBinary(position, cm, title, normalize, values):
    """Draw an annotated heatmap of a multi-class confusion matrix.

    Args:
        position: Subplot specifier passed to ``plt.subplot``.
        cm: Two-dimensional confusion matrix.
        title: Title of the subplot.
        normalize: When truthy, cell values are printed with two decimals,
            otherwise with none.
        values: Class labels, used for the tick labels and the cell names.
    """
    n_rows, n_cols = cm.shape[0], cm.shape[1]

    # Shrink the annotation font as the number of columns grows.
    if n_cols >= 10:
        fs = 6
    elif n_cols >= 5:
        fs = 8
    else:
        fs = 10

    plt.subplot(position)
    plt.title(title, size=15)

    cell_format = "{0:0.2f}" if normalize else "{0:0.0f}"
    cell_values = [cell_format.format(value) for value in cm.flatten()]

    # Diagonal cells are named "True <label>", all others "False <label>",
    # where <label> is the label of the cell's row.
    cell_names = [
        ("True " if row == col else "False ") + str(values[row])
        for row, col in itertools.product(range(n_rows), range(n_cols))
    ]

    annotations = np.asarray(
        [f"{shown}\n{name}" for shown, name in zip(cell_values, cell_names)]
    ).reshape(n_rows, n_cols)

    sns.heatmap(cm, annot=annotations, fmt='', cmap='Blues', xticklabels=values, yticklabels=values, square=True, annot_kws={"fontsize": fs})
    plt.ylabel('True Values', size=13)
    plt.xlabel('Predicted Values', size=13)
    plt.yticks(rotation=0)
def plotMultipleCM(cm_train, cm_test, normalize, values):
    """Draw the train and test confusion matrices side by side.

    Args:
        cm_train: Confusion matrix of the training split.
        cm_test: Confusion matrix of the test split.
        normalize: Forwarded to ``CMplotNonBinary`` to choose the cell format.
        values: Class labels forwarded to ``CMplotNonBinary``.

    Returns:
        The ``matplotlib.pyplot`` module, with the figure as current figure.
    """
    plt.figure(figsize=[8, 4])

    # (subplot position, matrix, title): train first, then test.
    panels = (
        (121, cm_train, 'Confusion Matrix Train'),
        (122, cm_test, 'Confusion Matrix Test'),
    )
    for position, matrix, title in panels:
        CMplotNonBinary(position, matrix, title, normalize, values)

    plt.tight_layout()

    return plt
def getIndependentVars(independent_vars, data, out_log, classname):
    """Select the independent-variable columns from ``data``.

    ``independent_vars`` is checked for these keys, in this order:

    * ``'indexes'``: positional column indexes, used with ``data.iloc``.
    * ``'range'``: list of ``[first, last]`` positional bounds; both ends
      are included.
    * ``'columns'``: column labels, used with ``data.loc``.

    Args:
        independent_vars: Mapping describing the selection.
        data: Table supporting ``.iloc`` and ``.loc`` indexing.
        out_log: Logger handed to ``fu.log``.
        classname: Name of the calling class, used as the message prefix.

    Returns:
        The selected columns of ``data``.

    Raises:
        SystemExit: If none of the three keys is present.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]

    if 'range' in independent_vars:
        # Expand every inclusive [first, last] pair into explicit positions.
        positions = list(itertools.chain.from_iterable(
            range(bounds[0], bounds[1] + 1) for bounds in independent_vars['range']
        ))
        return data.iloc[:, positions]

    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]

    fu.log(classname + ': Incorrect independent_vars format', out_log)
    raise SystemExit(classname + ': Incorrect independent_vars format')
def getIndependentVarsList(independent_vars):
    """Describe the selected independent variables as a comma-separated string.

    ``independent_vars`` is checked for ``'indexes'``, ``'range'`` and
    ``'columns'``, in that order. Ranges are expanded with both ends
    included. Every item is converted with ``str`` so that non-string column
    labels (for example integers) no longer make the join fail.

    Args:
        independent_vars: Mapping describing the selection.

    Returns:
        The items joined with ``', '``, or ``None`` when none of the three
        keys is present.
    """
    if 'indexes' in independent_vars:
        selected = independent_vars['indexes']
    elif 'range' in independent_vars:
        selected = [
            position
            for bounds in independent_vars['range']
            for position in range(bounds[0], bounds[1] + 1)
        ]
    elif 'columns' in independent_vars:
        selected = independent_vars['columns']
    else:
        return None

    return ', '.join(str(item) for item in selected)
def getTarget(target, data, out_log, classname):
    """Select the target column from ``data``.

    Args:
        target: Mapping with either ``'index'`` (positional column index) or
            ``'column'`` (column label); ``'index'`` wins when both exist.
        data: Table supporting ``.iloc`` and label indexing.
        out_log: Logger handed to ``fu.log``.
        classname: Name of the calling class, used as the message prefix.

    Returns:
        The selected column of ``data``.

    Raises:
        SystemExit: If neither key is present.
    """
    if 'index' in target:
        position = target['index']
        return data.iloc[:, position]

    if 'column' in target:
        label = target['column']
        return data[label]

    fu.log(classname + ': Incorrect target format', out_log)
    raise SystemExit(classname + ': Incorrect target format')
def getTargetValue(target):
    """Return a printable identifier of the target column.

    Args:
        target: Mapping with either ``'index'`` or ``'column'``; ``'index'``
            wins when both exist.

    Returns:
        ``str(target['index'])`` or ``target['column']`` as stored, or
        ``None`` when neither key is present.
    """
    if 'index' in target:
        return str(target['index'])
    if 'column' in target:
        return target['column']
    return None
def getWeight(weight, data, out_log, classname):
    """Select the weight column from ``data``.

    Args:
        weight: Mapping with either ``'index'`` (positional column index) or
            ``'column'`` (column label); ``'index'`` wins when both exist.
        data: Table supporting ``.iloc`` and label indexing.
        out_log: Logger handed to ``fu.log``.
        classname: Name of the calling class, used as the message prefix.

    Returns:
        The selected column of ``data``.

    Raises:
        SystemExit: If neither key is present.
    """
    if 'index' in weight:
        position = weight['index']
        return data.iloc[:, position]

    if 'column' in weight:
        label = weight['column']
        return data[label]

    fu.log(classname + ': Incorrect weight format', out_log)
    raise SystemExit(classname + ': Incorrect weight format')
def getHeader(file):
    """Read the first row of a delimited text file and return its fields.

    The row is first parsed with ``csv.reader`` using its default dialect. If
    that gives a single field, the field is assumed to use another separator:
    runs of whitespace and each ``;``, ``:`` or ``,`` are turned into commas
    and the result is split on commas.

    Args:
        file: Path of the file to read.

    Returns:
        List with the header fields.
    """
    with open(file, newline='') as handle:
        first_row = next(csv.reader(handle))

    if len(first_row) != 1:
        return first_row

    comma_separated = re.sub('\\s+|;|:|,|\t', ',', first_row[0])
    return comma_separated.split(",")