Coverage for biobb_ml/clustering/common.py: 81%
264 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
1""" Common functions for package biobb_ml.clustering """
2from pathlib import Path, PurePath
3import matplotlib.pyplot as plt
4from matplotlib.lines import Line2D
5import csv
6import re
7import numpy as np
8import pandas as pd
9import seaborn as sns
10from sklearn.neighbors import NearestNeighbors
11from sklearn.cluster import KMeans
12from sklearn.cluster import SpectralClustering
13from sklearn.cluster import AgglomerativeClustering
14from sklearn.metrics import silhouette_score
15from random import sample
16from math import isnan
17from biobb_common.tools import file_utils as fu
18from warnings import simplefilter
19# ignore all future warnings
20simplefilter(action='ignore', category=FutureWarning)
21sns.set()
24# CHECK PARAMETERS
def check_input_path(path, argument, out_log, classname):
    """Validate an input file path and return it unchanged.

    The file has to exist and its extension (without the leading dot) has to
    be accepted by ``is_valid_file`` for the given ``argument``.

    Args:
        path: Location of the input file.
        argument: Name of the argument the path was supplied for; used in the
            messages and as the key for the format lookup.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Name used as prefix of every message.

    Returns:
        ``path`` exactly as received.

    Raises:
        SystemExit: If the file does not exist or its format is rejected.
    """
    if not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)

    extension = PurePath(path).suffix[1:]  # drop the leading dot
    if is_valid_file(extension, argument):
        return path

    # Same text is logged and carried by the exception.
    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)
def check_output_path(path, argument, optional, out_log, classname):
    """Validate an output file path and return it unchanged.

    Args:
        path: Location where the output file will be written.
        argument: Name of the argument the path was supplied for; used in the
            messages and as the key for the format lookup.
        optional: When truthy, an empty/``None`` path is accepted and
            ``None`` is returned without any further check.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Name used as prefix of every message.

    Returns:
        ``path`` as received, or ``None`` for a missing optional path.

    Raises:
        SystemExit: If the parent folder does not exist or the extension is
            rejected by ``is_valid_file``.
    """
    if optional and not path:
        return None

    folder = PurePath(path).parent
    if folder and not Path(folder).exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    extension = PurePath(path).suffix[1:]  # drop the leading dot
    if is_valid_file(extension, argument):
        return path

    # Same text is logged and carried by the exception.
    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)
def is_valid_file(ext, argument):
    """Tell whether extension ``ext`` is accepted for ``argument``.

    Args:
        ext: File extension without the leading dot.
        argument: Name of the path argument being checked.

    Returns:
        ``True`` when ``ext`` is listed for ``argument``, ``False`` otherwise.

    Raises:
        KeyError: If ``argument`` is not one of the known argument names.
    """
    accepted_extensions = {
        'input_dataset_path': ['csv'],
        'output_model_path': ['pkl'],
        'input_model_path': ['pkl'],
        'output_results_path': ['csv'],
        'output_plot_path': ['png'],
    }[argument]
    return ext in accepted_extensions
def check_mandatory_property(property, name, out_log, classname):
    """Return ``property`` if it is truthy, otherwise log and abort.

    Args:
        property: Value of the mandatory property.
        name: Property name, used in the messages.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Name used as prefix of every message.

    Returns:
        ``property`` as received.

    Raises:
        SystemExit: If ``property`` is falsy.
    """
    if property:
        return property

    fu.log(classname + ': Unexisting %s property, exiting' % name, out_log)
    raise SystemExit(classname + ': Unexisting %s property' % name)
71# UTILITIES
def get_list_of_predictors(predictions):
    """Turn a sequence of mappings into a list of value lists.

    Args:
        predictions: Iterable of mappings (one per prediction).

    Returns:
        One list per mapping holding its values in the mapping's own
        iteration order.
    """
    return [[value for _, value in entry.items()] for entry in predictions]
def get_keys_of_predictors(predictions):
    """Return the keys of the first prediction as a list.

    Args:
        predictions: Non-empty sequence whose first element is iterated
            (for a dict this yields its keys).

    Returns:
        List with the items obtained by iterating ``predictions[0]``.
    """
    first_prediction = predictions[0]
    return [key for key in first_prediction]
# get best K in WCSS plot (getting elbow point)
def get_best_K(wcss):
    """Locate the elbow of a WCSS curve.

    The curve is treated as the points ``(i, wcss[i])``. The elbow is the
    point with the largest perpendicular distance to the straight line that
    joins the first and the last point of the curve.

    Args:
        wcss: Sequence of within-cluster sums of squares; position ``i`` is
            taken to correspond to ``i + 1`` clusters.

    Returns:
        Tuple ``(best_k, best_index)`` where ``best_index`` is the 0-based
        position of the elbow and ``best_k`` is ``best_index + 1``.

    Raises:
        IndexError: If ``wcss`` is empty.
    """
    curve = np.asarray(wcss, dtype=float)
    n_points = len(curve)
    coords = np.vstack((range(n_points), curve)).T

    first_point = coords[0]
    chord = coords[-1] - first_point
    chord_length = np.sqrt(np.sum(chord ** 2))
    if chord_length == 0:
        # Only one point: there is no line to measure against. Answer
        # directly instead of dividing 0 by 0 and taking argmax over NaN.
        best_index = np.intp(0)
        return best_index + 1, best_index

    chord_unit = chord / chord_length
    from_first = coords - first_point
    # Split every vector into its component along the chord and the rest;
    # the length of the rest is the distance to the line.
    along_chord = np.sum(from_first * chord_unit, axis=1)
    parallel_part = np.outer(along_chord, chord_unit)
    dist_to_line = np.sqrt(np.sum((from_first - parallel_part) ** 2, axis=1))

    best_index = np.argmax(dist_to_line)
    return best_index + 1, best_index
# hopkins test
# https://matevzkunaver.wordpress.com/2017/06/20/hopkins-test-for-cluster-tendency/
def hopkins(X):
    """Compute the Hopkins statistic of a dataset.

    A sample of ``m`` rows of ``X`` and ``m`` uniformly random points drawn
    inside the per-column [min, max] box of ``X`` are compared through their
    neighbour distances to the data. With ``u`` the distances of the random
    points and ``w`` the distances of the sampled rows, the result is
    ``sum(u) / (sum(u) + sum(w))``.

    Args:
        X: Tabular data exposing ``shape``, ``values`` and ``iloc``
            (presumably a pandas DataFrame of numeric columns -- confirm
            against callers). Two neighbours are requested per query, so at
            least two rows are presumably required.

    Returns:
        The statistic, or ``0`` when it cannot be computed (all distances
        are zero or the ratio is NaN).
    """
    d = X.shape[1]  # columns
    n = len(X)  # rows
    # 10 % heuristic from the article, but never fewer than one sample:
    # with m == 0 both sums are empty and the ratio below would be 0 / 0.
    m = max(1, int(0.1 * n))
    nbrs = NearestNeighbors(n_neighbors=1).fit(X.values)

    rand_X = sample(range(0, n, 1), m)

    # Bounds of the sampling box do not change between iterations.
    lower = np.amin(X, axis=0)
    upper = np.amax(X, axis=0)

    ujd = []
    wjd = []
    for j in range(0, m):
        # NOTE(review): index 1 is the second-nearest neighbour. For sampled
        # rows this skips the row itself; for the artificial point it skips
        # its true nearest neighbour as well -- kept as in the referenced
        # article, confirm this is intended.
        u_dist, _ = nbrs.kneighbors(np.random.uniform(lower, upper, d).reshape(1, -1), 2, return_distance=True)
        ujd.append(u_dist[0][1])
        w_dist, _ = nbrs.kneighbors(X.iloc[rand_X[j]].values.reshape(1, -1), 2, return_distance=True)
        wjd.append(w_dist[0][1])

    denominator = sum(ujd) + sum(wjd)
    if denominator == 0:
        # Every measured distance is zero: no ratio can be formed.
        return 0

    H = sum(ujd) / denominator
    if isnan(H):
        H = 0

    return H
def _wcss_from_labels(data, labels):
    """Sum of squared distances of every sample to the mean of its cluster."""
    points = np.asarray(data, dtype=float)
    labels = np.asarray(labels)
    total = 0.0
    for label in np.unique(labels):
        members = points[labels == label]
        total += float(np.sum((members - members.mean(axis=0)) ** 2))
    return total


# compute elbow
def getWCSS(method, max_clusters, t_predictors):
    """Compute the within-cluster sum of squares for 1..max_clusters clusters.

    Args:
        method: ``'kmeans'`` or ``'agglomerative'``.
        max_clusters: Largest number of clusters to evaluate (inclusive).
        t_predictors: Data to cluster, passed to the estimator's ``fit``.

    Returns:
        List with one WCSS value per number of clusters, starting at one
        cluster. Empty when ``max_clusters`` is lower than 1.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    wcss = []
    for i in range(1, max_clusters + 1):
        if method == 'kmeans':
            clusterer = KMeans(i)
            clusterer.fit(t_predictors)
            wcss_iter = clusterer.inertia_
        elif method == 'agglomerative':
            clusterer = AgglomerativeClustering(n_clusters=i, linkage="average")
            clusterer.fit(t_predictors)
            # AgglomerativeClustering exposes no ``inertia_`` attribute, so
            # the value is derived from the fitted labels.
            wcss_iter = _wcss_from_labels(t_predictors, clusterer.labels_)
        else:
            raise ValueError("getWCSS: unsupported clustering method %r" % (method,))
        wcss.append(wcss_iter)

    return wcss
# compute gap
# https://anaconda.org/milesgranger/gap-statistic/notebook
def getGap(method, data, nrefs=3, maxClusters=15):
    """
    Calculates KMeans optimal K using Gap Statistic from Tibshirani, Walther, Hastie

    For every k in ``range(1, maxClusters)`` the dispersion (KMeans inertia)
    of ``data`` is compared with the mean dispersion of ``nrefs`` random
    reference datasets of the same shape drawn from ``[0, 1)``.

    Params:
        method: not used by this function; KMeans is always applied
        data: ndarray of shape (n_samples, n_features)
        nrefs: number of sample reference datasets to create
        maxClusters: exclusive upper bound of the cluster counts to test
    Returns: (optimalK, resultsdf) where resultsdf has one row per tested k
        with the columns 'cluster' and 'gap'
    Raises:
        ValueError: if no k is tested (maxClusters <= 1), because the
            maximum of an empty array is undefined
    """
    cluster_counts = range(1, maxClusters)
    gaps = np.zeros((len(cluster_counts),))
    for gap_index, k in enumerate(cluster_counts):

        # Holder for reference dispersion results
        refDisps = np.zeros(nrefs)

        # For n references, generate random sample and perform kmeans getting resulting dispersion of each loop
        for i in range(nrefs):

            # Create new random reference set
            randomReference = np.random.random_sample(size=data.shape)

            # Fit to it
            clusterer = KMeans(k)
            clusterer.fit(randomReference)

            refDisps[i] = clusterer.inertia_

        # Fit cluster to original data and create dispersion
        clusterer = KMeans(k)
        clusterer.fit(data)

        origDisp = clusterer.inertia_

        # Calculate gap statistic and assign it to this loop's slot
        gaps[gap_index] = np.log(np.mean(refDisps)) - np.log(origDisp)

    # Built in one go: DataFrame.append (used previously, once per row) no
    # longer exists in pandas >= 2.0.
    # NOTE(review): 'cluster' is stored as float, which is presumably what
    # the row-wise append of mixed int/float rows produced -- verify if a
    # caller depends on the dtype.
    resultsdf = pd.DataFrame({'cluster': np.array(cluster_counts, dtype=float), 'gap': gaps})

    return (gaps.argmax() + 1, resultsdf)  # Plus 1 because index of 0 means 1 cluster is optimal, index 2 = 3 clusters are optimal
def getSilhouetthe(method, X, max_clusters, affinity=None, linkage=None, random_state=None):
    """Compute silhouette scores for 2..max_clusters clusters.

    Args:
        method: ``'kmeans'``, ``'agglomerative'`` or ``'spectral'``.
        X: Data to cluster.
        max_clusters: Largest number of clusters to evaluate (inclusive).
        affinity: Forwarded to ``AgglomerativeClustering`` only.
        linkage: Forwarded to ``AgglomerativeClustering`` only.
        random_state: Forwarded to ``KMeans`` and ``SpectralClustering``.

    Returns:
        Tuple ``(scores, cluster_counts)``. Both lists start with a
        placeholder entry for one cluster (score ``0``, count ``1``),
        followed by one entry per evaluated number of clusters; scores are
        rounded to four decimals.
    """
    # Entry for k == 1, for which no clustering is run.
    scores = [0]
    cluster_counts = [1]

    for n_clusters in range(2, max_clusters + 1):
        if method == 'kmeans':
            clusterer = KMeans(n_clusters=n_clusters, random_state=random_state)
        elif method == 'agglomerative':
            clusterer = AgglomerativeClustering(n_clusters=n_clusters, affinity=affinity, linkage=linkage)
        elif method == 'spectral':
            clusterer = SpectralClustering(n_clusters=n_clusters, affinity="nearest_neighbors", random_state=random_state)

        clusterer.fit(X)
        # The higher (up to 1) the better
        scores.append(round(silhouette_score(X, clusterer.labels_), 4))
        cluster_counts.append(n_clusters)

    return scores, cluster_counts
# plot elbow, gap & silhouette
def plotKmeansTrain(max_clusters, wcss, gap, sil, best_k, best_g, best_s):
    """Draw the elbow, gap and silhouette curves side by side.

    Each of the three panels plots one series against the cluster counts
    ``1..max_clusters`` and marks the selected number of clusters with a red
    vertical line.

    Args:
        max_clusters: Largest number of clusters on the x axis.
        wcss: Values for the elbow panel.
        gap: Values for the gap panel.
        sil: Values for the silhouette panel.
        best_k: x position of the marker in the elbow panel.
        best_g: x position of the marker in the gap panel.
        best_s: x position of the marker in the silhouette panel.

    Returns:
        The ``matplotlib.pyplot`` module holding the figure.
    """
    number_clusters = range(1, max_clusters + 1)
    plt.figure(figsize=[15, 4])

    # (subplot position, title, series, marker, legend label, y-axis label)
    panels = (
        (131, 'The Elbow Method', wcss, best_k, 'WCSS', 'Within-cluster Sum of Squares'),
        (132, 'Gap Statistics', gap, best_g, 'GAP', 'Gap'),
        (133, 'Silhouette', sil, best_s, 'Silhouette', 'Silhouette score'),
    )
    for position, title, series, best, series_label, y_label in panels:
        plt.subplot(position)
        plt.title(title, size=15)
        plt.plot(number_clusters, series, '-o')
        plt.axvline(x=best, c='red')
        plt.legend((series_label, 'Best K'))
        plt.xlabel('Cluster')
        plt.ylabel(y_label)

    plt.tight_layout()

    return plt
def _plot_cluster_2d(position, plot, clusters):
    """Draw one 2D panel: a scatter per cluster label, label -1 in grey."""
    x_feature = plot['features'][0]
    y_feature = plot['features'][1]

    plt.subplot(position)
    labels = set(clusters['cluster'])
    palette = plt.get_cmap('rainbow')(np.linspace(0.0, 1.0, len(labels)))
    has_outliers = False
    for label in labels:
        if label == -1:
            # outliers in grey
            has_outliers = True
            point_color = [0.4, 0.4, 0.4]
        else:
            point_color = palette[label]
        members = clusters[clusters.cluster == label]
        plt.scatter(members[x_feature], members[y_feature], color=point_color, s=20, alpha=0.85)
    plt.title(plot['title'], size=15)
    plt.xlabel(x_feature, size=13)
    plt.ylabel(y_feature, size=13)

    if has_outliers:
        legend_handles = [Line2D([0], [0], marker='o', color=([0, 0, 0, 0]), label='Outliers', markerfacecolor=([0.4, 0.4, 0.4]), markersize=10)]
        plt.legend(legend_handles, ['Outliers'])


def _plot_cluster_3d(position, plot, clusters):
    """Draw one 3D panel with all samples coloured by their cluster label."""
    x_feature, y_feature, z_feature = plot['features'][0], plot['features'][1], plot['features'][2]

    ax = plt.subplot(position, projection='3d')
    ax.scatter(clusters[x_feature], clusters[y_feature], clusters[z_feature],
               s=50, alpha=0.6, c=clusters['cluster'], cmap='rainbow')

    ax.set_xlabel(x_feature)
    ax.set_ylabel(y_feature)
    ax.set_zlabel(z_feature)

    plt.title(plot['title'], size=15, pad=35)


def plotCluster(new_plots, clusters):
    """Draw one scatter panel per requested plot, coloured by cluster.

    Args:
        new_plots: Sequence of dicts with the keys ``'features'`` (two or
            three column names) and ``'title'``. Entries with a different
            number of features are skipped.
        clusters: Table indexed by column name that contains a ``'cluster'``
            column (presumably a pandas DataFrame -- confirm against
            callers).

    Returns:
        The ``matplotlib.pyplot`` module holding the figure.
    """
    # number of plots -> (figure size, base of the 3-digit subplot position)
    layouts = {
        1: ((6, 6), 110),
        2: ((10, 6), 120),
        3: ((15, 4), 130),
    }
    fs, ps = layouts.get(len(new_plots), ((15, 8), 230))

    plt.figure(figsize=fs)

    for position, plot in enumerate(new_plots, start=ps + 1):
        n_features = len(plot['features'])
        if n_features == 2:
            _plot_cluster_2d(position, plot, clusters)
        elif n_features == 3:
            _plot_cluster_3d(position, plot, clusters)

    plt.tight_layout()

    return plt
# plot silhouette
def plotAgglomerativeTrain(max_clusters, sil, best_s):
    """Draw the silhouette curve with the selected number of clusters marked.

    Args:
        max_clusters: Largest number of clusters on the x axis; the axis
            runs over ``1..max_clusters``.
        sil: Silhouette values to plot, one per x position.
        best_s: x position of the red vertical marker.

    Returns:
        The ``matplotlib.pyplot`` module holding the figure.
    """
    cluster_axis = range(1, max_clusters + 1)

    plt.figure()
    plt.title('Silhouette', size=15)
    plt.plot(cluster_axis, sil, '-o')
    plt.ylabel('Silhouette score')
    plt.xlabel('Cluster')
    plt.axvline(x=best_s, c='red')
    plt.tight_layout()

    return plt
def getIndependentVars(independent_vars, data, out_log, classname):
    """Select the independent-variable columns of ``data``.

    The first key found in ``independent_vars`` decides how columns are
    chosen, in this order of precedence:

    * ``'indexes'``: list of positional column indexes.
    * ``'range'``: list of ``[first, last]`` positional ranges, both ends
      included.
    * ``'columns'``: list of column labels.

    Args:
        independent_vars: Mapping describing the selection.
        data: Table supporting ``iloc`` / ``loc`` indexing.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Name used as prefix of the error message.

    Returns:
        The selected columns of ``data``.

    Raises:
        SystemExit: If none of the three keys is present.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]

    if 'range' in independent_vars:
        positions = [
            position
            for bounds in independent_vars['range']
            for position in range(bounds[0], (bounds[1] + 1))
        ]
        return data.iloc[:, positions]

    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]

    fu.log(classname + ': Incorrect independent_vars format', out_log)
    raise SystemExit(classname + ': Incorrect independent_vars format')
def getIndependentVarsList(independent_vars):
    """Describe the selected independent variables as a ``', '``-joined string.

    Args:
        independent_vars: Mapping with one of the keys ``'indexes'``,
            ``'range'`` (list of inclusive ``[first, last]`` pairs) or
            ``'columns'`` (list of strings); checked in that order.

    Returns:
        The joined string, or ``None`` when none of the keys is present.
    """
    if 'indexes' in independent_vars:
        return ', '.join(map(str, independent_vars['indexes']))

    if 'range' in independent_vars:
        expanded = []
        for bounds in independent_vars['range']:
            expanded.extend(str(position) for position in range(bounds[0], bounds[1] + 1))
        return ', '.join(expanded)

    if 'columns' in independent_vars:
        return ', '.join(independent_vars['columns'])

    return None
def getTarget(target, data, out_log, classname):
    """Extract the target column from ``data``.

    Args:
        target: Mapping with either ``'index'`` (positional column index) or
            ``'column'`` (column label); ``'index'`` wins when both exist.
        data: Table supporting ``iloc`` and label indexing.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Name used as prefix of the error message.

    Returns:
        The selected column.

    Raises:
        SystemExit: If neither key is present.
    """
    if 'index' in target:
        return data.iloc[:, target['index']]

    if 'column' in target:
        return data[target['column']]

    fu.log(classname + ': Incorrect target format', out_log)
    raise SystemExit(classname + ': Incorrect target format')
def getTargetValue(target):
    """Return a printable identifier of the target.

    Args:
        target: Mapping with either ``'index'`` or ``'column'``; ``'index'``
            wins when both exist.

    Returns:
        ``str(target['index'])``, or ``target['column']`` unchanged, or
        ``None`` when neither key is present.
    """
    if 'index' in target:
        return str(target['index'])

    if 'column' in target:
        return target['column']

    return None
def getWeight(weight, data, out_log, classname):
    """Extract the weight column from ``data``.

    Args:
        weight: Mapping with either ``'index'`` (positional column index) or
            ``'column'`` (column label); ``'index'`` wins when both exist.
        data: Table supporting ``iloc`` and label indexing.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Name used as prefix of the error message.

    Returns:
        The selected column.

    Raises:
        SystemExit: If neither key is present.
    """
    if 'index' in weight:
        return data.iloc[:, weight['index']]

    if 'column' in weight:
        return data[weight['column']]

    fu.log(classname + ': Incorrect weight format', out_log)
    raise SystemExit(classname + ': Incorrect weight format')
def getHeader(file):
    """Read the column names from the first row of a delimited text file.

    The first row is parsed as comma-separated values. When that yields a
    single field, the field is split again on whitespace runs, ``;``, ``:``,
    ``,`` and tab, to cope with files that use another delimiter.

    Args:
        file: Path of the file to read.

    Returns:
        List with the header fields.
    """
    with open(file, newline='') as handle:
        first_row = next(csv.reader(handle))

    if len(first_row) != 1:
        return first_row

    # Normalise every alternative delimiter to a comma, then split.
    normalised = re.sub('\\s+|;|:|,|\t', ',', first_row[0])
    return normalised.split(",")