Coverage for biobb_ml/dimensionality_reduction/common.py: 65%
167 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
1""" Common functions for package biobb_analysis.ambertools """
2from pathlib import Path, PurePath
3import matplotlib.pyplot as plt
4from matplotlib.lines import Line2D
5import numpy as np
6import seaborn as sns
7import csv
8import re
9from biobb_common.tools import file_utils as fu
10from warnings import simplefilter
11# ignore all future warnings
12simplefilter(action='ignore', category=FutureWarning)
13sns.set()
16# CHECK PARAMETERS
def check_input_path(path, argument, out_log, classname):
    """Validate an input file path and return it unchanged.

    Args:
        path: Location of the input file.
        argument: Name of the argument the path was supplied for; it is
            interpolated into messages and passed to ``is_valid_file``.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Prefix used in every log / exit message.

    Returns:
        ``path`` when the file exists and its extension is accepted.

    Raises:
        SystemExit: If the file does not exist, or if ``is_valid_file``
            rejects the extension for ``argument``.
    """
    if not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)

    # The suffix includes the leading dot; drop it before checking the format.
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)
def check_output_path(path, argument, optional, out_log, classname):
    """Validate an output file path and return it unchanged.

    Args:
        path: Location where the output file will be written.
        argument: Name of the argument the path was supplied for; it is
            interpolated into messages and passed to ``is_valid_file``.
        optional: When truthy, a falsy ``path`` is accepted.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Prefix used in every log / exit message.

    Returns:
        ``None`` when the output is optional and no path was given,
        otherwise ``path``.

    Raises:
        SystemExit: If the parent folder does not exist, or if
            ``is_valid_file`` rejects the extension for ``argument``.
    """
    if optional and not path:
        return None

    # The file itself need not exist yet, but its folder must.
    folder = PurePath(path).parent
    if folder and not Path(folder).exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    # The suffix includes the leading dot; drop it before checking the format.
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)
def is_valid_file(ext, argument):
    """Tell whether a file extension is accepted for a given argument.

    Args:
        ext: File extension without the leading dot (e.g. ``'csv'``).
        argument: Name of the path argument being checked.

    Returns:
        ``True`` if ``ext`` is one of the formats registered for
        ``argument``; ``False`` otherwise, including when ``argument`` is
        not a registered name.
    """
    formats = {
        'input_dataset_path': ['csv'],
        'output_results_path': ['csv'],
        'output_plot_path': ['png'],
    }
    # An unknown argument has no compatible formats; answer False rather
    # than letting a KeyError escape from a yes/no check.
    return ext in formats.get(argument, ())
54# UTILITIES
def getWindowLength(default, feat):
    """Pick a window length that does not exceed the number of features.

    Args:
        default: Preferred window length.
        feat: Number of features available.

    Returns:
        ``default`` unless ``feat`` is smaller than it; in that case
        ``feat`` when ``feat`` is odd, or ``feat - 1`` when it is even.
    """
    if not feat < default:
        return default
    # Fewer features than the preferred window: fall back to the largest
    # odd value not greater than feat.
    if feat % 2 == 0:
        return feat - 1
    return feat
def generate_columns_labels(label, length):
    """Build ``length`` labels of the form ``'<label> <n>'``, numbered from 1.

    Args:
        label: Text placed before the number in every label.
        length: How many labels to produce.

    Returns:
        A list such as ``['PC 1', 'PC 2', ...]``; empty when ``length`` is 0.
    """
    return [label + ' ' + str(number) for number in range(1, length + 1)]
def plot2D(ax, pca_table, targets, target, x, y):
    """Scatter two principal components on ``ax``, one colour per target value.

    Args:
        ax: Axes object that receives the labels, title, points and legend.
        pca_table: Table holding the ``'PC <n>'`` columns and the ``target``
            column (presumably a pandas DataFrame, given the ``.loc``
            indexing - confirm against callers).
        targets: Sized collection of the distinct target values to draw.
        target: Name of the column in ``pca_table`` holding the target values.
        x: Number of the component plotted on the horizontal axis.
        y: Number of the component plotted on the vertical axis.
    """
    x_column = 'PC ' + str(x)
    y_column = 'PC ' + str(y)

    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    ax.set_title('2 Component PCA (PC ' + str(x) + ' vs PC ' + str(y) + ')', fontsize=15)

    # One evenly spaced colour from the reversed rainbow map per target value.
    palette = plt.get_cmap('rainbow_r')(np.linspace(0.0, 1.0, len(targets)))
    for value, shade in zip(targets, palette):
        rows = pca_table[target] == value
        ax.scatter(
            pca_table.loc[rows, x_column],
            pca_table.loc[rows, y_column],
            color=shade,
            s=50,
            alpha=0.6,
        )

    # The legend is only drawn for fewer than 15 target values.
    if len(targets) < 15:
        ax.legend(targets)
def PCA2CPlot(pca_table, targets, target):
    """Create an 8x8 figure with a single PC 1 vs PC 2 scatter plot.

    Args:
        pca_table: Table forwarded to ``plot2D``.
        targets: Distinct target values forwarded to ``plot2D``.
        target: Name of the target column forwarded to ``plot2D``.
    """
    figure = plt.figure(figsize=(8, 8))
    plot2D(figure.add_subplot(1, 1, 1), pca_table, targets, target, 1, 2)
    plt.tight_layout()
def scatter3DLegend(targets):
    """Build one marker-only legend handle per target value.

    Args:
        targets: Sized collection of target values; only its length and
            iteration order are used.

    Returns:
        A list of ``Line2D`` proxies (round marker, no connecting line),
        coloured with evenly spaced samples of the ``'rainbow_r'`` colormap.
    """
    palette = plt.get_cmap('rainbow_r')(np.linspace(0.0, 1.0, len(targets)))
    return [
        Line2D([0], [0], linestyle="none", c=palette[position], marker='o')
        for position, _ in enumerate(targets)
    ]
def plot3D(ax, pca_table, targets, dt):
    """Scatter the first three principal components on a 3D axes.

    Args:
        ax: Axes object providing ``scatter``, ``set_zlabel`` and ``legend``.
        pca_table: Table with the ``'PC 1'``, ``'PC 2'`` and ``'PC 3'`` columns.
        targets: Sized collection of distinct target values used for the legend.
        dt: Per-point values passed as ``c`` to colour the points through
            the ``'rainbow_r'`` colormap.
    """
    ax.scatter(
        pca_table['PC 1'],
        pca_table['PC 2'],
        pca_table['PC 3'],
        s=50,
        alpha=0.6,
        c=dt,
        cmap='rainbow_r',
    )

    ax.set_xlabel('PC 1')
    ax.set_ylabel('PC 2')
    ax.set_zlabel('PC 3')

    # The legend is only drawn for fewer than 15 target values; it uses
    # proxy handles built by scatter3DLegend.
    if len(targets) < 15:
        handles = scatter3DLegend(targets)
        ax.legend(handles, targets, numpoints=1)

    plt.title('3 Component PCA', size=15, pad=35)
def PCA3CPlot(pca_table, targets, target):
    """Create a 12x12 figure with one 3D scatter and three 2D projections.

    The 2x2 grid holds the 3D plot of PC 1-3 followed by the PC 1 vs PC 2,
    PC 1 vs PC 3 and PC 2 vs PC 3 scatter plots.

    Args:
        pca_table: Table with the ``'PC <n>'`` columns and the ``target`` column.
        targets: Distinct target values forwarded to the plotting helpers.
        target: Name of the column in ``pca_table`` holding the target values.
    """
    # Map every distinct target value to its position so the 3D scatter can
    # be coloured by a number.
    categories = pca_table[target].unique().tolist()
    codes = {value: position for position, value in enumerate(categories)}
    colour_codes = pca_table[target].map(codes)

    figure = plt.figure(figsize=(12, 12))
    plot3D(figure.add_subplot(2, 2, 1, projection='3d'), pca_table, targets, colour_codes)

    # Remaining grid slots: the three pairwise projections.
    for slot, (pc_x, pc_y) in enumerate(((1, 2), (1, 3), (2, 3)), start=2):
        plot2D(figure.add_subplot(2, 2, slot), pca_table, targets, target, pc_x, pc_y)

    plt.tight_layout()
def predictionPlot(tit, data1, data2, xlabel, ylabel):
    """Draw a scatter of two series with a best-fit line and a 1:1 line.

    The plot is drawn on the current pyplot axes. ``data2`` is placed on the
    horizontal axis and ``data1`` on the vertical axis, so ``xlabel`` should
    describe ``data2`` and ``ylabel`` should describe ``data1``.

    Args:
        tit: Plot title.
        data1: Values drawn on the vertical axis; also the independent
            variable of the first-degree fit.
        data2: Values drawn on the horizontal axis; the dependent variable
            of the fit.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
    """
    # First-degree polynomial modelling data2 as a function of data1.
    coefficients = np.polyfit(data1, data2, 1)
    plt.scatter(data2, data1, alpha=0.2)
    plt.title(tit, size=15)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
    # Plot the best fit line
    plt.plot(np.polyval(coefficients, data1), data1, c='red', linewidth=1, label='Best fit')
    # Plot the ideal 1:1 line; both axes are given the current x limits.
    axes = plt.gca()
    lims = axes.get_xlim()
    plt.xlim(lims)
    plt.ylim(lims)
    plt.plot(lims, lims, label='Ideal 1:1')
    # Labels are attached to the two lines explicitly. Passing a bare label
    # list to legend() pairs labels with artists by order, which can hand
    # 'Best fit' to the (unlabelled) scatter instead of the red line.
    plt.legend()
def histogramPlot(tit, data1, data2, xlabel, ylabel):
    """Draw a 25-bin histogram of ``data2 - data1`` on the current axes.

    Args:
        tit: Plot title.
        data1: Values subtracted from ``data2``.
        data2: Values ``data1`` is subtracted from.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
    """
    plt.title(tit, size=15)
    plt.hist(data2 - data1, bins=25)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
def PLSRegPlot(y, y_c, y_cv):
    """Build the 2x2 figure of prediction and error plots and return pyplot.

    Top row: calibration scatter and error histogram; bottom row: the same
    for cross validation.

    Args:
        y: True values.
        y_c: Calibration predictions.
        y_cv: Cross-validation predictions.

    Returns:
        The ``matplotlib.pyplot`` module, with the figure as current figure.
    """
    # FIGURE
    plt.figure(figsize=[8, 8])

    # predictionPlot draws its third argument (the predictions) on the
    # horizontal axis and its second (the true values) on the vertical axis,
    # so the x label must be 'predictions' and the y label 'true values'.
    plt.subplot(221)
    predictionPlot('Calibration predictions', y, y_c, 'predictions', 'true values')

    # NOTE(review): y_c[0] / y_cv[0] select the first element or column of
    # the predictions depending on their type - confirm against the caller.
    plt.subplot(222)
    histogramPlot('Calibration histogram', y, y_c[0], 'prediction error', 'count')

    plt.subplot(223)
    predictionPlot('Cross Validation predictions', y, y_cv, 'predictions', 'true values')

    plt.subplot(224)
    histogramPlot('Cross Validation histogram', y, y_cv[0], 'prediction error', 'count')

    plt.tight_layout()

    return plt
def getIndependentVars(independent_vars, data, out_log, classname):
    """Select the independent-variable columns of ``data``.

    ``independent_vars`` is inspected for the keys ``'indexes'``,
    ``'range'`` and ``'columns'``, in that order; the first one present wins.

    Args:
        independent_vars: Mapping with one of:
            ``'indexes'`` - column positions;
            ``'range'`` - pairs ``[first, last]`` of column positions, both
            ends included;
            ``'columns'`` - column labels.
        data: Table supporting ``.iloc`` / ``.loc`` indexing.
        out_log: Log handle forwarded to ``fu.log`` on error.
        classname: Prefix used in the error message.

    Returns:
        The selected columns of ``data``.

    Raises:
        SystemExit: If none of the three keys is present.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]

    if 'range' in independent_vars:
        # Expand every inclusive [first, last] pair into explicit positions.
        positions = [
            position
            for bounds in independent_vars['range']
            for position in range(bounds[0], bounds[1] + 1)
        ]
        return data.iloc[:, positions]

    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]

    message = classname + ': Incorrect independent_vars format'
    fu.log(message, out_log)
    raise SystemExit(message)
def getIndependentVarsList(independent_vars):
    """Render the selected independent variables as a comma-separated string.

    ``independent_vars`` is inspected for the keys ``'indexes'``,
    ``'range'`` and ``'columns'``, in that order; the first one present wins.

    Args:
        independent_vars: Mapping with one of:
            ``'indexes'`` - column positions;
            ``'range'`` - pairs ``[first, last]`` of column positions, both
            ends included;
            ``'columns'`` - column labels.

    Returns:
        The values joined with ``', '``, or ``None`` when none of the three
        keys is present.
    """
    if 'indexes' in independent_vars:
        return ', '.join(str(x) for x in independent_vars['indexes'])
    if 'range' in independent_vars:
        return ', '.join(
            str(position)
            for bounds in independent_vars['range']
            for position in range(bounds[0], bounds[1] + 1)
        )
    if 'columns' in independent_vars:
        # Column labels are not guaranteed to be strings (e.g. integer
        # labels), so convert them like the other two branches do.
        return ', '.join(str(column) for column in independent_vars['columns'])
    return None
def getTarget(target, data, out_log, classname):
    """Extract the target column from ``data``.

    Args:
        target: Mapping with either ``'index'`` (column position) or
            ``'column'`` (column label); ``'index'`` takes precedence.
        data: Table supporting ``.iloc`` and item access by label.
        out_log: Log handle forwarded to ``fu.log`` on error.
        classname: Prefix used in the error message.

    Returns:
        The selected column of ``data``.

    Raises:
        SystemExit: If neither key is present.
    """
    if 'index' in target:
        return data.iloc[:, target['index']]
    if 'column' in target:
        return data[target['column']]

    message = classname + ': Incorrect target format'
    fu.log(message, out_log)
    raise SystemExit(message)
def getTargetValue(target):
    """Return a printable identifier for the target.

    Args:
        target: Mapping with either ``'index'`` or ``'column'``;
            ``'index'`` takes precedence.

    Returns:
        The index converted to ``str``, or the column value as stored, or
        ``None`` when neither key is present.
    """
    if 'index' in target:
        return str(target['index'])
    return target['column'] if 'column' in target else None
def getWeight(weight, data, out_log, classname):
    """Extract the weight column from ``data``.

    Args:
        weight: Mapping with either ``'index'`` (column position) or
            ``'column'`` (column label); ``'index'`` takes precedence.
        data: Table supporting ``.iloc`` and item access by label.
        out_log: Log handle forwarded to ``fu.log`` on error.
        classname: Prefix used in the error message.

    Returns:
        The selected column of ``data``.

    Raises:
        SystemExit: If neither key is present.
    """
    if 'index' in weight:
        return data.iloc[:, weight['index']]
    if 'column' in weight:
        return data[weight['column']]

    message = classname + ': Incorrect weight format'
    fu.log(message, out_log)
    raise SystemExit(message)
def getHeader(file):
    """Read the header (first row) of a delimited text file.

    The first row is parsed with ``csv.reader`` (comma-separated). If that
    yields a single field, the field is split again on whitespace runs,
    ``;``, ``:``, ``,`` and tabs, to cope with files that use one of those
    as separator.

    Args:
        file: Path of the file to read.

    Returns:
        The list of header fields; an empty list when the file is empty or
        its first line is blank.
    """
    with open(file, newline='') as f:
        # A default avoids leaking StopIteration when the file has no rows.
        header = next(csv.reader(f), None)

    if header is None:
        return []
    if len(header) == 1:
        # Normalise every supported separator to a comma, then split.
        return re.sub(r'\s+|;|:|,|\t', ',', header[0]).split(',')
    return header