Coverage for biobb_common/biobb_common/tools/test_fixtures.py: 34%

271 statements  


1"""Boiler plate functions for testsys 

2""" 

3import os 

4import pickle 

5from typing import Optional, Union, Any 

6from pathlib import Path 

7import sys 

8import shutil 

9import hashlib 

10from Bio.PDB import Superimposer, PDBParser # type: ignore 

11import codecs 

12from biobb_common.configuration import settings 

13from biobb_common.tools import file_utils as fu 

14import numpy as np 

15import json 

16import jsonschema 

17 

18 

def test_setup(test_object, dict_key: Optional[str] = None, config: Optional[str] = None):
    """Add the testfile_dir, unitest_dir, test_dir, data_dir, reference_dir,
    conf_file_path, properties and paths attributes to the **test_object** and
    create a directory in which to launch the unit test.

    Args:
        test_object (:obj:`test`): The test object.
        dict_key (str): Key of the test parameters in the YAML config file.
        config (str): Path to the configuration file.
    """
    test_object.testfile_dir = str(Path(Path(str(sys.modules[test_object.__module__].__file__)).resolve()).parent)
    test_object.unitest_dir = str(Path(test_object.testfile_dir).parent)
    test_object.test_dir = str(Path(test_object.unitest_dir).parent)
    test_object.data_dir = str(Path(test_object.test_dir).joinpath('data'))
    test_object.reference_dir = str(Path(test_object.test_dir).joinpath('reference'))
    if config:
        test_object.conf_file_path = config
    else:
        test_object.conf_file_path = str(Path(test_object.test_dir).joinpath('conf.yml'))

    conf = settings.ConfReader(test_object.conf_file_path)

    if dict_key:
        test_object.properties = conf.get_prop_dic()[dict_key]
        test_object.paths = {k: v.replace('test_data_dir', test_object.data_dir, 1).replace('test_reference_dir', test_object.reference_dir, 1) for k, v in conf.get_paths_dic()[dict_key].items()}
    else:
        test_object.properties = conf.get_prop_dic()
        test_object.paths = {k: v.replace('test_data_dir', test_object.data_dir, 1).replace('test_reference_dir', test_object.reference_dir, 1) for k, v in conf.get_paths_dic().items()}

    fu.create_dir(test_object.properties['path'])
    os.chdir(test_object.properties['path'])


def test_teardown(test_object):
    """Remove **test_object.properties['working_dir_path']**.

    Args:
        test_object (:obj:`test`): The test object.
    """
    unitests_path = Path(test_object.properties['path']).resolve().parent
    print(f"\nRemoving: {unitests_path}")
    shutil.rmtree(unitests_path)

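# A minimal usage sketch (an assumption about external usage, not part of this
# module): a class-based unit test typically wires these two fixtures into its
# setup and teardown hooks. TestExample and the 'example' conf.yml key are
# hypothetical names for illustration:
#
#   class TestExample:
#       def setup_class(self):
#           test_setup(self, 'example')
#
#       def teardown_class(self):
#           test_teardown(self)
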

def exe_success(return_code: int) -> bool:
    """Check if **return_code** is 0.

    Args:
        return_code (int): Return code of a process.

    Returns:
        bool: True if the return code is equal to 0.
    """
    return return_code == 0


def not_empty(file_path: str) -> bool:
    """Check if a file exists and is not empty.

    Args:
        file_path (str): Path to the file.

    Returns:
        bool: True if **file_path** exists and is not empty.
    """
    if file_path.endswith('.zip'):
        print("Checking if empty zip: "+file_path)
        # Create a temporary directory to extract the zip
        temp_dir = fu.create_unique_dir()
        # Extract the zip and get the list of files
        unzipped_files = fu.unzip_list(file_path, dest_dir=temp_dir)
        # Check if there are any files in the zip
        return len(unzipped_files) > 0

    print("Checking if empty file: "+file_path)
    return Path(file_path).is_file() and Path(file_path).stat().st_size > 0


def compare_hash(file_a: str, file_b: str) -> bool:
    """Compute and compare the SHA-256 hashes of two files"""
    print("Comparing: ")
    print("    File_A: "+file_a)
    print("    File_B: "+file_b)
    with open(file_a, 'rb') as f_a, open(file_b, 'rb') as f_b:
        file_a_hash = hashlib.sha256(f_a.read()).digest()
        file_b_hash = hashlib.sha256(f_b.read()).digest()
    print("    File_A hash: "+str(file_a_hash))
    print("    File_B hash: "+str(file_b_hash))
    return file_a_hash == file_b_hash


def equal(file_a: str, file_b: str, ignore_list: Optional[list[Union[str, int]]] = None, **kwargs) -> bool:
    """Check if two files are equal, dispatching to a comparison strategy based on the file extension"""
    if ignore_list:
        # Line-by-line comparison
        return compare_line_by_line(file_a, file_b, ignore_list)

    if file_a.endswith(".zip") and file_b.endswith(".zip"):
        return compare_zip(file_a, file_b)

    if file_a.endswith(".pdb") and file_b.endswith(".pdb"):
        return compare_pdb(file_a, file_b, **kwargs)

    if file_a.endswith(".top") and file_b.endswith(".top"):
        return compare_top_itp(file_a, file_b)

    if file_a.endswith(".itp") and file_b.endswith(".itp"):
        return compare_top_itp(file_a, file_b)

    if file_a.endswith(".gro") and file_b.endswith(".gro"):
        return compare_ignore_first(file_a, file_b)

    if file_a.endswith(".prmtop") and file_b.endswith(".prmtop"):
        return compare_ignore_first(file_a, file_b)

    if file_a.endswith(".inp") and file_b.endswith(".inp"):
        return compare_ignore_first(file_a, file_b)

    if file_a.endswith(".par") and file_b.endswith(".par"):
        return compare_ignore_first(file_a, file_b)

    if file_a.endswith((".nc", ".netcdf", ".xtc")) and file_b.endswith((".nc", ".netcdf", ".xtc")):
        return compare_size(file_a, file_b, kwargs.get('percent_tolerance', 1.0))

    if file_a.endswith(".xvg") and file_b.endswith(".xvg"):
        return compare_xvg(file_a, file_b, kwargs.get('percent_tolerance', 1.0))

    image_extensions = ('.png', '.jfif', '.ppm', '.tiff', '.jpg', '.dib', '.pgm', '.bmp', '.jpeg', '.pbm', '.jpe', '.apng', '.pnm', '.gif', '.tif')
    if file_a.endswith(image_extensions) and file_b.endswith(image_extensions):
        return compare_images(file_a, file_b, kwargs.get('percent_tolerance', 1.0))

    return compare_hash(file_a, file_b)

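# A dispatch sketch with hypothetical file names (the extension, not the
# content, selects the comparison strategy; unmatched extensions fall back to
# the exact SHA-256 comparison):
#
#   equal('out.pdb', 'ref.pdb', rmsd_cutoff=2)         # RMSD after superposition
#   equal('out.xtc', 'ref.xtc', percent_tolerance=5)   # file-size comparison
#   equal('out.gro', 'ref.gro')                        # ignores the first line
#   equal('out.log', 'ref.log', ignore_list=['Date'])  # line-by-line comparison
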

def compare_line_by_line(file_a: str, file_b: str, ignore_list: list[Union[str, int]]) -> bool:
    """Compare two files line by line, skipping the line indices and the lines containing the words in **ignore_list**"""
    print(f"Comparing, ignoring lines containing these words: {ignore_list}")
    print("    FILE_A: "+file_a)
    print("    FILE_B: "+file_b)
    with open(file_a) as fa, open(file_b) as fb:
        for index, (line_a, line_b) in enumerate(zip(fa, fb)):
            if index in ignore_list or any(word in line_a for word in ignore_list if isinstance(word, str)):
                continue
            elif line_a != line_b:
                return False
        return True

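# Note on ignore_list semantics, read straight from the loop above: integers
# are zero-based line indices to skip, strings are substrings whose presence in
# a line of file_a skips that line pair. Hypothetical call:
#
#   compare_line_by_line('out.log', 'ref.log', ignore_list=[0, 'timestamp'])
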

def equal_txt(file_a: str, file_b: str) -> bool:
    """Check if two text files are equal"""
    return compare_hash(file_a, file_b)


def compare_zip(zip_a: str, zip_b: str) -> bool:
    """ Compare zip files """
    print("This is a ZIP comparison!")
    print("Unzipping:")
    print("Creating a unique_dir for: %s" % zip_a)
    zip_a_dir = fu.create_unique_dir()
    zip_a_list = fu.unzip_list(zip_a, dest_dir=zip_a_dir)
    print("Creating a unique_dir for: %s" % zip_b)
    zip_b_dir = fu.create_unique_dir()
    zip_b_list = fu.unzip_list(zip_b, dest_dir=zip_b_dir)

    if not len(zip_a_list) == len(zip_b_list):
        return False

    for uncompressed_zip_a in zip_a_list:
        uncompressed_zip_b = str(Path(zip_b_dir).joinpath(Path(uncompressed_zip_a).name))
        if not equal(uncompressed_zip_a, uncompressed_zip_b):
            return False

    return True


def compare_pdb(pdb_a: str, pdb_b: str, rmsd_cutoff: int = 1, remove_hetatm: bool = True, remove_hydrogen: bool = True, **kwargs):
    """ Compare pdb files """
    print("Checking RMSD between:")
    print("    PDB_A: "+pdb_a)
    print("    PDB_B: "+pdb_b)
    pdb_parser = PDBParser(PERMISSIVE=True, QUIET=True)
    st_a = pdb_parser.get_structure("st_a", pdb_a)
    st_b = pdb_parser.get_structure("st_b", pdb_b)
    if st_a is None or st_b is None:
        print("    One of the PDB structures could not be parsed.")
        return False
    st_a = st_a[0]
    st_b = st_b[0]

    if remove_hetatm:
        print("    Ignoring HETATM in RMSD")
        residues_a = [list(res.get_atoms()) for res in st_a.get_residues() if not res.id[0].startswith('H_')]
        residues_b = [list(res.get_atoms()) for res in st_b.get_residues() if not res.id[0].startswith('H_')]
        atoms_a = [atom for residue in residues_a for atom in residue]
        atoms_b = [atom for residue in residues_b for atom in residue]
    else:
        atoms_a = st_a.get_atoms()
        atoms_b = st_b.get_atoms()

    if remove_hydrogen:
        print("    Ignoring hydrogen atoms in RMSD")
        atoms_a = [atom for atom in atoms_a if not atom.get_name().startswith('H')]
        atoms_b = [atom for atom in atoms_b if not atom.get_name().startswith('H')]

    # Materialize once so generators are not exhausted before the superposition
    atoms_a_list = list(atoms_a)
    atoms_b_list = list(atoms_b)
    print("    Atoms ALIGNED in PDB_A: "+str(len(atoms_a_list)))
    print("    Atoms ALIGNED in PDB_B: "+str(len(atoms_b_list)))
    super_imposer = Superimposer()
    super_imposer.set_atoms(atoms_a_list, atoms_b_list)
    super_imposer.apply(atoms_b_list)
    super_imposer_rms = super_imposer.rms if super_imposer.rms is not None else float('inf')
    print('    RMS: '+str(super_imposer_rms))
    print('    RMS_CUTOFF: '+str(rmsd_cutoff))
    return super_imposer_rms < rmsd_cutoff

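# A hedged usage sketch (hypothetical paths): two structures are considered
# equal when the RMSD after superposition falls below rmsd_cutoff, here keeping
# heteroatoms in the alignment:
#
#   compare_pdb('out.pdb', 'ref.pdb', rmsd_cutoff=2, remove_hetatm=False)
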

def compare_top_itp(file_a: str, file_b: str) -> bool:
    """ Compare top/itp files, ignoring the first line and comment lines """
    print("Comparing TOP/ITP:")
    print("    FILE_A: "+file_a)
    print("    FILE_B: "+file_b)
    with codecs.open(file_a, 'r', encoding='utf-8', errors='ignore') as f_a:
        next(f_a)
        with codecs.open(file_b, 'r', encoding='utf-8', errors='ignore') as f_b:
            next(f_b)
            return [line.strip() for line in f_a if not line.strip().startswith(';')] == [line.strip() for line in f_b if not line.strip().startswith(';')]


def compare_ignore_first(file_a: str, file_b: str) -> bool:
    """ Compare two files ignoring the first line """
    print("Comparing ignoring the first line of both files:")
    print("    FILE_A: "+file_a)
    print("    FILE_B: "+file_b)
    with open(file_a) as f_a:
        next(f_a)
        with open(file_b) as f_b:
            next(f_b)
            return [line.strip() for line in f_a] == [line.strip() for line in f_b]


def compare_size(file_a: str, file_b: str, percent_tolerance: float = 1.0) -> bool:
    """ Compare two files by size, within a percent tolerance around the average size """
    print("Comparing the size of both files:")
    print(f"    FILE_A: {file_a}")
    print(f"    FILE_B: {file_b}")
    size_a = Path(file_a).stat().st_size
    size_b = Path(file_b).stat().st_size
    average_size = (size_a + size_b) / 2
    tolerance = average_size * percent_tolerance / 100
    tolerance_low = average_size - tolerance
    tolerance_high = average_size + tolerance
    print(f"    SIZE_A: {size_a} bytes")
    print(f"    SIZE_B: {size_b} bytes")
    print(f"    TOLERANCE: {percent_tolerance}%, Low: {tolerance_low} bytes, High: {tolerance_high} bytes")
    return (tolerance_low <= size_a <= tolerance_high) and (tolerance_low <= size_b <= tolerance_high)

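# Worked example of the tolerance band (illustrative numbers, not from a real
# run): size_a=1000 and size_b=1010 bytes give average_size=1005; with the
# default percent_tolerance=1.0 the tolerance is 10.05 bytes, so both sizes
# must fall in [994.95, 1015.05]; here they do, and the comparison passes.
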

def compare_xvg(file_a: str, file_b: str, percent_tolerance: float = 1.0) -> bool:
    """ Compare two XVG files numerically, column by column, within a percent tolerance """
    print("Comparing XVG data of both files:")
    print(f"    FILE_A: {file_a}")
    print(f"    FILE_B: {file_b}")
    arrays_tuple_a = np.loadtxt(file_a, comments=["@", '#'], unpack=True)
    arrays_tuple_b = np.loadtxt(file_b, comments=["@", '#'], unpack=True)
    for array_a, array_b in zip(arrays_tuple_a, arrays_tuple_b):
        if not np.allclose(array_a, array_b, rtol=percent_tolerance / 100):
            return False
    return True

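# Note: rtol is a relative tolerance, so percent_tolerance=1.0 accepts values
# within roughly 1% of the reference column; np.allclose checks
# |a - b| <= atol + rtol * |b| elementwise, with its default atol also applied.
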

def compare_images(file_a: str, file_b: str, percent_tolerance: float = 1.0) -> bool:
    """ Compare two images using a perceptual average hash, within a percent tolerance """
    try:
        from PIL import Image  # type: ignore
        import imagehash
    except ImportError:
        print("To compare images, please install the following packages: Pillow, imagehash")
        return False

    print("Comparing images of both files:")
    print(f"    IMAGE_A: {file_a}")
    print(f"    IMAGE_B: {file_b}")
    hash_a = imagehash.average_hash(Image.open(file_a))
    hash_b = imagehash.average_hash(Image.open(file_b))
    # Absolute tolerance in bits, never below one bit
    tolerance = (len(hash_a) + len(hash_b)) / 2 * percent_tolerance / 100
    if tolerance < 1:
        tolerance = 1
    difference = hash_a - hash_b
    print(f"    IMAGE_A HASH: {hash_a} SIZE: {len(hash_a)} bits")
    print(f"    IMAGE_B HASH: {hash_b} SIZE: {len(hash_b)} bits")
    print(f"    TOLERANCE: {percent_tolerance}%, ABS TOLERANCE: {tolerance} bits, DIFFERENCE: {difference} bits")
    if difference > tolerance:
        return False
    return True

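# Worked example (illustrative): imagehash.average_hash yields 64-bit hashes by
# default, so with percent_tolerance=1.0 the raw tolerance is 64 * 0.01 = 0.64
# bits, clamped up to 1 bit; the images match when their hashes differ in at
# most 1 bit (hash_a - hash_b is the Hamming distance between the two hashes).
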

def compare_object_pickle(python_object: Any, pickle_file_path: Union[str, Path], **kwargs) -> bool:
    """ Compare a python object with the object stored in a pickle file """
    print(f"Loading pickle file: {pickle_file_path}")
    with open(pickle_file_path, 'rb') as f:
        pickle_object = pickle.load(f)

    # Special case for dictionaries
    if isinstance(python_object, dict) and isinstance(pickle_object, dict):
        differences = compare_dictionaries(python_object, pickle_object, ignore_keys=kwargs.get('ignore_keys', []), compare_values=kwargs.get('compare_values', True), ignore_substring=kwargs.get('ignore_substring', ""))
        if differences:
            print(50*'*')
            print("OBJECT:")
            print(python_object)
            print(50*'*')
            print()
            print(50*'*')
            print("EXPECTED OBJECT:")
            print(pickle_object)
            print(50*'*')

            print("Differences found:")
            for difference in differences:
                print(f"    {difference}")
            return False
        return True

    return python_object == pickle_object

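# A hedged usage sketch (hypothetical reference file): properties produced by a
# run are compared against a pickled reference, skipping volatile keys via the
# ignore_keys kwarg forwarded to compare_dictionaries:
#
#   compare_object_pickle(test_object.properties, 'reference/props.pkl',
#                         ignore_keys=['path', 'working_dir_path'])
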

def compare_dictionaries(dict1: dict, dict2: dict, path: str = "", ignore_keys: Optional[list[str]] = None, compare_values: bool = True, ignore_substring: str = "") -> list[str]:
    """Compare two dictionaries and return a list of differences, ignoring the specified keys."""
    if ignore_keys is None:
        ignore_keys = []

    differences = []

    # Get all keys from both dictionaries
    all_keys = set(dict1.keys()).union(set(dict2.keys()))

    for key in all_keys:
        if key in ignore_keys:
            continue
        if key not in dict1:
            differences.append(f"Key '{path + key}' found in dict2 but not in dict1")
        elif key not in dict2:
            differences.append(f"Key '{path + key}' found in dict1 but not in dict2")
        else:
            value1 = dict1[key]
            value2 = dict2[key]
            if isinstance(value1, dict) and isinstance(value2, dict):
                # Recursively compare nested dictionaries
                nested_differences = compare_dictionaries(value1, value2, path + key + ".", ignore_keys, compare_values, ignore_substring)
                differences.extend(nested_differences)
            elif (value1 != value2) and compare_values:
                if ignore_substring:
                    if (not str(value1).endswith(str(value2).replace(ignore_substring, ""))) and (not str(value2).endswith(str(value1).replace(ignore_substring, ""))):
                        differences.append(f"Difference at '{path + key}': dict1 has {value1}, dict2 has {value2}")
                else:
                    differences.append(f"Difference at '{path + key}': dict1 has {value1}, dict2 has {value2}")

    return differences

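# Illustrative call (made-up dictionaries), showing the dotted path built up
# during recursion:
#
#   compare_dictionaries({'a': 1, 'b': {'c': 2}}, {'a': 1, 'b': {'c': 3}})
#   # -> ["Difference at 'b.c': dict1 has 2, dict2 has 3"]
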

def validate_json(json_file_path: Union[str, Path], json_schema_path: Union[str, Path]) -> bool:
    """
    Validate a JSON file against a provided JSON schema.

    Args:
        json_file_path (str): Path to the JSON file to validate.
        json_schema_path (str): Path to the JSON schema file.

    Returns:
        bool: True if the JSON is valid, False if invalid.
    """
    print("Validating JSON file:")
    print(f"    JSON file: {json_file_path}")
    print(f"    JSON schema: {json_schema_path}")
    try:
        # Load the JSON file
        with open(json_file_path, 'r') as json_file:
            json_data = json.load(json_file)

        # Load the JSON schema
        with open(json_schema_path, 'r') as schema_file:
            schema = json.load(schema_file)

        # Validate the JSON data against the schema
        jsonschema.validate(instance=json_data, schema=schema)

        return True
    except jsonschema.ValidationError as ve:
        print(f"Validation error: {ve.message}")
        return False
    except json.JSONDecodeError as je:
        print(f"Invalid JSON format: {je.msg}")
        return False
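
# A hedged usage sketch (hypothetical paths):
#
#   validate_json('output.json', 'reference/schema.json')
#   # -> True when output.json conforms to the schema, False otherwise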