Coverage for biobb_pytorch / mdae / loss / utils / openmm_thread.py: 0%
315 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-02 16:33 +0000
1import os
2import pandas as pd
3from openmm import Platform
4from openmm.app import ForceField, PDBFile, Simulation
5from openmm.app import element as elem
6import openmm
7from openmm.app.forcefield import _createResidueSignature
8from openmm.app.internal import compiled
9from torchexposedintegratorplugin import TorchExposedIntegrator
11import torch
12import numpy as np
13from copy import deepcopy
# OpenMM force-field XML whose embedded <Script> section builds a soft
# repulsive potential: a CustomNonbondedForce with energy 'C/((r/0.2)^4+1)'
# (global parameter C = 1.0) over every particle of the system, excluding
# directly bonded pairs and the two end atoms of every angle.
# NOTE(review): nothing in the visible part of this file references this
# constant - confirm whether it is still used elsewhere before removing it.
soft_xml_script = """\
<ForceField>
 <Script>
import openmm as mm
nb = mm.CustomNonbondedForce('C/((r/0.2)^4+1)')
nb.addGlobalParameter('C', 1.0)
sys.addForce(nb)
for i in range(sys.getNumParticles()):
 nb.addParticle([])
exclusions = set()
for bond in data.bonds:
 exclusions.add((min(bond.atom1, bond.atom2), max(bond.atom1, bond.atom2)))
for angle in data.angles:
 exclusions.add((min(angle[0], angle[2]), max(angle[0], angle[2])))
for a1, a2 in exclusions:
 nb.addExclusion(a1, a2)
 </Script>
</ForceField>
"""
class ModifiedForceField(ForceField):
    '''
    `openmm.app.ForceField` subclass that uses residue names (and optional
    aliases) to choose between several templates matching the same residue.
    '''

    def __init__(self, *args, alternative_residue_names=None, **kwargs):
        '''
        Takes all `*args` and `**kwargs` of `openmm.app.ForceField`, plus an optional parameter described here.

        :param dict alternative_residue_names: aliases for resnames, e.g., `{'HIS':'HIE'}`.
            Any value that is not a dict (including the default `None`) selects `{'HIS': 'HIE'}`.
        '''
        super().__init__(*args, **kwargs)
        aliases_given = isinstance(alternative_residue_names, dict)
        self._alternative_residue_names = alternative_residue_names if aliases_given else {'HIS': 'HIE'}

    def _getResidueTemplateMatches(self, res, bondedToAtom, templateSignatures=None, ignoreExternalBonds=False, ignoreExtraParticles=False):
        """
        Return the template that matches a residue, or None if none is found.

        :param res: Topology.Residue, the residue for which template matches are to be retrieved.
        :param bondedToAtom: list of set of int, bondedToAtom[i] is the set of atoms bonded to atom index i
        :param templateSignatures: signature -> templates mapping to search; `self._templateSignatures` when None.
        :returns: list with two elements [template, matches].
            template is the matching forcefield residue template, or None if no matches are found.
            matches is a list specifying which atom of the template each atom of the residue corresponds to,
            or None if it does not match the template.
        :raises ValueError: if a custom template matcher returns a template that does not fit the residue.
        :raises Exception: if several non-identical templates match and none can be chosen by name.
        """
        # 1) Custom matchers take precedence; the first one that answers decides.
        for matcher in self._templateMatchers:
            candidate = matcher(self, res, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
            if candidate is None:
                continue
            mapping = compiled.matchResidueToTemplate(res, candidate, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
            if mapping is None:
                raise ValueError('A custom template matcher returned a template for residue %d (%s), but it does not match the residue.' % (res.index, res.name))
            return [candidate, mapping]

        # 2) Collect every template sharing the residue's element signature.
        if templateSignatures is None:
            templateSignatures = self._templateSignatures
        signature = _createResidueSignature([atom.element for atom in res.atoms()])
        if signature not in templateSignatures:
            return [None, None]

        allMatches = []
        for candidate in templateSignatures[signature]:
            mapping = compiled.matchResidueToTemplate(res, candidate, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
            if mapping is not None:
                allMatches.append((candidate, mapping))

        if not allMatches:
            return [None, None]
        if len(allMatches) == 1:
            return list(allMatches[0])

        # 3) Several templates fit: prefer the one whose name (before any '-')
        #    equals the (aliased) residue name, or its N-/C-terminal variant.
        wanted = self._alternative_residue_names.get(res.name, res.name)
        for t, m in allMatches:
            base_name = t.name.split('-')[0]
            if base_name == wanted:
                return [t, m]
            if base_name == 'N' + wanted:
                print(f'{str(res)}, {res.index}, is a being set as a N terminal residue')
                return [t, m]
            if base_name == 'C' + wanted:
                print(f'{str(res)} is a being set as a C terminal residue')
                return [t, m]
        print(f'multiple for {t.name}')

        # 4) No name-based choice was possible. Multiple matches are OK if and
        #    only if they assign identical types and parameters to all atoms.
        first_template, first_mapping = allMatches[0]
        for other_template, other_mapping in allMatches[1:]:
            if not first_template.areParametersIdentical(other_template, first_mapping, other_mapping):
                raise Exception('Multiple non-identical matching templates found for residue %d (%s): %s.' % (res.index + 1, res.name, ', '.join(match[0].name for match in allMatches)))
        return [first_template, first_mapping]
class OpenmmPluginScore():
    '''
    This will use the new OpenMM Plugin to calculate forces and energy. The intention is that this will be fast enough to be able to calculate forces and energy during training.
    N.B.: The current torchintegratorplugin only supports float on GPU and double on CPU.
    '''

    def __init__(self, mol=None, xml_file=None, platform='CUDA', remove_NB=False,
                 alternative_residue_names=None, atoms=None,
                 soft=False):
        '''
        Build the OpenMM system, integrator and simulation used for scoring.

        :param mol: molecule object exposing a pandas-like ``data`` table with a
            ``resname`` column and a ``write_pdb(path, split_struc=False)`` method
            (e.g. a `biobox.Molecule`). Its ``resname`` column is modified in place.
        :param xml_file: force-field xml file name, or a non-empty list of them.
            ``None`` (default) means ``['amber14-all.xml']``.
        :param str platform: OpenMM platform name, e.g. 'CUDA' or 'Reference'.
        :param bool remove_NB: if True remove NonbondedForce, CustomGBForce and CMMotionRemover, else just remove CustomGBForce
        :param dict alternative_residue_names: aliases for resnames, e.g., `{'HIS':'HIE'}`.
            ``None`` (default) means ``{'HIS': 'HIE', 'HSE': 'HIE'}``.
        :param atoms: list of atom names to keep in every residue template, or the
            string ``'no_hydrogen'`` to delete only hydrogens.
            ``None`` (default) means ``['CA', 'C', 'N', 'CB', 'O']``.
        :param bool soft: if True, take the force field from
            ``PDBFixer._createForceField`` instead of ``xml_file``; atom selection
            and force removal are then skipped.
        :raises ValueError: if ``xml_file`` is neither a str nor a non-empty sequence.
        '''
        # Function-scope import: tempfile is not imported at file level.
        import tempfile

        # Defaults are materialised per call so instances never share (and can
        # never mutate) a single default list/dict object.
        if xml_file is None:
            xml_file = ['amber14-all.xml']
        if alternative_residue_names is None:
            alternative_residue_names = dict(HIS='HIE', HSE='HIE')
        if atoms is None:
            atoms = ['CA', 'C', 'N', 'CB', 'O']

        self.mol = mol
        for key, value in alternative_residue_names.items():
            self.mol.data.loc[self.mol.data['resname'] == key, 'resname'] = value
        self.atoms = atoms

        # The structure is round-tripped through a PDB file so that OpenMM (and
        # PDBFixer) can parse it. A unique temp file avoids name collisions and
        # the `finally` below guarantees it is removed even when setup fails.
        fd, tmp_file = tempfile.mkstemp(prefix='tmp', suffix='.pdb')
        os.close(fd)
        try:
            self.mol.write_pdb(tmp_file, split_struc=False)
            self.pdb = PDBFile(tmp_file)
            if soft:
                print('attempting soft forcefield')
                from pdbfixer import PDBFixer
                f = PDBFixer(tmp_file)
                self.forcefield = f._createForceField(self.pdb.topology, False)
                self.system = self.forcefield.createSystem(self.pdb.topology)
            else:
                if isinstance(xml_file, str):
                    self.forcefield = ModifiedForceField(xml_file, alternative_residue_names=alternative_residue_names)
                elif len(xml_file) > 0:
                    self.forcefield = ModifiedForceField(*xml_file, alternative_residue_names=alternative_residue_names)
                else:
                    raise ValueError(f'xml_file: {xml_file} needs to be a str or a list of str')

                # Patch the residue templates so they match the reduced atom set.
                if atoms == 'no_hydrogen':
                    self.ignore_hydrogen()
                else:
                    self.atomselect(atoms)
                templates, unique_unmatched_residues = self.forcefield.generateTemplatesForUnmatchedResidues(self.pdb.topology)
                self.system = self.forcefield.createSystem(self.pdb.topology)

                if remove_NB:
                    unwanted = (openmm.CustomGBForce,
                                openmm.NonbondedForce,
                                openmm.CMMotionRemover)
                else:
                    unwanted = (openmm.CustomGBForce,)
                self._remove_forces(unwanted)

            self.integrator = TorchExposedIntegrator()
            self.platform = Platform.getPlatformByName(platform)
            self.simulation = Simulation(self.pdb.topology, self.system, self.integrator, self.platform)
            if platform == 'CUDA':
                self.platform.setPropertyValue(self.simulation.context, 'Precision', 'single')
            self.n_particles = self.simulation.context.getSystem().getNumParticles()
            self.simulation.context.setPositions(self.pdb.positions)
            self.get_score = self.get_energy
            print(self.simulation.context.getState(getEnergy=True).getPotentialEnergy()._value)
        finally:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    def _remove_forces(self, force_types):
        '''Remove every force of `self.system` that is an instance of `force_types`.'''
        forces = self.system.getForces()
        # Walk backwards so removing a force never shifts an index still to visit.
        for idx in reversed(range(len(forces))):
            if isinstance(forces[idx], force_types):
                self.system.removeForce(idx)

    def ignore_hydrogen(self):
        '''
        Register, for every residue template of the force field, a patch that
        deletes all hydrogen atoms and every bond involving a hydrogen.

        :raises ValueError: if a template lists the same hydrogen atom name twice.
        '''
        for name, template in self.forcefield._templates.items():
            patch_name = name + '_remove_h'
            patchData = ForceField._PatchData(patch_name, 1)

            for atom in template.atoms:
                if atom.element is not elem.hydrogen:
                    continue
                if atom.name in patchData.allAtomNames:
                    raise ValueError(f'duplicate atom name {atom.name!r} in template {name!r}')
                patchData.allAtomNames.add(atom.name)
                patchData.deletedAtoms.append(ForceField._PatchAtomData(atom.name))

            for bond in template.bonds:
                atom1 = template.atoms[bond[0]]
                atom2 = template.atoms[bond[1]]
                if atom1.element is elem.hydrogen or atom2.element is elem.hydrogen:
                    a1 = ForceField._PatchAtomData(atom1.name)
                    a2 = ForceField._PatchAtomData(atom2.name)
                    patchData.deletedBonds.append((a1, a2))

            self.forcefield.registerTemplatePatch(name, patch_name, 0)
            self.forcefield.registerPatch(patchData)

    def atomselect(self, atoms):
        '''
        Register, for every residue template of the force field, a patch that
        deletes every atom whose name is not in `atoms` and every bond touching
        such an atom.

        :param atoms: atom names to keep. The argument is copied, never modified.
            'OXT' is appended to the copy when 'OT1' or 'OT2' is requested.
        :raises ValueError: if a template lists the same deleted atom name twice.
        '''
        atoms = deepcopy(atoms)
        if 'OT2' in atoms:
            atoms.append('OXT')
        if 'OT1' in atoms:
            atoms.append('OXT')
        suffix = '_leave_only_' + '_'.join(atoms)

        for name, template in self.forcefield._templates.items():
            patch_name = name + suffix
            patchData = ForceField._PatchData(patch_name, 1)

            for atom in template.atoms:
                if atom.name in atoms:
                    continue
                if atom.name in patchData.allAtomNames:
                    raise ValueError(f'duplicate atom name {atom.name!r} in template {name!r}')
                patchData.allAtomNames.add(atom.name)
                patchData.deletedAtoms.append(ForceField._PatchAtomData(atom.name))

            for bond in template.bonds:
                atom1 = template.atoms[bond[0]]
                atom2 = template.atoms[bond[1]]
                if atom1.name not in atoms or atom2.name not in atoms:
                    a1 = ForceField._PatchAtomData(atom1.name)
                    a2 = ForceField._PatchAtomData(atom2.name)
                    patchData.deletedBonds.append((a1, a2))

            self.forcefield.registerTemplatePatch(name, patch_name, 0)
            self.forcefield.registerPatch(patchData)

    def get_energy(self, pos_ptr, force_ptr, energy_ptr, n_particles, batch_size):
        '''
        Ask the integrator plugin to evaluate a batch of structures, writing the
        results through the supplied raw pointers.

        :param pos_ptr: tensor.data_ptr()
        :param force_ptr: tensor.data_ptr()
        :param energy_ptr: tensor.data_ptr()
        :param int n_particles: number of particles
        :param int batch_size: batch size
        :returns: True
        :raises ValueError: if `n_particles` differs from the simulated system.
        '''
        # A real exception (not `assert`, which `python -O` strips): a wrong
        # size here would be handed straight to a routine working on raw pointers.
        if n_particles != self.n_particles:
            raise ValueError(f'expected {self.n_particles} particles, got {n_particles}')
        # Synchronise only where CUDA exists so CPU-only hosts do not fail here.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.integrator.torchMultiStructureE(pos_ptr, force_ptr, energy_ptr, n_particles, batch_size)
        return True

    def execute(self, x):
        '''
        Evaluate forces and energies for a batch of coordinates.

        :param `torch.Tensor` x: shape [b, N, 3]. dtype=float. device = gpu
        :returns: ``(force, energy)``; `force` is allocated like `x`, `energy`
            is a CPU double tensor of shape [b].
        '''
        force = torch.zeros_like(x)
        energy = torch.zeros(x.shape[0], device=torch.device('cpu'), dtype=torch.double)
        self.get_energy(x.data_ptr(), force.data_ptr(), energy.data_ptr(), x.shape[1], x.shape[0])
        return force, energy
class OpenmmTorchEnergyMinimizer(OpenmmPluginScore):
    '''
    `OpenmmPluginScore` with a simple per-structure steepest-descent minimiser
    driven by the forces returned from `execute`.
    '''

    def minimize(self, x, maxIterations=10, threshold=10000):
        '''
        Lower the energy of every structure in `x` with an adaptive-step descent.

        :param `torch.Tensor` x: batch of coordinates, one structure per entry
            along dimension 0 (each is evaluated as a batch of one).
        :param int maxIterations: maximum number of trial steps per structure.
        :param threshold: stop as soon as an accepted step brings the energy below this value.
        :returns: tensor shaped like `x` holding the last accepted coordinates.
        '''
        minimized_x = torch.empty_like(x)
        for i, s in enumerate(x.unsqueeze(1)):
            h = 0.01
            force, energy = self.execute(s)
            largest = force.abs().max()
            if largest == 0:
                # No force anywhere: nothing to descend along, and normalising
                # would divide by zero.
                minimized_x[i] = s
                continue
            abs_max = 1 / largest
            for _ in range(maxIterations):
                # The autograd wrappers in this file use -force as dE/dx, i.e.
                # `force` points downhill, so a descent step moves ALONG it.
                new_s = s + force * abs_max * h
                new_force, new_energy = self.execute(new_s)
                if new_energy < energy:
                    # Accept the step and grow the step size.
                    s, energy, force = new_s, new_energy, new_force
                    if energy < threshold:
                        break
                    h *= 1.2
                else:
                    # Reject the step and retry with a much smaller one.
                    h *= 0.2
            minimized_x[i] = s
        return minimized_x
class OpenMMPluginScoreSoftForceField(OpenmmPluginScore):
    '''
    Variant of `OpenmmPluginScore` whose force field always comes from
    `PDBFixer._createForceField`.
    '''

    def __init__(self, mol=None, platform='CUDA', atoms=None):
        '''
        :param mol: molecule object providing ``write_pdb(path, split_struc=False)``.
        :param str platform: OpenMM platform name, e.g. 'CUDA' or 'Reference'.
        :param atoms: stored on ``self.atoms`` only; ``None`` (default) means
            ``['CA', 'C', 'N', 'CB', 'O']``.
        '''
        # Function-scope import: tempfile is not imported at file level.
        import tempfile
        from pdbfixer import PDBFixer

        self.mol = mol
        # Fresh list per instance instead of one shared mutable default.
        self.atoms = ['CA', 'C', 'N', 'CB', 'O'] if atoms is None else atoms

        # A unique temp file (instead of a fixed 'tmp.pdb' in the working
        # directory) avoids collisions between concurrent runs; it is always
        # removed once parsed.
        fd, tmp_file = tempfile.mkstemp(prefix='tmp', suffix='.pdb')
        os.close(fd)
        try:
            self.mol.write_pdb(tmp_file, split_struc=False)
            self.pdb = PDBFile(tmp_file)
            f = PDBFixer(tmp_file)
            # Second argument selects the water-aware force field; False matches
            # the call made by OpenmmPluginScore's soft branch.
            self.forcefield = f._createForceField(self.pdb.topology, False)
        finally:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

        self.system = self.forcefield.createSystem(self.pdb.topology)
        self.integrator = TorchExposedIntegrator()
        self.platform = Platform.getPlatformByName(platform)
        self.simulation = Simulation(self.pdb.topology, self.system, self.integrator, self.platform)
        if platform == 'CUDA':
            self.platform.setPropertyValue(self.simulation.context, 'Precision', 'single')
        self.n_particles = self.simulation.context.getSystem().getNumParticles()
        self.simulation.context.setPositions(self.pdb.positions)
        self.get_score = self.get_energy
        print(self.simulation.context.getState(getEnergy=True).getPotentialEnergy()._value)
class openmm_energy_function(torch.autograd.Function):
    '''
    Autograd function returning the OpenMM potential energy of each structure
    and using the OpenMM forces to provide the gradient.
    '''

    @staticmethod
    def forward(ctx, plugin, x):
        '''
        :param plugin: OpenmmPluginScore instance
        :param `torch.Tensor` x: dtype = float, shape = [B, N, 3], device = any
        :returns: energy tensor, dtype = float, shape = [B], device = same as x
        '''
        if x.device != torch.device('cpu'):
            # Non-CPU input: let the plugin evaluate the whole batch at once.
            force, energy = plugin.execute(x)
        else:
            # CPU input: evaluate the structures one by one through the
            # regular OpenMM context.
            context = plugin.simulation.context
            force_buffer = np.zeros(x.shape)
            energy_buffer = np.zeros(x.shape[0])
            for idx, structure in enumerate(x):
                context.setPositions(structure.numpy())
                state = context.getState(getForces=True, getEnergy=True)
                force_buffer[idx] = state.getForces(asNumpy=True)
                energy_buffer[idx] = state.getPotentialEnergy()._value
            force = torch.tensor(force_buffer).float()
            energy = torch.tensor(energy_buffer).float()
        ctx.save_for_backward(force)
        return energy.float().to(x.device)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        :param grad_output: gradient with respect to the energies, shape [B]
        :returns: (None, gradient with respect to x); the plugin gets no gradient.
        '''
        (force,) = ctx.saved_tensors  # shape [B, N, 3]
        per_structure = grad_output.view(-1, 1, 1)
        # The gradient of the energy is minus the force.
        return None, -force * per_structure
class openmm_clamped_energy_function(torch.autograd.Function):
    '''
    Same as `openmm_energy_function`, except that the forces are clamped
    before being stored for the backward pass.
    '''

    @staticmethod
    def forward(ctx, plugin, x, clamp):
        '''
        :param plugin: OpenmmPluginScore instance
        :param `torch.Tensor` x: dtype = float, shape = [B, N, 3]
        :param dict clamp: keyword arguments forwarded to `torch.clamp` for the forces
        :returns: energy tensor, dtype = float, shape = [B], device = same as x
        '''
        if x.device != torch.device('cpu'):
            # Non-CPU input: let the plugin evaluate the whole batch at once.
            force, energy = plugin.execute(x)
        else:
            # CPU input: evaluate the structures one by one through the
            # regular OpenMM context.
            context = plugin.simulation.context
            force_buffer = np.zeros(x.shape)
            energy_buffer = np.zeros(x.shape[0])
            for idx, structure in enumerate(x):
                context.setPositions(structure.numpy())
                state = context.getState(getForces=True, getEnergy=True)
                force_buffer[idx] = state.getForces(asNumpy=True)
                energy_buffer[idx] = state.getPotentialEnergy()._value
            force = torch.tensor(force_buffer).float()
            energy = torch.tensor(energy_buffer).float()

        # Only the stored forces (hence the gradient) are limited; the energy is not.
        ctx.save_for_backward(torch.clamp(force, **clamp))
        return energy.float().to(x.device)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        :param grad_output: gradient with respect to the energies, shape [B]
        :returns: (None, gradient with respect to x, None); neither the plugin
            nor the clamp settings receive a gradient.
        '''
        (clamped_force,) = ctx.saved_tensors
        per_structure = grad_output.view(-1, 1, 1)
        return None, -clamped_force * per_structure, None
class openmm_energy(torch.nn.Module):
    '''
    Module wrapper that rescales coordinates and returns their OpenMM energy
    through `openmm_energy_function` or `openmm_clamped_energy_function`.
    '''

    def __init__(self, mol, std, clamp=None, **kwargs):
        '''
        :param mol: molecule handed to `OpenmmPluginScore`.
        :param std: scale applied to the input coordinates; it is stored divided by 10
            (presumably an Angstrom -> nanometre conversion - TODO confirm).
        :param clamp: None for unclamped forces, otherwise the `torch.clamp` keyword
            arguments passed to `openmm_clamped_energy_function`.
        :param kwargs: forwarded to `OpenmmPluginScore`.
        '''
        super().__init__()
        self.openmmplugin = OpenmmPluginScore(mol, **kwargs)
        self.std = std / 10
        self.clamp = clamp
        # The forward implementation is chosen once, at construction time.
        self.forward = self._forward if self.clamp is None else self._clamp_forward

    def _scaled_positions(self, x):
        '''Rescale `x` by `self.std` and reorder it from [B, 3, N] to a contiguous [B, N, 3].'''
        scaled = x * self.std
        return scaled.permute(0, 2, 1).contiguous()

    def _forward(self, x):
        '''
        :param `torch.Tensor` x: dtype=torch.float, device=CUDA, shape B, 3, N
        :returns: torch energy tensor dtype should be float and on same device as x
        '''
        return openmm_energy_function.apply(self.openmmplugin, self._scaled_positions(x))

    def _clamp_forward(self, x):
        '''
        :param `torch.Tensor` x: dtype=torch.float, device=CUDA, shape B, 3, N
        :returns: torch energy tensor dtype should be float and on same device as x
        '''
        return openmm_clamped_energy_function.apply(self.openmmplugin, self._scaled_positions(x), self.clamp)
class openmm_energy_setup():
    """
    Helper extracting atom information from an object that exposes a
    ``topology`` attribute.
    """

    def __init__(self, topology):
        """
        :param topology: object with a ``topology.atoms`` iterable; for
            `mol_dataframe` it must also satisfy what
            `MolDataFrame.from_trajectory` reads (presumably an
            `mdtraj.Trajectory` - TODO confirm against callers).
        """
        self.top = topology

    def get_atominfo(self):
        """
        Get the atom information from the topology.

        :returns: object array with one ``[atom name, residue name, residue index + 1]`` row per atom.
        """
        atom_info = [
            [atom.name, atom.residue.name, atom.residue.index + 1]
            for atom in self.top.topology.atoms
        ]
        return np.array(atom_info, dtype=object)

    def mol_dataframe(self):
        """
        Get the atom information from the topology and convert it to a pandas DataFrame.

        :returns: `MolDataFrame` built from the stored object.
        """
        # MolDataFrame's only alternate constructor is `from_trajectory`; the
        # previously called `from_stats` is not defined on that class.
        return MolDataFrame.from_trajectory(self.top)
class MolDataFrame(pd.DataFrame):
    """
    Subclass of pandas.DataFrame specialized for molecular data.
    Stores atomic information plus coordinates, and can write itself to PDB.
    """

    @property
    def data(self):
        """Return the frame itself, so code written for ``mol.data`` works on it."""
        return self

    @property
    def _constructor(self):
        # Makes pandas operations return MolDataFrame rather than DataFrame.
        return MolDataFrame

    @classmethod
    def from_trajectory(cls, traj, frame=0):
        """
        Build MolDataFrame from an mdtraj.Trajectory `traj`.
        - `frame`: index of the frame whose coords to use (default 0).
        Adds x, y, z columns.

        Occupancy and B-factor are read from the first row of
        ``traj.occupancies`` / ``traj.bfactors`` when those attributes exist;
        otherwise (or when an entry is None) they default to 1.0 and 0.0.
        """
        # Atomic radii map (Angstrom), looked up by atom name first, element second.
        radius_dict = {"N": 1.55, "C": 1.70, "O": 1.52,
                       "S": 1.80, "H": 1.20}

        def idx_to_letters(idx):
            # 0 -> 'A', 25 -> 'Z', 26 -> 'AA', ...
            letters = ''
            while idx >= 0:
                letters = chr(idx % 26 + ord('A')) + letters
                idx = idx // 26 - 1
            return letters

        # Looked up once; a missing attribute must fall back to the default for
        # EVERY atom, not only for atom 0.
        occupancies = getattr(traj, 'occupancies', None)
        bfactors = getattr(traj, 'bfactors', None)

        # NOTE(review): coordinates are copied verbatim; if `traj` stores
        # nanometres while the PDB writer expects Angstrom, a conversion is
        # needed somewhere - confirm.
        coords = traj.xyz[frame]  # shape (n_atoms,3)
        rows = []
        for atom in traj.topology.atoms:
            i = atom.index

            occ = None if occupancies is None else occupancies[0][i]
            beta = None if bfactors is None else bfactors[0][i]
            occ = 1.0 if occ is None else occ
            beta = 0.0 if beta is None else beta

            name = atom.name
            element = atom.element.symbol if atom.element else None
            radius = radius_dict.get(name, radius_dict.get(element, None))

            x, y, z = coords[i]
            rows.append({
                'atom': 'ATOM',
                'index': i,
                'name': name,
                'resname': atom.residue.name,
                'chain': idx_to_letters(atom.residue.chain.index),
                'resid': atom.residue.index + 1,
                'occupancy': occ,
                'beta': beta,
                'atomtype': element,
                'radius': radius,
                'charge': 0.0,
                'x': x,
                'y': y,
                'z': z,
            })

        cols = ['atom', 'index', 'name', 'resname', 'chain', 'resid',
                'occupancy', 'beta', 'atomtype', 'radius', 'charge',
                'x', 'y', 'z']
        return cls(rows, columns=cols)

    def write_pdb(self, filename, split_struc=False):
        """
        Write this MolDataFrame (with x,y,z) to a PDB file.

        :param filename: path of the file to create.
        :param split_struc: accepted so callers written for
            ``write_pdb(path, split_struc=False)`` also work with this class;
            the value is ignored, a single structure is always written.
        """
        # Fixed-column ATOM record: serial in columns 7-11, atom name 13-16,
        # residue name 18-20, chain 22, residue number 23-26, x/y/z 31-54,
        # occupancy 55-60, B-factor 61-66, element 77-78.
        fmt = (
            "ATOM  {index:5d} {name:^4s} {resname:>3s} {chain:1s}"
            "{resid:4d}    {x:8.3f}{y:8.3f}{z:8.3f}"
            "{occupancy:6.2f}{beta:6.2f}          {atomtype:>2s}\n"
        )
        with open(filename, 'w') as f:
            for row in self.to_dict('records'):
                element = row['atomtype']
                if not isinstance(element, str):
                    # Unknown element (None/NaN): leave the element columns blank.
                    element = ''
                f.write(fmt.format(**dict(row, atomtype=element)))
            f.write("TER\nEND\n")