Coverage for biobb_pytorch / mdae / loss / utils / openmm_thread.py: 0%

315 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 16:33 +0000

1import os 

2import pandas as pd 

3from openmm import Platform 

4from openmm.app import ForceField, PDBFile, Simulation 

5from openmm.app import element as elem 

6import openmm 

7from openmm.app.forcefield import _createResidueSignature 

8from openmm.app.internal import compiled 

9from torchexposedintegratorplugin import TorchExposedIntegrator 

10 

11import torch 

12import numpy as np 

13from copy import deepcopy 

14 

# OpenMM ForceField XML document consisting solely of an embedded <Script>.
# The script adds a CustomNonbondedForce with energy C/((r/0.2)^4+1) (global
# parameter C = 1.0), registers every particle of the system with it, and
# excludes atom pairs that share a bond (1-2) or are the two ends of an
# angle (1-3). `sys` and `data` are names the script expects to find in
# scope when it runs; they are not defined in this module.
# NOTE(review): not referenced by any code visible in this module - confirm
# whether it is still needed.
soft_xml_script = """\
<ForceField>
 <Script>
import openmm as mm
nb = mm.CustomNonbondedForce('C/((r/0.2)^4+1)')
nb.addGlobalParameter('C', 1.0)
sys.addForce(nb)
for i in range(sys.getNumParticles()):
    nb.addParticle([])
exclusions = set()
for bond in data.bonds:
    exclusions.add((min(bond.atom1, bond.atom2), max(bond.atom1, bond.atom2)))
for angle in data.angles:
    exclusions.add((min(angle[0], angle[2]), max(angle[0], angle[2])))
for a1, a2 in exclusions:
    nb.addExclusion(a1, a2)
 </Script>
</ForceField>
"""

34 

35 

class ModifiedForceField(ForceField):
    '''
    `openmm.app.ForceField` subclass that breaks ties between several matching
    residue templates by comparing template names against the residue name
    (after applying a table of residue-name aliases).
    '''

    def __init__(self, *args, alternative_residue_names=None, **kwargs):
        '''
        Accepts every `*args` / `**kwargs` of `openmm.app.ForceField` and one extra keyword.

        :param dict alternative_residue_names: aliases for resnames, e.g., `{'HIS':'HIE'}`.
            Anything that is not a dict (including the default None) selects `{'HIS': 'HIE'}`.
        '''
        super().__init__(*args, **kwargs)
        aliases = alternative_residue_names
        if not isinstance(aliases, dict):
            aliases = {'HIS': 'HIE'}
        self._alternative_residue_names = aliases

    def _getResidueTemplateMatches(self, res, bondedToAtom, templateSignatures=None, ignoreExternalBonds=False, ignoreExtraParticles=False):
        """
        Find the force field template that matches a residue.

        :param res: Topology.Residue, the residue for which template matches are to be retrieved.
        :param bondedToAtom: list of set of int, bondedToAtom[i] is the set of atoms bonded to atom index i
        :param templateSignatures: signature -> templates mapping; defaults to `self._templateSignatures`.
        :param ignoreExternalBonds: forwarded unchanged to the matchers.
        :param ignoreExtraParticles: forwarded unchanged to the matchers.
        :returns: list with two elements [template, matches].
            template is the matching force field residue template, or None if no matches are found.
            matches is a list specifying which atom of the template each atom of the residue
            corresponds to, or None if it does not match the template.
        :raises ValueError: a custom matcher returned a template that does not fit the residue.
        :raises Exception: several templates match, none is selected by name, and they do not
            assign identical parameters.
        """
        # Custom matchers take precedence: the first one that proposes a template decides.
        for matcher in self._templateMatchers:
            proposed = matcher(self, res, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
            if proposed is None:
                continue
            mapping = compiled.matchResidueToTemplate(res, proposed, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
            if mapping is None:
                raise ValueError('A custom template matcher returned a template for residue %d (%s), but it does not match the residue.' % (res.index, res.name))
            return [proposed, mapping]

        if templateSignatures is None:
            templateSignatures = self._templateSignatures
        signature = _createResidueSignature([atom.element for atom in res.atoms()])
        if signature not in templateSignatures:
            return [None, None]

        # Collect every template with this element signature that really fits the residue.
        candidates = []
        for candidate in templateSignatures[signature]:
            mapping = compiled.matchResidueToTemplate(res, candidate, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
            if mapping is not None:
                candidates.append((candidate, mapping))

        if not candidates:
            return [None, None]
        if len(candidates) == 1:
            return list(candidates[0])

        # Several templates fit: prefer the one whose name (text before the first '-')
        # equals the aliased residue name, optionally prefixed with 'N' or 'C'.
        wanted = self._alternative_residue_names.get(res.name, res.name)
        for candidate, mapping in candidates:
            base_name = candidate.name.split('-')[0]
            if base_name == wanted:
                return [candidate, mapping]
            if base_name == 'N' + wanted:
                print(f'{str(res)}, {res.index}, is a being set as a N terminal residue')
                return [candidate, mapping]
            if base_name == 'C' + wanted:
                print(f'{str(res)} is a being set as a C terminal residue')
                return [candidate, mapping]

        # No candidate was selected by name; report the last candidate inspected.
        print(f'multiple for {candidates[-1][0].name}')
        # We found multiple matches. This is OK if and only if they assign identical types and parameters to all atoms.
        first_template, first_mapping = candidates[0]
        for other_template, other_mapping in candidates[1:]:
            if not first_template.areParametersIdentical(other_template, first_mapping, other_mapping):
                raise Exception('Multiple non-identical matching templates found for residue %d (%s): %s.' % (res.index + 1, res.name, ', '.join(match[0].name for match in candidates)))
        return [first_template, first_mapping]

109 

110 

class OpenmmPluginScore():
    '''
    Use the OpenMM `TorchExposedIntegrator` plugin to calculate forces and energy.
    The intention is that this will be fast enough to be able to calculate forces and energy during training.
    N.B.: The current torchintegratorplugin only supports float on GPU and double on CPU.
    '''

    def __init__(self, mol=None, xml_file=['amber14-all.xml'], platform='CUDA', remove_NB=False,
                 alternative_residue_names=dict(HIS='HIE', HSE='HIE'), atoms=['CA', 'C', 'N', 'CB', 'O'],
                 soft=False):
        '''
        Build the OpenMM system, integrator and simulation for one molecule.

        The list/dict defaults are kept for interface compatibility; they are only read
        here (`atomselect` works on a copy), never mutated.

        :param mol: molecule object exposing `.data` (indexable with `.loc` and holding a
            'resname' column) and `write_pdb(filename, split_struc=False)`. Required.
        :param xml_file: force field xml file name, or a non-empty list of such names
        :param str platform: OpenMM platform name, 'CUDA' or 'Reference'.
        :param bool remove_NB: if True remove NonbondedForce, CustomGBForce and CMMotionRemover, else just remove CustomGBForce
        :param dict alternative_residue_names: aliases for resnames, e.g., `{'HIS':'HIE'}`.
            Residue names in `mol.data` are rewritten in place according to this mapping.
        :param atoms: list of atom names to keep in every residue template, or the string
            'no_hydrogen' to strip only hydrogens. Ignored when `soft` is True.
        :param bool soft: if True build the force field through `pdbfixer.PDBFixer` instead of `xml_file`.
        :raises ValueError: if `mol` is None or `xml_file` is an empty sequence.
        '''
        if mol is None:
            raise ValueError('mol must be provided to build the OpenMM system')
        # Function-scope import, in the same style as the pdbfixer import below.
        import tempfile

        self.mol = mol
        for key, value in alternative_residue_names.items():
            self.mol.data.loc[self.mol.data['resname'] == key, 'resname'] = value
        self.atoms = atoms

        # mkstemp guarantees a unique name; try/finally guarantees the file is
        # removed even when force field or system construction fails.
        fd, tmp_file = tempfile.mkstemp(prefix='tmp', suffix='.pdb')
        os.close(fd)
        try:
            self.mol.write_pdb(tmp_file, split_struc=False)
            self.pdb = PDBFile(tmp_file)
            if soft:
                print('attempting soft forcefield')
                from pdbfixer import PDBFixer
                f = PDBFixer(tmp_file)
                self.forcefield = f._createForceField(self.pdb.topology, False)
                self.system = self.forcefield.createSystem(self.pdb.topology)
            else:
                self._build_standard_system(xml_file, alternative_residue_names, atoms)
        finally:
            os.remove(tmp_file)

        self._strip_forces(remove_NB)

        self.integrator = TorchExposedIntegrator()
        self.platform = Platform.getPlatformByName(platform)
        self.simulation = Simulation(self.pdb.topology, self.system, self.integrator, self.platform)
        if platform == 'CUDA':
            self.platform.setPropertyValue(self.simulation.context, 'Precision', 'single')
        self.n_particles = self.simulation.context.getSystem().getNumParticles()
        self.simulation.context.setPositions(self.pdb.positions)
        # `get_score` is kept as an alias of `get_energy`.
        self.get_score = self.get_energy
        print(self.simulation.context.getState(getEnergy=True).getPotentialEnergy()._value)

    def _build_standard_system(self, xml_file, alternative_residue_names, atoms):
        '''Create `self.forcefield` from xml file(s), patch its templates and build `self.system`.'''
        if isinstance(xml_file, str):
            self.forcefield = ModifiedForceField(xml_file, alternative_residue_names=alternative_residue_names)
        elif len(xml_file) > 0:
            self.forcefield = ModifiedForceField(*xml_file, alternative_residue_names=alternative_residue_names)
        else:
            raise ValueError(f'xml_file: {xml_file} needs to be a str or a list of str')

        if atoms == 'no_hydrogen':
            self.ignore_hydrogen()
        else:
            self.atomselect(atoms)
        # The return values are not used; the call is kept because it runs
        # template matching over the whole topology before the system is created.
        _templates, _unmatched = self.forcefield.generateTemplatesForUnmatchedResidues(self.pdb.topology)
        self.system = self.forcefield.createSystem(self.pdb.topology)

    def _strip_forces(self, remove_NB):
        '''Remove CustomGBForce (and, if `remove_NB`, NonbondedForce and CMMotionRemover) from `self.system`.'''
        if remove_NB:
            unwanted = (openmm.CustomGBForce, openmm.NonbondedForce, openmm.CMMotionRemover)
        else:
            unwanted = (openmm.CustomGBForce,)
        forces = self.system.getForces()
        # Walk backwards so that removing a force does not shift the indices still to visit.
        for idx in reversed(range(len(forces))):
            if isinstance(forces[idx], unwanted):
                self.system.removeForce(idx)

    def ignore_hydrogen(self):
        '''
        Register, for every residue template of `self.forcefield`, a patch named
        `<template>_remove_h` that deletes all hydrogen atoms and every bond touching one.

        :raises ValueError: if a template lists the same hydrogen atom name twice.
        '''
        for name, template in self.forcefield._templates.items():
            patch_name = name + '_remove_h'
            patchData = ForceField._PatchData(patch_name, 1)

            for atom in template.atoms:
                if atom.element is not elem.hydrogen:
                    continue
                if atom.name in patchData.allAtomNames:
                    raise ValueError(f'duplicate atom name {atom.name} in template {name}')
                patchData.allAtomNames.add(atom.name)
                patchData.deletedAtoms.append(ForceField._PatchAtomData(atom.name))

            for bond in template.bonds:
                atom1 = template.atoms[bond[0]]
                atom2 = template.atoms[bond[1]]
                if atom1.element is elem.hydrogen or atom2.element is elem.hydrogen:
                    a1 = ForceField._PatchAtomData(atom1.name)
                    a2 = ForceField._PatchAtomData(atom2.name)
                    patchData.deletedBonds.append((a1, a2))

            self.forcefield.registerTemplatePatch(name, patch_name, 0)
            self.forcefield.registerPatch(patchData)

    def atomselect(self, atoms):
        '''
        Register, for every residue template of `self.forcefield`, a patch that deletes all
        atoms whose name is not in `atoms`, together with every bond touching a deleted atom.

        :param list atoms: atom names to keep. The list is copied, and 'OXT' is appended to
            the copy once for each of 'OT1' / 'OT2' that it contains.
        :raises ValueError: if a template lists the same removable atom name twice.
        '''
        atoms = deepcopy(atoms)
        if 'OT2' in atoms:
            atoms.append('OXT')
        if 'OT1' in atoms:
            atoms.append('OXT')
        suffix = '_leave_only_' + '_'.join(atoms)

        for name, template in self.forcefield._templates.items():
            patch_name = name + suffix
            patchData = ForceField._PatchData(patch_name, 1)

            for atom in template.atoms:
                if atom.name in atoms:
                    continue
                if atom.name in patchData.allAtomNames:
                    raise ValueError(f'duplicate atom name {atom.name} in template {name}')
                patchData.allAtomNames.add(atom.name)
                patchData.deletedAtoms.append(ForceField._PatchAtomData(atom.name))

            for bond in template.bonds:
                atom1 = template.atoms[bond[0]]
                atom2 = template.atoms[bond[1]]
                if atom1.name not in atoms or atom2.name not in atoms:
                    a1 = ForceField._PatchAtomData(atom1.name)
                    a2 = ForceField._PatchAtomData(atom2.name)
                    patchData.deletedBonds.append((a1, a2))

            self.forcefield.registerTemplatePatch(name, patch_name, 0)
            self.forcefield.registerPatch(patchData)

    def get_energy(self, pos_ptr, force_ptr, energy_ptr, n_particles, batch_size):
        '''
        Hand raw tensor pointers to the integrator plugin, which fills the force and energy buffers.

        :param pos_ptr: tensor.data_ptr()
        :param force_ptr: tensor.data_ptr()
        :param energy_ptr: tensor.data_ptr()
        :param int n_particles: number of particles
        :param int batch_size: batch size
        :returns: True
        :raises ValueError: if `n_particles` differs from the particle count of the simulation.
        '''
        # Raw pointers are passed to the plugin, so this check must survive `python -O`
        # (a plain assert would be stripped).
        if n_particles != self.n_particles:
            raise ValueError(f'expected {self.n_particles} particles, got {n_particles}')
        # Synchronize only when CUDA exists; the unconditional call fails on CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.integrator.torchMultiStructureE(pos_ptr, force_ptr, energy_ptr, n_particles, batch_size)
        return True

    def execute(self, x):
        '''
        Evaluate forces and energies for a batch of structures.

        :param `torch.Tensor` x: shape [b, N, 3]. dtype=float. device = gpu
        :returns: tuple `(force, energy)`; `force` is allocated like `x`, `energy` is a
            double tensor of shape [b] allocated on the CPU. Both are filled by the plugin.
        '''
        force = torch.zeros_like(x)
        energy = torch.zeros(x.shape[0], device=torch.device('cpu'), dtype=torch.double)
        self.get_energy(x.data_ptr(), force.data_ptr(), energy.data_ptr(), x.shape[1], x.shape[0])
        return force, energy

262 

263 

class OpenmmTorchEnergyMinimizer(OpenmmPluginScore):
    '''
    `OpenmmPluginScore` with a simple adaptive-step minimiser built on `execute`.
    '''

    def minimize(self, x, maxIterations=10, threshold=10000):
        '''
        Lower the energy of every structure in a batch, one structure at a time.

        Each trial moves the structure by `force * step / max|force|`, where the scale
        comes from the force of the starting structure. An accepted trial (lower energy)
        grows the step by 1.2, a rejected one shrinks it by 0.2.

        :param `torch.Tensor` x: batch of structures; each `x[i]` is passed to `execute`
            with a leading batch dimension of one.
        :param int maxIterations: maximum number of trial steps per structure.
        :param threshold: stop early once an accepted energy drops below this value.
        :returns: tensor shaped like `x` holding the last accepted structure of each entry.
        '''
        # NOTE(review): the trial step is `s - force * ...`; if the plugin returns physical
        # forces (-dE/dx) this moves uphill and every trial is rejected - confirm the
        # plugin's sign convention.
        result = torch.empty_like(x)
        for idx, current in enumerate(x.unsqueeze(1)):
            step = 0.01
            force, energy = self.execute(current)
            scale = 1 / (force.abs().max())
            for _ in range(maxIterations):
                trial = current - force * scale * step
                trial_force, trial_energy = self.execute(trial)
                if not trial_energy < energy:
                    # Rejected: keep the current structure and retry with a smaller step.
                    step *= 0.2
                    continue
                current, energy, force = trial, trial_energy, trial_force
                if energy < threshold:
                    break
                step *= 1.2
            result[idx] = current
        return result

285 

286 

class OpenMMPluginScoreSoftForceField(OpenmmPluginScore):
    '''
    `OpenmmPluginScore` variant whose force field is always created by `pdbfixer.PDBFixer`.
    '''

    def __init__(self, mol=None, platform='CUDA', atoms=['CA', 'C', 'N', 'CB', 'O']):
        '''
        Build the OpenMM system, integrator and simulation for one molecule.

        The parent constructor is intentionally not called; all attributes are set here.

        :param mol: molecule object exposing `write_pdb(filename, split_struc=False)`. Required.
        :param str platform: OpenMM platform name; 'CUDA' additionally selects single precision.
        :param atoms: stored on `self.atoms`; not otherwise used by this constructor.
        :raises ValueError: if `mol` is None.
        '''
        if mol is None:
            raise ValueError('mol must be provided to build the OpenMM system')
        # Function-scope imports, in the same style as the rest of this module.
        import tempfile
        from pdbfixer import PDBFixer

        self.mol = mol
        self.atoms = atoms

        # A unique temporary file (instead of a fixed 'tmp.pdb' in the working directory)
        # avoids clashes between concurrent runs; it is removed once it has been read,
        # matching the cleanup done by OpenmmPluginScore.
        fd, tmp_file = tempfile.mkstemp(prefix='tmp', suffix='.pdb')
        os.close(fd)
        try:
            self.mol.write_pdb(tmp_file, split_struc=False)
            self.pdb = PDBFile(tmp_file)
            f = PDBFixer(tmp_file)
            # NOTE(review): OpenmmPluginScore calls `_createForceField(topology, False)`
            # with a second argument - confirm whether it is required here as well.
            self.forcefield = f._createForceField(self.pdb.topology)
            self.system = self.forcefield.createSystem(self.pdb.topology)
        finally:
            os.remove(tmp_file)

        self.integrator = TorchExposedIntegrator()
        self.platform = Platform.getPlatformByName(platform)
        self.simulation = Simulation(self.pdb.topology, self.system, self.integrator, self.platform)
        if platform == 'CUDA':
            self.platform.setPropertyValue(self.simulation.context, 'Precision', 'single')
        self.n_particles = self.simulation.context.getSystem().getNumParticles()
        self.simulation.context.setPositions(self.pdb.positions)
        # `get_score` is kept as an alias of `get_energy`.
        self.get_score = self.get_energy
        print(self.simulation.context.getState(getEnergy=True).getPotentialEnergy()._value)

308 

309 

class openmm_energy_function(torch.autograd.Function):
    '''
    Autograd bridge: OpenMM supplies the energy in `forward`, and the forces it
    returns are used to build the gradient in `backward`.
    '''

    @staticmethod
    def forward(ctx, plugin, x):
        '''
        :param plugin: OpenmmPluginScore instance
        :param `torch.Tensor` x: dtype = float, shape = [B, N, 3], device = any
        :returns: energy tensor, dtype = float, shape = [B], on the device of `x`
        '''
        if x.device != torch.device('cpu'):
            # Non-CPU tensors go through the plugin, which evaluates the whole batch at once.
            force, energy = plugin.execute(x)
        else:
            # CPU tensors are evaluated one structure at a time through the OpenMM context.
            context = plugin.simulation.context
            batch_forces = np.zeros(x.shape)
            batch_energies = np.zeros(x.shape[0])
            for idx, frame in enumerate(x):
                context.setPositions(frame.numpy())
                state = context.getState(getForces=True, getEnergy=True)
                batch_forces[idx] = state.getForces(asNumpy=True)
                batch_energies[idx] = state.getPotentialEnergy()._value
            force = torch.tensor(batch_forces).float()
            energy = torch.tensor(batch_energies).float()
        ctx.save_for_backward(force)
        return energy.float().to(x.device)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        :param grad_output: gradient with respect to the energies, one value per structure
        :returns: `(None, grad_x)`; no gradient for `plugin`, and
            `grad_x = -force * grad_output` broadcast over the [B, N, 3] force tensor
        '''
        (force,) = ctx.saved_tensors  # force shape [B, N, 3]
        per_structure = grad_output.view(-1, 1, 1)
        return None, -force * per_structure

342 

343 

class openmm_clamped_energy_function(torch.autograd.Function):
    '''
    Same autograd bridge as `openmm_energy_function`, except that the forces are
    clamped before they are stored for the backward pass.
    '''

    @staticmethod
    def forward(ctx, plugin, x, clamp):
        '''
        :param plugin: OpenmmPluginScore instance
        :param `torch.Tensor` x: dtype = float, shape = [B, N, 3], device = any
        :param dict clamp: keyword arguments forwarded to `torch.clamp` for the forces
        :returns: energy tensor, dtype = float, shape = [B], on the device of `x`
        '''
        if x.device != torch.device('cpu'):
            # Non-CPU tensors go through the plugin, which evaluates the whole batch at once.
            force, energy = plugin.execute(x)
        else:
            # CPU tensors are evaluated one structure at a time through the OpenMM context.
            context = plugin.simulation.context
            batch_forces = np.zeros(x.shape)
            batch_energies = np.zeros(x.shape[0])
            for idx, frame in enumerate(x):
                context.setPositions(frame.numpy())
                state = context.getState(getForces=True, getEnergy=True)
                batch_forces[idx] = state.getForces(asNumpy=True)
                batch_energies[idx] = state.getPotentialEnergy()._value
            force = torch.tensor(batch_forces).float()
            energy = torch.tensor(batch_energies).float()

        # Only the forces (hence the gradient) are clamped; the energy is returned as is.
        ctx.save_for_backward(torch.clamp(force, **clamp))
        return energy.float().to(x.device)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        :param grad_output: gradient with respect to the energies, one value per structure
        :returns: `(None, grad_x, None)`; gradients exist only for `x`, computed as
            `-clamped_force * grad_output` broadcast over the [B, N, 3] force tensor
        '''
        (force,) = ctx.saved_tensors
        per_structure = grad_output.view(-1, 1, 1)
        return None, -force * per_structure, None

375 

376 

class openmm_energy(torch.nn.Module):
    '''
    `torch.nn.Module` wrapper that turns a batch of coordinates into OpenMM energies,
    with gradients supplied by the OpenMM forces.
    '''

    def __init__(self, mol, std, clamp=None, **kwargs):
        '''
        :param mol: molecule passed to `OpenmmPluginScore`
        :param std: scale applied to the input coordinates; stored divided by 10
            (presumably an Angstrom -> nanometre conversion - TODO confirm)
        :param dict clamp: if given, keyword arguments for `torch.clamp` applied to the forces
        :param kwargs: forwarded to `OpenmmPluginScore`
        '''
        super().__init__()
        self.openmmplugin = OpenmmPluginScore(mol, **kwargs)
        self.std = std / 10
        self.clamp = clamp
        # The forward implementation is selected once, at construction time.
        self.forward = self._forward if self.clamp is None else self._clamp_forward

    def _forward(self, x):
        '''
        Energy without force clamping.

        :param `torch.Tensor` x: dtype=torch.float, device=CUDA, shape B, 3, N
        :returns: torch energy tensor dtype should be float and on same device as x
        '''
        scaled = x * self.std
        # The energy functions expect [B, N, 3] in contiguous memory.
        positions = scaled.permute(0, 2, 1).contiguous()
        return openmm_energy_function.apply(self.openmmplugin, positions)

    def _clamp_forward(self, x):
        '''
        Energy with the forces clamped by `self.clamp`.

        :param `torch.Tensor` x: dtype=torch.float, device=CUDA, shape B, 3, N
        :returns: torch energy tensor dtype should be float and on same device as x
        '''
        scaled = x * self.std
        # The energy functions expect [B, N, 3] in contiguous memory.
        positions = scaled.permute(0, 2, 1).contiguous()
        return openmm_clamped_energy_function.apply(self.openmmplugin, positions, self.clamp)

406 

407 

class openmm_energy_setup():
    """
    Helper that extracts atom information from a trajectory-like object.
    """

    def __init__(self, topology):
        """
        :param topology: object exposing `.topology.atoms`; `mol_dataframe` additionally
            requires what `MolDataFrame.from_trajectory` reads (e.g. `.xyz`).
        """
        self.top = topology

    def get_atominfo(self):
        """
        Get the atom information from the topology.

        :returns: object array with one `[atom name, residue name, residue index + 1]`
            row per atom.
        """
        atom_info = [
            [atom.name, atom.residue.name, atom.residue.index + 1]
            for atom in self.top.topology.atoms
        ]
        return np.array(atom_info, dtype=object)

    def mol_dataframe(self):
        """
        Get the atom information from the topology and convert it to a pandas DataFrame.

        :returns: `MolDataFrame` built from the first frame of `self.top`.
        """
        # `MolDataFrame` defines no `from_stats` constructor; `from_trajectory` is the
        # one that builds the frame from a trajectory object.
        return MolDataFrame.from_trajectory(self.top)

428 

429 

class MolDataFrame(pd.DataFrame):
    """
    Subclass of pandas.DataFrame specialized for molecular data.
    Stores atomic information plus coordinates, and can write itself to PDB.
    """
    @property
    def data(self):
        # Lets code written against an object with a `.data` table (as
        # OpenmmPluginScore expects) use the frame itself.
        return self

    @property
    def _constructor(self):
        # Keeps pandas operations returning MolDataFrame instead of DataFrame.
        return MolDataFrame

    @classmethod
    def from_trajectory(cls, traj, frame=0):
        """
        Build MolDataFrame from an mdtraj.Trajectory `traj`.

        :param traj: object exposing `.xyz` (indexable as `[frame][atom]`) and
            `.topology.atoms`; optional `.occupancies` / `.bfactors` are read from their
            first row when present.
        :param int frame: index of the frame whose coords to use (default 0).
        :returns: MolDataFrame with one row per atom and the columns atom, index, name,
            resname, chain, resid, occupancy, beta, atomtype, radius, charge, x, y, z.
        """
        # NOTE(review): coordinates are copied as stored in `traj.xyz`, without unit
        # conversion - confirm they are in the unit expected by `write_pdb` consumers.
        # Atomic radii map (Å)
        radius_dict = {"N": 1.55, "C": 1.70, "O": 1.52,
                       "S": 1.80, "H": 1.20}

        def idx_to_letters(idx):
            # 0 -> 'A', 25 -> 'Z', 26 -> 'AA', ... (spreadsheet-style column letters)
            letters = ''
            while idx >= 0:
                letters = chr(idx % 26 + ord('A')) + letters
                idx = idx // 26 - 1
            return letters

        def first_row(attr):
            # Optional per-atom table; None when the trajectory does not provide it.
            table = getattr(traj, attr, None)
            return None if table is None else table[0]

        # Resolved once: indexing a one-element fallback with the atom index fails
        # for every atom after the first.
        occupancies = first_row('occupancies')
        bfactors = first_row('bfactors')

        rows = []
        coords = traj.xyz[frame]  # shape (n_atoms,3)
        for atom in traj.topology.atoms:
            i = atom.index
            # occupancy and B-factor, with defaults when missing
            occ = None if occupancies is None else occupancies[i]
            beta = None if bfactors is None else bfactors[i]
            occ = 1.0 if occ is None else occ
            beta = 0.0 if beta is None else beta

            name = atom.name
            element = atom.element.symbol if atom.element else None
            # Look up by atom name first, then by element symbol.
            radius = radius_dict.get(name, radius_dict.get(element, None))

            x, y, z = coords[i]
            rows.append({
                'atom': 'ATOM',
                'index': i,
                'name': name,
                'resname': atom.residue.name,
                'chain': idx_to_letters(atom.residue.chain.index),  # chain as letters
                'resid': atom.residue.index + 1,
                'occupancy': occ,
                'beta': beta,
                'atomtype': element,
                'radius': radius,
                'charge': 0.0,
                'x': x,
                'y': y,
                'z': z,
            })

        cols = ['atom', 'index', 'name', 'resname', 'chain', 'resid',
                'occupancy', 'beta', 'atomtype', 'radius', 'charge',
                'x', 'y', 'z']
        return cls(rows, columns=cols)

    def write_pdb(self, filename, split_struc=False):
        """
        Write this MolDataFrame (with x,y,z) to a PDB file.

        :param filename: path of the file to create or overwrite.
        :param bool split_struc: accepted so that callers such as OpenmmPluginScore can
            pass `split_struc=False`; ignored because the frame holds a single structure.
        """
        # Fixed-column ATOM record: serial 7-11, name 13-16, resName 18-20, chain 22,
        # resSeq 23-26, x/y/z 31-54, occupancy 55-60, tempFactor 61-66, element 77-78.
        # Values wider than their field (e.g. a two-letter chain) still shift the columns.
        fmt = (
            "ATOM  {index:5d} {name:4s} {resname:>3s} {chain:1s}"
            "{resid:4d}    {x:8.3f}{y:8.3f}{z:8.3f}"
            "{occupancy:6.2f}{beta:6.2f}          {atomtype:>2s}\n"
        )
        lines = []
        for row in self.to_dict('records'):
            fields = dict(row)
            name = str(fields['name'])
            # Names shorter than four characters start in column 14.
            fields['name'] = name if len(name) >= 4 else f' {name:<3s}'
            element = fields['atomtype']
            # A missing element is written as blank instead of failing the format.
            fields['atomtype'] = '' if pd.isna(element) else str(element)
            lines.append(fmt.format(**fields))
        lines.append("TER\nEND\n")
        with open(filename, 'w') as f:
            f.writelines(lines)