Coverage for biobb_ml/dimensionality_reduction/common.py: 65%

167 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-07 09:39 +0000

1""" Common functions for package biobb_ml.dimensionality_reduction """

2from pathlib import Path, PurePath 

3import matplotlib.pyplot as plt 

4from matplotlib.lines import Line2D 

5import numpy as np 

6import seaborn as sns 

7import csv 

8import re 

9from biobb_common.tools import file_utils as fu 

10from warnings import simplefilter 

11# ignore all future warnings 

12simplefilter(action='ignore', category=FutureWarning) 

13sns.set() 

14 

15 

16# CHECK PARAMETERS 

17 

def check_input_path(path, argument, out_log, classname):
    """ Validates an input file and returns its path unchanged

    The file must exist and its extension (without the leading dot) must be
    accepted by is_valid_file for the given argument name. On either failure
    the problem is logged through fu.log and SystemExit is raised.
    """
    source = Path(path)
    if not source.exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)

    # suffix includes the leading dot, is_valid_file expects the bare extension
    extension = source.suffix[1:]
    if is_valid_file(extension, argument):
        return path

    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)

28 

29 

def check_output_path(path, argument, optional, out_log, classname):
    """ Validates an output file path and returns it unchanged

    Returns None when the argument is optional and no path was given.
    Otherwise the parent folder must exist and the extension (without the
    leading dot) must be accepted by is_valid_file for the given argument
    name; on failure the problem is logged through fu.log and SystemExit
    is raised.
    """
    if optional and not path:
        return None

    target = Path(path)
    folder = target.parent
    if folder and not folder.exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    # suffix includes the leading dot, is_valid_file expects the bare extension
    extension = target.suffix[1:]
    if is_valid_file(extension, argument):
        return path

    message = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(message, out_log)
    raise SystemExit(message)

42 

43 

def is_valid_file(ext, argument):
    """ Tells whether extension ext is accepted for the given argument name

    Raises KeyError when argument is not one of the known argument names.
    """
    accepted = dict(
        input_dataset_path=['csv'],
        output_results_path=['csv'],
        output_plot_path=['png'],
    )
    allowed_extensions = accepted[argument]
    return ext in allowed_extensions

52 

53 

54# UTILITIES 

55 

def getWindowLength(default, feat):
    """ Returns the window length to use for feat features

    The default is kept unless feat is smaller than it; in that case the
    result is feat itself when odd, or feat - 1 when even.
    """
    if not feat < default:
        return default
    # feat is below the default: fall back to the last odd number <= feat
    return feat if feat % 2 else feat - 1

65 

66 

def generate_columns_labels(label, length):
    """ Builds the list '<label> 1', '<label> 2', ... with length entries """
    names = []
    for number in range(1, length + 1):
        names.append(label + ' ' + str(number))
    return names

69 

70 

def plot2D(ax, pca_table, targets, target, x, y):
    """ Draws a 2D scatter of components x and y on ax, one colour per target

    Points are read from the 'PC <x>' and 'PC <y>' columns of pca_table.
    For every value in targets, the rows whose target column equals that
    value are drawn with a colour sampled evenly from the 'rainbow_r'
    colormap. A legend is added only when there are fewer than 15 targets.
    """
    x_column = 'PC ' + str(x)
    y_column = 'PC ' + str(y)

    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    ax.set_title('2 Component PCA (PC ' + str(x) + ' vs PC ' + str(y) + ')', fontsize=15)

    palette = plt.get_cmap('rainbow_r')(np.linspace(0.0, 1.0, len(targets)))
    for value, shade in zip(targets, palette):
        # boolean mask selecting the rows that belong to this target value
        rows = pca_table[target] == value
        ax.scatter(
            pca_table.loc[rows, x_column],
            pca_table.loc[rows, y_column],
            color=shade,
            s=50,
            alpha=0.6,
        )

    if len(targets) < 15:
        ax.legend(targets)

86 

87 

def PCA2CPlot(pca_table, targets, target):
    """ Creates an 8x8 figure holding the PC 1 vs PC 2 scatter from plot2D """
    figure = plt.figure(figsize=(8, 8))
    single_axes = figure.add_subplot(1, 1, 1)
    plot2D(single_axes, pca_table, targets, target, 1, 2)
    plt.tight_layout()

93 

94 

def scatter3DLegend(targets):
    """ Builds one round-marker Line2D proxy per target, to be used as legend handles

    Colours are sampled evenly from the 'rainbow_r' colormap, one per target.
    """
    palette = plt.get_cmap('rainbow_r')(np.linspace(0.0, 1.0, len(targets)))
    # palette has exactly len(targets) rows, so one proxy is built per target
    return [
        Line2D([0], [0], linestyle="none", c=shade, marker='o')
        for shade in palette
    ]

101 

102 

def plot3D(ax, pca_table, targets, dt):
    """ Draws the 3D scatter of 'PC 1', 'PC 2' and 'PC 3' on ax

    Points are coloured by the values in dt through the 'rainbow_r'
    colormap. A legend built from scatter3DLegend proxies is added only
    when there are fewer than 15 targets.
    """
    xs = pca_table['PC 1']
    ys = pca_table['PC 2']
    zs = pca_table['PC 3']
    ax.scatter(xs, ys, zs, s=50, alpha=0.6, c=dt, cmap='rainbow_r')

    ax.set_xlabel('PC 1')
    ax.set_ylabel('PC 2')
    ax.set_zlabel('PC 3')

    if len(targets) < 15:
        scatter_proxies = scatter3DLegend(targets)
        ax.legend(scatter_proxies, targets, numpoints=1)

    # Set the title on the axes that was passed in rather than through
    # pyplot's "current axes", so it cannot land on a different subplot.
    ax.set_title('3 Component PCA', size=15, pad=35)

118 

119 

def PCA3CPlot(pca_table, targets, target):
    """ Creates a 12x12 figure with a 3D PCA scatter and its three 2D projections

    The 2x2 grid holds the 3D plot first, followed by PC 1 vs PC 2,
    PC 1 vs PC 3 and PC 2 vs PC 3.
    """
    # encode every distinct value of the target column as its position in
    # the list of unique values, to be used as colour values in the 3D plot
    distinct = pca_table[target].unique().tolist()
    codes = {value: position for position, value in enumerate(distinct)}
    dt = pca_table[target].map(codes)

    figure = plt.figure(figsize=(12, 12))

    plot3D(figure.add_subplot(2, 2, 1, projection='3d'), pca_table, targets, dt)

    projections = ((2, 1, 2), (3, 1, 3), (4, 2, 3))
    for cell, pc_x, pc_y in projections:
        panel = figure.add_subplot(2, 2, cell)
        plot2D(panel, pca_table, targets, target, pc_x, pc_y)

    plt.tight_layout()

143 

144 

def predictionPlot(tit, data1, data2, xlabel, ylabel):
    """ Draws a scatter of data2 against data1 with a best-fit and a 1:1 line

    The plot goes on pyplot's current axes. data2 is placed on the
    horizontal axis and data1 on the vertical one. A degree-1 polynomial
    fitted with data1 as input and data2 as output is drawn in red, and
    the diagonal over the horizontal axis limits is drawn as reference.
    The same limits are applied to both axes.
    """
    # NOTE(review): PLSRegPlot passes 'true values' as xlabel and
    # 'predictions' as ylabel, while the horizontal axis holds data2 (its
    # prediction argument) - the labels look swapped; confirm before changing.
    z = np.polyfit(data1, data2, 1)
    plt.scatter(data2, data1, alpha=0.2)
    plt.title(tit, size=15)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)
    # Plot the best fit line
    plt.plot(np.polyval(z, data1), data1, c='red', linewidth=1, label='Best fit')
    # Plot the ideal 1:1 line
    axes = plt.gca()
    lims = axes.get_xlim()
    plt.xlim(lims)
    plt.ylim(lims)
    plt.plot(lims, lims, label='Ideal 1:1')
    # Labels are attached to the two lines explicitly: passing only a tuple
    # of labels to legend() assigns them by artist order, and the scatter
    # is drawn before the lines.
    plt.legend()

160 

161 

def histogramPlot(tit, data1, data2, xlabel, ylabel):
    """ Draws a 25-bin histogram of data2 - data1 on pyplot's current axes """
    plt.title(tit, size=15)
    plt.hist(data2 - data1, bins=25)
    plt.xlabel(xlabel, size=14)
    plt.ylabel(ylabel, size=14)

168 

169 

def PLSRegPlot(y, y_c, y_cv):
    """ Builds the 2x2 PLS regression figure and returns the pyplot module

    The top row shows the calibration results (y_c) and the bottom row the
    cross validation results (y_cv); each row holds a prediction scatter
    followed by a histogram of the prediction error.
    """
    plt.figure(figsize=[8, 8])

    rows = (
        (221, 'Calibration predictions', 'Calibration histogram', y_c),
        (223, 'Cross Validation predictions', 'Cross Validation histogram', y_cv),
    )
    for left_cell, scatter_title, histogram_title, predicted in rows:
        plt.subplot(left_cell)
        predictionPlot(scatter_title, y, predicted, 'true values', 'predictions')

        # the histogram sits in the cell to the right of the scatter
        plt.subplot(left_cell + 1)
        histogramPlot(histogram_title, y, predicted[0], 'prediction error', 'count')

    plt.tight_layout()

    return plt

190 

191 

def getIndependentVars(independent_vars, data, out_log, classname):
    """ Selects the independent variable columns from data

    independent_vars is looked up for one of these keys, in this order:
    'indexes' (positional column indexes), 'range' (list of [first, last]
    positional pairs, both ends included) or 'columns' (column labels).
    When none is present the problem is logged and SystemExit is raised.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]

    if 'range' in independent_vars:
        # expand every inclusive [first, last] pair into explicit positions
        positions = [
            position
            for bounds in independent_vars['range']
            for position in range(bounds[0], bounds[1] + 1)
        ]
        return data.iloc[:, positions]

    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]

    fu.log(classname + ': Incorrect independent_vars format', out_log)
    raise SystemExit(classname + ': Incorrect independent_vars format')

206 

207 

def getIndependentVarsList(independent_vars):
    """ Returns the selected independent variables as a ', '-separated string

    independent_vars is looked up for one of these keys, in this order:
    'indexes' (positional column indexes), 'range' (list of [first, last]
    pairs, both ends included, expanded to every position) or 'columns'
    (column labels). Returns None when none of the keys is present.
    """
    if 'indexes' in independent_vars:
        return ', '.join(str(x) for x in independent_vars['indexes'])
    if 'range' in independent_vars:
        return ', '.join(str(y) for r in independent_vars['range'] for y in range(r[0], r[1] + 1))
    if 'columns' in independent_vars:
        # str() mirrors the other branches and keeps non-string column
        # labels (e.g. integers) from raising TypeError inside join
        return ', '.join(str(c) for c in independent_vars['columns'])
    return None

215 

216 

def getTarget(target, data, out_log, classname):
    """ Selects the target column from data

    The column is taken by position when target has an 'index' key, or by
    label when it has a 'column' key ('index' wins if both are present).
    Otherwise the problem is logged and SystemExit is raised.
    """
    if 'index' in target:
        position = target['index']
        return data.iloc[:, position]

    if 'column' in target:
        label = target['column']
        return data[label]

    fu.log(classname + ': Incorrect target format', out_log)
    raise SystemExit(classname + ': Incorrect target format')

225 

226 

def getTargetValue(target):
    """ Returns the target identifier

    The 'index' entry is returned converted to string when present,
    otherwise the 'column' entry is returned as is. Returns None when
    neither key is present.
    """
    if 'index' in target:
        return str(target['index'])
    if 'column' in target:
        return target['column']
    return None

232 

233 

def getWeight(weight, data, out_log, classname):
    """ Selects the weight column from data

    The column is taken by position when weight has an 'index' key, or by
    label when it has a 'column' key ('index' wins if both are present).
    Otherwise the problem is logged and SystemExit is raised.
    """
    if 'index' in weight:
        position = weight['index']
        return data.iloc[:, position]

    if 'column' in weight:
        label = weight['column']
        return data[label]

    fu.log(classname + ': Incorrect weight format', out_log)
    raise SystemExit(classname + ': Incorrect weight format')

242 

243 

def getHeader(file):
    """ Returns the fields of the first row of a CSV file as a list

    When the csv reader yields a single field for that row, the field is
    split again on whitespace, ';', ':', ',' and tab, to cope with files
    that use a separator other than the comma.
    """
    with open(file, newline='') as handle:
        first_row = next(csv.reader(handle))

    if len(first_row) != 1:
        return first_row

    # normalise every alternative separator to a comma, then split on it
    normalised = re.sub('\\s+|;|:|,|\t', ',', first_row[0])
    return normalised.split(",")