Coverage for biobb_pytorch / mdae / loss / utils / torch_protein_energy_utils.py: 3%

756 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1# Copyright (c) 2021 Venkata K. Ramaswamy, Samuel C. Musson, Chris G. Willcocks, Matteo T. Degiacomi 

2# 

3# Molearn is free software ; 

4# you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation ; 

5# either version 2 of the License, or (at your option) any later version. 

6# molearn is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY ; 

7# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 

8# See the GNU General Public License for more details. 

9# You should have received a copy of the GNU General Public License along with molearn ; 

10# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. 

11 

12import numpy as np 

13import torch 

14from copy import deepcopy 

15import os 

16from importlib.resources import files 

17 

18 

def read_lib_file(file_name, amber_atoms, atom_charge, connectivity):
    """Parse a library file from the packaged ``mdae/loss/parameters`` directory.

    The three dictionaries are filled in place. Their first-level keys are the
    residue names listed (in double quotes) after the ``!!index array str``
    line of the file:

    * ``amber_atoms[res][pdb_atom_name]`` -> AMBER atom type name
    * ``atom_charge[res][amber_atom_type]`` -> charge as ``float``
    * ``connectivity[res][pdb_atom_name]`` -> list of bonded PDB atom names

    Args:
        file_name: name of the file inside ``mdae/loss/parameters``.
        amber_atoms: dict updated in place (see above).
        atom_charge: dict updated in place (see above).
        connectivity: dict updated in place (see above).

    Raises:
        Exception: if the file cannot be located or opened (the underlying
            error is attached as ``__cause__``), or if an index/atom name is
            not wrapped in double quotes.
    """
    try:
        f_location = files('mdae').joinpath('loss/parameters')
        f_in = open(f'{str(f_location)}/{file_name}')
        print(f'File {f_location}/{file_name} opened')
    except Exception as err:
        raise Exception(f'ERROR: file {file_name} not found!') from err

    # Everything is parsed from memory, so release the handle straight away.
    with f_in:
        lines = f_in.readlines()

    quoted_msg = 'I was expecting something of the form "XXX" but got %s instead'

    # indexs[res][atom_index] -> PDB atom name; used to resolve the integer
    # atom indices of the connectivity tables.
    indexs = {}

    # Pass 1: the residue index. Only the first '!!index array str' section
    # is read; its entries are the space-indented lines that follow it.
    for i, tline in enumerate(lines):
        if tline.split() != ['!!index', 'array', 'str']:
            continue
        for line in lines[i + 1:]:
            if line[0] != ' ':
                break
            contents = line.split()
            # NOTE(review): condition kept exactly as in the original; a
            # whitespace-only line would raise IndexError here.
            if len(contents) != 1 and len(contents[0]) != 5:
                break
            res = contents[0]
            if not (res[0] == '"' and res[-1] == '"'):
                raise Exception(quoted_msg % res)
            res = res[1:-1]
            amber_atoms[res] = {}
            atom_charge[res] = {}
            indexs[res] = {}
            connectivity[res] = {}
        break

    # Pass 2: per-residue atom and connectivity tables. The residue name is
    # read from fixed columns 7-9 of the '!entry.' header line.
    for i, tline in enumerate(lines):
        if tline[0:7] != '!entry.':
            continue
        res = tline[7:10]
        if tline[10:22] == '.unit.atoms ':
            for line in lines[i + 1:]:
                if line[0] != ' ':
                    break
                contents = line.split()
                if len(contents) < 3 and len(contents[0]) > 4 and len(contents[1]) > 4:
                    break
                # Exactly eight whitespace-separated fields are expected;
                # anything else raises ValueError on unpacking.
                pdb_name, amber_name, _, _, _, index, _, charge = contents
                pdb_quoted = pdb_name[0] == '"' and pdb_name[-1] == '"'
                amber_quoted = amber_name[0] == '"' and amber_name[-1] == '"'
                if not (pdb_quoted and amber_quoted):
                    # Report the token that is actually malformed (the
                    # original reported the residue name instead).
                    offending = amber_name if pdb_quoted else pdb_name
                    raise Exception(quoted_msg % offending)
                amber_atoms[res][pdb_name[1:-1]] = amber_name[1:-1]
                # NOTE: charges are keyed by AMBER atom type, so atoms of a
                # residue that share a type overwrite each other's entry.
                atom_charge[res][amber_name[1:-1]] = float(charge)
                indexs[res][int(index)] = pdb_name[1:-1]
                connectivity[res][pdb_name[1:-1]] = []
        elif tline[10:29] == '.unit.connectivity ':
            for line in lines[i + 1:]:
                if line[0] != ' ':
                    break
                contents = line.split()
                if len(contents) != 3:
                    break
                a1, a2, _ = contents
                first = indexs[res][int(a1)]
                second = indexs[res][int(a2)]
                # Record the bond in both directions.
                connectivity[res][first].append(second)
                connectivity[res][second].append(first)

85 

86 

def _open_amber_parameter_file(file_name):
    """Open ``file_name`` from the packaged ``mdae/loss/parameters`` directory."""
    try:
        f_location = files('mdae').joinpath('loss/parameters')
        f_in = open(f'{str(f_location)}/{file_name}')
        print('File %s opened' % file_name)
    except Exception as err:
        raise Exception('ERROR: file %s not found!' % file_name) from err
    return f_in


def get_amber_parameters(order=False, radians=True):
    """Load the AMBER parameters used by the energy terms.

    Reads, in this order, ``amino12.lib`` (via ``read_lib_file``),
    ``parm10.dat`` and ``frcmod.ff14SB``; entries of the frcmod file are
    written into the same dictionaries as the ``parm10.dat`` entries.

    Args:
        order: accepted for interface compatibility; it is not used by this
            function and is not forwarded to the card parsers.
        radians: if true, the values of ``angle_equil`` and ``torsion_phase``
            are converted from degrees to radians. ``improper_phase`` is left
            as read from the files.

    Returns:
        A 16-tuple ``(amber_atoms, atom_mass, atom_polarizability, bond_force,
        bond_equil, angle_force, angle_equil, torsion_factor, torsion_barrier,
        torsion_phase, torsion_period, improper_factor, improper_barrier,
        improper_phase, improper_period, other_parameters)``.
        ``other_parameters`` is a dict with the keys
        ``'vdw_potential_well_depth'``, ``'H_bond_10_12_parameters'``,
        ``'equivalences'``, ``'charge'`` and ``'connectivity'``.

    Raises:
        Exception: if one of the parameter files cannot be opened.
    """
    file_names = ('amino12.lib',
                  'parm10.dat',
                  'frcmod.ff14SB')

    # amber19 is dangerous because they've replaced parameters with cmap
    # pdb atom names to amber atom names using amino19.lib
    amber_atoms = {}  # knowledge[res][pdb_atom] = amber_atom
    atom_mass = {}
    atom_polarizability = {}
    bond_force = {}
    bond_equil = {}
    angle_force = {}
    angle_equil = {}
    torsion_factor = {}
    torsion_barrier = {}
    torsion_phase = {}
    torsion_period = {}
    improper_factor = {}
    improper_barrier = {}
    improper_phase = {}
    improper_period = {}

    other_parameters = {
        'vdw_potential_well_depth': {},
        'H_bond_10_12_parameters': {},
        'equivalences': {},
        'charge': {},
        'connectivity': {},
    }
    read_lib_file(file_names[0], amber_atoms, other_parameters['charge'], other_parameters['connectivity'])

    # The context manager guarantees the handle is closed even when one of
    # the card parsers raises on malformed input.
    with _open_amber_parameter_file(file_names[1]) as f_in:
        # section 1 title
        line = f_in.readline()
        print(line)

        # The cards of the main parameter file are read in fixed order.
        amber_card_type_2(f_in, atom_mass, atom_polarizability)
        amber_card_type_3(f_in)
        amber_card_type_4(f_in, bond_force, bond_equil)
        amber_card_type_5(f_in, angle_force, angle_equil)
        amber_card_type_6(f_in, torsion_factor, torsion_barrier, torsion_phase, torsion_period)
        amber_card_type_7(f_in, improper_factor,
                          improper_barrier, improper_phase, improper_period)
        amber_card_type_8(f_in, other_parameters)
        amber_card_type_9(f_in, other_parameters)
        for line in f_in:
            tokens = line.split()
            if len(tokens) > 1:
                # A header whose second token is 'RE' introduces the 6-12 table.
                if tokens[1] == 'RE':
                    amber_card_type_10B(f_in, other_parameters)
            # NOTE(review): attached to the outer `if` so that a bare 'END'
            # line reaches it; the flattened source this was recovered from
            # does not show the nesting - confirm against the repository.
            elif line[0:3] == 'END':
                print('parameters loaded')

    # Open the frcmod file; it should be in an identical format but may be
    # missing any or all cards, so cards are located by their header keyword.
    with _open_amber_parameter_file(file_names[2]) as f_in:
        # section 1 title
        line = f_in.readline()
        print(line)

        for line in f_in:
            card = line[:4]
            if card == 'MASS':
                amber_card_type_2(f_in, atom_mass, atom_polarizability)
            elif card == 'BOND':
                amber_card_type_4(f_in, bond_force, bond_equil)
            elif card == 'ANGL':
                amber_card_type_5(f_in, angle_force, angle_equil)
            elif card == 'DIHE':
                amber_card_type_6(f_in, torsion_factor, torsion_barrier, torsion_phase, torsion_period)
            elif card == 'IMPR':
                amber_card_type_7(f_in, improper_factor,
                                  improper_barrier, improper_phase, improper_period)
            elif card == 'HBON':
                amber_card_type_8(f_in, other_parameters)
            elif card == 'NONB':
                amber_card_type_10B(f_in, other_parameters)
            elif card == 'CMAP':
                print('Yeah, Im not bothering to implement cmap')
            elif line[0:3] == 'END':
                print('parameters loaded')

    if radians:
        for angle in angle_equil:
            angle_equil[angle] = np.deg2rad(angle_equil[angle])
        for torsion in torsion_phase:
            # each torsion holds a list of phases (one per term)
            torsion_phase[torsion] = list(np.deg2rad(torsion_phase[torsion]))

    return (amber_atoms, atom_mass, atom_polarizability, bond_force, bond_equil,
            angle_force, angle_equil, torsion_factor, torsion_barrier, torsion_phase,
            torsion_period, improper_factor, improper_barrier, improper_phase,
            improper_period, other_parameters)

191 

192 

def amber_card_type_2(f_in, atom_mass, atom_polarizability):
    """Read card type 2: atom symbols with their mass and polarizability.

    Lines are consumed from ``f_in`` up to and including the first blank
    line. The atom symbol is taken from columns 0-1 and the numeric fields
    from columns 2-23.

    Args:
        f_in: iterable of lines (e.g. an open text file), positioned at the
            first line of the card.
        atom_mass: dict updated in place, ``atom -> mass``.
        atom_polarizability: dict updated in place, ``atom -> polarizability``.
            When a line lists only a mass, 0.0 is stored.

    Raises:
        ValueError: if a line holds neither one nor two numeric fields, or a
            field is not a valid float.
    """
    # section 2 input for atom symbols and masses
    for line in f_in:
        if line.strip() == '':
            break
        atom = line[0:2].strip()
        contents = line[2:24].split()
        if len(contents) == 2:
            polarizability = float(contents[1])
        elif len(contents) == 1:
            # Sometimes a polarizability is not listed. The original code
            # reused the value of the previous line here (or failed with
            # UnboundLocalError on the first line); store a neutral default.
            # NOTE(review): 0.0 is the chosen default - confirm it suits
            # downstream use.
            polarizability = 0.0
        else:
            raise ValueError('Should be 2A, X, F10.2, F10.2, comments but got %s' % line)
        atom_mass[atom] = float(contents[0])
        atom_polarizability[atom] = polarizability

211 

212 

def amber_card_type_3(f_in):
    """Skip card type 3 (atom symbols that are hydrophilic).

    Exactly one line is read from ``f_in`` and thrown away.
    """
    _ = f_in.readline()

216 

217 

def amber_card_type_4(f_in, bond_force, bond_equil, order=False):
    """Read card type 4: bond length parameters.

    Lines are consumed from ``f_in`` up to and including the first blank
    line. Atom symbols are taken from columns 0-1 and 3-4, the two numeric
    fields from columns 5-24.

    Args:
        f_in: iterable of lines, positioned at the first line of the card.
        bond_force: dict updated in place, ``(atom1, atom2) -> force constant``.
        bond_equil: dict updated in place, ``(atom1, atom2) -> equilibrium length``.
        order: if true, the two atom symbols of the key are sorted
            alphabetically; otherwise they are kept in file order.

    Raises:
        ValueError: if columns 5-24 do not hold exactly two values, or a
            value is not a valid float.
    """
    # section 4 bond length parameters
    for line in f_in:
        if line.strip() == '':
            break
        atom1 = line[0:2].strip()
        atom2 = line[3:5].strip()
        bond = tuple(sorted((atom1, atom2))) if order else (atom1, atom2)
        field = line[5:25]
        contents = field.split()
        if len(contents) != 2:
            # Report the slice that was actually parsed (the original message
            # showed columns 6-25 instead).
            raise ValueError('Expected 2 floats but got %s' % field)
        bond_force[bond] = float(contents[0])
        bond_equil[bond] = float(contents[1])

236 

237 

def amber_card_type_5(f_in, angle_force, angle_equil, order=False):
    """Read card type 5: angle parameters.

    Lines are consumed from ``f_in`` up to and including the first blank
    line. Atom symbols are taken from columns 0-1, 3-4 and 6-7, the two
    numeric fields from columns 8-27.

    Args:
        f_in: iterable of lines, positioned at the first line of the card.
        angle_force: dict updated in place, ``(a1, a2, a3) -> force constant``.
        angle_equil: dict updated in place, ``(a1, a2, a3) -> equilibrium angle``
            (stored exactly as written in the file).
        order: if true, the outer atoms (1 and 3) of the key are sorted
            alphabetically; the central atom stays in the middle.

    Raises:
        ValueError: if columns 8-27 do not hold exactly two values, or a
            value is not a valid float.
    """
    # section 5
    for line in f_in:
        if line.strip() == '':
            break
        atom1 = line[0:2].strip()
        atom2 = line[3:5].strip()
        atom3 = line[6:8].strip()
        if order:
            # sorted alphabetically by the 1-3 atoms
            first, last = sorted((atom1, atom3))
            angle = (first, atom2, last)
        else:
            angle = (atom1, atom2, atom3)
        field = line[8:28]
        contents = field.split()
        if len(contents) != 2:
            # Report the slice that was actually parsed (the original message
            # showed columns 6-25, which include the third atom symbol).
            raise ValueError('Expected 2 floats but got %s' % field)
        angle_force[angle] = float(contents[0])
        angle_equil[angle] = float(contents[1])

258 

259 

def amber_card_type_6(f_in, torsion_factor, torsion_barrier, torsion_phase,
                      torsion_period, order=False):
    """Read card type 6: torsion / proper dihedral parameters.

    Records are consumed from ``f_in`` up to and including the first blank
    line. The four atom symbols sit in columns 0-1, 3-4, 6-7 and 9-10; the
    four values (factor, barrier, phase, period) are split out of columns
    11-54. Every dictionary maps the 4-tuple of atom symbols to a list with
    one entry per term.

    A record for a torsion that is already stored is appended as a further
    term when the last stored period is negative, and replaces the stored
    terms when the last stored period is positive.

    The torsion potential is (barrier/factor)*(1+cos(period*phi-phase)).

    Raises:
        Exception: if columns 11-54 do not hold exactly four values.
    """
    tables = (torsion_factor, torsion_barrier, torsion_phase, torsion_period)
    converters = (int, float, float, float)

    for record in f_in:
        if not record.strip():
            break

        a1, a2, a3, a4 = (record[col:col + 2].strip() for col in (0, 3, 6, 9))
        if order:
            # sort on the two central atoms, each carrying its outer neighbour
            halves = sorted([(a2, a1), (a3, a4)], key=lambda half: half[0])
            key = (halves[0][1], halves[0][0], halves[1][0], halves[1][1])
        else:
            key = (a1, a2, a3, a4)

        fields = record[11:55].split()
        if len(fields) != 4:
            raise Exception('I wanted four values here?')

        extend = False
        if key in torsion_period:
            last_period = torsion_period[key][-1]
            if last_period < 0:
                extend = True
            elif not last_period > 0:
                # last stored period is neither positive nor negative:
                # the record is ignored
                continue

        for table, convert, text in zip(tables, converters, fields):
            value = convert(text)
            if extend:
                table[key].append(value)
            else:
                table[key] = [value]

295 

296 

def amber_card_type_7(f_in, improper_factor, improper_barrier,
                      improper_phase, improper_period, order=False):
    """Read card type 7: improper dihedral parameters.

    Lines are consumed from ``f_in`` up to and including the first blank
    line. The four atom symbols sit in columns 0-1, 3-4, 6-7 and 9-10; the
    values are split out of columns 11-54. Lines with three values are
    stored as (barrier, phase, period). Lines with any other number of
    values except four are skipped without error.

    Unlike the proper torsions, no division factor is read for impropers, so
    ``improper_factor`` is accepted for interface compatibility but is never
    written to.

    Args:
        f_in: iterable of lines, positioned at the first line of the card.
        improper_factor: unused (see above).
        improper_barrier: dict updated in place, ``4-tuple -> barrier``.
        improper_phase: dict updated in place, ``4-tuple -> phase``.
        improper_period: dict updated in place, ``4-tuple -> period``.
        order: if true, the key is arranged so that the two central atoms
            (each carrying its outer neighbour) are sorted alphabetically.

    Raises:
        NotImplementedError: if a line carries four values; that form is not
            supported.
    """
    # section 7 improper dihedrals
    for line in f_in:
        if line.strip() == '':
            break
        atom1 = line[0:2].strip()
        atom2 = line[3:5].strip()
        atom3 = line[6:8].strip()
        atom4 = line[9:11].strip()
        if order:
            sort23 = sorted([(atom2, atom1), (atom3, atom4)], key=lambda x: x[0])
            torsion = (sort23[0][1], sort23[0][0], sort23[1][0], sort23[1][1])
        else:
            torsion = (atom1, atom2, atom3, atom4)
        contents = line[11:55].split()
        if len(contents) == 3:
            improper_barrier[torsion] = float(contents[0])
            improper_phase[torsion] = float(contents[1])
            improper_period[torsion] = float(contents[2])
        elif len(contents) == 4:
            # The original had four assignments after this raise; they could
            # never execute and have been removed.
            raise NotImplementedError('This seems allowed in the doc but doesnt appear in reality')

325 

326 

def amber_card_type_8(f_in, other_parameters, order=False):
    """Read card type 8: H-bond 10-12 potential parameters.

    Records are consumed from ``f_in`` up to and including the first blank
    line. The two atom symbols are read from columns 2-3 and 6-7. Everything
    from column 8 onwards is whitespace-split and stored, as a list of
    strings, under ``other_parameters['H_bond_10_12_parameters'][pair]``.
    With ``order`` set, the pair is sorted alphabetically.
    """
    for record in f_in:
        if not record.strip():
            break
        symbols = (record[2:4].strip(), record[6:8].strip())
        if order:
            symbols = tuple(sorted(symbols))
        other_parameters['H_bond_10_12_parameters'][symbols] = record[8:].split()

340 

341 

def amber_card_type_9(f_in, other_parameters):
    """Read card type 9: equivalenced atom symbols for the 6-12 parameters.

    Records are consumed from ``f_in`` up to and including the first blank
    line. Each record is whitespace-split and the full list of symbols
    (first one included) is stored under
    ``other_parameters['equivalences'][first_symbol]``.
    """
    for record in f_in:
        if not record.strip():
            break
        symbols = record.split()
        leader = symbols[0]
        other_parameters['equivalences'][leader] = symbols

349 

350 

def amber_card_type_10B(f_in, other_parameters):
    """Read card type 10B: 6-12 potential parameters.

    Records are consumed from ``f_in`` up to and including the first blank
    line. The first token of a record is the atom symbol; the next (up to)
    two tokens are converted to floats and stored as a list under
    ``other_parameters['vdw_potential_well_depth'][symbol]``. Later tokens
    on the record are ignored.
    """
    for record in f_in:
        if not record.strip():
            break
        symbol, *values = record.split()
        other_parameters['vdw_potential_well_depth'][symbol] = list(map(float, values[:2]))

358 

359 

360def get_convolutions(dataset, pdb_atom_names, 

361 atom_label=('set', 'string')[0], 

362 perform_checks=True, 

363 v=5, 

364 order=False, 

365 return_type=['mask', 'idxs'][1], 

366 absolute_torsion_period=True, 

367 NB=('matrix',)[0], 

368 fix_terminal=True, 

369 fix_charmm_residues=True, 

370 fix_slice_method=False, 

371 fix_h=False, 

372 alt_vdw=[], 

373 permitivity=1.0 

374 ): 

375 ''' 

376 ##INPUTS## 

377 

378 dataset: one frame of a trajectory of shape [3, N] 

379 

380 pdb_atom_names: should be an array of shape [N,2] 

381 pdb_atom_names[:,0] is the pdb_atom_names and 

382 pdb_atom_names[:,1] is the residue names 

383 

384 atom_label: (default, 'set') deprecated and broken for anything other than 'set' 

385 

386 perform_checks: No longer works so has been removed 

387 

388 v: (default, 5) atom_selection version, bonds are determined by interatomic 

389 distance, with v=2. v=1 shouldn't be used except in specific cirumstances. v=5 is using connectivity from amber parameters. 

390 

391 order: (bool, default false) are atoms ordered, I think I've fixed this so that 

392 it shouldn't matter either way but keep as False. 

393 

394 return_type: (option now removed) 

395 

396 

397 ##OUTPUTS## 

398 convolution output shape N* will be N-(conv length -1)+padding 

399 

400 bond_masks, b_equil, b_force: shape [number of convolutions, N*] 

401 

402 bond_weights: shape[number of convolutions, conv_size] 

403 

404 angle_masks, a_equil, a_force: shape [number of convolutions, conv_size] 

405 

406 angle_weights: shape[number of convolutions, 2, conv_size] 

407 

408 torsion_masks: shape[number of convolution, 3, conv_size] 

409 

410 t_para: shape[num of convs, N*, 4, max number torsion parameters ] 

411 

412 tornsion_weigths: shape [number of convolutions, 3, conv_size] 

413 

414 ''' 

415 

416 # get amber parameters 

417 (amber_atoms, atom_mass, atom_polarizability, bond_force, bond_equil, 

418 angle_force, angle_equil, torsion_factor, torsion_barrier, torsion_phase, 

419 torsion_period, improper_factor, improper_barrier, improper_phase, 

420 improper_period, other_parameters) = get_amber_parameters() 

421 if fix_terminal: 

422 pdb_atom_names[pdb_atom_names[:, 0] == 'OXT', 0] = 'O' 

423 if fix_charmm_residues: 

424 pdb_atom_names[pdb_atom_names[:, 1] == 'HSD', 1] = 'HID' 

425 pdb_atom_names[pdb_atom_names[:, 1] == 'HSE', 1] = 'HIE' 

426 for i in np.unique(pdb_atom_names[:, 2]): 

427 res_mask = pdb_atom_names[:, 2] == i 

428 if (pdb_atom_names[res_mask, 1] == 'HIS').all(): # if a HIS residue 

429 if (pdb_atom_names[res_mask, 0] == 'HD1').any() and (pdb_atom_names[res_mask, 0] == 'HE2').any(): 

430 pdb_atom_names[res_mask, 1] = 'HIP' 

431 elif (pdb_atom_names[res_mask, 0] == 'HD1').any(): 

432 pdb_atom_names[res_mask, 1] = 'HID' 

433 elif (pdb_atom_names[res_mask, 0] == 'HE2').any(): 

434 pdb_atom_names[res_mask, 1] = 'HIE' 

435 # if any HIS are remaining it does not matter which because the H is dealt with above 

436 pdb_atom_names[pdb_atom_names[:, 1] == 'HIS', 1] = 'HIE' 

437 if fix_h: 

438 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'MET'), 0] = 'HB3' 

439 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'MET'), 0] = 'HG3' 

440 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'ASN'), 0] = 'HB3' 

441 pdb_atom_names[pdb_atom_names[:, 0] == 'HN', 0] = 'H' 

442 pdb_atom_names[pdb_atom_names[:, 0] == '1HD2', 0] = 'HD21' 

443 pdb_atom_names[pdb_atom_names[:, 0] == '2HD2', 0] = 'HD22' 

444 pdb_atom_names[pdb_atom_names[:, 0] == '1HG2', 0] = 'HG21' 

445 pdb_atom_names[pdb_atom_names[:, 0] == '2HG2', 0] = 'HG22' 

446 pdb_atom_names[pdb_atom_names[:, 0] == '3HG2', 0] = 'HG23' 

447 pdb_atom_names[pdb_atom_names[:, 0] == '3HG1', 0] = 'HG13' 

448 pdb_atom_names[pdb_atom_names[:, 0] == '1HG1', 0] = 'HG11' 

449 pdb_atom_names[pdb_atom_names[:, 0] == '2HG1', 0] = 'HG12' 

450 pdb_atom_names[pdb_atom_names[:, 0] == '1HD1', 0] = 'HD11' 

451 pdb_atom_names[pdb_atom_names[:, 0] == '2HD1', 0] = 'HD12' 

452 pdb_atom_names[pdb_atom_names[:, 0] == '3HD1', 0] = 'HD13' 

453 pdb_atom_names[pdb_atom_names[:, 0] == '3HD2', 0] = 'HD23' 

454 pdb_atom_names[pdb_atom_names[:, 0] == '1HH1', 0] = 'HH11' 

455 pdb_atom_names[pdb_atom_names[:, 0] == '2HH1', 0] = 'HH12' 

456 pdb_atom_names[pdb_atom_names[:, 0] == '1HH2', 0] = 'HH21' 

457 pdb_atom_names[pdb_atom_names[:, 0] == '2HH2', 0] = 'HH22' 

458 pdb_atom_names[pdb_atom_names[:, 0] == '1HE2', 0] = 'HE21' 

459 pdb_atom_names[pdb_atom_names[:, 0] == '2HE2', 0] = 'HE22' 

460 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG11', pdb_atom_names[:, 1] == 'ILE'), 0] = 'HG13' 

461 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'CD', pdb_atom_names[:, 1] == 'ILE'), 0] = 'CD1' 

462 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1] == 'ILE'), 0] = 'HD11' 

463 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD2', pdb_atom_names[:, 1] == 'ILE'), 0] = 'HD12' 

464 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD3', pdb_atom_names[:, 1] == 'ILE'), 0] = 'HD13' 

465 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'PHE'), 0] = 'HB3' 

466 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'GLU'), 0] = 'HB3' 

467 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'GLU'), 0] = 'HG3' 

468 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'LEU'), 0] = 'HB3' 

469 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'ARG'), 0] = 'HB3' 

470 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'ARG'), 0] = 'HG3' 

471 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1] == 'ARG'), 0] = 'HD3' 

472 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'ASP'), 0] = 'HB3' 

473 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HA1', pdb_atom_names[:, 1] == 'GLY'), 0] = 'HA3' 

474 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'LYS'), 0] = 'HB3' 

475 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'LYS'), 0] = 'HG3' 

476 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1] == 'LYS'), 0] = 'HD3' 

477 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HE1', pdb_atom_names[:, 1] == 'LYS'), 0] = 'HE3' 

478 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'TYR'), 0] = 'HB3' 

479 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'HIP'), 0] = 'HB3' 

480 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'SER'), 0] = 'HB3' 

481 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'SER'), 0] = 'HG' 

482 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'PRO'), 0] = 'HB3' 

483 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'PRO'), 0] = 'HG3' 

484 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1] == 'PRO'), 0] = 'HD3' 

485 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'LEU'), 0] = 'HB3' 

486 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'GLN'), 0] = 'HB3' 

487 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'GLN'), 0] = 'HG3' 

488 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'TRP'), 0] = 'HB3' 

489 # writes termini as H because we haven't loaded in termini parameters 

490 atom_names = [[amber_atoms[res][atom], res, resid] if atom not in ['H2', 'H3'] else [amber_atoms[res]['H'], res, resid] for atom, res, resid in pdb_atom_names] 

491 p_atom_names = [[atom, res, resid] if atom not in ['H2', 'H3'] else ['H', res, resid] for atom, res, resid in pdb_atom_names] 

492 

493 # atom_names = [[amber_atoms[res][atom],res] for atom, res, resid in pdb_atom_names ] 

494 atom_charges = [other_parameters['charge'][res][atom] for atom, res, resid in atom_names] 

495 if NB == 'matrix': 

496 equiv_t = other_parameters['equivalences'] 

497 vdw_para = other_parameters['vdw_potential_well_depth'] 

498 # switch these around so that values point to key 

499 equiv = {} 

500 for i in equiv_t.keys(): 

501 j = equiv_t[i] 

502 for k in j: 

503 equiv[k] = i 

504 atom_R = torch.tensor([vdw_para[equiv.get(atom, atom)][0] for atom, res, resid in atom_names]) # radius 

505 atom_e = torch.tensor([vdw_para[equiv.get(atom, atom)][1] for atom, res, resid in atom_names]) # welldepth 

506 

507 print('Determining bonds') 

508 version = v # method of selecting bonded atoms 

509 N = dataset.shape[1] # 145 

510 

511 cmat = (torch.nn.functional.pdist((dataset).permute(1, 0))).cpu().numpy() 

512 if version == 1: 

513 bond_idxs = np.argpartition(cmat, (N - 1, N)) 

514 # this will work for any non cyclic monomeric protein 

515 # that in mind will break if enough proline atoms to make a cycle are selected 

516 bond_idxs, u = bond_idxs[:N - 1], bond_idxs[N - 1] 

517 if cmat[u] - cmat[bond_idxs[-1]] < 0.25: 

518 dist_val = str(cmat[u] - cmat[bond_idxs[-1]]) 

519 msg = ("WARNING: May not have correctly selected the bonded distances: value %s " 

520 "should be roughly between 0.42 and 0.57 (>0.25)" % dist_val) 

521 raise Exception(msg) # should be 0.42-0.57 

522 version += 1 # try version 2 instead 

523 mid = cmat[bond_idxs[-1]] + ((cmat[u] - cmat[bond_idxs[-1]]) / 2) # mid point 

524 full_mask = (cmat < mid).astype('int8') 

525 if version == 2: 

526 full_mask = (cmat < (1.643 + 2.129) / 2).astype('int8') 

527 bond_idxs = np.where(full_mask)[0] # for some reason returns tuple with one array 

528 if version == 3: 

529 if alt_vdw: 

530 vdw = torch.tensor(alt_vdw) 

531 max_bond_dist = (0.6 * (vdw.view(1, -1) + vdw.view(-1, 1))) 

532 cdist = torch.cdist(dataset.T, dataset.T) 

533 i, j = np.where((max_bond_dist > cdist).triu(diagonal=1).numpy()) 

534 remove = np.where(np.abs(j - i) > 30) 

535 max_bond_dist[i[remove], j[remove]] = 0.0 

536 max_bond_dist = max_bond_dist.numpy() 

537 else: 

538 max_bond_dist = (0.6 * (atom_R.view(1, -1) + atom_R.view(-1, 1))).cpu().numpy() 

539 max_bond_dist = max_bond_dist[np.where(np.triu(np.ones((N, N)), k=1))] 

540 full_mask = np.greater(max_bond_dist, cmat) 

541 bond_idxs = np.where(full_mask)[0] # for some reason returns tuple with one array 

542 if version == 4: 

543 # fix_hydrogens = [[atom, res, resid] for atom, res, resid in pdb_atom_names if atom in ['H2', 'H3']] 

544 connectivity = other_parameters['connectivity'] 

545 bond_types = [] 

546 bond_idxs = [] 

547 # tracker = [[]]*N doesn't work because of mutability 

548 tracker = [[] for i in range(N)] 

549 current_resid = -9999 

550 current_atoms = [] 

551 for i1, (atom1, res, resid) in enumerate(p_atom_names): 

552 assert atom1 in connectivity[res] 

553 for atom2, i2 in current_atoms: 

554 if resid != current_resid: # and atom2 == 'C': 

555 if atom2 != 'C': 

556 continue 

557 elif not (atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]): 

558 continue 

559 # if not (atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]): 

560 # if resid != current_resid and atom2 == 'C': 

561 # current_resid = resid 

562 # current_atoms = [] 

563 # else: 

564 # continue 

565 if atom1 == 'N' and atom2 == 'CA': 

566 continue 

567 tracker[i1].append(i2) 

568 tracker[i2].append(i1) 

569 if atom_label == 'set': 

570 if order: 

571 names = tuple(sorted((atom_names[i2][0], atom_names[i1][0]))) 

572 else: 

573 names = tuple((atom_names[i2][0], atom_names[i1][0])) 

574 bond_types.append(names) 

575 bond_idxs.append([i2, i1]) 

576 if resid != current_resid: # and atom2 == 'C': 

577 current_resid = resid 

578 current_atoms = [] 

579 current_atoms.append([atom1, i1]) 

580 if version == 5: 

581 connectivity = other_parameters['connectivity'] 

582 bond_types = [] 

583 bond_idxs = [] 

584 tracker = [[] for i in range(N)] 

585 current_resid = -9999 

586 current_atoms = [] 

587 previous_atoms = [] 

588 for i1, (atom1, res, resid) in enumerate(p_atom_names): 

589 assert atom1 in connectivity[res] 

590 if resid != current_resid: 

591 previous_atoms = deepcopy(current_atoms) 

592 current_atoms = [] 

593 current_resid = resid 

594 if atom1 == 'N': 

595 for atom2, i2 in previous_atoms: 

596 if atom2 == 'C': 

597 tracker[i1].append(i2) 

598 tracker[i2].append(i1) 

599 bond_types.append(tuple((atom_names[i2][0], atom_names[i1][0]))) 

600 bond_idxs.append([i2, i1]) 

601 

602 for atom2, i2 in current_atoms: 

603 if atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]: 

604 tracker[i1].append(i2) 

605 tracker[i2].append(i1) 

606 names = tuple((atom_names[i2][0], atom_names[i1][0])) 

607 bond_types.append(names) 

608 bond_idxs.append([i2, i1]) 

609 current_atoms.append([atom1, i1]) 

610 if version < 4: 

611 all_bond_idxs = np.sort(bond_idxs) 

612 

613 bond_types = [] 

614 bond_idxs = [] 

615 tracker = [[]] # this will keep track of some of the bonds to help work out the angles 

616 atom1 = 0 

617 atom2 = 1 

618 counter = 0 # index of the distance N,N+1 

619 for bond in all_bond_idxs: 

620 if bond < counter + (N - atom1 - 1): 

621 atom2 = atom1 + bond - counter + 1 # 0-0+1 

622 tracker[-1].append(atom2) 

623 while bond > counter + (N - atom1 - 2): 

624 counter += (N - atom1 - 1) 

625 atom1 += 1 

626 tracker.append([]) 

627 if bond < counter + (N - atom1 - 1): 

628 atom2 = atom1 + bond - counter + 1 

629 tracker[-1].append(atom2) 

630 if atom_label == 'string': # string of atom labels, doesn't handle Proline alternate ordering 

631 bond_types.append(atom_names[atom1][0] + '_' + atom_names[atom2][0]) 

632 bond_idxs.append([atom1, atom2]) 

633 elif atom_label == 'set': # set of atom labels 

634 if order: 

635 names = tuple(sorted((atom_names[atom1][0], atom_names[atom2][0]))) 

636 else: 

637 names = (atom_names[atom1][0], atom_names[atom2][0]) 

638 bond_types.append(names) 

639 bond_idxs.append([atom1, atom2]) 

640 

641 while len(tracker) < N: 

642 tracker.append([]) # ensure so the next bit doesn't break by indexing N-1 

643 

644 # Angles/1-3 

645 print('Determining angles') 

646 angle_types = [] 

647 angle_idxs = [] 

648 

649 torsion_types = [] 

650 torsion_idxs = [] 

651 

652 bond_14_idxs = [] 

653 

654 counter = 0 

655 # add missing bonds (each bond counted twice after but atom3>atom1 prevents duplicates later ) 

656 if version < 4: 

657 for atom1, atom1_bonds in enumerate(deepcopy(tracker)): # for _, [] in enum [[]] 

658 for atom2 in atom1_bonds: # for _ in [] 

659 tracker[atom2].append(atom1) 

660 # find every angle and add it 

661 for atom1, atom1_bonds in enumerate(tracker): 

662 for atom2 in atom1_bonds: 

663 for atom3 in tracker[atom2]: 

664 if atom3 > atom1: # each angle will only be counter once 

665 if order: 

666 sort13 = sorted([(atom_names[atom1][0], atom1), (atom_names[atom3][0], atom3)], key=lambda x: x[0]) 

667 names = tuple((sort13[0][0], atom_names[atom2][0], sort13[1][0])) 

668 

669 angle_types.append(names) 

670 angle_idxs.append([sort13[0][1], atom2, sort13[1][1]]) 

671 else: 

672 angle_types.append((atom_names[atom1][0], atom_names[atom2][0], 

673 atom_names[atom3][0])) 

674 angle_idxs.append([atom1, atom2, atom3]) 

675 if atom3 != atom1: 

676 for atom4 in tracker[atom3]: 

677 if atom4 > atom1 and atom2 != atom4: # each torsion will be counter once 

678 # torsions are done based on the 2 3 atoms, so sort 23 

679 if order: 

680 sort23 = sorted([(atom_names[atom2][0], atom2, atom_names[atom1][0], atom1), 

681 (atom_names[atom3][0], atom3, atom_names[atom4][0], atom4)], key=lambda x: x[0]) 

682 names = tuple((sort23[0][2], sort23[0][0], sort23[1][0], sort23[1][2])) 

683 torsion_types.append(names) 

684 torsion_idxs.append([sort23[0][3], sort23[0][1], sort23[1][1], sort23[1][3]]) 

685 else: 

686 torsion_types.append((atom_names[atom1][0], atom_names[atom2][0], 

687 atom_names[atom3][0], atom_names[atom4][0])) 

688 torsion_idxs.append([atom1, atom2, atom3, atom4]) 

689 bond_14_idxs.append([atom1, atom4]) 

690 # currently have bond_types, angle_types, and torsion_typs + idxs 

691 bond_idxs = np.array(bond_idxs) 

692 angle_idxs = np.array(angle_idxs) 

693 torsion_idxs = np.array(torsion_idxs) 

694 bond_max_conv = (bond_idxs.max(axis=1) - bond_idxs.min(axis=1)).max() + 1 

695 if bond_max_conv < 3 and fix_slice_method: 

696 bond_max_conv = 3 

697 angle_max_conv = (angle_idxs.max(axis=1) - angle_idxs.min(axis=1)).max() + 1 

698 if angle_max_conv < 5 and fix_slice_method: 

699 angle_max_conv = 5 

700 torsion_max_conv = (torsion_idxs.max(axis=1) - torsion_idxs.min(axis=1)).max() + 1 

701 if torsion_max_conv < 7 and fix_slice_method: 

702 torsion_max_conv = 7 

703 # there is a problem where i accidentally index [padding-3] so if (len -4) < 3 we index -1 which breaks things 

704 # it shouldn't affect anything to say the max conv is greater than 6 

705 # this little bit just turns the 'types' list into equivalent parameters 

706 # key error if you don't have the parameter 

707 bond_para = np.array([[bond_equil[bond], bond_force[bond]] if bond in bond_equil 

708 else [bond_equil[(bond[1], bond[0])], bond_force[(bond[1], bond[0])]] 

709 for bond in bond_types]) 

710 angle_para = np.array([[angle_equil[angle], angle_force[angle]] if angle in angle_equil 

711 else [angle_equil[(angle[2], angle[1], angle[0])], angle_force[(angle[2], angle[1], angle[0])]] 

712 for angle in angle_types]) 

713 torsion_para = [] 

714 t_unique = list(set(torsion_types)) 

715 t_unique_para = {} 

716 max_para = 0 

717 for torsion in t_unique: 

718 torsion_b = (torsion[3], torsion[2], torsion[1], torsion[0]) 

719 torsion_xx = ('X', torsion[2], torsion[1], 'X') 

720 torsion_xb = ('X', torsion[1], torsion[2], 'X') 

721 if torsion in torsion_barrier: 

722 max_para = max(max_para, len(torsion_barrier[torsion])) 

723 t_unique_para[torsion] = [torsion_factor[torsion], torsion_barrier[torsion], 

724 torsion_phase[torsion], torsion_period[torsion]] 

725 elif torsion_b in torsion_barrier: 

726 max_para = max(max_para, len(torsion_barrier[torsion_b])) 

727 t_unique_para[torsion] = [torsion_factor[torsion_b], torsion_barrier[torsion_b], 

728 torsion_phase[torsion_b], torsion_period[torsion_b]] 

729 elif torsion_xx in torsion_barrier: 

730 max_para = max(max_para, len(torsion_barrier[torsion_xx])) 

731 t_unique_para[torsion] = [torsion_factor[torsion_xx], torsion_barrier[torsion_xx], 

732 torsion_phase[torsion_xx], torsion_period[torsion_xx]] 

733 elif torsion_xb in torsion_barrier: 

734 max_para = max(max_para, len(torsion_barrier[torsion_xb])) 

735 t_unique_para[torsion] = [torsion_factor[torsion_xb], torsion_barrier[torsion_xb], 

736 torsion_phase[torsion_xb], torsion_period[torsion_xb]] 

737 else: 

738 print('ERROR: Torsion %s cannot be found in torsion_barrier and will not be included' % torsion) 

739 torsion_para = np.zeros((len(torsion_types), 4, max_para)) 

740 # we do not want barrier/factor to return nan so set factor to 1 by default 

741 torsion_para[:, 0, :] = 1.0 

742 for i, torsion in enumerate(torsion_types): 

743 para = t_unique_para[torsion] 

744 torsion_para[i, :, :len(para[0])] = para 

745 # make the torsion period positive (index 3 of each parameter row is the period)

746 if absolute_torsion_period: 

747 torsion_para[:, 3, :] = np.abs(torsion_para[:, 3, :]) 

748 

749 # bonds 

750 

751 bond_masks = np.zeros((bond_max_conv - 1, N - (bond_max_conv - 1) + 2 * (bond_max_conv - 2)), dtype=bool) 

752 bond_conv = (bond_idxs.max(axis=1) - bond_idxs.min(axis=1)) - 1 

753 

754 bond_weights = [] 

755 b_equil = np.zeros(bond_masks.shape) 

756 b_force = np.zeros(bond_masks.shape) 

757 for i in range(bond_max_conv - 1): 

758 weight = [0.0] * bond_max_conv 

759 weight[0] = 1.0 

760 weight[i + 1] = -1.0 

761 bond_weights.append(weight) 

762 mask_index = bond_idxs.min(axis=1)[bond_conv == i] + bond_max_conv - 2 

763 bond_masks[i, mask_index] = True 

764 b_equil[i, mask_index] = bond_para[bond_conv == i, 0] 

765 b_force[i, mask_index] = bond_para[bond_conv == i, 1] 

766 

767 # angles 

768 

769 angle_conv = (angle_idxs - angle_idxs.min(axis=1).reshape(-1, 1)) # relative positions of atoms 

770 angle_conv = np.where((angle_conv[:, 0] < angle_conv[:, 2]).reshape(-1, 1), angle_conv, angle_conv[:, [2, 1, 0]]) # remove mirrors 

771 angle_unique = np.unique(angle_conv, axis=0) # unique 

772 

773 angle_masks = np.zeros((len(angle_unique), N - (angle_max_conv - 1) + 2 * (angle_max_conv - 3)), dtype=bool) 

774 angle_weights = [] 

775 a_equil = np.zeros(angle_masks.shape) 

776 a_force = np.zeros(angle_masks.shape) 

777 for i, angle in enumerate(angle_unique): 

778 weight = [[0.0] * angle_max_conv, [0.0] * angle_max_conv] # 2xsize 

779 weight[0][angle[0]] = 1.0 

780 weight[0][angle[1]] = -1.0 

781 weight[1][angle[1]] = -1.0 

782 weight[1][angle[2]] = 1.0 

783 angle_weights.append(weight) 

784 mask_index = angle_idxs.min(axis=1)[(angle_conv == angle).all(axis=1)] + angle_max_conv - 3 

785 a_equil[i, mask_index] = angle_para[(angle_conv == angle).all(axis=1), 0] 

786 a_force[i, mask_index] = angle_para[(angle_conv == angle).all(axis=1), 1] 

787 angle_masks[i, mask_index] = True 

788 

789 # torsion 

790 

791 torsion_conv = (torsion_idxs - torsion_idxs.min(axis=1).reshape(-1, 1)) # relative positions of atoms 

792 torsion_conv = np.where((torsion_conv[:, 0] < torsion_conv[:, 3]).reshape(-1, 1), torsion_conv, torsion_conv[:, [3, 2, 1, 0]]) # remove mirrors 

793 torsion_unique = np.unique(torsion_conv, axis=0) # unique 

794 

795 torsion_masks = np.zeros((len(torsion_unique), N - (torsion_max_conv - 1) + 2 * (torsion_max_conv - 4)), dtype=bool) 

796 torsion_weights = [] 

797 # ts = torsion_masks.shape 

798 t_para = np.zeros((torsion_masks.shape[0], torsion_masks.shape[1], 

799 torsion_para.shape[1], torsion_para.shape[2])) 

800 # we do not want barrier/factor to return nan so set factor to 1 by default 

801 t_para[:, :, 0, :] = 1.0 

802 for i, torsion in enumerate(torsion_unique): 

803 weight = [[0.0] * torsion_max_conv, [0.0] * torsion_max_conv, [0.0] * torsion_max_conv] 

804 weight[0][torsion[0]] = 1.0 # b1 = ri-rj 

805 weight[0][torsion[1]] = -1.0 

806 weight[1][torsion[1]] = 1.0 # b2 = rj-rk 

807 weight[1][torsion[2]] = -1.0 

808 weight[2][torsion[2]] = -1.0 # b3 = rl-rk 

809 weight[2][torsion[3]] = 1.0 

810 torsion_weights.append(weight) 

811 mask_index = torsion_idxs.min(axis=1)[(torsion_conv == torsion).all(axis=1)] + torsion_max_conv - 4 

812 torsion_masks[i, mask_index] = True 

813 t_para[i, mask_index] = torsion_para[(torsion_conv == torsion).all(axis=1)] 

814 

815 if NB == 'matrix': 

816 # cdist is easier to work with than pdist; batched pdist was removed from torch and had not been re-added at the time of writing

817 vdw_R = 0.5 * torch.cdist(atom_R.view(-1, 1), -atom_R.view(-1, 1)).triu(diagonal=1) 

818 vdw_e = (atom_e.view(1, -1) * atom_e.view(-1, 1)).triu(diagonal=1).sqrt() 

819 # set 1-2, and 1-3 distances to 0.0 

820 vdw_R[bond_idxs.T] = 0.0 

821 vdw_e[bond_idxs.T] = 0.0 

822 vdw_R[angle_idxs[:, (0, 2)].T] = 0.0 

823 vdw_e[angle_idxs[:, (0, 2)].T] = 0.0 

824 vdw_R[torsion_idxs[:, (0, 3)].T] = 0.0 

825 vdw_e[torsion_idxs[:, (0, 3)].T] = 0.0 

826 

827 e_ = permitivity # permitivity 

828 atom_charges = torch.tensor(atom_charges) 

829 q1q2 = (atom_charges.view(1, -1) * atom_charges.view(-1, 1) / e_).triu(diagonal=1) # Aij=bi*bj 

830 q1q2[bond_idxs.T] = 0.0 

831 q1q2[angle_idxs[:, (0, 2)].T] = 0.0 

832 q1q2[torsion_idxs[:, (0, 3)].T] = 0.0 

833 

834 # 1-4 pairs should be included, but scaled

835 bond_14_idxs = np.array(bond_14_idxs) 

836 vdw_14R = 0.5 * (atom_R[bond_14_idxs[:, 0]] + atom_R[bond_14_idxs[:, 1]]) 

837 vdw_14e = (atom_e[bond_14_idxs[:, 0]] + atom_e[bond_14_idxs[:, 1]]).sqrt() 

838 q1q2_14 = (atom_charges[bond_14_idxs[:, 0]] * atom_charges[bond_14_idxs[:, 1]]) / e_ 

839 

840 return (bond_masks, b_equil, b_force, bond_weights, 

841 angle_masks, a_equil, a_force, angle_weights, 

842 torsion_masks, t_para, torsion_weights, 

843 vdw_R, vdw_e, vdw_14R, vdw_14e, 

844 q1q2, q1q2_14) 

845 

846 return (bond_masks, b_equil, b_force, bond_weights, 

847 angle_masks, a_equil, a_force, angle_weights, 

848 torsion_masks, t_para, torsion_weights) 

849 

850 

def get_conv_pad_res(dataset, pdb_atom_names,
                     absolute_torsion_period=True,
                     NB=('matrix',)[0],
                     fix_terminal=True,
                     fix_charmm_residues=True,
                     correct_1_4=True,
                     permitivity=1.0):
    '''
    Build the index and parameter tensors needed to evaluate bonded (and,
    optionally, non-bonded) energy terms on a residue-padded structure.

    ##INPUTS##

    dataset: one frame of shape [R, M, 3]
             (R residues, M padded atom slots per residue, xyz)

    pdb_atom_names: array with the same shape as dataset, [R, M, 3];
             the last axis holds (atom name, residue name, residue id).
             Padding slots carry None as the atom name.
             NOTE: terminal / CHARMM renaming is written in place into the
             reshaped array, which may be a view of the caller's array.

    absolute_torsion_period: if True, torsion periods are replaced by their
             absolute value.

    NB: 'matrix' additionally returns pairwise van der Waals and charge
             matrices; any other value returns only the bonded terms.

    fix_terminal: rename atom 'OXT' to 'O'.

    fix_charmm_residues: rename residues 'HSD' -> 'HID' and 'HSE' -> 'HIE'.

    correct_1_4: if True, 1-4 pairs keep a scaled interaction (well depth / 2,
             charge product / 1.2); if False they are zeroed like 1-2 and 1-3.

    permitivity: divisor applied to every charge product.

    ##OUTPUTS##

    (bond_idxs, bond_para,
     angle_idxs, angle_para, angle_mask, ij_jk,
     torsion_idxs, torsion_para, torsion_mask, ij_jk_kl)
    followed by (vdw_R, vdw_e, q1q2) when NB == 'matrix'.

    Only bonds between atoms sharing a residue id are detected here, because
    the list of candidate partners is reset whenever the residue id changes.

    Raises Exception on a shape mismatch or when a bond / angle / torsion has
    no matching force-field parameter.
    '''

    if len(dataset.shape) != 3:
        raise Exception(f'967 dataset frame here should be of shape [R, M, 3] not {str(dataset.shape)}')
    if dataset.shape != pdb_atom_names.shape:
        raise Exception('969 dataset.shape != pdb_atom_names.shape')

    # get amber parameters
    (amber_atoms, atom_mass, atom_polarizability, bond_force, bond_equil,
     angle_force, angle_equil, torsion_factor, torsion_barrier, torsion_phase,
     torsion_period, improper_factor, improper_barrier, improper_phase,
     improper_period, other_parameters) = get_amber_parameters()

    R, M, _ = dataset.shape  # R residues, M max atoms per residue, xyz
    N = R * M
    # [R*M, 3], 3 = atom, res, resid
    pdb_atom_names = pdb_atom_names.reshape(-1, 3)
    # [R*M, 3], 3 = x, y, z
    dataset = dataset.reshape(-1, 3)

    if fix_terminal:  # fix atoms and residues not in amber parameters
        pdb_atom_names[pdb_atom_names[:, 0] == 'OXT', 0] = 'O'
    if fix_charmm_residues:
        # These are residue names, so they belong in column 1. Writing them
        # to column 0 would overwrite the atom names of every histidine atom.
        pdb_atom_names[pdb_atom_names[:, 1] == 'HSD', 1] = 'HID'
        pdb_atom_names[pdb_atom_names[:, 1] == 'HSE', 1] = 'HIE'

    # pdb atoms -> amber atom names and residues (padding slots stay None)
    padded_atom_names = np.array([[amber_atoms[res][atom], res, resid] if atom is not None else [atom, res, resid]
                                  for atom, res, resid in pdb_atom_names])
    # NOTE(review): this lookup (and the connectivity lookup below) is keyed by
    # the converted atom label from padded_atom_names, not the original PDB
    # atom name - confirm this matches how the 'charge' and 'connectivity'
    # tables are keyed.
    padded_atom_charges = np.array([other_parameters['charge'][res][atom] if atom is not None else np.nan
                                    for atom, res, _ in padded_atom_names])
    if padded_atom_names.shape != dataset.shape:  # just a little check
        raise Exception('996 padded_atom_names!=dataset.shape')
    atom_names = padded_atom_names
    atom_charges = padded_atom_charges
    print('Determining bonds')

    connect = other_parameters['connectivity']
    # one independent neighbour list per atom slot (careful with mutability)
    connectivity = [[] for _ in range(N)]
    current_resid = -9999
    current_atoms = []  # (atom label, flat index) pairs already seen in the current residue
    for i1, (atom1, res, resid) in enumerate(padded_atom_names):
        if atom1 is None:
            continue
        if resid != current_resid:
            current_resid = resid
            current_atoms = []
        assert atom1 in connect[res]
        for atom2, i2 in current_atoms:
            if atom2 in connect[res][atom1] and atom1 in connect[res][atom2]:
                connectivity[i1].append(i2)
                connectivity[i2].append(i1)
        # Remember this atom so later atoms of the same residue can bond to it.
        # Without this the loop above never runs and no bond is ever found.
        current_atoms.append((atom1, i1))

    # An earlier variant inferred bonds from a distance cutoff on
    # torch.cdist(dataset, dataset); the template connectivity above is used instead.

    # ##################### Angles/1-3 #####################
    print('Determining angles')

    bond_idxs_ = []
    angle_idxs = []
    torsion_idxs = []
    bond_14_idxs = []

    bond_para = []
    angle_para = []
    torsion_para_ = []

    for atom1, atom2_list in enumerate(connectivity):
        for atom2 in atom2_list:
            a1, a2 = atom_names[atom1][0], atom_names[atom2][0]
            if atom1 < atom2:  # stops any pair of atoms being selected twice
                bond_idxs_.append([atom1, atom2])
                for b in [(a1, a2), (a2, a1)]:
                    if b in bond_equil:
                        bond_para.append([bond_equil[b], bond_force[b]])
                        break  # break prevents any bond from being added twice
                else:
                    raise Exception('No associated bond parameter')

            for atom3 in connectivity[atom2]:
                a3 = atom_names[atom3][0]
                if atom3 > atom1:  # each angle will only be counted once
                    angle_idxs.append([atom1, atom2, atom3])
                    for a in [(a1, a2, a3), (a3, a2, a1)]:
                        if a in angle_equil:
                            angle_para.append([angle_equil[a], angle_force[a]])
                            break
                    else:
                        raise Exception('No associated angle parameter')
                if atom3 != atom1:  # don't go back to same atom
                    for atom4 in connectivity[atom3]:
                        if atom4 > atom1 and atom2 != atom4:
                            torsion_idxs.append([atom1, atom2, atom3, atom4])
                            bond_14_idxs.append([atom1, atom4])
                            a4 = atom_names[atom4][0]
                            # exact match, reversed match, then the two wildcard forms
                            for t in [(a1, a2, a3, a4), (a4, a3, a2, a1), ('X', a2, a3, 'X'), ('X', a3, a2, 'X')]:
                                if t in torsion_barrier:
                                    torsion_para_.append(torch.tensor([
                                        torsion_factor[t],
                                        torsion_barrier[t],
                                        torsion_phase[t],
                                        torsion_period[t]]))
                                    break  # each torsion only counted once
                            else:
                                raise Exception('No associated torsion parameter')

    bond_idxs = torch.as_tensor(bond_idxs_)
    angle_idxs = torch.tensor(angle_idxs)
    torsion_idxs = torch.tensor(torsion_idxs)
    bond_14_idxs = torch.tensor(bond_14_idxs)
    bond_para = torch.tensor(bond_para)
    angle_para = torch.tensor(angle_para)
    # each entry of torsion_para_ is [4, n_terms]; pad all of them to the widest one
    max_number_torsion_para = max([tf.shape[1] for tf in torsion_para_])
    torsion_para = torch.zeros(torsion_idxs.shape[0], 4, max_number_torsion_para)
    # padded terms keep factor 1 so that barrier/factor never divides by zero
    torsion_para[:, 0, :] = 1.0
    for i, tf in enumerate(torsion_para_):
        torsion_para[i, :, 0:tf.shape[1]] = tf
    if absolute_torsion_period:
        torsion_para[:, 3, :] = torch.abs(torsion_para[:, 3, :])

    # Gather based potential
    # currently for data [B, R*M, 3] or [B, 3, N]
    bond_columns = bond_idxs.reshape(-1, 2, 1)

    def _bond_matches(term_idxs, first, second):
        # [n_bonds, n_terms] boolean table: True where a bond equals the
        # ordered atom pair (term_idxs[:, first], term_idxs[:, second])
        pair = term_idxs[:, [first, second]].view(-1, 2, 1).permute(2, 1, 0)
        return bond_columns.eq(pair).all(dim=1)

    aij0 = _bond_matches(angle_idxs, 0, 1)
    aij1 = _bond_matches(angle_idxs, 1, 0)
    ajk0 = _bond_matches(angle_idxs, 1, 2)
    ajk1 = _bond_matches(angle_idxs, 2, 1)
    ij_jk = torch.stack([torch.where((aij0 | aij1).T)[1], torch.where((ajk0 | ajk1).T)[1]])
    aij_ = aij1.float() - aij0.float()  # sign change needed for loss_function equation
    ajk_ = ajk0.float() - ajk1.float()
    angle_mask = torch.stack([aij_.sum(dim=0), ajk_.sum(dim=0)])

    # following are [N_bonds, N_torsions] arrays comparing if the ij or jk are the same
    ij0 = _bond_matches(torsion_idxs, 0, 1)
    ij1 = _bond_matches(torsion_idxs, 1, 0)
    jk0 = _bond_matches(torsion_idxs, 1, 2)
    jk1 = _bond_matches(torsion_idxs, 2, 1)
    kl0 = _bond_matches(torsion_idxs, 2, 3)
    kl1 = _bond_matches(torsion_idxs, 3, 2)
    ij_jk_kl = torch.stack([torch.where((ij0 | ij1).T)[1],
                            torch.where((jk0 | jk1).T)[1],
                            torch.where((kl0 | kl1).T)[1]])
    # +1 where the stored bond runs in the same direction as the torsion, -1 where reversed
    ij_ = ij0.float() - ij1.float()
    jk_ = jk0.float() - jk1.float()
    kl_ = kl0.float() - kl1.float()
    torsion_mask = torch.stack([ij_.sum(dim=0), jk_.sum(dim=0), kl_.sum(dim=0)])

    if NB == 'matrix':
        equiv_t = other_parameters['equivalences']
        vdw_para = other_parameters['vdw_potential_well_depth']
        # switch these around so that values point to key
        equiv = {}
        for key, aliases in equiv_t.items():
            for alias in aliases:
                equiv[alias] = key
        atom_R = torch.tensor([vdw_para[equiv.get(name, name)][0] if name is not None else np.nan
                               for name, _, _ in atom_names])  # radius
        atom_e = torch.tensor([vdw_para[equiv.get(name, name)][1] if name is not None else np.nan
                               for name, _, _ in atom_names])  # welldepth
        # cdist is easier to work with than pdist; cdist(R, -R) gives |Ri + Rj|
        vdw_R = 0.5 * torch.cdist(atom_R.view(-1, 1), -atom_R.view(-1, 1)).triu(diagonal=1)
        vdw_e = (atom_e.view(1, -1) * atom_e.view(-1, 1)).triu(diagonal=1).sqrt()

        # Explicit (row, column) index tuples: indexing with a *list* of
        # tensors relies on deprecated sequence-as-tuple behaviour in torch.
        pairs_12 = (bond_idxs[:, 0], bond_idxs[:, 1])
        pairs_13 = (angle_idxs[:, 0], angle_idxs[:, 2])
        pairs_14 = (torsion_idxs[:, 0], torsion_idxs[:, 3])

        # set 1-2, and 1-3 distances to 0.0
        vdw_R[pairs_12] = 0.0
        vdw_e[pairs_12] = 0.0
        vdw_R[pairs_13] = 0.0
        vdw_e[pairs_13] = 0.0
        if correct_1_4:
            # sum A/R**12 - B/R**6; A = e* (R**12); B = 2*e *(R**6)
            # therefore scale vdw by setting e /= 2.0
            vdw_e[pairs_14] /= 2.0
        else:
            vdw_R[pairs_14] = 0.0
            vdw_e[pairs_14] = 0.0
        # padding slots carry NaN radius / well depth; neutralise them
        vdw_R[torch.isnan(vdw_R)] = 0.0
        vdw_e[torch.isnan(vdw_e)] = 0.0

        # partial charges are given as fragments of electron charge.
        # Can convert coulomb energy into kcal/mol by multiplying with 332.05.
        # therefore multiply q by sqrt(332.05)=18.22
        e_ = permitivity  # permittivity
        atom_charges = torch.tensor(atom_charges)
        q1q2 = (atom_charges.view(1, -1) * atom_charges.view(-1, 1) / e_).triu(diagonal=1)  # Aij=bi*bj
        # NOTE(review): unlike vdw_R / vdw_e, NaN entries from padding slots
        # are not zeroed in q1q2 - confirm the consumer handles them.
        q1q2[pairs_12] = 0.0
        q1q2[pairs_13] = 0.0
        if correct_1_4:
            # 1-4 pairs should be included, but scaled
            q1q2[pairs_14] /= 1.2
        else:
            q1q2[pairs_14] = 0.0
        return (bond_idxs, bond_para,
                angle_idxs, angle_para, angle_mask, ij_jk,
                torsion_idxs, torsion_para, torsion_mask, ij_jk_kl,
                vdw_R, vdw_e,
                q1q2)
    return (bond_idxs, bond_para,
            angle_idxs, angle_para, angle_mask, ij_jk,
            torsion_idxs, torsion_para, torsion_mask, ij_jk_kl)

1079 

1080 

1081if __name__ == '__main__': 

1082 

1083 import sys 

1084 sys.path.insert(0, os.path.abspath('../'))