Coverage for biobb_pytorch / mdae / loss / utils / torch_protein_energy_utils.py: 3%
756 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1# Copyright (c) 2021 Venkata K. Ramaswamy, Samuel C. Musson, Chris G. Willcocks, Matteo T. Degiacomi
2#
3# Molearn is free software ;
4# you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation ;
5# either version 2 of the License, or (at your option) any later version.
6# molearn is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY ;
7# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
8# See the GNU General Public License for more details.
9# You should have received a copy of the GNU General Public License along with molearn ;
10# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
12import numpy as np
13import torch
14from copy import deepcopy
15import os
16from importlib.resources import files
def read_lib_file(file_name, amber_atoms, atom_charge, connectivity):
    """Parse an AMBER residue library (``.lib``) file into the supplied dicts.

    The file is opened from ``loss/parameters/<file_name>`` inside the
    ``mdae`` package. Three parts of the file are used:

    * the ``!!index array str`` table, whose quoted entries create an empty
      sub-dict per residue in every output dict;
    * each ``!entry.<RES>.unit.atoms`` table, which fills the atom-name
      mapping and the charges;
    * each ``!entry.<RES>.unit.connectivity`` table, which fills the bonded
      neighbour lists (both directions of every bond are recorded).

    ``<RES>`` is taken from a fixed three-character slice of the header line.

    Args:
        file_name: Name of the library file inside the parameters directory.
        amber_atoms: Filled in place as
            ``amber_atoms[res][pdb_atom] = amber_atom``.
        atom_charge: Filled in place as
            ``atom_charge[res][amber_atom] = charge``. The key is the AMBER
            atom name, so two atoms of one residue that share an AMBER name
            overwrite each other.
        connectivity: Filled in place as
            ``connectivity[res][pdb_atom] = [bonded pdb_atom, ...]``.

    Raises:
        Exception: If the file cannot be located or opened (the underlying
            error is attached as ``__cause__``), or if a name that should be
            wrapped in double quotes is not.
    """
    try:
        f_location = files('mdae').joinpath('loss/parameters')
        f_in = open(f'{str(f_location)}/{file_name}')
        print(f'File {f_location}/{file_name} opened')
    except Exception as err:
        raise Exception(f'ERROR: file {file_name} not found!') from err

    # Read everything up front so the handle is released even when the
    # parsing below raises.
    with f_in:
        lines = f_in.readlines()

    # indexs[res][atom_index] = pdb_atom; used to resolve connectivity rows.
    indexs = {}

    # Pass 1: the residue index table (only the first one is read).
    for i, tline in enumerate(lines):
        if tline.split() != ['!!index', 'array', 'str']:
            continue
        for line in lines[i + 1:]:
            if line[0] != ' ':  # table rows are indented; anything else ends it
                break
            contents = line.split()
            if not contents:  # whitespace-only row: treat as end of table
                break
            if len(contents) != 1 and len(contents[0]) != 5:
                break
            res = contents[0]
            if not (res[0] == '"' and res[-1] == '"'):
                msg = 'I was expecting something of the form "XXX" but got %s instead' % res
                raise Exception(msg)
            res_name = res[1:-1]
            amber_atoms[res_name] = {}
            atom_charge[res_name] = {}
            indexs[res_name] = {}
            connectivity[res_name] = {}
        break

    # Pass 2: per-residue atom and connectivity tables.
    for i, tline in enumerate(lines):
        entry, res = tline[0:7], tline[7:10]
        unit_atoms, unit_connectivity = tline[10:22], tline[10:29]
        if entry != '!entry.':
            continue
        if unit_atoms == '.unit.atoms ':
            for line in lines[i + 1:]:
                if line[0] != ' ':
                    break
                contents = line.split()
                if not contents:  # whitespace-only row: treat as end of table
                    break
                if len(contents) < 3 and len(contents[0]) > 4 and len(contents[1]) > 4:
                    break
                pdb_name, amber_name, _, _, _, index, _, charge = contents
                pdb_quoted = pdb_name[0] == '"' and pdb_name[-1] == '"'
                amber_quoted = amber_name[0] == '"' and amber_name[-1] == '"'
                if not (pdb_quoted and amber_quoted):
                    # Report the tokens that actually failed the check, not
                    # the residue name.
                    msg = ('I was expecting something of the form "XXX" but got %s instead'
                           % ' '.join((pdb_name, amber_name)))
                    raise Exception(msg)
                amber_atoms[res][pdb_name[1:-1]] = amber_name[1:-1]
                atom_charge[res][amber_name[1:-1]] = float(charge)
                indexs[res][int(index)] = pdb_name[1:-1]
                connectivity[res][pdb_name[1:-1]] = []
        elif unit_connectivity == '.unit.connectivity ':
            for line in lines[i + 1:]:
                if line[0] != ' ':
                    break
                contents = line.split()
                if len(contents) != 3:
                    break
                a1, a2, _ = contents
                first, second = indexs[res][int(a1)], indexs[res][int(a2)]
                connectivity[res][first].append(second)
                connectivity[res][second].append(first)
def _open_parameter_file(file_name):
    """Open ``loss/parameters/<file_name>`` from the ``mdae`` package for reading."""
    try:
        f_location = files('mdae').joinpath('loss/parameters')
        f_in = open(f'{str(f_location)}/{file_name}')
        print('File %s opened' % file_name)
    except Exception as err:
        raise Exception('ERROR: file %s not found!' % file_name) from err
    return f_in


def get_amber_parameters(order=False, radians=True):
    """Load the AMBER parameter tables used by the energy terms.

    Three files are read from the package's parameters directory, in order:
    ``amino12.lib`` (via ``read_lib_file``), ``parm10.dat`` (all card sections
    in their fixed order) and ``frcmod.ff14SB`` (same card formats, but each
    section is introduced by a keyword line and any section may be absent).
    Values read from the frcmod file are written into the same dicts, so they
    replace entries with the same key loaded from ``parm10.dat``.

    Args:
        order: Accepted for interface compatibility; this function does not
            use it and calls the card parsers with their default ordering.
        radians: When true, every value in ``angle_equil`` and every list in
            ``torsion_phase`` is converted with ``np.deg2rad``.
            NOTE(review): ``improper_phase`` is not converted — confirm that
            is intended.

    Returns:
        A 16-tuple ``(amber_atoms, atom_mass, atom_polarizability, bond_force,
        bond_equil, angle_force, angle_equil, torsion_factor, torsion_barrier,
        torsion_phase, torsion_period, improper_factor, improper_barrier,
        improper_phase, improper_period, other_parameters)``.
        ``other_parameters`` is a dict with the keys
        ``'vdw_potential_well_depth'``, ``'H_bond_10_12_parameters'``,
        ``'equivalences'``, ``'charge'`` and ``'connectivity'``.

    Raises:
        Exception: If a parameter file cannot be opened, or a card parser
            rejects a line.
    """
    file_names = ('amino12.lib',
                  'parm10.dat',
                  'frcmod.ff14SB')

    # amber19 is dangerous because they've replaced parameters with cmap
    # pdb atom names to amber atom names using amino19.lib
    amber_atoms = {}  # knowledge[res][pdb_atom] = amber_atom
    atom_mass = {}
    atom_polarizability = {}
    bond_force = {}
    bond_equil = {}
    angle_force = {}
    angle_equil = {}
    torsion_factor = {}
    torsion_barrier = {}
    torsion_phase = {}
    torsion_period = {}
    improper_factor = {}
    improper_barrier = {}
    improper_phase = {}
    improper_period = {}

    other_parameters = {
        'vdw_potential_well_depth': {},
        'H_bond_10_12_parameters': {},
        'equivalences': {},
        'charge': {},
        'connectivity': {},
    }
    read_lib_file(file_names[0], amber_atoms,
                  other_parameters['charge'], other_parameters['connectivity'])

    # Main parameter file: sections appear in a fixed order. The ``with``
    # block closes the file even when one of the card parsers raises.
    with _open_parameter_file(file_names[1]) as f_in:
        # section 1 title
        line = f_in.readline()
        print(line)

        amber_card_type_2(f_in, atom_mass, atom_polarizability)
        amber_card_type_3(f_in)
        amber_card_type_4(f_in, bond_force, bond_equil)
        amber_card_type_5(f_in, angle_force, angle_equil)
        amber_card_type_6(f_in, torsion_factor, torsion_barrier,
                          torsion_phase, torsion_period)
        amber_card_type_7(f_in, improper_factor, improper_barrier,
                          improper_phase, improper_period)
        amber_card_type_8(f_in, other_parameters)
        amber_card_type_9(f_in, other_parameters)
        for line in f_in:
            tokens = line.split()
            if len(tokens) > 1:
                # A header whose second token is 'RE' introduces the 6-12
                # parameter table.
                if tokens[1] == 'RE':
                    amber_card_type_10B(f_in, other_parameters)
            elif line[0:3] == 'END':
                print('parameters loaded')

    # Open the frcmod file; it uses the same card formats, but any or all
    # cards may be missing, so each section is located by its keyword line.
    with _open_parameter_file(file_names[2]) as f_in:
        # section 1 title
        line = f_in.readline()
        print(line)

        for line in f_in:
            keyword = line[:4]
            if keyword == 'MASS':
                amber_card_type_2(f_in, atom_mass, atom_polarizability)
            elif keyword == 'BOND':
                amber_card_type_4(f_in, bond_force, bond_equil)
            elif keyword == 'ANGL':
                amber_card_type_5(f_in, angle_force, angle_equil)
            elif keyword == 'DIHE':
                amber_card_type_6(f_in, torsion_factor, torsion_barrier,
                                  torsion_phase, torsion_period)
            elif keyword == 'IMPR':
                amber_card_type_7(f_in, improper_factor, improper_barrier,
                                  improper_phase, improper_period)
            elif keyword == 'HBON':
                amber_card_type_8(f_in, other_parameters)
            elif keyword == 'NONB':
                amber_card_type_10B(f_in, other_parameters)
            elif keyword == 'CMAP':
                print('Yeah, Im not bothering to implement cmap')
            elif line[0:3] == 'END':
                print('parameters loaded')

    if radians:
        for angle in angle_equil:
            angle_equil[angle] = np.deg2rad(angle_equil[angle])
        for torsion in torsion_phase:
            torsion_phase[torsion] = list(np.deg2rad(torsion_phase[torsion]))

    return (amber_atoms, atom_mass, atom_polarizability, bond_force, bond_equil,
            angle_force, angle_equil, torsion_factor, torsion_barrier, torsion_phase,
            torsion_period, improper_factor, improper_barrier, improper_phase,
            improper_period, other_parameters)
def amber_card_type_2(f_in, atom_mass, atom_polarizability):
    """Read section 2 of an AMBER parameter file: atom symbols and masses.

    Lines are consumed from ``f_in`` until a blank (or whitespace-only) line
    or the end of the file. For each line the atom symbol is taken from
    columns 0-1 and the numeric fields from columns 2-23.

    Args:
        f_in: Iterable of text lines positioned at the first line of the
            section; it is advanced past the terminating blank line.
        atom_mass: Filled in place as ``atom_mass[atom] = mass``.
        atom_polarizability: Filled in place as
            ``atom_polarizability[atom] = polarizability``. When a line lists
            only the mass, ``0.0`` is stored.

    Raises:
        Exception: If a line holds neither one nor two numeric fields.
        ValueError: If a field cannot be converted to ``float``.
    """
    for line in f_in:
        if not line.strip():
            break
        atom = line[0:2].strip()
        contents = line[2:24].split()
        if len(contents) == 2:
            mass, polarizability = float(contents[0]), float(contents[1])
        elif len(contents) == 1:
            # Sometimes a polarizability is not listed. Previously this branch
            # read an unassigned (or stale) local; a blank fixed-width numeric
            # field is taken as zero instead.
            mass, polarizability = float(contents[0]), 0.0
        else:
            raise Exception('Should be 2A, X, F10.2, F10.2, comments but got %s' % line)
        atom_mass[atom] = mass
        atom_polarizability[atom] = polarizability
def amber_card_type_3(f_in):
    """Skip section 3 (the hydrophilic atom symbols) of an AMBER parameter file.

    Exactly one line is read from ``f_in`` and discarded.
    """
    _ = f_in.readline()
def amber_card_type_4(f_in, bond_force, bond_equil, order=False):
    """Read section 4 of an AMBER parameter file: bond length parameters.

    Lines are consumed from ``f_in`` until a blank (or whitespace-only) line
    or the end of the file. The two atom symbols come from columns 0-1 and
    3-4; the two numeric fields come from columns 5-24.

    Args:
        f_in: Iterable of text lines positioned at the first bond line.
        bond_force: Filled in place as ``bond_force[(atom1, atom2)] = force``.
        bond_equil: Filled in place as ``bond_equil[(atom1, atom2)] = length``.
        order: When true, the key tuple is sorted alphabetically; otherwise
            the atoms keep the order in which they appear on the line.

    Raises:
        Exception: If columns 5-24 do not hold exactly two whitespace-separated
            fields; the message shows that same slice.
        ValueError: If a field cannot be converted to ``float``.
    """
    for line in f_in:
        if not line.strip():
            break
        atom1 = line[0:2].strip()
        atom2 = line[3:5].strip()
        bond = tuple(sorted((atom1, atom2))) if order else (atom1, atom2)
        field = line[5:25]
        contents = field.split()
        if len(contents) != 2:
            # Show exactly the text that was parsed (the message used to be
            # sliced one column to the right of it).
            raise Exception('Expected 2 floats but got %s' % field)
        force_constant, equil_length = float(contents[0]), float(contents[1])
        bond_force[bond] = force_constant
        bond_equil[bond] = equil_length
def amber_card_type_5(f_in, angle_force, angle_equil, order=False):
    """Read section 5 of an AMBER parameter file: bond angle parameters.

    Lines are consumed from ``f_in`` until a blank (or whitespace-only) line
    or the end of the file. The three atom symbols come from columns 0-1,
    3-4 and 6-7; the two numeric fields come from columns 8-27.

    Args:
        f_in: Iterable of text lines positioned at the first angle line.
        angle_force: Filled in place as
            ``angle_force[(atom1, atom2, atom3)] = force``.
        angle_equil: Filled in place as
            ``angle_equil[(atom1, atom2, atom3)] = angle``, stored as written
            in the file.
        order: When true, the two outer atoms of the key are sorted
            alphabetically (the central atom stays in the middle).

    Raises:
        Exception: If columns 8-27 do not hold exactly two whitespace-separated
            fields; the message shows that same slice.
        ValueError: If a field cannot be converted to ``float``.
    """
    for line in f_in:
        if not line.strip():
            break
        atom1 = line[0:2].strip()
        atom2 = line[3:5].strip()
        atom3 = line[6:8].strip()
        if order:
            # Sorted alphabetically by the 1-3 atoms.
            first, last = sorted((atom1, atom3))
            angle = (first, atom2, last)
        else:
            angle = (atom1, atom2, atom3)
        field = line[8:28]
        contents = field.split()
        if len(contents) != 2:
            # Show exactly the text that was parsed (the message used to be
            # sliced from the bond-card columns instead).
            raise Exception('Expected 2 floats but got %s' % field)
        force_constant, equil_angle = float(contents[0]), float(contents[1])
        angle_force[angle] = force_constant
        angle_equil[angle] = equil_angle
def amber_card_type_6(f_in, torsion_factor, torsion_barrier, torsion_phase,
                      torsion_period, order=False):
    """Read section 6 of an AMBER parameter file: torsion / proper dihedral terms.

    Lines are consumed from ``f_in`` until a blank (or whitespace-only) line
    or the end of the file. The four atom symbols come from columns 0-1, 3-4,
    6-7 and 9-10; the four values (factor, barrier, phase, period) come from
    columns 11-54. Each output dict maps the four-atom tuple to a list with
    one entry per term.

    How a line is stored depends on the last period already recorded for the
    same torsion:

    * no entry yet, or last period > 0: the line starts a fresh one-term list
      (replacing any previous definition);
    * last period < 0: the line is appended as a further term;
    * last period neither > 0 nor < 0 (i.e. exactly 0): the line is skipped.

    Args:
        f_in: Iterable of text lines positioned at the first torsion line.
        torsion_factor: Filled in place with ``int`` factors.
        torsion_barrier: Filled in place with ``float`` barriers.
        torsion_phase: Filled in place with ``float`` phases.
        torsion_period: Filled in place with ``float`` periods.
        order: When true, the key is reversed whenever the second atom sorts
            after the third, so the two central atoms are in alphabetical
            order.

    Raises:
        Exception: If columns 11-54 do not hold exactly four fields.
        ValueError: If a field cannot be converted to a number.
    """
    # the actual torsion potential is (barrier/factor)*(1+cos(period*phi-phase))
    columns = ((torsion_factor, int), (torsion_barrier, float),
               (torsion_phase, float), (torsion_period, float))
    for row in f_in:
        if not row.strip():
            break
        a1, a2, a3, a4 = (row[start:start + 2].strip() for start in (0, 3, 6, 9))
        if order and a2 > a3:
            key = (a4, a3, a2, a1)
        else:
            key = (a1, a2, a3, a4)
        fields = row[11:55].split()
        if len(fields) != 4:
            raise Exception('I wanted four values here?')
        extend = False
        if key in torsion_period:
            last_period = torsion_period[key][-1]
            if last_period < 0:
                extend = True
            elif not last_period > 0:
                # A stored period of exactly zero matches neither case; the
                # line is dropped without being converted.
                continue
        for (table, convert), raw in zip(columns, fields):
            if extend:
                table[key].append(convert(raw))
            else:
                table[key] = [convert(raw)]
def amber_card_type_7(f_in, improper_factor, improper_barrier,
                      improper_phase, improper_period, order=False):
    """Read section 7 of an AMBER parameter file: improper dihedral terms.

    Lines are consumed from ``f_in`` until a blank (or whitespace-only) line
    or the end of the file. The four atom symbols come from columns 0-1, 3-4,
    6-7 and 9-10; the values come from columns 11-54. A line with three
    values is stored as (barrier, phase, period). A line with four values is
    rejected. Lines with any other number of values are ignored.

    Args:
        f_in: Iterable of text lines positioned at the first improper line.
        improper_factor: Accepted for interface compatibility; this function
            never writes to it.
        improper_barrier: Filled in place as ``improper_barrier[key] = float``.
        improper_phase: Filled in place as ``improper_phase[key] = float``.
        improper_period: Filled in place as ``improper_period[key] = float``.
        order: When true, the key is reversed whenever the second atom sorts
            after the third, so the two central atoms are in alphabetical
            order.

    Raises:
        Exception: If a line carries four values.
        ValueError: If a field cannot be converted to ``float``.
    """
    for line in f_in:
        if not line.strip():
            break
        atom1 = line[0:2].strip()
        atom2 = line[3:5].strip()
        atom3 = line[6:8].strip()
        atom4 = line[9:11].strip()
        if order:
            sort23 = sorted([(atom2, atom1), (atom3, atom4)], key=lambda x: x[0])
            torsion = (sort23[0][1], sort23[0][0], sort23[1][0], sort23[1][1])
        else:
            torsion = (atom1, atom2, atom3, atom4)
        contents = line[11:55].split()
        if len(contents) == 3:
            # the actual torsion potential is (barrier/factor)*(1+cos(period*phi-phase))
            # it seems improper potential don't divide by the factor
            improper_barrier[torsion] = float(contents[0])
            improper_phase[torsion] = float(contents[1])
            improper_period[torsion] = float(contents[2])
        elif len(contents) == 4:
            # The four-value form (with a leading factor) is deliberately
            # rejected; the assignments that used to follow this raise could
            # never run and have been removed.
            raise Exception('This seems allowed in the doc but doesnt appear in reality')
def amber_card_type_8(f_in, other_parameters, order=False):
    """Read section 8 of an AMBER parameter file: H-bond 10-12 parameters.

    Lines are consumed from ``f_in`` until a blank (or whitespace-only) line
    or the end of the file. The two atom symbols come from columns 2-3 and
    6-7; everything from column 8 onwards is split on whitespace and stored,
    unconverted, under ``other_parameters['H_bond_10_12_parameters'][pair]``.

    Args:
        f_in: Iterable of text lines positioned at the first line of the
            section.
        other_parameters: Dict holding the ``'H_bond_10_12_parameters'``
            sub-dict that is filled in place.
        order: When true, the pair key is sorted alphabetically.
    """
    for row in f_in:
        if not row.strip():
            break
        pair = (row[2:4].strip(), row[6:8].strip())
        if order:
            pair = tuple(sorted(pair))
        other_parameters['H_bond_10_12_parameters'][pair] = row[8:].split()
def amber_card_type_9(f_in, other_parameters):
    """Read section 9 of an AMBER parameter file: equivalenced atom symbols.

    Lines are consumed from ``f_in`` until a blank (or whitespace-only) line
    or the end of the file. Each line is split on whitespace and the whole
    token list (first symbol included) is stored under
    ``other_parameters['equivalences'][first_symbol]``.
    """
    for row in f_in:
        if not row.strip():
            break
        symbols = row.split()
        leader = symbols[0]
        other_parameters['equivalences'][leader] = symbols
def amber_card_type_10B(f_in, other_parameters):
    """Read section 10 of an AMBER parameter file: 6-12 potential parameters.

    Lines are consumed from ``f_in`` until a blank (or whitespace-only) line
    or the end of the file. Each line is split on whitespace; the first token
    is the atom symbol and the next (up to) two tokens are converted to
    ``float`` and stored as a list under
    ``other_parameters['vdw_potential_well_depth'][symbol]``. Any further
    tokens on the line are ignored.

    Raises:
        ValueError: If either numeric token cannot be converted to ``float``.
    """
    for row in f_in:
        if not row.strip():
            break
        tokens = row.split()
        values = list(map(float, tokens[1:3]))
        other_parameters['vdw_potential_well_depth'][tokens[0]] = values
360def get_convolutions(dataset, pdb_atom_names,
361 atom_label=('set', 'string')[0],
362 perform_checks=True,
363 v=5,
364 order=False,
365 return_type=['mask', 'idxs'][1],
366 absolute_torsion_period=True,
367 NB=('matrix',)[0],
368 fix_terminal=True,
369 fix_charmm_residues=True,
370 fix_slice_method=False,
371 fix_h=False,
372 alt_vdw=[],
373 permitivity=1.0
374 ):
375 '''
376 ##INPUTS##
378 dataset: one frame of a trajectory of shape [3, N]
380 pdb_atom_names: should be an array of shape [N,2]
381 pdb_atom_names[:,0] is the pdb_atom_names and
382 pdb_atom_names[:,1] is the residue names
384 atom_label: (default, 'set') deprecated and broken for anything other than 'set'
386 perform_checks: No longer works so has been removed
388 v: (default, 5) atom_selection version, bonds are determined by interatomic
389 distance, with v=2. v=1 shouldn't be used except in specific cirumstances. v=5 is using connectivity from amber parameters.
391 order: (bool, default false) are atoms ordered, I think I've fixed this so that
392 it shouldn't matter either way but keep as False.
394 return_type: (option now removed)
397 ##OUTPUTS##
398 convolution output shape N* will be N-(conv length -1)+padding
400 bond_masks, b_equil, b_force: shape [number of convolutions, N*]
402 bond_weights: shape[number of convolutions, conv_size]
404 angle_masks, a_equil, a_force: shape [number of convolutions, conv_size]
406 angle_weights: shape[number of convolutions, 2, conv_size]
408 torsion_masks: shape[number of convolution, 3, conv_size]
410 t_para: shape[num of convs, N*, 4, max number torsion parameters ]
412 tornsion_weigths: shape [number of convolutions, 3, conv_size]
414 '''
416 # get amber parameters
417 (amber_atoms, atom_mass, atom_polarizability, bond_force, bond_equil,
418 angle_force, angle_equil, torsion_factor, torsion_barrier, torsion_phase,
419 torsion_period, improper_factor, improper_barrier, improper_phase,
420 improper_period, other_parameters) = get_amber_parameters()
421 if fix_terminal:
422 pdb_atom_names[pdb_atom_names[:, 0] == 'OXT', 0] = 'O'
423 if fix_charmm_residues:
424 pdb_atom_names[pdb_atom_names[:, 1] == 'HSD', 1] = 'HID'
425 pdb_atom_names[pdb_atom_names[:, 1] == 'HSE', 1] = 'HIE'
426 for i in np.unique(pdb_atom_names[:, 2]):
427 res_mask = pdb_atom_names[:, 2] == i
428 if (pdb_atom_names[res_mask, 1] == 'HIS').all(): # if a HIS residue
429 if (pdb_atom_names[res_mask, 0] == 'HD1').any() and (pdb_atom_names[res_mask, 0] == 'HE2').any():
430 pdb_atom_names[res_mask, 1] = 'HIP'
431 elif (pdb_atom_names[res_mask, 0] == 'HD1').any():
432 pdb_atom_names[res_mask, 1] = 'HID'
433 elif (pdb_atom_names[res_mask, 0] == 'HE2').any():
434 pdb_atom_names[res_mask, 1] = 'HIE'
435 # if any HIS are remaining it does not matter which because the H is dealt with above
436 pdb_atom_names[pdb_atom_names[:, 1] == 'HIS', 1] = 'HIE'
437 if fix_h:
438 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'MET'), 0] = 'HB3'
439 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'MET'), 0] = 'HG3'
440 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'ASN'), 0] = 'HB3'
441 pdb_atom_names[pdb_atom_names[:, 0] == 'HN', 0] = 'H'
442 pdb_atom_names[pdb_atom_names[:, 0] == '1HD2', 0] = 'HD21'
443 pdb_atom_names[pdb_atom_names[:, 0] == '2HD2', 0] = 'HD22'
444 pdb_atom_names[pdb_atom_names[:, 0] == '1HG2', 0] = 'HG21'
445 pdb_atom_names[pdb_atom_names[:, 0] == '2HG2', 0] = 'HG22'
446 pdb_atom_names[pdb_atom_names[:, 0] == '3HG2', 0] = 'HG23'
447 pdb_atom_names[pdb_atom_names[:, 0] == '3HG1', 0] = 'HG13'
448 pdb_atom_names[pdb_atom_names[:, 0] == '1HG1', 0] = 'HG11'
449 pdb_atom_names[pdb_atom_names[:, 0] == '2HG1', 0] = 'HG12'
450 pdb_atom_names[pdb_atom_names[:, 0] == '1HD1', 0] = 'HD11'
451 pdb_atom_names[pdb_atom_names[:, 0] == '2HD1', 0] = 'HD12'
452 pdb_atom_names[pdb_atom_names[:, 0] == '3HD1', 0] = 'HD13'
453 pdb_atom_names[pdb_atom_names[:, 0] == '3HD2', 0] = 'HD23'
454 pdb_atom_names[pdb_atom_names[:, 0] == '1HH1', 0] = 'HH11'
455 pdb_atom_names[pdb_atom_names[:, 0] == '2HH1', 0] = 'HH12'
456 pdb_atom_names[pdb_atom_names[:, 0] == '1HH2', 0] = 'HH21'
457 pdb_atom_names[pdb_atom_names[:, 0] == '2HH2', 0] = 'HH22'
458 pdb_atom_names[pdb_atom_names[:, 0] == '1HE2', 0] = 'HE21'
459 pdb_atom_names[pdb_atom_names[:, 0] == '2HE2', 0] = 'HE22'
460 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG11', pdb_atom_names[:, 1] == 'ILE'), 0] = 'HG13'
461 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'CD', pdb_atom_names[:, 1] == 'ILE'), 0] = 'CD1'
462 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1] == 'ILE'), 0] = 'HD11'
463 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD2', pdb_atom_names[:, 1] == 'ILE'), 0] = 'HD12'
464 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD3', pdb_atom_names[:, 1] == 'ILE'), 0] = 'HD13'
465 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'PHE'), 0] = 'HB3'
466 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'GLU'), 0] = 'HB3'
467 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'GLU'), 0] = 'HG3'
468 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'LEU'), 0] = 'HB3'
469 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'ARG'), 0] = 'HB3'
470 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'ARG'), 0] = 'HG3'
471 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1] == 'ARG'), 0] = 'HD3'
472 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'ASP'), 0] = 'HB3'
473 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HA1', pdb_atom_names[:, 1] == 'GLY'), 0] = 'HA3'
474 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'LYS'), 0] = 'HB3'
475 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'LYS'), 0] = 'HG3'
476 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1] == 'LYS'), 0] = 'HD3'
477 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HE1', pdb_atom_names[:, 1] == 'LYS'), 0] = 'HE3'
478 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'TYR'), 0] = 'HB3'
479 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'HIP'), 0] = 'HB3'
480 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'SER'), 0] = 'HB3'
481 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'SER'), 0] = 'HG'
482 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'PRO'), 0] = 'HB3'
483 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'PRO'), 0] = 'HG3'
484 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1] == 'PRO'), 0] = 'HD3'
485 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'LEU'), 0] = 'HB3'
486 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'GLN'), 0] = 'HB3'
487 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1] == 'GLN'), 0] = 'HG3'
488 pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1] == 'TRP'), 0] = 'HB3'
489 # writes termini as H because we haven't loaded in termini parameters
490 atom_names = [[amber_atoms[res][atom], res, resid] if atom not in ['H2', 'H3'] else [amber_atoms[res]['H'], res, resid] for atom, res, resid in pdb_atom_names]
491 p_atom_names = [[atom, res, resid] if atom not in ['H2', 'H3'] else ['H', res, resid] for atom, res, resid in pdb_atom_names]
493 # atom_names = [[amber_atoms[res][atom],res] for atom, res, resid in pdb_atom_names ]
494 atom_charges = [other_parameters['charge'][res][atom] for atom, res, resid in atom_names]
495 if NB == 'matrix':
496 equiv_t = other_parameters['equivalences']
497 vdw_para = other_parameters['vdw_potential_well_depth']
498 # switch these around so that values point to key
499 equiv = {}
500 for i in equiv_t.keys():
501 j = equiv_t[i]
502 for k in j:
503 equiv[k] = i
504 atom_R = torch.tensor([vdw_para[equiv.get(atom, atom)][0] for atom, res, resid in atom_names]) # radius
505 atom_e = torch.tensor([vdw_para[equiv.get(atom, atom)][1] for atom, res, resid in atom_names]) # welldepth
507 print('Determining bonds')
508 version = v # method of selecting bonded atoms
509 N = dataset.shape[1] # 145
511 cmat = (torch.nn.functional.pdist((dataset).permute(1, 0))).cpu().numpy()
512 if version == 1:
513 bond_idxs = np.argpartition(cmat, (N - 1, N))
514 # this will work for any non cyclic monomeric protein
515 # that in mind will break if enough proline atoms to make a cycle are selected
516 bond_idxs, u = bond_idxs[:N - 1], bond_idxs[N - 1]
517 if cmat[u] - cmat[bond_idxs[-1]] < 0.25:
518 dist_val = str(cmat[u] - cmat[bond_idxs[-1]])
519 msg = ("WARNING: May not have correctly selected the bonded distances: value %s "
520 "should be roughly between 0.42 and 0.57 (>0.25)" % dist_val)
521 raise Exception(msg) # should be 0.42-0.57
522 version += 1 # try version 2 instead
523 mid = cmat[bond_idxs[-1]] + ((cmat[u] - cmat[bond_idxs[-1]]) / 2) # mid point
524 full_mask = (cmat < mid).astype('int8')
525 if version == 2:
526 full_mask = (cmat < (1.643 + 2.129) / 2).astype('int8')
527 bond_idxs = np.where(full_mask)[0] # for some reason returns tuple with one array
528 if version == 3:
529 if alt_vdw:
530 vdw = torch.tensor(alt_vdw)
531 max_bond_dist = (0.6 * (vdw.view(1, -1) + vdw.view(-1, 1)))
532 cdist = torch.cdist(dataset.T, dataset.T)
533 i, j = np.where((max_bond_dist > cdist).triu(diagonal=1).numpy())
534 remove = np.where(np.abs(j - i) > 30)
535 max_bond_dist[i[remove], j[remove]] = 0.0
536 max_bond_dist = max_bond_dist.numpy()
537 else:
538 max_bond_dist = (0.6 * (atom_R.view(1, -1) + atom_R.view(-1, 1))).cpu().numpy()
539 max_bond_dist = max_bond_dist[np.where(np.triu(np.ones((N, N)), k=1))]
540 full_mask = np.greater(max_bond_dist, cmat)
541 bond_idxs = np.where(full_mask)[0] # for some reason returns tuple with one array
542 if version == 4:
543 # fix_hydrogens = [[atom, res, resid] for atom, res, resid in pdb_atom_names if atom in ['H2', 'H3']]
544 connectivity = other_parameters['connectivity']
545 bond_types = []
546 bond_idxs = []
547 # tracker = [[]]*N doesn't work because of mutability
548 tracker = [[] for i in range(N)]
549 current_resid = -9999
550 current_atoms = []
551 for i1, (atom1, res, resid) in enumerate(p_atom_names):
552 assert atom1 in connectivity[res]
553 for atom2, i2 in current_atoms:
554 if resid != current_resid: # and atom2 == 'C':
555 if atom2 != 'C':
556 continue
557 elif not (atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]):
558 continue
559 # if not (atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]):
560 # if resid != current_resid and atom2 == 'C':
561 # current_resid = resid
562 # current_atoms = []
563 # else:
564 # continue
565 if atom1 == 'N' and atom2 == 'CA':
566 continue
567 tracker[i1].append(i2)
568 tracker[i2].append(i1)
569 if atom_label == 'set':
570 if order:
571 names = tuple(sorted((atom_names[i2][0], atom_names[i1][0])))
572 else:
573 names = tuple((atom_names[i2][0], atom_names[i1][0]))
574 bond_types.append(names)
575 bond_idxs.append([i2, i1])
576 if resid != current_resid: # and atom2 == 'C':
577 current_resid = resid
578 current_atoms = []
579 current_atoms.append([atom1, i1])
580 if version == 5:
581 connectivity = other_parameters['connectivity']
582 bond_types = []
583 bond_idxs = []
584 tracker = [[] for i in range(N)]
585 current_resid = -9999
586 current_atoms = []
587 previous_atoms = []
588 for i1, (atom1, res, resid) in enumerate(p_atom_names):
589 assert atom1 in connectivity[res]
590 if resid != current_resid:
591 previous_atoms = deepcopy(current_atoms)
592 current_atoms = []
593 current_resid = resid
594 if atom1 == 'N':
595 for atom2, i2 in previous_atoms:
596 if atom2 == 'C':
597 tracker[i1].append(i2)
598 tracker[i2].append(i1)
599 bond_types.append(tuple((atom_names[i2][0], atom_names[i1][0])))
600 bond_idxs.append([i2, i1])
602 for atom2, i2 in current_atoms:
603 if atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]:
604 tracker[i1].append(i2)
605 tracker[i2].append(i1)
606 names = tuple((atom_names[i2][0], atom_names[i1][0]))
607 bond_types.append(names)
608 bond_idxs.append([i2, i1])
609 current_atoms.append([atom1, i1])
610 if version < 4:
611 all_bond_idxs = np.sort(bond_idxs)
613 bond_types = []
614 bond_idxs = []
615 tracker = [[]] # this will keep track of some of the bonds to help work out the angles
616 atom1 = 0
617 atom2 = 1
618 counter = 0 # index of the distance N,N+1
619 for bond in all_bond_idxs:
620 if bond < counter + (N - atom1 - 1):
621 atom2 = atom1 + bond - counter + 1 # 0-0+1
622 tracker[-1].append(atom2)
623 while bond > counter + (N - atom1 - 2):
624 counter += (N - atom1 - 1)
625 atom1 += 1
626 tracker.append([])
627 if bond < counter + (N - atom1 - 1):
628 atom2 = atom1 + bond - counter + 1
629 tracker[-1].append(atom2)
630 if atom_label == 'string': # string of atom labels, doesn't handle Proline alternate ordering
631 bond_types.append(atom_names[atom1][0] + '_' + atom_names[atom2][0])
632 bond_idxs.append([atom1, atom2])
633 elif atom_label == 'set': # set of atom labels
634 if order:
635 names = tuple(sorted((atom_names[atom1][0], atom_names[atom2][0])))
636 else:
637 names = (atom_names[atom1][0], atom_names[atom2][0])
638 bond_types.append(names)
639 bond_idxs.append([atom1, atom2])
641 while len(tracker) < N:
642 tracker.append([]) # ensure so the next bit doesn't break by indexing N-1
644 # Angles/1-3
645 print('Determining angles')
646 angle_types = []
647 angle_idxs = []
649 torsion_types = []
650 torsion_idxs = []
652 bond_14_idxs = []
654 counter = 0
655 # add missing bonds (each bond counted twice after but atom3>atom1 prevents duplicates later )
656 if version < 4:
657 for atom1, atom1_bonds in enumerate(deepcopy(tracker)): # for _, [] in enum [[]]
658 for atom2 in atom1_bonds: # for _ in []
659 tracker[atom2].append(atom1)
660 # find every angle and add it
661 for atom1, atom1_bonds in enumerate(tracker):
662 for atom2 in atom1_bonds:
663 for atom3 in tracker[atom2]:
664 if atom3 > atom1: # each angle will only be counter once
665 if order:
666 sort13 = sorted([(atom_names[atom1][0], atom1), (atom_names[atom3][0], atom3)], key=lambda x: x[0])
667 names = tuple((sort13[0][0], atom_names[atom2][0], sort13[1][0]))
669 angle_types.append(names)
670 angle_idxs.append([sort13[0][1], atom2, sort13[1][1]])
671 else:
672 angle_types.append((atom_names[atom1][0], atom_names[atom2][0],
673 atom_names[atom3][0]))
674 angle_idxs.append([atom1, atom2, atom3])
675 if atom3 != atom1:
676 for atom4 in tracker[atom3]:
677 if atom4 > atom1 and atom2 != atom4: # each torsion will be counter once
678 # torsions are done based on the 2 3 atoms, so sort 23
679 if order:
680 sort23 = sorted([(atom_names[atom2][0], atom2, atom_names[atom1][0], atom1),
681 (atom_names[atom3][0], atom3, atom_names[atom4][0], atom4)], key=lambda x: x[0])
682 names = tuple((sort23[0][2], sort23[0][0], sort23[1][0], sort23[1][2]))
683 torsion_types.append(names)
684 torsion_idxs.append([sort23[0][3], sort23[0][1], sort23[1][1], sort23[1][3]])
685 else:
686 torsion_types.append((atom_names[atom1][0], atom_names[atom2][0],
687 atom_names[atom3][0], atom_names[atom4][0]))
688 torsion_idxs.append([atom1, atom2, atom3, atom4])
689 bond_14_idxs.append([atom1, atom4])
690 # currently have bond_types, angle_types, and torsion_typs + idxs
691 bond_idxs = np.array(bond_idxs)
692 angle_idxs = np.array(angle_idxs)
693 torsion_idxs = np.array(torsion_idxs)
694 bond_max_conv = (bond_idxs.max(axis=1) - bond_idxs.min(axis=1)).max() + 1
695 if bond_max_conv < 3 and fix_slice_method:
696 bond_max_conv = 3
697 angle_max_conv = (angle_idxs.max(axis=1) - angle_idxs.min(axis=1)).max() + 1
698 if angle_max_conv < 5 and fix_slice_method:
699 angle_max_conv = 5
700 torsion_max_conv = (torsion_idxs.max(axis=1) - torsion_idxs.min(axis=1)).max() + 1
701 if torsion_max_conv < 7 and fix_slice_method:
702 torsion_max_conv = 7
703 # there is a problem where i accidentally index [padding-3] so if (len -4) < 3 we index -1 which breaks things
704 # it shouldn't affect anything to say the max conv is greater than 6
705 # this little bit just turns the 'types' list into equivalent parameters
706 # key error if you don't have the parameter
707 bond_para = np.array([[bond_equil[bond], bond_force[bond]] if bond in bond_equil
708 else [bond_equil[(bond[1], bond[0])], bond_force[(bond[1], bond[0])]]
709 for bond in bond_types])
710 angle_para = np.array([[angle_equil[angle], angle_force[angle]] if angle in angle_equil
711 else [angle_equil[(angle[2], angle[1], angle[0])], angle_force[(angle[2], angle[1], angle[0])]]
712 for angle in angle_types])
713 torsion_para = []
714 t_unique = list(set(torsion_types))
715 t_unique_para = {}
716 max_para = 0
717 for torsion in t_unique:
718 torsion_b = (torsion[3], torsion[2], torsion[1], torsion[0])
719 torsion_xx = ('X', torsion[2], torsion[1], 'X')
720 torsion_xb = ('X', torsion[1], torsion[2], 'X')
721 if torsion in torsion_barrier:
722 max_para = max(max_para, len(torsion_barrier[torsion]))
723 t_unique_para[torsion] = [torsion_factor[torsion], torsion_barrier[torsion],
724 torsion_phase[torsion], torsion_period[torsion]]
725 elif torsion_b in torsion_barrier:
726 max_para = max(max_para, len(torsion_barrier[torsion_b]))
727 t_unique_para[torsion] = [torsion_factor[torsion_b], torsion_barrier[torsion_b],
728 torsion_phase[torsion_b], torsion_period[torsion_b]]
729 elif torsion_xx in torsion_barrier:
730 max_para = max(max_para, len(torsion_barrier[torsion_xx]))
731 t_unique_para[torsion] = [torsion_factor[torsion_xx], torsion_barrier[torsion_xx],
732 torsion_phase[torsion_xx], torsion_period[torsion_xx]]
733 elif torsion_xb in torsion_barrier:
734 max_para = max(max_para, len(torsion_barrier[torsion_xb]))
735 t_unique_para[torsion] = [torsion_factor[torsion_xb], torsion_barrier[torsion_xb],
736 torsion_phase[torsion_xb], torsion_period[torsion_xb]]
737 else:
738 print('ERROR: Torsion %s cannot be found in torsion_barrier and will not be included' % torsion)
739 torsion_para = np.zeros((len(torsion_types), 4, max_para))
740 # we do not want barrier/factor to return nan so set factor to 1 by default
741 torsion_para[:, 0, :] = 1.0
742 for i, torsion in enumerate(torsion_types):
743 para = t_unique_para[torsion]
744 torsion_para[i, :, :len(para[0])] = para
745 # make phase positive
746 if absolute_torsion_period:
747 torsion_para[:, 3, :] = np.abs(torsion_para[:, 3, :])
749 # bonds
751 bond_masks = np.zeros((bond_max_conv - 1, N - (bond_max_conv - 1) + 2 * (bond_max_conv - 2)), dtype=bool)
752 bond_conv = (bond_idxs.max(axis=1) - bond_idxs.min(axis=1)) - 1
754 bond_weights = []
755 b_equil = np.zeros(bond_masks.shape)
756 b_force = np.zeros(bond_masks.shape)
757 for i in range(bond_max_conv - 1):
758 weight = [0.0] * bond_max_conv
759 weight[0] = 1.0
760 weight[i + 1] = -1.0
761 bond_weights.append(weight)
762 mask_index = bond_idxs.min(axis=1)[bond_conv == i] + bond_max_conv - 2
763 bond_masks[i, mask_index] = True
764 b_equil[i, mask_index] = bond_para[bond_conv == i, 0]
765 b_force[i, mask_index] = bond_para[bond_conv == i, 1]
767 # angles
769 angle_conv = (angle_idxs - angle_idxs.min(axis=1).reshape(-1, 1)) # relative positions of atoms
770 angle_conv = np.where((angle_conv[:, 0] < angle_conv[:, 2]).reshape(-1, 1), angle_conv, angle_conv[:, [2, 1, 0]]) # remove mirrors
771 angle_unique = np.unique(angle_conv, axis=0) # unique
773 angle_masks = np.zeros((len(angle_unique), N - (angle_max_conv - 1) + 2 * (angle_max_conv - 3)), dtype=bool)
774 angle_weights = []
775 a_equil = np.zeros(angle_masks.shape)
776 a_force = np.zeros(angle_masks.shape)
777 for i, angle in enumerate(angle_unique):
778 weight = [[0.0] * angle_max_conv, [0.0] * angle_max_conv] # 2xsize
779 weight[0][angle[0]] = 1.0
780 weight[0][angle[1]] = -1.0
781 weight[1][angle[1]] = -1.0
782 weight[1][angle[2]] = 1.0
783 angle_weights.append(weight)
784 mask_index = angle_idxs.min(axis=1)[(angle_conv == angle).all(axis=1)] + angle_max_conv - 3
785 a_equil[i, mask_index] = angle_para[(angle_conv == angle).all(axis=1), 0]
786 a_force[i, mask_index] = angle_para[(angle_conv == angle).all(axis=1), 1]
787 angle_masks[i, mask_index] = True
789 # torsion
791 torsion_conv = (torsion_idxs - torsion_idxs.min(axis=1).reshape(-1, 1)) # relative positions of atoms
792 torsion_conv = np.where((torsion_conv[:, 0] < torsion_conv[:, 3]).reshape(-1, 1), torsion_conv, torsion_conv[:, [3, 2, 1, 0]]) # remove mirrors
793 torsion_unique = np.unique(torsion_conv, axis=0) # unique
795 torsion_masks = np.zeros((len(torsion_unique), N - (torsion_max_conv - 1) + 2 * (torsion_max_conv - 4)), dtype=bool)
796 torsion_weights = []
797 # ts = torsion_masks.shape
798 t_para = np.zeros((torsion_masks.shape[0], torsion_masks.shape[1],
799 torsion_para.shape[1], torsion_para.shape[2]))
800 # we do not want barrier/factor to return nan so set factor to 1 by default
801 t_para[:, :, 0, :] = 1.0
802 for i, torsion in enumerate(torsion_unique):
803 weight = [[0.0] * torsion_max_conv, [0.0] * torsion_max_conv, [0.0] * torsion_max_conv]
804 weight[0][torsion[0]] = 1.0 # b1 = ri-rj
805 weight[0][torsion[1]] = -1.0
806 weight[1][torsion[1]] = 1.0 # b2 = rj-rk
807 weight[1][torsion[2]] = -1.0
808 weight[2][torsion[2]] = -1.0 # b3 = rl-rk
809 weight[2][torsion[3]] = 1.0
810 torsion_weights.append(weight)
811 mask_index = torsion_idxs.min(axis=1)[(torsion_conv == torsion).all(axis=1)] + torsion_max_conv - 4
812 torsion_masks[i, mask_index] = True
813 t_para[i, mask_index] = torsion_para[(torsion_conv == torsion).all(axis=1)]
815 if NB == 'matrix':
816 # cdist is easier to work with than pdist, batch pdist was removed from torch and has not been readded as of writting this
817 vdw_R = 0.5 * torch.cdist(atom_R.view(-1, 1), -atom_R.view(-1, 1)).triu(diagonal=1)
818 vdw_e = (atom_e.view(1, -1) * atom_e.view(-1, 1)).triu(diagonal=1).sqrt()
819 # set 1-2, and 1-3 distances to 0.0
820 vdw_R[bond_idxs.T] = 0.0
821 vdw_e[bond_idxs.T] = 0.0
822 vdw_R[angle_idxs[:, (0, 2)].T] = 0.0
823 vdw_e[angle_idxs[:, (0, 2)].T] = 0.0
824 vdw_R[torsion_idxs[:, (0, 3)].T] = 0.0
825 vdw_e[torsion_idxs[:, (0, 3)].T] = 0.0
827 e_ = permitivity # permitivity
828 atom_charges = torch.tensor(atom_charges)
829 q1q2 = (atom_charges.view(1, -1) * atom_charges.view(-1, 1) / e_).triu(diagonal=1) # Aij=bi*bj
830 q1q2[bond_idxs.T] = 0.0
831 q1q2[angle_idxs[:, (0, 2)].T] = 0.0
832 q1q2[torsion_idxs[:, (0, 3)].T] = 0.0
834 # 1-4 are should be included but scaled
835 bond_14_idxs = np.array(bond_14_idxs)
836 vdw_14R = 0.5 * (atom_R[bond_14_idxs[:, 0]] + atom_R[bond_14_idxs[:, 1]])
837 vdw_14e = (atom_e[bond_14_idxs[:, 0]] + atom_e[bond_14_idxs[:, 1]]).sqrt()
838 q1q2_14 = (atom_charges[bond_14_idxs[:, 0]] * atom_charges[bond_14_idxs[:, 1]]) / e_
840 return (bond_masks, b_equil, b_force, bond_weights,
841 angle_masks, a_equil, a_force, angle_weights,
842 torsion_masks, t_para, torsion_weights,
843 vdw_R, vdw_e, vdw_14R, vdw_14e,
844 q1q2, q1q2_14)
846 return (bond_masks, b_equil, b_force, bond_weights,
847 angle_masks, a_equil, a_force, angle_weights,
848 torsion_masks, t_para, torsion_weights)
def get_conv_pad_res(dataset, pdb_atom_names,
                     absolute_torsion_period=True,
                     NB=('matrix',)[0],
                     fix_terminal=True,
                     fix_charmm_residues=True,
                     correct_1_4=True,
                     permitivity=1.0):
    '''
    Build index/parameter tensors for bonded (and optionally non-bonded)
    energy terms from a residue-padded frame.

    ##INPUTS##

    dataset: one frame of a trajectory of shape [R, M, 3]
        (R residues, M atom slots per residue, xyz). Only its shape is used.

    pdb_atom_names: array of shape [R, M, 3] where the last axis holds
        (atom name, residue name, residue id). Padding slots carry
        ``None`` as the atom name.
        NOTE: the array is reshaped to [R*M, 3] and, when ``fix_terminal``
        or ``fix_charmm_residues`` is set, renamed in place; if the reshape
        yields a view the caller's array is modified as well.

    absolute_torsion_period: if True the torsion periodicity (row 3 of the
        torsion parameters) is replaced by its absolute value.

    NB: when equal to 'matrix' the pairwise van der Waals and Coulomb
        matrices are also computed and returned.

    fix_terminal: rename atom 'OXT' to 'O'.

    fix_charmm_residues: rename residues 'HSD' -> 'HID' and 'HSE' -> 'HIE'.

    correct_1_4: if True 1-4 pairs keep a scaled interaction (well depth
        divided by 2.0, charge product divided by 1.2); otherwise they are
        zeroed like 1-2 and 1-3 pairs.

    permitivity: divisor applied to the charge products.

    ##OUTPUTS##

    (bond_idxs, bond_para,
     angle_idxs, angle_para, angle_mask, ij_jk,
     torsion_idxs, torsion_para, torsion_mask, ij_jk_kl)
    followed by (vdw_R, vdw_e, q1q2) when NB == 'matrix'.

    Raises Exception on a shape mismatch, on an atom missing from the
    residue connectivity table, or when no bond/angle/torsion parameter
    exists for a detected term.
    '''

    if len(dataset.shape) != 3:
        raise Exception(f'967 dataset frame here should be of shape [R, M, 3] not {str(dataset.shape)}')
    if dataset.shape != pdb_atom_names.shape:
        raise Exception('969 dataset.shape != pdb_atom_names.shape')

    # get amber parameters
    (amber_atoms, atom_mass, atom_polarizability, bond_force, bond_equil,
     angle_force, angle_equil, torsion_factor, torsion_barrier, torsion_phase,
     torsion_period, improper_factor, improper_barrier, improper_phase,
     improper_period, other_parameters) = get_amber_parameters()

    R, M, _ = dataset.shape  # R residues, Max atom per res, dimension
    N = R * M
    # [R*M, 3], 3 = atom, res, resid
    pdb_atom_names = pdb_atom_names.reshape(-1, 3)
    # [R*M, 3], 3 = x, y, z
    dataset = dataset.reshape(-1, 3)

    if fix_terminal:  # fix atoms and residues not in amber parameters
        pdb_atom_names[pdb_atom_names[:, 0] == 'OXT', 0] = 'O'
    if fix_charmm_residues:
        # The residue name lives in column 1; writing to column 0 would
        # overwrite every atom name of the residue instead of renaming it.
        pdb_atom_names[pdb_atom_names[:, 1] == 'HSD', 1] = 'HID'
        pdb_atom_names[pdb_atom_names[:, 1] == 'HSE', 1] = 'HIE'

    # pdb atoms -> amber atom names and residues (padding slots stay None)
    padded_atom_names = np.array([
        [amber_atoms[res][atom], res, resid] if atom is not None else [atom, res, resid]
        for atom, res, resid in pdb_atom_names])
    # NOTE(review): the charge table (and the connectivity table below) are
    # indexed with the value returned by amber_atoms[res][atom], not with the
    # original PDB atom name - confirm both tables are keyed that way.
    padded_atom_charges = np.array([
        other_parameters['charge'][res][atom] if atom is not None else np.nan
        for atom, res, _ in padded_atom_names])
    if padded_atom_names.shape != dataset.shape:  # just a little check
        raise Exception('996 padded_atom_names!=dataset.shape')
    atom_names = padded_atom_names
    atom_charges = padded_atom_charges
    print('Determining bonds')

    connect = other_parameters['connectivity']
    # one independent list per atom (careful with mutability: not [[]] * N)
    connectivity = [[] for _ in range(N)]
    current_resid = -9999
    current_atoms = []  # (atom, flat index) pairs already seen in this residue
    # NOTE(review): atoms are only paired with earlier atoms of the same
    # resid, so no inter-residue bond is produced by this loop.
    # A distance-based alternative (torch.cdist with a cut-off between the
    # longest bond and the shortest non-bonded distance) was previously
    # sketched here but is not used.
    for i1, (atom1, res, resid) in enumerate(padded_atom_names):
        if atom1 is None:
            continue
        if resid != current_resid:
            current_resid = resid
            current_atoms = []
        if atom1 not in connect[res]:
            raise Exception(f'Atom {atom1} not found in connectivity of residue {res}')
        for atom2, i2 in current_atoms:
            if atom2 in connect[res][atom1] and atom1 in connect[res][atom2]:
                connectivity[i1].append(i2)
                connectivity[i2].append(i1)
        # Register this atom so that later atoms of the residue can bond to
        # it; without this the list stays empty and no bond is ever found.
        current_atoms.append((atom1, i1))

    # ##################### Angles/1-3 #####################
    print('Determining angles')

    bond_idxs_ = []
    angle_idxs = []
    torsion_idxs = []
    bond_14_idxs = []

    bond_para = []
    angle_para = []
    torsion_para_ = []

    for atom1, atom2_list in enumerate(connectivity):
        for atom2 in atom2_list:
            a1, a2 = atom_names[atom1][0], atom_names[atom2][0]
            if atom1 < atom2:  # stops any pair of atoms being selected twice
                bond_idxs_.append([atom1, atom2])
                for b in [(a1, a2), (a2, a1)]:
                    if b in bond_equil:
                        bond_para.append([bond_equil[b], bond_force[b]])
                        break  # break prevents any bond from being added twice
                else:
                    raise Exception('No associated bond parameter')

            for atom3 in connectivity[atom2]:
                a3 = atom_names[atom3][0]
                if atom3 > atom1:  # each angle will only be counted once
                    angle_idxs.append([atom1, atom2, atom3])
                    for a in [(a1, a2, a3), (a3, a2, a1)]:
                        if a in angle_equil:
                            angle_para.append([angle_equil[a], angle_force[a]])
                            break
                    else:
                        raise Exception('No associated angle parameter')
                if atom3 != atom1:  # don't go back to same atom
                    for atom4 in connectivity[atom3]:
                        if atom4 > atom1 and atom2 != atom4:
                            torsion_idxs.append([atom1, atom2, atom3, atom4])
                            bond_14_idxs.append([atom1, atom4])
                            a4 = atom_names[atom4][0]
                            # exact, reversed, then the two wildcard forms
                            for t in [(a1, a2, a3, a4), (a4, a3, a2, a1),
                                      ('X', a2, a3, 'X'), ('X', a3, a2, 'X')]:
                                if t in torsion_barrier:
                                    torsion_para_.append(torch.tensor([
                                        torsion_factor[t],
                                        torsion_barrier[t],
                                        torsion_phase[t],
                                        torsion_period[t]]))
                                    break  # each torsion only counted once
                            else:
                                raise Exception('No associated torsion parameter')

    bond_idxs = torch.as_tensor(bond_idxs_)
    angle_idxs = torch.tensor(angle_idxs)
    torsion_idxs = torch.tensor(torsion_idxs)
    bond_14_idxs = torch.tensor(bond_14_idxs)
    bond_para = torch.tensor(bond_para)
    angle_para = torch.tensor(angle_para)
    # torsions can carry a different number of terms: pad to the longest
    max_number_torsion_para = max(tf.shape[1] for tf in torsion_para_)
    torsion_para = torch.zeros(torsion_idxs.shape[0], 4, max_number_torsion_para)
    # we do not want barrier/factor to return nan so set factor to 1 by default
    torsion_para[:, 0, :] = 1.0
    for i, tf in enumerate(torsion_para_):
        torsion_para[i, :, 0:tf.shape[1]] = tf
    if absolute_torsion_period:
        torsion_para[:, 3, :] = torsion_para[:, 3, :].abs()

    # Gather based potential
    # currently for data [B, R*M, 3] or [B, 3, N]
    bonds_col = bond_idxs.reshape(-1, 2, 1)

    def _bond_matches(idxs, cols):
        # [N_bonds, N_terms] boolean: bond row equals the chosen atom pair
        return bonds_col.eq(idxs[:, cols].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1)

    aij0 = _bond_matches(angle_idxs, (0, 1))
    aij1 = _bond_matches(angle_idxs, (1, 0))
    ajk0 = _bond_matches(angle_idxs, (1, 2))
    ajk1 = _bond_matches(angle_idxs, (2, 1))
    ij_jk = torch.stack([torch.where((aij0 + aij1).T)[1], torch.where((ajk0 + ajk1).T)[1]])
    aij_ = aij1.float() - aij0.float()  # sign change needed for loss_function equation
    ajk_ = ajk0.float() - ajk1.float()
    angle_mask = torch.stack([aij_.sum(dim=0), ajk_.sum(dim=0)])

    # following are [N_bonds, N_torsions] arrays comparing if the ij or jk are the same
    ij0 = _bond_matches(torsion_idxs, (0, 1))
    ij1 = _bond_matches(torsion_idxs, (1, 0))
    jk0 = _bond_matches(torsion_idxs, (1, 2))
    jk1 = _bond_matches(torsion_idxs, (2, 1))
    kl0 = _bond_matches(torsion_idxs, (2, 3))
    kl1 = _bond_matches(torsion_idxs, (3, 2))
    ij_jk_kl = torch.stack([torch.where((ij0 + ij1).T)[1],
                            torch.where((jk0 + jk1).T)[1],
                            torch.where((kl0 + kl1).T)[1]])
    ij_ = ij0.float() - ij1.float()
    jk_ = jk0.float() - jk1.float()
    kl_ = kl0.float() - kl1.float()
    torsion_mask = torch.stack([ij_.sum(dim=0), jk_.sum(dim=0), kl_.sum(dim=0)])

    # j-i    i->j
    # i-j    j->i reverse
    # k-j    j->k
    # j-k    k->j reverse
    # l-k    k->l
    # k-l    l->k reverse

    if NB == 'matrix':
        equiv_t = other_parameters['equivalences']
        vdw_para = other_parameters['vdw_potential_well_depth']
        # switch these around so that values point to key
        equiv = {}
        for key, members in equiv_t.items():
            for member in members:
                equiv[member] = key
        atom_R = torch.tensor([vdw_para[equiv.get(name, name)][0] if name is not None else np.nan
                               for name, _, _ in atom_names])  # radius
        atom_e = torch.tensor([vdw_para[equiv.get(name, name)][1] if name is not None else np.nan
                               for name, _, _ in atom_names])  # welldepth

        # explicit (row, col) index tuples for the excluded / scaled pairs
        pairs_12 = tuple(bond_idxs.T)
        pairs_13 = tuple(angle_idxs[:, (0, 2)].T)
        pairs_14 = tuple(torsion_idxs[:, (0, 3)].T)

        # cdist is easier to work with than pdist, batch pdist doesn't seem to exist too
        # cdist(R, -R) gives |Ri + Rj|, so this is the mean of the two radii
        vdw_R = 0.5 * torch.cdist(atom_R.view(-1, 1), -atom_R.view(-1, 1)).triu(diagonal=1)
        vdw_e = (atom_e.view(1, -1) * atom_e.view(-1, 1)).triu(diagonal=1).sqrt()
        # set 1-2, and 1-3 distances to 0.0
        vdw_R[pairs_12] = 0.0
        vdw_e[pairs_12] = 0.0
        vdw_R[pairs_13] = 0.0
        vdw_e[pairs_13] = 0.0
        if correct_1_4:
            # sum A/R**12 - B/R**6; A = e* (R**12); B = 2*e *(R**6)
            # therefore scale vdw by setting e /= 2.0 (radius left untouched)
            vdw_e[pairs_14] /= 2.0
        else:
            vdw_R[pairs_14] = 0.0
            vdw_e[pairs_14] = 0.0
        # padding slots were given nan radius/depth above: neutralise them
        vdw_R[torch.isnan(vdw_R)] = 0.0
        vdw_e[torch.isnan(vdw_e)] = 0.0

        # partial charges are given as fragments of electron charge.
        # Can convert coulomb energy into kcal/mol by multiplying with 332.05
        # (i.e. q by sqrt(332.05)=18.22); no such factor is applied here.
        e_ = permitivity  # permittivity
        atom_charges = torch.tensor(atom_charges)
        q1q2 = (atom_charges.view(1, -1) * atom_charges.view(-1, 1) / e_).triu(diagonal=1)  # Aij=bi*bj
        q1q2[pairs_12] = 0.0
        q1q2[pairs_13] = 0.0
        if correct_1_4:
            # 1-4 pairs should be included but scaled
            q1q2[pairs_14] /= 1.2
        else:
            q1q2[pairs_14] = 0.0
        return (bond_idxs, bond_para,
                angle_idxs, angle_para, angle_mask, ij_jk,
                torsion_idxs, torsion_para, torsion_mask, ij_jk_kl,
                vdw_R, vdw_e,
                q1q2)
    return (bond_idxs, bond_para,
            angle_idxs, angle_para, angle_mask, ij_jk,
            torsion_idxs, torsion_para, torsion_mask, ij_jk_kl)
1081if __name__ == '__main__':
1083 import sys
1084 sys.path.insert(0, os.path.abspath('../'))