Coverage for biobb_ml/classification/common.py: 76%

190 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-03 14:57 +0000

""" Common functions for package biobb_ml.classification """

2from pathlib import Path, PurePath 

3import matplotlib.pyplot as plt 

4import itertools 

5import csv 

6import re 

7import numpy as np 

8import pandas as pd 

9import seaborn as sns 

10from sklearn.metrics import roc_curve, auc 

11from biobb_common.tools import file_utils as fu 

12from warnings import simplefilter 

# Ignore all FutureWarning messages. simplefilter() edits the interpreter-wide
# warnings filter list, so this affects warnings raised anywhere in the
# process after this module is imported, not only in this file.
simplefilter(action='ignore', category=FutureWarning)

# Configure seaborn's global plotting style (sns.set() is seaborn's alias for
# set_theme()); every figure drawn afterwards inherits this style.
sns.set()

16 

17 

18# CHECK PARAMETERS 

19 

def check_input_path(path, argument, out_log, classname):
    """Validate an input file path and hand it back unchanged.

    The file must exist, and its extension (without the dot) must be one of
    the formats accepted for *argument* by ``is_valid_file``.

    Args:
        path: Location of the input file.
        argument: Name of the argument being validated (e.g.
            ``'input_dataset_path'``).
        out_log: Logger object forwarded to ``fu.log``.
        classname: Prefix for every log and exit message.

    Returns:
        The same *path* that was passed in.

    Raises:
        SystemExit: when the file is missing or its format is not accepted.
    """
    def _abort(log_text, exit_text):
        # Log first, then terminate with a slightly shorter message.
        fu.log(classname + log_text, out_log)
        raise SystemExit(classname + exit_text)

    if not Path(path).exists():
        _abort(': Unexisting %s file, exiting' % argument,
               ': Unexisting %s file' % argument)

    bare_ext = PurePath(path).suffix[1:]
    if not is_valid_file(bare_ext, argument):
        reason = ': Format %s in %s file is not compatible' % (bare_ext, argument)
        _abort(reason, reason)

    return path

30 

31 

def check_output_path(path, argument, optional, out_log, classname):
    """Validate an output file path and return it unchanged.

    Args:
        path: Destination file path; may be falsy when *optional* is true.
        argument: Name of the argument being validated; selects the accepted
            extensions in ``is_valid_file``.
        optional: When true, a falsy *path* is accepted and ``None`` returned.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Prefix for every log and exit message.

    Returns:
        *path*, or ``None`` when an optional path was not supplied.

    Raises:
        SystemExit: when the parent folder does not exist or the file
            extension is not accepted for *argument*.
    """
    if optional and not path:
        return None
    # Only the folder's existence matters: pathlib path objects are always
    # truthy, and a bare file name has parent '.', so no separate truthiness
    # test on the parent is needed.
    if not Path(path).parent.exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)
    extension = PurePath(path).suffix[1:]  # suffix without the leading dot
    if not is_valid_file(extension, argument):
        message = ': Format %s in %s file is not compatible' % (extension, argument)
        fu.log(classname + message, out_log)
        raise SystemExit(classname + message)
    return path

44 

45 

def is_valid_file(ext, argument):
    """Tell whether *ext* is an accepted file format for *argument*.

    Args:
        ext: File extension without the leading dot (compared case-sensitively).
        argument: Name of the path argument, e.g. ``'output_plot_path'``.

    Returns:
        ``True`` if *ext* is listed for *argument*, ``False`` otherwise.

    Raises:
        KeyError: if *argument* is not a known path argument name.
    """
    csv_formats = ['csv']
    pickle_formats = ['pkl']
    accepted_by_argument = {
        'input_dataset_path': csv_formats,
        'output_model_path': pickle_formats,
        'output_dataset_path': csv_formats,
        'output_results_path': csv_formats,
        'input_model_path': pickle_formats,
        'output_test_table_path': csv_formats,
        'output_plot_path': ['png'],
    }
    allowed = accepted_by_argument[argument]
    return ext in allowed

58 

59 

def check_mandatory_property(property, name, out_log, classname):
    """Return a required property value, exiting if it is missing.

    Any falsy value (``None``, ``''``, ``0``, empty containers) counts as
    missing. The parameter name shadows the ``property`` builtin but is kept
    for keyword-argument compatibility.

    Args:
        property: Value to check.
        name: Property name used in messages.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Prefix for every log and exit message.

    Raises:
        SystemExit: when *property* is falsy.
    """
    if property:
        return property
    fu.log(classname + ': Unexisting %s property, exiting' % name, out_log)
    raise SystemExit(classname + ': Unexisting %s property' % name)

65 

66 

67# UTILITIES 

68 

def get_list_of_predictors(predictions):
    """Turn a sequence of prediction dicts into a list of value lists.

    Args:
        predictions: Iterable of mappings; each mapping's values are taken
            in iteration (insertion) order.

    Returns:
        list[list]: one list of values per input mapping; ``[]`` for empty input.
    """
    return [list(obj.values()) for obj in predictions]

77 

78 

def get_keys_of_predictors(predictions):
    """Return the keys of the first prediction mapping, in iteration order.

    Args:
        predictions: Indexable sequence of mappings.

    Returns:
        list: keys of ``predictions[0]``; ``[]`` when *predictions* is empty
        (this used to raise ``IndexError``).
    """
    try:
        first = predictions[0]
    except IndexError:
        return []
    return list(first)

84 

85 

def CMPlotBinary(position, cm, group_names, title, normalize):
    """Draw an annotated 2x2 confusion-matrix heatmap into one subplot.

    Args:
        position: Three-digit matplotlib subplot spec (e.g. ``231``).
        cm: Confusion matrix; its flattened cells are paired, in order, with
            *group_names* and laid out as a 2x2 grid of annotations.
        group_names: Four cell captions in row-major order.
        title: Subplot title.
        normalize: If true, cell values are printed with two decimals,
            otherwise as whole numbers.
    """
    plt.subplot(position)
    plt.title(title, size=15)

    cell_template = "{0:0.2f}" if normalize else "{0:0.0f}"
    captions = []
    for cell_value, caption in zip(cm.flatten(), group_names):
        captions.append("{}\n{}".format(cell_template.format(cell_value), caption))
    annotation_grid = np.asarray(captions).reshape(2, 2)

    sns.heatmap(cm, annot=annotation_grid, fmt='', cmap='Blues', square=True)
    plt.ylabel('True Values', size=13)
    plt.xlabel('Predicted Values', size=13)
    plt.yticks(rotation=0)

99 

100 

def distPredPlot(position, y, pos_p, labels, title):
    """Overlay normalized histograms of predicted positive-class probabilities.

    Samples whose target equals 1 are drawn in green (legend ``labels[0]``),
    samples whose target equals 0 in red (legend ``labels[1]``). A dashed
    blue line marks the 0.5 decision boundary.

    Args:
        position: Three-digit matplotlib subplot spec.
        y: Target values, paired element-wise with *pos_p*.
        pos_p: Predicted probability of the positive class per sample.
        labels: Legend names for the target-1 and target-0 histograms.
        title: Subplot title.
    """
    samples = pd.DataFrame({'probPos': pos_p, 'target': y})
    plt.subplot(position)
    # (target value, colour, index into labels) -- labels are looked up lazily.
    for target_value, colour, label_idx in ((1, 'green', 0), (0, 'red', 1)):
        subset = samples[samples.target == target_value].probPos
        plt.hist(subset, density=True, bins=25,
                 alpha=.5, color=colour, label=labels[label_idx])
    plt.axvline(.5, color='blue', linestyle='--', label='Boundary')
    plt.xlim([0, 1])
    plt.title(title, size=15)
    plt.xlabel('Positive Probability (predicted)', size=13)
    plt.ylabel('Samples (normalized scale)', size=13)
    plt.legend(loc="upper right")

114 

115 

def ROCPlot(position, y, p, cm, title):
    """Plot a ROC curve and mark the classifier's actual operating point.

    Args:
        position: Three-digit matplotlib subplot spec.
        y: True labels passed to ``roc_curve``.
        p: Probability array; column 1 is used as the score (presumably the
            positive class -- TODO confirm against ``model.classes_``).
        cm: Confusion matrix whose ``ravel()`` order is tn, fp, fn, tp.
        title: Subplot title.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y, p[:, 1])
    area = auc(false_pos_rate, true_pos_rate)

    plt.subplot(position)
    plt.plot(false_pos_rate, true_pos_rate, color='green',
             lw=1, label='ROC curve (area = %0.2f)' % area)
    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='grey')

    # Operating point of the fitted classifier, derived from its confusion matrix.
    tn, fp, fn, tp = cm.ravel()
    operating_fpr = fp / (fp + tn)
    operating_tpr = tp / (tp + fn)
    plt.plot(operating_fpr, operating_tpr, 'bo', markersize=8, label='Decision Point')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', size=13)
    plt.ylabel('True Positive Rate', size=13)
    plt.title(title, size=15)
    plt.legend(loc="lower right")

132 

133 

def _positive_class_proba(model, proba):
    """Return the column of *proba* holding the probability of class label 1.

    Raises:
        ValueError: if *model* does not have exactly two classes, or if
            neither of them is labelled 1.
    """
    if len(model.classes_) != 2:
        raise ValueError('A binary class problem is required')
    if model.classes_[1] == 1:
        return proba[:, 1]
    if model.classes_[0] == 1:
        return proba[:, 0]
    # Without this, the caller would hit an unbound local variable.
    raise ValueError('A binary class problem with a class labelled 1 is required')


# Visualize the performance of a Logistic Regression Binary Classifier.
# https://towardsdatascience.com/how-to-interpret-a-binary-logistic-regressor-with-scikit-learn-6d56c5783b49
def plotBinaryClassifier(model, proba_train, proba_test, cm_train, cm_test, y_train, y_test, normalize=False, labels=('Positives', 'Negatives'), cmticks=(0, 1), get_plot=True):
    """Build a 2x3 figure summarizing a binary classifier on train and test data.

    Top row (train) and bottom row (test) each show: the confusion matrix,
    histograms of predicted positive-class probabilities, and the ROC curve
    with the decision point marked.

    Args:
        model: Fitted classifier exposing ``classes_`` (two labels, one of
            them equal to 1).
        proba_train, proba_test: Two-column class-probability arrays.
        cm_train, cm_test: 2x2 confusion matrices.
        y_train, y_test: True targets for each split.
        normalize: Whether the confusion matrices hold normalized values.
        labels: Legend names for the target-1 and target-0 histograms.
        cmticks, get_plot: Unused; kept for backward compatibility.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can save or show the figure.

    Raises:
        ValueError: if the model is not binary or has no class labelled 1.
    """
    cell_names = ['True Negatives', 'False Positives', 'False Negatives', 'True Positives']

    # Validate both splits before creating the figure, so a bad model does not
    # leave a half-drawn figure behind.
    pos_train = _positive_class_proba(model, proba_train)
    pos_test = _positive_class_proba(model, proba_test)

    plt.figure(figsize=[15, 8])

    # TRAINING (subplots 1-3 of the 2x3 grid)
    CMPlotBinary(231, cm_train, cell_names, 'Confusion Matrix Train', normalize)
    distPredPlot(232, y_train, pos_train, labels, 'Distributions of Predictions Train')
    ROCPlot(233, y_train, proba_train, cm_train, 'ROC Curve Train')

    # TESTING (subplots 4-6 of the 2x3 grid)
    CMPlotBinary(234, cm_test, cell_names, 'Confusion Matrix Test', normalize)
    distPredPlot(235, y_test, pos_test, labels, 'Distributions of Predictions Test')
    ROCPlot(236, y_test, proba_test, cm_test, 'ROC Curve Test')

    plt.tight_layout()

    return plt

184 

185 

def CMplotNonBinary(position, cm, title, normalize, values):
    """Draw an annotated confusion-matrix heatmap for a multi-class problem.

    Each cell shows its value (two decimals when *normalize* is true, a whole
    number otherwise) and a caption: ``"True <label>"`` on the diagonal and
    ``"False <label>"`` elsewhere. ``<label>`` is the predicted class
    (column), matching CMPlotBinary, where the cell (true negative, predicted
    positive) is called "False Positives".

    Args:
        position: Three-digit matplotlib subplot spec (e.g. ``121``).
        cm: Confusion matrix with true classes on rows and predicted classes
            on columns (as the axis labels below state).
        title: Subplot title.
        normalize: Whether *cm* holds normalized values.
        values: Class labels, one per row/column; used for ticks and captions.
    """
    n_rows, n_cols = cm.shape

    # Shrink the annotation font as the matrix grows so captions stay legible.
    if n_cols < 5:
        fs = 10
    elif n_cols < 10:
        fs = 8
    else:
        fs = 6

    plt.subplot(position)
    plt.title(title, size=15)

    number_fmt = "{0:0.2f}" if normalize else "{0:0.0f}"
    group_counts = [number_fmt.format(value) for value in cm.flatten()]
    # Row-major order, matching cm.flatten(); caption uses the predicted class j.
    group_names = [
        ("True " if i == j else "False ") + str(values[j])
        for i, j in itertools.product(range(n_rows), range(n_cols))
    ]
    labels_cfm = np.asarray(
        [f"{count}\n{name}" for count, name in zip(group_counts, group_names)]
    ).reshape(n_rows, n_cols)

    sns.heatmap(cm, annot=labels_cfm, fmt='', cmap='Blues', xticklabels=values, yticklabels=values, square=True, annot_kws={"fontsize": fs})
    plt.ylabel('True Values', size=13)
    plt.xlabel('Predicted Values', size=13)
    plt.yticks(rotation=0)

213 

214 

def plotMultipleCM(cm_train, cm_test, normalize, values):
    """Build a 1x2 figure with the train and test confusion matrices side by side.

    Args:
        cm_train, cm_test: Confusion matrices for each split.
        normalize: Whether the matrices hold normalized values.
        values: Class labels used for ticks and captions.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can save or show the figure.
    """
    plt.figure(figsize=[8, 4])

    panels = (
        (121, cm_train, 'Confusion Matrix Train'),
        (122, cm_test, 'Confusion Matrix Test'),
    )
    for subplot_spec, matrix, caption in panels:
        CMplotNonBinary(subplot_spec, matrix, caption, normalize, values)

    plt.tight_layout()
    return plt

229 

230 

def getIndependentVars(independent_vars, data, out_log, classname):
    """Select the feature columns described by *independent_vars* from *data*.

    *independent_vars* is a dict; the first recognised key wins:
      - ``'indexes'``: positional column indices;
      - ``'range'``: list of ``[start, end]`` pairs, both ends inclusive;
      - ``'columns'``: column labels.

    Raises:
        SystemExit: (after logging) when none of those keys is present.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]
    if 'range' in independent_vars:
        positions = [
            col
            for bounds in independent_vars['range']
            for col in range(bounds[0], bounds[1] + 1)  # inclusive upper bound
        ]
        return data.iloc[:, positions]
    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]
    fu.log(classname + ': Incorrect independent_vars format', out_log)
    raise SystemExit(classname + ': Incorrect independent_vars format')

245 

246 

def getIndependentVarsList(independent_vars):
    """Describe the selected independent variables as a comma-separated string.

    Args:
        independent_vars: Dict with one of ``'indexes'``, ``'range'``
            (inclusive ``[start, end]`` pairs) or ``'columns'``.

    Returns:
        str: e.g. ``'0, 1, 2'``; ``None`` if no recognised key is present.
    """
    if 'indexes' in independent_vars:
        return ', '.join(str(x) for x in independent_vars['indexes'])
    if 'range' in independent_vars:
        return ', '.join(str(y) for r in independent_vars['range'] for y in range(r[0], r[1] + 1))
    if 'columns' in independent_vars:
        # str() so non-string column labels (pandas allows ints) do not crash join().
        return ', '.join(str(c) for c in independent_vars['columns'])
    return None

254 

255 

def getTarget(target, data, out_log, classname):
    """Return the target column of *data* described by *target*.

    Args:
        target: Dict with ``'index'`` (positional column) or ``'column'``
            (column label); ``'index'`` takes precedence.
        data: Table to select from.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Prefix for log and exit messages.

    Raises:
        SystemExit: (after logging) when neither key is present.
    """
    if 'index' in target:
        return data.iloc[:, target['index']]
    if 'column' in target:
        return data[target['column']]
    failure = classname + ': Incorrect target format'
    fu.log(failure, out_log)
    raise SystemExit(failure)

264 

265 

def getTargetValue(target):
    """Return a printable identifier for the target column.

    Args:
        target: Dict with ``'index'`` or ``'column'``; ``'index'`` wins.

    Returns:
        ``str(target['index'])``, else ``target['column']`` as stored, else
        ``None`` when neither key is present.
    """
    if 'index' in target:
        return str(target['index'])
    return target['column'] if 'column' in target else None

271 

272 

def getWeight(weight, data, out_log, classname):
    """Return the sample-weight column of *data* described by *weight*.

    Args:
        weight: Dict with ``'index'`` (positional column) or ``'column'``
            (column label); ``'index'`` takes precedence.
        data: Table to select from.
        out_log: Logger object forwarded to ``fu.log``.
        classname: Prefix for log and exit messages.

    Raises:
        SystemExit: (after logging) when neither key is present.
    """
    if 'index' in weight:
        selected = data.iloc[:, weight['index']]
    elif 'column' in weight:
        selected = data[weight['column']]
    else:
        error_text = classname + ': Incorrect weight format'
        fu.log(error_text, out_log)
        raise SystemExit(error_text)
    return selected

281 

282 

283def getHeader(file): 

284 with open(file, newline='') as f: 

285 reader = csv.reader(f) 

286 header = next(reader) 

287 

288 if (len(header) == 1): 

289 return list(re.sub('\\s+|;|:|,|\t', ',', header[0]).split(",")) 

290 else: 

291 return header