Coverage for biobb_ml/neural_networks/common.py: 71%

248 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-03 14:57 +0000

1""" Common functions for package biobb_ml.neural_networks """ 

2import matplotlib.pyplot as plt 

3import numpy as np 

4import pandas as pd 

5import seaborn as sns 

6import itertools 

7import csv 

8import re 

9from pathlib import Path, PurePath 

10from sklearn.metrics import roc_curve, auc 

11from biobb_common.tools import file_utils as fu 

12from warnings import simplefilter 

# Globally ignore every FutureWarning from this point on.
simplefilter('ignore', category=FutureWarning)
# Activate seaborn's plotting defaults for the figures drawn below.
sns.set()

16 

17 

18# CHECK PARAMETERS 

19 

def check_input_path(path, argument, optional, out_log, classname):
    """Validate an input file path and return it unchanged.

    Args:
        path: Path of the input file; may be empty/None when ``optional``.
        argument: Argument name (e.g. ``'input_dataset_path'``); used in the
            messages and passed to ``is_valid_file`` to pick accepted formats.
        optional: When true, an empty ``path`` is accepted and yields ``None``.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Name of the calling class, prefixed to every message.

    Returns:
        ``path`` when it passes every check, ``None`` for an omitted
        optional path.

    Raises:
        SystemExit: If a mandatory path is empty, the file does not exist,
            or its extension is not accepted for ``argument``.
    """
    if optional and not path:
        return None
    # An empty mandatory path is reported like a missing file; without this
    # guard Path(None) would raise an unhandled TypeError.
    if not path or not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)
    # Extension without its leading dot.
    extension = PurePath(path).suffix[1:]
    if not is_valid_file(extension, argument):
        message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
        fu.log(message, out_log)
        raise SystemExit(message)
    return path

32 

33 

def check_output_path(path, argument, optional, out_log, classname):
    """Validate an output file path and return it unchanged.

    The parent folder of ``path`` must already exist and the file extension
    must be one accepted by ``is_valid_file`` for ``argument``.

    Args:
        path: Path of the output file; may be empty/None when ``optional``.
        argument: Argument name, used in messages and for the format lookup.
        optional: When true, an empty ``path`` is accepted and yields ``None``.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Name of the calling class, prefixed to every message.

    Returns:
        ``path`` when valid, ``None`` for an omitted optional path.

    Raises:
        SystemExit: If the parent folder is missing or the extension is not
            accepted.
    """
    if not path and optional:
        return None

    destination = PurePath(path)
    # A PurePath is always truthy, so only the existence check can fail here.
    if not Path(destination.parent).exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    file_extension = destination.suffix
    if is_valid_file(file_extension[1:], argument):
        return path

    fu.log(classname + ': Format %s in %s file is not compatible' % (file_extension[1:], argument), out_log)
    raise SystemExit(classname + ': Format %s in %s file is not compatible' % (file_extension[1:], argument))

46 

47 

def is_valid_file(ext, argument):
    """Tell whether extension ``ext`` is accepted for ``argument``.

    Args:
        ext: File extension without the leading dot (e.g. ``'csv'``).
        argument: Name of the input/output argument being checked.

    Returns:
        True when ``ext`` is listed for ``argument``, False otherwise.

    Raises:
        KeyError: If ``argument`` is not one of the known argument names.
    """
    csv_arguments = (
        'input_dataset_path',
        'input_decode_path',
        'input_predict_path',
        'output_results_path',
        'output_test_table_path',
        'output_test_decode_path',
        'output_test_predict_path',
        'output_decode_path',
        'output_predict_path',
    )
    model_arguments = ('input_model_path', 'output_model_path')

    accepted = {}
    for name in csv_arguments:
        accepted[name] = ('csv',)
    for name in model_arguments:
        accepted[name] = ('h5',)
    accepted['output_plot_path'] = ('png',)

    return ext in accepted[argument]

65 

66 

def check_mandatory_property(property, name, out_log, classname):
    """Return ``property`` if it is truthy, otherwise log and exit.

    Args:
        property: Value of the property to check.
        name: Property name used in the messages.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Name of the calling class, prefixed to every message.

    Returns:
        The unchanged ``property`` value.

    Raises:
        SystemExit: If ``property`` is falsy.
    """
    if property:
        return property
    fu.log(classname + ': Unexisting %s property, exiting' % name, out_log)
    raise SystemExit(classname + ': Unexisting %s property' % name)

72 

73 

74# UTILITIES 

75 

def get_list_of_predictors(predictions):
    """Return the values of each prediction mapping as a list of lists.

    Args:
        predictions: Iterable of mappings (objects exposing ``.values()``).

    Returns:
        One list per mapping holding its values in the mapping's own order.
    """
    return [list(obj.values()) for obj in predictions]

84 

85 

def get_keys_of_predictors(predictions):
    """Return the keys of the first prediction mapping.

    Args:
        predictions: Sequence of mappings; only the first one is inspected.

    Returns:
        List with the keys of ``predictions[0]``, or an empty list when
        ``predictions`` holds no elements (rather than raising IndexError).
    """
    # len() instead of truthiness so array-like sequences work too.
    if len(predictions) == 0:
        return []
    return list(predictions[0])

91 

92 

def get_num_cols(num):
    """Build the labels ``'item 1'`` .. ``'item <num>'``.

    Args:
        num: Number of labels to generate.

    Returns:
        List of ``num`` labels, numbered from 1 (empty when ``num`` < 1).
    """
    return [f'item {position}' for position in range(1, num + 1)]

98 

99 

def split_sequence(sequence, n_steps):
    """Cut ``sequence`` into sliding windows and the value following each.

    Args:
        sequence: Indexable, sliceable sequence of samples.
        n_steps: Number of consecutive samples in every window.

    Returns:
        Tuple ``(X, y)`` of numpy arrays: ``X[k]`` is
        ``sequence[k:k + n_steps]`` and ``y[k]`` is ``sequence[k + n_steps]``.
    """
    total = len(sequence)
    # A window is kept only while the sample right after it still exists;
    # the count can never exceed the sequence length itself.
    n_windows = min(total, total - n_steps)
    windows = []
    targets = []
    for start in range(n_windows):
        stop = start + n_steps
        windows.append(sequence[start:stop])
        targets.append(sequence[stop])
    return np.asarray(windows), np.asarray(targets)

113 

114 

def doublePlot(tit, data1, data2, xlabel, ylabel, legend):
    """Plot two series as lines on the current axes.

    Args:
        tit: Title of the plot.
        data1: First series to draw.
        data2: Second series to draw.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        legend: Legend entries, passed straight to ``plt.legend``.
    """
    plt.title(tit, size=15)
    for series in (data1, data2):
        plt.plot(series)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
    plt.legend(legend, loc='best')

122 

123 

def predictionPlot(tit, data1, data2, xlabel, ylabel):
    """Scatter ``data2`` against ``data1`` and add the identity line.

    Args:
        tit: Title of the plot.
        data1: Values for the x axis.
        data2: Values for the y axis.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    plt.title(tit, size=15)
    plt.scatter(data1, data2, alpha=0.2)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
    # Reuse the x range chosen for the scatter on both axes, then draw the
    # diagonal across that range.
    bounds = plt.gca().get_xlim()
    plt.xlim(bounds)
    plt.ylim(bounds)
    plt.plot(bounds, bounds)

134 

135 

def histogramPlot(tit, data1, data2, xlabel, ylabel):
    """Draw a 25-bin histogram of the difference ``data2 - data1``.

    Args:
        tit: Title of the plot.
        data1: Values subtracted from ``data2``.
        data2: Values ``data1`` is subtracted from.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    plt.title(tit, size=15)
    plt.hist(data2 - data1, bins=25)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)

142 

143 

def CMPlotBinary(position, cm, group_names, title, normalize):
    """Draw an annotated 2x2 confusion-matrix heatmap in a subplot.

    Args:
        position: Subplot position passed to ``plt.subplot``.
        cm: Confusion matrix; its flattened cells are paired with
            ``group_names`` and the labels are reshaped to 2x2.
        group_names: Names for the cells, in flattened (row-major) order.
        title: Title of the subplot.
        normalize: When truthy, cell values are printed with two decimals;
            otherwise with none.
    """
    plt.subplot(position)
    plt.title(title, size=15)
    number_spec = "0.2f" if normalize else "0.0f"
    cell_texts = [
        f"{format(count, number_spec)}\n{cell_name}"
        for count, cell_name in zip(cm.flatten(), group_names)
    ]
    annotations = np.asarray(cell_texts).reshape(2, 2)
    sns.heatmap(cm, annot=annotations, fmt='', cmap='Blues', square=True)
    plt.ylabel('True Values', size=13)
    plt.xlabel('Predicted Values', size=13)
    plt.yticks(rotation=0)

157 

158 

def CMplotNonBinary(position, cm, title, normalize, values):
    """Draw an annotated confusion-matrix heatmap for any number of classes.

    Rows are labelled as true values and columns as predicted values. Each
    cell shows its number plus ``"True <class>"`` on the diagonal or
    ``"False <class>"`` elsewhere, where ``<class>`` is the predicted
    (column) class.

    Args:
        position: Subplot position passed to ``plt.subplot``.
        cm: Confusion matrix (2-D array).
        title: Title of the subplot.
        normalize: When truthy, cell values are printed with two decimals;
            otherwise with none.
        values: Class labels, used for the tick labels and cell names.
    """
    # Shrink the annotation font as the number of classes grows.
    n_classes = cm.shape[1]
    if n_classes < 5:
        fs = 10
    elif n_classes < 10:
        fs = 8
    else:
        fs = 6

    plt.subplot(position)
    plt.title(title, size=15)
    number_format = "{0:0.2f}" if normalize else "{0:0.0f}"
    group_counts = [number_format.format(value) for value in cm.flatten()]
    # A misprediction is a "false <predicted class>", so index by column j.
    # Indexing by row i named the true class instead (e.g. a false positive
    # came out as "False <negative class>").
    group_names = [
        ("True " if i == j else "False ") + str(values[j])
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1]))
    ]
    labels_cfm = [f"{v1}\n{v2}" for v1, v2 in zip(group_counts, group_names)]
    labels_cfm = np.asarray(labels_cfm).reshape(cm.shape[0], cm.shape[1])
    sns.heatmap(cm, annot=labels_cfm, fmt='', cmap='Blues', xticklabels=values, yticklabels=values, square=True, annot_kws={"fontsize": fs})
    plt.ylabel('True Values', size=13)
    plt.xlabel('Predicted Values', size=13)
    plt.yticks(rotation=0)

186 

187 

def plotResultsClassMultCM(data, cm_train, cm_test, normalize, values):
    """Build the results figure for a multi-class classification model.

    The top row shows loss, accuracy and MSE for training and validation;
    the bottom row shows the train and test confusion matrices.

    Args:
        data: Mapping read with the keys ``loss``, ``accuracy``, ``mse`` and
            their ``val_`` counterparts.
        cm_train: Confusion matrix of the training set.
        cm_test: Confusion matrix of the test set.
        normalize: Forwarded to ``CMplotNonBinary``.
        values: Class labels, forwarded to ``CMplotNonBinary``.

    Returns:
        The ``matplotlib.pyplot`` module holding the composed figure.
    """
    plt.figure(figsize=[12, 8])

    # (subplot position, title, training key, validation key, y label)
    history_panels = (
        (231, 'Model loss', 'loss', 'val_loss', 'loss'),
        (232, 'Model accuracy', 'accuracy', 'val_accuracy', 'accuracy'),
        (233, 'Model MSE', 'mse', 'val_mse', 'mse'),
    )
    for panel, panel_title, train_key, val_key, metric_label in history_panels:
        plt.subplot(panel)
        doublePlot(panel_title, data[train_key], data[val_key], 'epoch', metric_label, ['training', 'validation'])

    matrix_panels = (
        (234, cm_train, 'Confusion Matrix Train'),
        (235, cm_test, 'Confusion Matrix Test'),
    )
    for panel, matrix, panel_title in matrix_panels:
        CMplotNonBinary(panel, matrix, panel_title, normalize, values)

    plt.tight_layout()
    return plt

209 

210 

def distPredPlot(position, y, pos_p, labels, title):
    """Plot density histograms of ``pos_p`` split by target value.

    Rows with target 1 are drawn in green and rows with target 0 in red; a
    dashed vertical line marks 0.5.

    Args:
        position: Subplot position passed to ``plt.subplot``.
        y: Target values (compared against 1 and 0).
        pos_p: Predicted values plotted on the x axis, one per target.
        labels: Legend labels; ``labels[0]`` for target 1, ``labels[1]``
            for target 0.
        title: Title of the subplot.
    """
    frame = pd.DataFrame({'probPos': pos_p, 'target': y})
    plt.subplot(position)
    # (target value, colour, index into labels)
    for target_value, colour, label_index in ((1, 'green', 0), (0, 'red', 1)):
        subset = frame[frame.target == target_value].probPos
        plt.hist(subset, density=True, bins=25,
                 alpha=.5, color=colour, label=labels[label_index])
    plt.axvline(.5, color='blue', linestyle='--', label='Boundary')
    plt.xlim([0, 1])
    plt.title(title, size=15)
    plt.xlabel('Positive Probability (predicted)', size=13)
    plt.ylabel('Samples (normalized scale)', size=13)
    plt.legend(loc="upper right")

224 

225 

def ROCPlot(position, y, p, cm, title):
    """Plot the ROC curve and mark the point given by a confusion matrix.

    Args:
        position: Subplot position passed to ``plt.subplot``.
        y: True labels, passed to ``roc_curve``.
        p: Scores, passed to ``roc_curve``.
        cm: Confusion matrix whose flattened cells are read as
            ``tn, fp, fn, tp``.
        title: Title of the subplot.
    """
    fp_rates, tp_rates, _ = roc_curve(y, p)
    area = auc(fp_rates, tp_rates)

    plt.subplot(position)
    plt.plot(fp_rates, tp_rates, color='green',
             lw=1, label='ROC curve (area = %0.2f)' % area)
    # Grey dashed diagonal as a visual reference.
    plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='grey')

    # Rates computed from the confusion matrix give the decision point.
    tn, fp, fn, tp = cm.ravel()
    decision_fpr = fp/(fp+tn)
    decision_tpr = tp/(tp+fn)
    plt.plot(decision_fpr, decision_tpr, 'bo', markersize=8, label='Decision Point')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', size=13)
    plt.ylabel('True Positive Rate', size=13)
    plt.title(title, size=15)
    plt.legend(loc="lower right")

242 

243 

def plotResultsClassBinCM(data, proba_train, proba_test, y_train, y_test, cm_train, cm_test, normalize, values):
    """Build the 3x3 results figure for a binary classification model.

    Row 1 shows loss, accuracy and MSE for training and validation. Rows 2
    and 3 show, for the train and test sets respectively, the confusion
    matrix, the distribution of predictions and the ROC curve.

    Args:
        data: Mapping read with the keys ``loss``, ``accuracy``, ``mse`` and
            their ``val_`` counterparts.
        proba_train: 2-D array of train predictions; column 1 is used.
        proba_test: 2-D array of test predictions; column 1 is used.
        y_train: True labels of the training set.
        y_test: True labels of the test set.
        cm_train: Confusion matrix of the training set.
        cm_test: Confusion matrix of the test set.
        normalize: Forwarded to ``CMPlotBinary``.
        values: Not used by this function.

    Returns:
        The ``matplotlib.pyplot`` module holding the composed figure.
    """
    plt.figure(figsize=[15, 15])

    # (subplot position, title, training key, validation key, y label)
    history_panels = (
        (331, 'Model loss', 'loss', 'val_loss', 'loss'),
        (332, 'Model accuracy', 'accuracy', 'val_accuracy', 'accuracy'),
        (333, 'Model MSE', 'mse', 'val_mse', 'mse'),
    )
    for panel, panel_title, train_key, val_key, metric_label in history_panels:
        plt.subplot(panel)
        doublePlot(panel_title, data[train_key], data[val_key], 'epoch', metric_label, ['training', 'validation'])

    # One figure row per data split: (first position, probabilities, labels,
    # confusion matrix, titles of the three panels).
    split_rows = (
        (334, proba_train, y_train, cm_train,
         'Confusion Matrix Train', 'Distributions of Predictions Train', 'ROC Curve Train'),
        (337, proba_test, y_test, cm_test,
         'Confusion Matrix Test', 'Distributions of Predictions Test', 'ROC Curve Test'),
    )
    for first_panel, proba, y_true, matrix, cm_title, dist_title, roc_title in split_rows:
        pos_p = proba[:, 1]
        CMPlotBinary(first_panel, matrix, ['True Negatives', 'False Positives', 'False Negatives', 'True Positives'], cm_title, normalize)
        distPredPlot(first_panel + 1, y_true, pos_p, ['Positives', 'Negatives'], dist_title)
        ROCPlot(first_panel + 2, y_true, pos_p, matrix, roc_title)

    plt.tight_layout()
    return plt

279 

280 

def plotResultsReg(data, test_labels, test_predictions, train_labels, train_predictions):
    """Build the results figure for a regression model.

    Row 1 shows loss, MAE and MSE for training and validation. Rows 2 and 3
    show, for the train and test sets respectively, a predictions scatter
    plot and a histogram of the prediction error.

    Args:
        data: Mapping read with the keys ``loss``, ``mae``, ``mse`` and
            their ``val_`` counterparts.
        test_labels: True values of the test set.
        test_predictions: Predicted values of the test set.
        train_labels: True values of the training set.
        train_predictions: Predicted values of the training set.

    Returns:
        The ``matplotlib.pyplot`` module holding the composed figure.
    """
    plt.figure(figsize=[12, 12])

    # (subplot position, title, training key, validation key, y label)
    history_panels = (
        (331, 'Model loss', 'loss', 'val_loss', 'loss'),
        (332, 'Model MAE', 'mae', 'val_mae', 'mae'),
        (333, 'Model MSE', 'mse', 'val_mse', 'mse'),
    )
    for panel, panel_title, train_key, val_key, metric_label in history_panels:
        plt.subplot(panel)
        doublePlot(panel_title, data[train_key], data[val_key], 'epoch', metric_label, ['training', 'validation'])

    # (first position, true values, predictions, scatter title, histogram title)
    split_rows = (
        (334, train_labels, train_predictions, 'Train predictions', 'Train histogram'),
        (337, test_labels, test_predictions, 'Test predictions', 'Test histogram'),
    )
    for first_panel, truth, predicted, scatter_title, hist_title in split_rows:
        plt.subplot(first_panel)
        predictionPlot(scatter_title, truth, predicted, 'true values', 'predictions')
        plt.subplot(first_panel + 1)
        histogramPlot(hist_title, truth, predicted, 'prediction error', 'count')

    plt.tight_layout()
    return plt

310 

311 

def getFeatures(independent_vars, data, out_log, classname):
    """Select the feature columns of ``data`` named by ``independent_vars``.

    ``independent_vars`` is checked for these keys, in order:

    * ``'indexes'``: positional column indexes.
    * ``'range'``: list of ``[first, last]`` positional ranges, both ends
      included.
    * ``'columns'``: column labels.

    Args:
        independent_vars: Mapping describing the columns to pick.
        data: Object supporting ``.iloc`` / ``.loc`` column selection.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Name of the calling class, prefixed to the error message.

    Returns:
        The selected columns of ``data``.

    Raises:
        SystemExit: If none of the supported keys is present.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]

    if 'range' in independent_vars:
        # Expand every inclusive [first, last] pair into explicit positions.
        positions = [
            column
            for bounds in independent_vars['range']
            for column in range(bounds[0], bounds[1] + 1)
        ]
        return data.iloc[:, positions]

    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]

    fu.log(classname + ': Incorrect independent_vars format', out_log)
    raise SystemExit(classname + ': Incorrect independent_vars format')

326 

327 

def getIndependentVarsList(independent_vars):
    """Render the selected independent variables as a ``', '``-joined string.

    Args:
        independent_vars: Mapping holding ``'indexes'``, ``'range'`` (list
            of inclusive ``[first, last]`` pairs) or ``'columns'``; checked
            in that order.

    Returns:
        The joined string, or ``None`` when none of the keys is present.
    """
    if 'indexes' in independent_vars:
        items = [str(index) for index in independent_vars['indexes']]
        return ', '.join(items)

    if 'range' in independent_vars:
        items = []
        for bounds in independent_vars['range']:
            items.extend(str(position) for position in range(bounds[0], bounds[1] + 1))
        return ', '.join(items)

    if 'columns' in independent_vars:
        return ', '.join(independent_vars['columns'])

    return None

335 

336 

def getTarget(target, data, out_log, classname):
    """Return the target column of ``data``.

    Args:
        target: Mapping with either ``'index'`` (positional column index)
            or ``'column'`` (column label); ``'index'`` takes precedence.
        data: Object supporting ``.iloc`` and label indexing.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Name of the calling class, prefixed to the error message.

    Returns:
        The selected column of ``data``.

    Raises:
        SystemExit: If ``target`` has neither key.
    """
    if 'index' in target:
        position = target['index']
        return data.iloc[:, position]

    if 'column' in target:
        label = target['column']
        return data[label]

    fu.log(classname + ': Incorrect target format', out_log)
    raise SystemExit(classname + ': Incorrect target format')

345 

346 

def getWeight(weight, data, out_log, classname):
    """Return the weight column of ``data``.

    Args:
        weight: Mapping with either ``'index'`` (positional column index)
            or ``'column'`` (column label); ``'index'`` takes precedence.
        data: Object supporting ``.iloc`` and label indexing.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Name of the calling class, prefixed to the error message.

    Returns:
        The selected column of ``data``.

    Raises:
        SystemExit: If ``weight`` has neither key.
    """
    if 'index' in weight:
        position = weight['index']
        return data.iloc[:, position]

    if 'column' in weight:
        label = weight['column']
        return data[label]

    fu.log(classname + ': Incorrect weight format', out_log)
    raise SystemExit(classname + ': Incorrect weight format')

355 

356 

def getHeader(file):
    """Read the header (first row) of a delimited text file.

    The first row is parsed as comma-separated values. If that yields a
    single field, the field is split again on whitespace runs, ``;``, ``:``
    and ``,`` so files using another delimiter are handled too.

    Args:
        file: Path of the file to read.

    Returns:
        List of header fields; an empty list when the file has no rows.
    """
    with open(file, newline='') as f:
        # Default to an empty row so an empty file yields [] instead of
        # leaking StopIteration to the caller.
        header = next(csv.reader(f), [])

    if len(header) == 1:
        # Splitting on the separators directly is equivalent to replacing
        # each with ',' and then splitting on ','. Tabs are covered by \s.
        return re.split(r'\s+|[;:,]', header[0])
    return header

366 

367 

def getTargetValue(target):
    """Return the identifier stored in ``target``.

    Args:
        target: Mapping with an ``'index'`` and/or a ``'column'`` entry.

    Returns:
        ``target['index']`` when present, otherwise ``target['column']``,
        or ``None`` when neither key exists.
    """
    for key in ('index', 'column'):
        if key in target:
            return target[key]
    return None