Coverage for biobb_ml/neural_networks/common.py: 71%
248 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
""" Common functions for package biobb_ml.neural_networks """
2import matplotlib.pyplot as plt
3import numpy as np
4import pandas as pd
5import seaborn as sns
6import itertools
7import csv
8import re
9from pathlib import Path, PurePath
10from sklearn.metrics import roc_curve, auc
11from biobb_common.tools import file_utils as fu
12from warnings import simplefilter
# Ignore every FutureWarning for the rest of the process (process-wide filter).
simplefilter(category=FutureWarning, action='ignore')
# Apply seaborn's default theme to all figures drawn by this module.
sns.set()
18# CHECK PARAMETERS
def check_input_path(path, argument, optional, out_log, classname):
    """ Checks input file

    Verifies that ``path`` exists and that its extension is accepted for
    ``argument`` (as decided by ``is_valid_file``).

    Args:
        path: input file path; may be empty/None when ``optional`` is truthy.
        argument: name of the tool argument the path belongs to.
        optional: when truthy and ``path`` is falsy, all checks are skipped.
        out_log: logger forwarded to ``fu.log``.
        classname: prefix for the log and error messages.

    Returns:
        ``None`` for a skipped optional path, otherwise ``path`` unchanged.

    Raises:
        SystemExit: if the file does not exist or its format is not accepted.
    """
    if optional and not path:
        return None

    if not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)

    # Extension without its leading dot, e.g. 'csv'.
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)
def check_output_path(path, argument, optional, out_log, classname):
    """ Checks output file

    Verifies that the directory that will hold ``path`` exists and that the
    extension of ``path`` is accepted for ``argument`` (as decided by
    ``is_valid_file``).

    Args:
        path: output file path; may be empty/None when ``optional`` is truthy.
        argument: name of the tool argument the path belongs to.
        optional: when truthy and ``path`` is falsy, all checks are skipped.
        out_log: logger forwarded to ``fu.log``.
        classname: prefix for the log and error messages.

    Returns:
        ``None`` for a skipped optional path, otherwise ``path`` unchanged.

    Raises:
        SystemExit: if the parent folder is missing or the format is not accepted.
    """
    if optional and not path:
        return None

    # Path objects are always truthy, so only the existence test matters;
    # a bare file name has parent '.', i.e. the current directory.
    if not Path(path).parent.exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)
def is_valid_file(ext, argument):
    """ Checks if file format is compatible

    Args:
        ext: file extension without the leading dot (e.g. ``'csv'``);
            the comparison is case-sensitive.
        argument: name of the tool argument the file belongs to.

    Returns:
        True if ``ext`` is an accepted extension for ``argument``.

    Raises:
        KeyError: if ``argument`` is not one of the known file arguments.
    """
    csv_arguments = (
        'input_dataset_path',
        'input_decode_path',
        'input_predict_path',
        'output_results_path',
        'output_test_table_path',
        'output_test_decode_path',
        'output_test_predict_path',
        'output_decode_path',
        'output_predict_path',
    )
    accepted = dict.fromkeys(csv_arguments, ('csv',))
    accepted['input_model_path'] = ('h5',)
    accepted['output_model_path'] = ('h5',)
    accepted['output_plot_path'] = ('png',)
    return ext in accepted[argument]
def check_mandatory_property(property, name, out_log, classname):
    """ Checks that a mandatory property has a value

    Any falsy value (``None``, ``''``, ``0``, empty containers) is treated
    as missing.

    Args:
        property: the property value to check.
        name: property name used in the messages.
        out_log: logger forwarded to ``fu.log``.
        classname: prefix for the log and error messages.

    Returns:
        ``property`` unchanged when it is truthy.

    Raises:
        SystemExit: if ``property`` is falsy.
    """
    if property:
        return property
    fu.log(classname + ': Unexisting %s property, exiting' % name, out_log)
    raise SystemExit(classname + ': Unexisting %s property' % name)
74# UTILITIES
def get_list_of_predictors(predictions):
    """Return the values of each mapping in ``predictions`` as a list of lists.

    Each inner list keeps the iteration order of its mapping's ``items()``.
    """
    return [[value for _key, value in row.items()] for row in predictions]
def get_keys_of_predictors(predictions):
    """Return the keys of the first element of ``predictions`` as a list.

    Raises IndexError when ``predictions`` is empty.
    """
    first_prediction = predictions[0]
    return list(first_prediction)
def get_num_cols(num):
    """Return ``num`` column labels: ``['item 1', 'item 2', ..., 'item <num>']``."""
    return [f'item {index}' for index in range(1, num + 1)]
def split_sequence(sequence, n_steps):
    """Split ``sequence`` into sliding windows and their next-value targets.

    For each start position ``s``, the window is ``sequence[s:s + n_steps]``
    and the target is ``sequence[s + n_steps]``. Generation stops at the
    first window whose target index would fall outside the sequence.

    Returns:
        ``(X, y)`` as numpy arrays of windows and targets.
    """
    total = len(sequence)
    windows, targets = [], []
    for start in range(total):
        stop = start + n_steps
        if stop >= total:
            break
        windows.append(sequence[start:stop])
        targets.append(sequence[stop])
    return np.asarray(windows), np.asarray(targets)
def doublePlot(tit, data1, data2, xlabel, ylabel, legend):
    """Plot two series on the current axes with title, axis labels and legend.

    Args:
        tit: plot title.
        data1: first series passed to ``plt.plot``.
        data2: second series passed to ``plt.plot``.
        xlabel: x-axis label.
        ylabel: y-axis label.
        legend: legend entries, matched to the series in plotting order.
    """
    plt.title(tit, size=15)
    for series in (data1, data2):
        plt.plot(series)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
    plt.legend(legend, loc='best')
def predictionPlot(tit, data1, data2, xlabel, ylabel):
    """Scatter ``data2`` (predictions) against ``data1`` (true values).

    Both axes get the same range, plus the identity line ``y = x`` as a
    reference. The shared range is the union of the autoscaled x and y
    ranges, so no point is clipped from view.

    Args:
        tit: plot title.
        data1: values for the x axis.
        data2: values for the y axis.
        xlabel: x-axis label.
        ylabel: y-axis label.
    """
    plt.title(tit, size=15)
    plt.scatter(data1, data2, alpha=0.2)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
    axes = plt.gca()
    x_low, x_high = axes.get_xlim()
    y_low, y_high = axes.get_ylim()
    # Reusing only the x-range for y would hide predictions lying outside the
    # true-value range; take the union of both autoscaled ranges instead.
    lims = (min(x_low, y_low), max(x_high, y_high))
    plt.xlim(lims)
    plt.ylim(lims)
    plt.plot(lims, lims)
def histogramPlot(tit, data1, data2, xlabel, ylabel):
    """Draw a 25-bin histogram of the errors ``data2 - data1`` on the current axes.

    Args:
        tit: plot title.
        data1: reference values (subtracted).
        data2: compared values.
        xlabel: x-axis label.
        ylabel: y-axis label.
    """
    plt.title(tit, size=15)
    plt.hist(data2 - data1, bins=25)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
def CMPlotBinary(position, cm, group_names, title, normalize):
    """Draw a 2x2 confusion-matrix heatmap in subplot ``position``.

    Each cell is annotated with its value and, below it, the matching entry of
    ``group_names`` (taken in row-major order). Values are shown with two
    decimals when ``normalize`` is truthy and with none otherwise.

    Args:
        position: subplot spec passed to ``plt.subplot``.
        cm: confusion matrix with 4 entries; rows are drawn against
            'True Values' and columns against 'Predicted Values'.
        group_names: four cell labels in row-major order.
        title: subplot title.
        normalize: whether ``cm`` holds normalized values (affects formatting).
    """
    plt.subplot(position)
    plt.title(title, size=15)
    count_format = "{0:0.2f}" if normalize else "{0:0.0f}"
    cell_text = [
        f"{count_format.format(count)}\n{name}"
        for count, name in zip(cm.flatten(), group_names)
    ]
    annotations = np.asarray(cell_text).reshape(2, 2)
    sns.heatmap(cm, annot=annotations, fmt='', cmap='Blues', square=True)
    plt.ylabel('True Values', size=13)
    plt.xlabel('Predicted Values', size=13)
    plt.yticks(rotation=0)
def CMplotNonBinary(position, cm, title, normalize, values):
    """Draw a multi-class confusion-matrix heatmap in subplot ``position``.

    Each cell is annotated with its value (two decimals when ``normalize`` is
    truthy, none otherwise) and a label. Diagonal cells are labelled
    ``"True <class>"``. Off-diagonal cells are labelled
    ``"False <predicted class>"``, which matches the ``False Positives`` /
    ``False Negatives`` naming used by ``CMPlotBinary``.

    Args:
        position: subplot spec passed to ``plt.subplot``.
        cm: 2-D confusion matrix; rows are drawn against 'True Values' and
            columns against 'Predicted Values'.
        title: subplot title.
        normalize: whether ``cm`` holds normalized values (affects formatting).
        values: class labels, used as tick labels and inside the annotations.
    """
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    # Shrink the annotation font as the number of classes grows.
    if n_cols < 5:
        fs = 10
    elif n_cols < 10:
        fs = 8
    else:
        fs = 6

    plt.subplot(position)
    plt.title(title, size=15)
    count_format = "{0:0.2f}" if normalize else "{0:0.0f}"
    group_counts = [count_format.format(value) for value in cm.flatten()]
    # Cell (i, j) holds samples of true class i predicted as class j. Naming
    # the predicted class j gives "False <predicted>", as in the binary plot
    # (cell (0, 1) there is "False Positives").
    group_names = [
        ("True " if i == j else "False ") + str(values[j])
        for i, j in itertools.product(range(n_rows), range(n_cols))
    ]
    labels_cfm = [f"{v1}\n{v2}" for v1, v2 in zip(group_counts, group_names)]
    labels_cfm = np.asarray(labels_cfm).reshape(n_rows, n_cols)
    sns.heatmap(cm, annot=labels_cfm, fmt='', cmap='Blues', xticklabels=values, yticklabels=values, square=True, annot_kws={"fontsize": fs})
    plt.ylabel('True Values', size=13)
    plt.xlabel('Predicted Values', size=13)
    plt.yticks(rotation=0)
def plotResultsClassMultCM(data, cm_train, cm_test, normalize, values):
    """Build the multi-class classification summary figure and return ``plt``.

    Top row: training vs. validation curves read from ``data`` under the keys
    ``loss``/``val_loss``, ``accuracy``/``val_accuracy`` and ``mse``/``val_mse``.
    Bottom row: train and test confusion matrices drawn by ``CMplotNonBinary``.

    Args:
        data: mapping of metric name to per-epoch series.
        cm_train: confusion matrix for the training set.
        cm_test: confusion matrix for the test set.
        normalize: forwarded to ``CMplotNonBinary``.
        values: class labels forwarded to ``CMplotNonBinary``.

    Returns:
        The ``matplotlib.pyplot`` module, with the figure as current figure.
    """
    plt.figure(figsize=[12, 8])

    history_panels = (
        (231, 'Model loss', 'loss', 'val_loss'),
        (232, 'Model accuracy', 'accuracy', 'val_accuracy'),
        (233, 'Model MSE', 'mse', 'val_mse'),
    )
    for position, panel_title, train_key, val_key in history_panels:
        plt.subplot(position)
        doublePlot(panel_title, data[train_key], data[val_key], 'epoch', train_key, ['training', 'validation'])

    CMplotNonBinary(234, cm_train, 'Confusion Matrix Train', normalize, values)
    CMplotNonBinary(235, cm_test, 'Confusion Matrix Test', normalize, values)

    plt.tight_layout()
    return plt
def distPredPlot(position, y, pos_p, labels, title):
    """Overlay normalized histograms of positive-class probabilities by target.

    Args:
        position: subplot spec passed to ``plt.subplot``.
        y: true targets; samples with target 1 and target 0 are drawn separately.
        pos_p: predicted positive-class probabilities, aligned with ``y``.
        labels: legend labels; ``labels[0]`` for target 1 (green) and
            ``labels[1]`` for target 0 (red).
        title: subplot title.
    """
    scores = pd.DataFrame({'probPos': pos_p, 'target': y})
    plt.subplot(position)
    positives = scores[scores.target == 1].probPos
    plt.hist(positives, density=True, bins=25, alpha=.5, color='green', label=labels[0])
    negatives = scores[scores.target == 0].probPos
    plt.hist(negatives, density=True, bins=25, alpha=.5, color='red', label=labels[1])
    # Dashed reference line at probability 0.5, labelled 'Boundary'.
    plt.axvline(.5, color='blue', linestyle='--', label='Boundary')
    plt.xlim([0, 1])
    plt.title(title, size=15)
    plt.xlabel('Positive Probability (predicted)', size=13)
    plt.ylabel('Samples (normalized scale)', size=13)
    plt.legend(loc="upper right")
def ROCPlot(position, y, p, cm, title):
    """Plot the ROC curve of scores ``p`` against targets ``y``, with its AUC.

    The (FPR, TPR) point derived from ``cm`` is marked as 'Decision Point'.

    Args:
        position: subplot spec passed to ``plt.subplot``.
        y: true binary targets.
        p: scores/probabilities for the positive class.
        cm: 4-entry confusion matrix whose flattened values are read in the
            order tn, fp, fn, tp.
        title: subplot title.
    """
    fpr, tpr, _thresholds = roc_curve(y, p)
    area = auc(fpr, tpr)
    plt.subplot(position)
    plt.plot(fpr, tpr, color='green', lw=1, label='ROC curve (area = %0.2f)' % area)
    # Chance-level diagonal.
    plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='grey')
    tn, fp, fn, tp = cm.ravel()
    decision_fpr = fp / (fp + tn)
    decision_tpr = tp / (tp + fn)
    plt.plot(decision_fpr, decision_tpr, 'bo', markersize=8, label='Decision Point')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', size=13)
    plt.ylabel('True Positive Rate', size=13)
    plt.title(title, size=15)
    plt.legend(loc="lower right")
def plotResultsClassBinCM(data, proba_train, proba_test, y_train, y_test, cm_train, cm_test, normalize, values):
    """Build the binary classification summary figure (3x3 grid) and return ``plt``.

    Row 1: training vs. validation curves for loss, accuracy and MSE read from
    ``data``. Rows 2 and 3 (train, then test): confusion matrix, distribution
    of predicted positive probabilities (column 1 of ``proba_*``), and ROC curve.

    Args:
        data: mapping of metric name to per-epoch series.
        proba_train: class-probability matrix for the training set.
        proba_test: class-probability matrix for the test set.
        y_train: true training targets.
        y_test: true test targets.
        cm_train: training confusion matrix.
        cm_test: test confusion matrix.
        normalize: forwarded to ``CMPlotBinary``.
        values: accepted for signature compatibility; not used here.

    Returns:
        The ``matplotlib.pyplot`` module, with the figure as current figure.
    """
    plt.figure(figsize=[15, 15])

    history_panels = (
        (331, 'Model loss', 'loss', 'val_loss'),
        (332, 'Model accuracy', 'accuracy', 'val_accuracy'),
        (333, 'Model MSE', 'mse', 'val_mse'),
    )
    for position, panel_title, train_key, val_key in history_panels:
        plt.subplot(position)
        doublePlot(panel_title, data[train_key], data[val_key], 'epoch', train_key, ['training', 'validation'])

    cell_names = ['True Negatives', 'False Positives', 'False Negatives', 'True Positives']
    dataset_rows = (
        (334, proba_train, y_train, cm_train,
         'Confusion Matrix Train', 'Distributions of Predictions Train', 'ROC Curve Train'),
        (337, proba_test, y_test, cm_test,
         'Confusion Matrix Test', 'Distributions of Predictions Test', 'ROC Curve Test'),
    )
    for first_slot, proba, y_true, cm, cm_title, dist_title, roc_title in dataset_rows:
        # Probability of the positive class (second column).
        pos_p = proba[:, 1]
        CMPlotBinary(first_slot, cm, cell_names, cm_title, normalize)
        distPredPlot(first_slot + 1, y_true, pos_p, ['Positives', 'Negatives'], dist_title)
        ROCPlot(first_slot + 2, y_true, pos_p, cm, roc_title)

    plt.tight_layout()
    return plt
def plotResultsReg(data, test_labels, test_predictions, train_labels, train_predictions):
    """Build the regression summary figure (3x3 grid) and return ``plt``.

    Row 1: training vs. validation curves for loss, MAE and MSE read from
    ``data``. Rows 2 and 3 (train, then test): predictions vs. true values
    scatter and a histogram of prediction errors.

    Args:
        data: mapping of metric name to per-epoch series.
        test_labels: true test values.
        test_predictions: predicted test values.
        train_labels: true training values.
        train_predictions: predicted training values.

    Returns:
        The ``matplotlib.pyplot`` module, with the figure as current figure.
    """
    plt.figure(figsize=[12, 12])

    history_panels = (
        (331, 'Model loss', 'loss', 'val_loss'),
        (332, 'Model MAE', 'mae', 'val_mae'),
        (333, 'Model MSE', 'mse', 'val_mse'),
    )
    for position, panel_title, train_key, val_key in history_panels:
        plt.subplot(position)
        doublePlot(panel_title, data[train_key], data[val_key], 'epoch', train_key, ['training', 'validation'])

    dataset_rows = (
        (334, train_labels, train_predictions, 'Train predictions', 'Train histogram'),
        (337, test_labels, test_predictions, 'Test predictions', 'Test histogram'),
    )
    for first_slot, truth, predicted, scatter_title, hist_title in dataset_rows:
        plt.subplot(first_slot)
        predictionPlot(scatter_title, truth, predicted, 'true values', 'predictions')
        plt.subplot(first_slot + 1)
        histogramPlot(hist_title, truth, predicted, 'prediction error', 'count')

    plt.tight_layout()
    return plt
def getFeatures(independent_vars, data, out_log, classname):
    """Select the independent-variable columns from ``data``.

    ``independent_vars`` is checked for these keys, in this order (the first
    one found wins):

    * ``'indexes'``: column positions, selected with ``iloc``;
    * ``'range'``: list of ``[start, end]`` position pairs, both ends
      inclusive, selected with ``iloc``;
    * ``'columns'``: column labels, selected with ``loc``.

    Raises:
        SystemExit: if none of the keys is present.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]

    if 'range' in independent_vars:
        positions = [
            column
            for bounds in independent_vars['range']
            for column in range(bounds[0], bounds[1] + 1)
        ]
        return data.iloc[:, positions]

    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]

    fu.log(classname + ': Incorrect independent_vars format', out_log)
    raise SystemExit(classname + ': Incorrect independent_vars format')
def getIndependentVarsList(independent_vars):
    """Return a ``', '``-joined description of the selected independent variables.

    Uses the first key present among ``'indexes'``, ``'range'`` (inclusive
    ``[start, end]`` pairs, expanded) and ``'columns'``. Returns ``None``
    when none of them is present.
    """
    if 'indexes' in independent_vars:
        items = [str(index) for index in independent_vars['indexes']]
    elif 'range' in independent_vars:
        items = [
            str(position)
            for bounds in independent_vars['range']
            for position in range(bounds[0], bounds[1] + 1)
        ]
    elif 'columns' in independent_vars:
        items = independent_vars['columns']
    else:
        return None
    return ', '.join(items)
def getTarget(target, data, out_log, classname):
    """Return the target column of ``data``.

    ``target['index']`` (a column position, via ``iloc``) takes precedence
    over ``target['column']`` (a column label).

    Raises:
        SystemExit: if ``target`` has neither ``'index'`` nor ``'column'``.
    """
    if 'index' not in target and 'column' not in target:
        fu.log(classname + ': Incorrect target format', out_log)
        raise SystemExit(classname + ': Incorrect target format')
    return data.iloc[:, target['index']] if 'index' in target else data[target['column']]
def getWeight(weight, data, out_log, classname):
    """Return the sample-weight column of ``data``.

    ``weight['index']`` (a column position, via ``iloc``) takes precedence
    over ``weight['column']`` (a column label).

    Raises:
        SystemExit: if ``weight`` has neither ``'index'`` nor ``'column'``.
    """
    if 'index' not in weight and 'column' not in weight:
        fu.log(classname + ': Incorrect weight format', out_log)
        raise SystemExit(classname + ': Incorrect weight format')
    return data.iloc[:, weight['index']] if 'index' in weight else data[weight['column']]
def getHeader(file):
    """Return the column names from the first row of a CSV file.

    The first row is parsed with the default ``csv`` dialect (comma
    delimiter). If that yields exactly one field, the header is assumed to use
    another separator. It is then split on runs of whitespace, or on ``;``,
    ``:`` or ``,``, with whitespace around those separators absorbed: both
    ``"a;b"`` and ``"a ; b"`` give ``['a', 'b']``. Consecutive explicit
    separators such as ``"a;;b"`` still produce an empty name (``['a', '', 'b']``).

    Args:
        file: path to the CSV file.

    Returns:
        list of column-name strings (the parsed row itself when it already
        has a number of fields other than one).

    Raises:
        StopIteration: if the file has no rows.
    """
    with open(file, newline='') as f:
        header = next(csv.reader(f))

    if len(header) != 1:
        return header

    # Strip edge padding and treat whitespace around an explicit separator as
    # part of that separator, so neither creates spurious empty column names.
    return re.split(r'\s*[;:,]\s*|\s+', header[0].strip())
def getTargetValue(target):
    """Return ``target['index']`` if present, else ``target['column']``, else None."""
    for key in ('index', 'column'):
        if key in target:
            return target[key]
    return None