Coverage for biobb_ml/regression/common.py: 80%
129 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-03 14:57 +0000
""" Common functions for package biobb_ml.regression """
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import re
from pathlib import Path, PurePath
from biobb_common.tools import file_utils as fu
from warnings import simplefilter
# Ignore every FutureWarning and RuntimeWarning. simplefilter changes the
# process-wide warnings filter, so this is a side effect of importing the module.
simplefilter(action='ignore', category=FutureWarning)
simplefilter(action='ignore', category=RuntimeWarning)
# Apply seaborn's default theme to the matplotlib figures drawn afterwards.
sns.set()
# CHECK PARAMETERS
def check_input_path(path, argument, out_log, classname):
    """ Validate that an input file exists and has an accepted extension.

    Args:
        path: Path of the input file.
        argument: Name of the argument being checked. It is the lookup key
            for ``is_valid_file`` (an unknown name propagates its KeyError)
            and is interpolated into the messages.
        out_log: Logger handed to ``fu.log``.
        classname: Prefix for the log and error messages.

    Returns:
        ``path``, unchanged.

    Raises:
        SystemExit: If the file does not exist, or if its extension (compared
            case-sensitively, without the leading dot) is not accepted.
    """
    if not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path
    # The same message is both logged and used as the exit reason.
    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)
def check_output_path(path, argument, optional, out_log, classname):
    """ Validate an output file path before anything is written to it.

    Args:
        path: Path of the output file.
        argument: Name of the argument being checked. It is the lookup key
            for ``is_valid_file`` and is interpolated into the messages.
        optional: When true, a falsy ``path`` is accepted and skips all checks.
        out_log: Logger handed to ``fu.log``.
        classname: Prefix for the log and error messages.

    Returns:
        None for an optional, falsy path; otherwise ``path``, unchanged.

    Raises:
        SystemExit: If the parent folder does not exist or the extension is
            not accepted for ``argument``.
    """
    if optional and not path:
        return None
    target = PurePath(path)
    parent_folder = target.parent
    if parent_folder and not Path(parent_folder).exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)
    extension = target.suffix[1:]
    if is_valid_file(extension, argument):
        return path
    # The same message is both logged and used as the exit reason.
    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)
def is_valid_file(ext, argument):
    """ Tell whether extension ``ext`` is accepted for file argument ``argument``.

    Args:
        ext: File extension without the leading dot (compared case-sensitively).
        argument: One of the known file-argument names listed below.

    Returns:
        True if ``ext`` is among the extensions accepted for ``argument``.

    Raises:
        KeyError: If ``argument`` is not a known file-argument name.
    """
    accepted_by_argument = dict(
        input_dataset_path=['csv'],
        output_model_path=['pkl'],
        output_dataset_path=['csv'],
        output_results_path=['csv'],
        input_model_path=['pkl'],
        output_test_table_path=['csv'],
        output_plot_path=['png'],
    )
    accepted = accepted_by_argument[argument]
    return ext in accepted
def check_mandatory_property(property, name, out_log, classname):
    """ Ensure that a mandatory property has a truthy value.

    Args:
        property: Value of the property to check.
        name: Property name used in the messages.
        out_log: Logger handed to ``fu.log``.
        classname: Prefix for the log and error messages.

    Returns:
        ``property`` itself (the same object) when it is truthy.

    Raises:
        SystemExit: If ``property`` is falsy (None, empty, zero, ...).
    """
    if property:
        return property
    fu.log(classname + ': Unexisting %s property, exiting' % name, out_log)
    raise SystemExit(classname + ': Unexisting %s property' % name)
# UTILITIES
def adjusted_r2(x, y, r2):
    """ Return the adjusted coefficient of determination.

    Computes ``1 - (1 - r2) * (n - 1) / (n - p - 1)``, where ``n`` is
    ``x.shape[0]`` (rows) and ``p`` is ``x.shape[1]`` (columns). ``y`` is
    accepted for signature compatibility but is not used. When
    ``n - p - 1 == 0`` the division either raises ZeroDivisionError or yields
    inf/nan, depending on the numeric type of ``r2``.
    """
    n_samples = x.shape[0]
    n_features = x.shape[1]
    unexplained = 1 - r2
    return 1 - unexplained * (n_samples - 1) / (n_samples - n_features - 1)
def get_list_of_predictors(predictions):
    """ Turn a sequence of mappings into a list of their value lists.

    Each element of ``predictions`` must provide ``.items()`` (e.g. a dict).
    Its values are collected in that iteration order, giving one inner list
    per element. An empty ``predictions`` yields ``[]``.
    """
    # Iterate .items() rather than calling .values(): some mapping-like objects
    # (e.g. pandas Series) expose ``values`` as an attribute, not a method.
    return [[value for _, value in obj.items()] for obj in predictions]
def get_keys_of_predictors(predictions):
    """ Return the items obtained by iterating the first element of ``predictions``.

    For a sequence of dicts these are the keys of the first dict, in order.
    An empty ``predictions`` yields ``[]``, matching
    ``get_list_of_predictors([])``, instead of failing with IndexError.
    """
    # len() rather than truthiness so array-like containers don't raise
    # "truth value is ambiguous".
    if len(predictions) == 0:
        return []
    return list(predictions[0])
def predictionPlot(tit, data1, data2, xlabel, ylabel):
    """ Draw a scatter plot of ``data2`` against ``data1`` on the current axes.

    Args:
        tit: Plot title.
        data1: Values for the x axis (``plotResults`` passes the true values).
        data2: Values for the y axis (``plotResults`` passes the predictions).
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.

    Both axes get identical limits spanning the autoscaled range of x and y,
    and a y = x reference line is drawn across them.
    """
    plt.title(tit, size=15)
    plt.scatter(data1, data2, alpha=0.2)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
    axes = plt.gca()
    # Take the union of both autoscaled ranges. Reusing only the x range for
    # both axes would clip points whose y value falls outside that range.
    x_low, x_high = axes.get_xlim()
    y_low, y_high = axes.get_ylim()
    lims = (min(x_low, y_low), max(x_high, y_high))
    plt.xlim(lims)
    plt.ylim(lims)
    plt.plot(lims, lims)
def histogramPlot(tit, data1, data2, xlabel, ylabel):
    """ Draw a 25-bin histogram of the differences ``data2 - data1``.

    Args:
        tit: Plot title.
        data1: Reference values (``plotResults`` passes the true values).
        data2: Compared values (``plotResults`` passes the predictions).
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
    """
    plt.title(tit, size=15)
    residuals = data2 - data1
    plt.hist(residuals, bins=25)
    for set_label, text in ((plt.xlabel, xlabel), (plt.ylabel, ylabel)):
        set_label(text, size=14)
def plotResults(y_train, y_hat_train, y_test, y_hat_test):
    """ Build an 8x8-inch 2x2 figure summarising train and test predictions.

    Panels, in order: train prediction scatter, train error histogram,
    test prediction scatter, test error histogram.

    Returns:
        The ``matplotlib.pyplot`` module itself (not a Figure object).
    """
    plt.figure(figsize=[8, 8])
    # (subplot position, drawing function, title, true values, predicted values, x label, y label)
    panels = (
        (221, predictionPlot, 'Train predictions', y_train, y_hat_train, 'true values', 'predictions'),
        (222, histogramPlot, 'Train histogram', y_train, y_hat_train, 'prediction error', 'count'),
        (223, predictionPlot, 'Test predictions', y_test, y_hat_test, 'true values', 'predictions'),
        (224, histogramPlot, 'Test histogram', y_test, y_hat_test, 'prediction error', 'count'),
    )
    for position, draw, title, truth, predicted, x_text, y_text in panels:
        plt.subplot(position)
        draw(title, truth, predicted, x_text, y_text)
    plt.tight_layout()
    return plt
def getIndependentVars(independent_vars, data, out_log, classname):
    """ Select the independent-variable columns of ``data``.

    ``independent_vars`` is checked for these keys, in this order, and the
    first one present wins:

    * ``'indexes'``: column positions, passed to ``data.iloc``;
    * ``'range'``: list of ``[start, end]`` position pairs, both ends
      inclusive, expanded and passed to ``data.iloc``;
    * ``'columns'``: column labels, passed to ``data.loc``.

    Args:
        independent_vars: Selection specification (see above).
        data: Table to select from (uses pandas-style ``iloc``/``loc``).
        out_log: Logger handed to ``fu.log``.
        classname: Prefix for the log and error messages.

    Raises:
        SystemExit: If none of the recognised keys is present.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]
    if 'range' in independent_vars:
        # Expand every inclusive [start, end] pair into its column positions.
        positions = [pos for rng in independent_vars['range'] for pos in range(rng[0], rng[1] + 1)]
        return data.iloc[:, positions]
    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]
    fu.log(classname + ': Incorrect independent_vars format', out_log)
    raise SystemExit(classname + ': Incorrect independent_vars format')
def getIndependentVarsList(independent_vars):
    """ Describe the selected independent variables as a ``', '``-joined string.

    Uses the same key precedence as ``getIndependentVars``: ``'indexes'``,
    then ``'range'`` (inclusive ``[start, end]`` pairs, expanded to every
    position), then ``'columns'``.

    Returns:
        The joined description, or None when none of the keys is present.
    """
    if 'indexes' in independent_vars:
        items = independent_vars['indexes']
    elif 'range' in independent_vars:
        items = [pos for rng in independent_vars['range'] for pos in range(rng[0], rng[1] + 1)]
    elif 'columns' in independent_vars:
        items = independent_vars['columns']
    else:
        return None
    # Apply str() on every branch. Otherwise non-string column labels
    # (e.g. integers) would make join() raise TypeError.
    return ', '.join(str(item) for item in items)
def getTarget(target, data, out_log, classname):
    """ Select the target column of ``data``.

    ``target['index']`` (a position, via ``data.iloc``) takes precedence over
    ``target['column']`` (a label, via ``data[...]``).

    Raises:
        SystemExit: If ``target`` has neither an ``'index'`` nor a ``'column'`` key.
    """
    if 'index' in target:
        return data.iloc[:, target['index']]
    if 'column' in target:
        return data[target['column']]
    fu.log(classname + ': Incorrect target format', out_log)
    raise SystemExit(classname + ': Incorrect target format')
def getTargetValue(target):
    """ Return a printable identifier for ``target``.

    Returns ``str(target['index'])`` if an ``'index'`` key is present,
    otherwise ``target['column']`` if a ``'column'`` key is present,
    otherwise None.
    """
    if 'index' not in target:
        return target['column'] if 'column' in target else None
    return str(target['index'])
def getWeight(weight, data, out_log, classname):
    """ Select the weight column of ``data``.

    ``weight['index']`` (a position, via ``data.iloc``) takes precedence over
    ``weight['column']`` (a label, via ``data[...]``).

    Raises:
        SystemExit: If ``weight`` has neither an ``'index'`` nor a ``'column'`` key.
    """
    if 'index' in weight:
        selected = data.iloc[:, weight['index']]
    elif 'column' in weight:
        selected = data[weight['column']]
    else:
        fu.log(classname + ': Incorrect weight format', out_log)
        raise SystemExit(classname + ': Incorrect weight format')
    return selected
def getHeader(file):
    """ Read the column names from the first row of a delimited text file.

    The first row is parsed with the default (comma-delimited) ``csv`` reader.
    If that yields exactly one field, each whitespace run, ``;``, ``:``, ``,``
    and tab in it is replaced by ``,``, and the result is split on ``,``.
    Adjacent separators therefore produce empty names
    (``'a; b'`` -> ``['a', '', 'b']``).
    NOTE(review): presumably data rows are parsed elsewhere with the same
    separator pattern, so column counts line up; confirm before changing it.

    Raises:
        StopIteration: If the file is empty.
    """
    with open(file, newline='') as handle:
        first_row = next(csv.reader(handle))
    if len(first_row) != 1:
        return first_row
    normalized = re.sub('\\s+|;|:|,|\t', ',', first_row[0])
    return normalized.split(',')
192 return header