Coverage for biobb_ml/regression/common.py: 80%

129 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-03 14:57 +0000

1""" Common functions for package biobb_ml.regression """ 

2import matplotlib.pyplot as plt 

3import seaborn as sns 

4import csv 

5import re 

6from pathlib import Path, PurePath 

7from biobb_common.tools import file_utils as fu 

8from warnings import simplefilter 

9# ignore all future and runtime warnings 

10simplefilter(action='ignore', category=FutureWarning) 

11simplefilter(action='ignore', category=RuntimeWarning) 

12sns.set() 

13 

14 

15# CHECK PARAMETERS 

16 

def check_input_path(path, argument, out_log, classname):
    """Validate an input file path and hand it back unchanged.

    The path must exist on disk, and its extension (without the leading dot)
    must be accepted by ``is_valid_file`` for ``argument``.

    Args:
        path: Location of the input file.
        argument: Name of the argument the path was supplied for; used in
            the messages and passed to ``is_valid_file``.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Prefix used in log and error messages.

    Returns:
        ``path`` itself when both checks pass.

    Raises:
        SystemExit: If the file does not exist or its format is rejected.
    """
    if not Path(path).exists():
        fu.log(classname + ': Unexisting %s file, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s file' % argument)

    # Suffix includes the dot (".csv"); the format table stores bare names.
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    format_error = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(format_error, out_log)
    raise SystemExit(format_error)

27 

28 

def check_output_path(path, argument, optional, out_log, classname):
    """Validate an output file path and hand it back unchanged.

    Args:
        path: Location where the output file will be written.
        argument: Name of the argument the path was supplied for; used in
            the messages and passed to ``is_valid_file``.
        optional: When truthy, an empty/missing ``path`` is accepted.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Prefix used in log and error messages.

    Returns:
        ``None`` when the path is optional and not given, otherwise ``path``.

    Raises:
        SystemExit: If the parent folder does not exist or the file format
            is rejected.
    """
    if optional and not path:
        return None

    target_folder = PurePath(path).parent
    if target_folder and not Path(target_folder).exists():
        fu.log(classname + ': Unexisting %s folder, exiting' % argument, out_log)
        raise SystemExit(classname + ': Unexisting %s folder' % argument)

    # Suffix includes the dot (".csv"); the format table stores bare names.
    extension = PurePath(path).suffix[1:]
    if is_valid_file(extension, argument):
        return path

    format_error = classname + ': Format %s in %s file is not compatible' % (extension, argument)
    fu.log(format_error, out_log)
    raise SystemExit(format_error)

41 

42 

def is_valid_file(ext, argument):
    """Check whether a file extension is accepted for a given argument.

    Args:
        ext: File extension without the leading dot (e.g. ``'csv'``).
        argument: Name of the path argument being validated.

    Returns:
        ``True`` if ``ext`` is listed for ``argument``; ``False`` otherwise,
        including when ``argument`` is not a known argument name.
    """
    formats = {
        'input_dataset_path': ['csv'],
        'output_model_path': ['pkl'],
        'output_dataset_path': ['csv'],
        'output_results_path': ['csv'],
        'input_model_path': ['pkl'],
        'output_test_table_path': ['csv'],
        'output_plot_path': ['png'],
    }
    # An unknown argument name is reported as "not valid" rather than
    # escaping as a bare KeyError.
    return ext in formats.get(argument, ())

55 

56 

def check_mandatory_property(property, name, out_log, classname):
    """Ensure a mandatory property was supplied and return it.

    Any falsy value (``None``, ``''``, ``0``, empty container, ...) counts
    as missing.

    Args:
        property: Value of the property to check.
        name: Property name used in the messages.
        out_log: Log handle forwarded to ``fu.log``.
        classname: Prefix used in log and error messages.

    Returns:
        ``property`` unchanged when it is truthy.

    Raises:
        SystemExit: If ``property`` is falsy.
    """
    if property:
        return property

    fu.log(classname + ': Unexisting %s property, exiting' % name, out_log)
    raise SystemExit(classname + ': Unexisting %s property' % name)

62 

63 

64# UTILITIES 

65 

def adjusted_r2(x, y, r2):
    """Return the adjusted R-squared: ``1 - (1 - r2) * (n - 1) / (n - p - 1)``.

    Args:
        x: Object exposing ``shape``; ``shape[0]`` is used as ``n`` and
            ``shape[1]`` as ``p`` (presumably rows are observations and
            columns are predictors -- confirm against callers).
        y: Not used by the computation.
        r2: Plain R-squared score to adjust.

    Returns:
        The adjusted score.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    unexplained = 1 - r2
    return 1 - unexplained * (n_rows - 1) / (n_rows - n_cols - 1)

71 

72 

def get_list_of_predictors(predictions):
    """Return the values of each prediction mapping as a list of lists.

    Args:
        predictions: Iterable of mappings (one per sample to predict).

    Returns:
        One list per mapping holding its values in the mapping's own
        iteration order.
    """
    return [list(obj.values()) for obj in predictions]

81 

82 

def get_keys_of_predictors(predictions):
    """Return the keys of the first prediction mapping as a list.

    Args:
        predictions: Sequence of mappings; only the first one is inspected.

    Returns:
        The keys of ``predictions[0]`` in iteration order, or an empty list
        when ``predictions`` has no elements.
    """
    try:
        first = predictions[0]
    except IndexError:
        # No predictions supplied: there are no keys to report.
        return []
    return list(first)

88 

89 

def predictionPlot(tit, data1, data2, xlabel, ylabel):
    """Scatter ``data1`` against ``data2`` on the current axes with a diagonal.

    Args:
        tit: Plot title.
        data1: Values for the x axis.
        data2: Values for the y axis.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    label_size = 14
    plt.title(tit, size=15)
    plt.scatter(data1, data2, alpha=0.2)
    plt.xlabel(xlabel, size=label_size)
    plt.ylabel(ylabel, size=label_size)

    # Apply the x-range obtained after the scatter to both axes, then draw
    # the line through (lo, lo) and (hi, hi).
    x_range = plt.gca().get_xlim()
    for apply_limits in (plt.xlim, plt.ylim):
        apply_limits(x_range)
    plt.plot(x_range, x_range)

100 

101 

def histogramPlot(tit, data1, data2, xlabel, ylabel):
    """Draw a 25-bin histogram of ``data2 - data1`` on the current axes.

    Args:
        tit: Plot title.
        data1: Values subtracted from ``data2``.
        data2: Values ``data1`` is subtracted from; the subtraction must be
            supported by the operand types.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    label_size = 14
    plt.title(tit, size=15)
    plt.hist(data2 - data1, bins=25)
    plt.xlabel(xlabel, size=label_size)
    plt.ylabel(ylabel, size=label_size)

108 

109 

def plotResults(y_train, y_hat_train, y_test, y_hat_test):
    """Build a 2x2 figure of prediction scatters and error histograms.

    The top row uses the train pair and the bottom row the test pair; the
    left column is a ``predictionPlot`` and the right a ``histogramPlot``.

    Args:
        y_train: True values of the training set.
        y_hat_train: Predicted values for the training set.
        y_test: True values of the test set.
        y_hat_test: Predicted values for the test set.

    Returns:
        The ``plt`` module, with the new figure as the current figure.
    """
    plt.figure(figsize=[8, 8])

    # (subplot position, drawing function, title, true values, predictions,
    #  x label, y label) -- drawn in this order.
    panels = (
        (221, predictionPlot, 'Train predictions', y_train, y_hat_train, 'true values', 'predictions'),
        (222, histogramPlot, 'Train histogram', y_train, y_hat_train, 'prediction error', 'count'),
        (223, predictionPlot, 'Test predictions', y_test, y_hat_test, 'true values', 'predictions'),
        (224, histogramPlot, 'Test histogram', y_test, y_hat_test, 'prediction error', 'count'),
    )
    for position, draw, title, observed, predicted, x_text, y_text in panels:
        plt.subplot(position)
        draw(title, observed, predicted, x_text, y_text)

    plt.tight_layout()
    return plt

130 

131 

def getIndependentVars(independent_vars, data, out_log, classname):
    """Select the independent-variable columns from ``data``.

    ``independent_vars`` is inspected for one of these keys, in this order:

    * ``'indexes'``: positional column indexes, passed to ``data.iloc``.
    * ``'range'``: list of ``[first, last]`` positional pairs; both ends
      are included.
    * ``'columns'``: column labels, passed to ``data.loc``.

    Args:
        independent_vars: Mapping describing the columns to select.
        data: Object supporting ``.iloc`` / ``.loc`` column selection
            (presumably a pandas DataFrame -- confirm against callers).
        out_log: Log handle forwarded to ``fu.log``.
        classname: Prefix used in log and error messages.

    Returns:
        The selected columns of ``data``.

    Raises:
        SystemExit: If none of the supported keys is present.
    """
    if 'indexes' in independent_vars:
        return data.iloc[:, independent_vars['indexes']]

    if 'range' in independent_vars:
        # Expand every inclusive [first, last] pair into explicit positions.
        positions = [
            position
            for bounds in independent_vars['range']
            for position in range(bounds[0], bounds[1] + 1)
        ]
        return data.iloc[:, positions]

    if 'columns' in independent_vars:
        return data.loc[:, independent_vars['columns']]

    fu.log(classname + ': Incorrect independent_vars format', out_log)
    raise SystemExit(classname + ': Incorrect independent_vars format')

146 

147 

def getIndependentVarsList(independent_vars):
    """Render the selected independent variables as a ``', '``-joined string.

    ``independent_vars`` is inspected for one of these keys, in this order:

    * ``'indexes'``: each index is listed.
    * ``'range'``: every position of each inclusive ``[first, last]`` pair
      is listed.
    * ``'columns'``: each column label is listed.

    Args:
        independent_vars: Mapping describing the selected columns.

    Returns:
        The comma-separated description, or ``None`` when none of the
        supported keys is present.
    """
    if 'indexes' in independent_vars:
        return ', '.join(str(x) for x in independent_vars['indexes'])

    if 'range' in independent_vars:
        return ', '.join(
            str(position)
            for bounds in independent_vars['range']
            for position in range(bounds[0], bounds[1] + 1)
        )

    if 'columns' in independent_vars:
        # Stringify like the other branches so non-string column labels
        # (e.g. integer headers) do not make str.join raise TypeError.
        return ', '.join(str(column) for column in independent_vars['columns'])

    return None

155 

156 

def getTarget(target, data, out_log, classname):
    """Select the target column from ``data``.

    Args:
        target: Mapping with either ``'index'`` (positional column index,
            checked first) or ``'column'`` (column label).
        data: Object supporting ``.iloc`` and ``[]`` column selection
            (presumably a pandas DataFrame -- confirm against callers).
        out_log: Log handle forwarded to ``fu.log``.
        classname: Prefix used in log and error messages.

    Returns:
        The selected column.

    Raises:
        SystemExit: If ``target`` has neither key.
    """
    if 'index' in target:
        return data.iloc[:, target['index']]
    if 'column' in target:
        return data[target['column']]

    fu.log(classname + ': Incorrect target format', out_log)
    raise SystemExit(classname + ': Incorrect target format')

165 

166 

def getTargetValue(target):
    """Return a printable identifier of the target.

    Args:
        target: Mapping with either ``'index'`` (checked first) or
            ``'column'``.

    Returns:
        ``str(target['index'])`` when an index is given, the raw
        ``target['column']`` value when a column is given, and ``None``
        when neither key is present.
    """
    if 'index' in target:
        return str(target['index'])
    return target['column'] if 'column' in target else None

172 

173 

def getWeight(weight, data, out_log, classname):
    """Select the weight column from ``data``.

    Args:
        weight: Mapping with either ``'index'`` (positional column index,
            checked first) or ``'column'`` (column label).
        data: Object supporting ``.iloc`` and ``[]`` column selection
            (presumably a pandas DataFrame -- confirm against callers).
        out_log: Log handle forwarded to ``fu.log``.
        classname: Prefix used in log and error messages.

    Returns:
        The selected column.

    Raises:
        SystemExit: If ``weight`` has neither key.
    """
    if 'index' in weight:
        return data.iloc[:, weight['index']]
    if 'column' in weight:
        return data[weight['column']]

    fu.log(classname + ': Incorrect weight format', out_log)
    raise SystemExit(classname + ': Incorrect weight format')

182 

183 

184def getHeader(file): 

185 with open(file, newline='') as f: 

186 reader = csv.reader(f) 

187 header = next(reader) 

188 

189 if (len(header) == 1): 

190 return list(re.sub('\\s+|;|:|,|\t', ',', header[0]).split(",")) 

191 else: 

192 return header