Last active
October 19, 2022 09:33
-
-
Save moritzschaefer/2b2c97c266b93d0df2559ab1b9ef512f to your computer and use it in GitHub Desktop.
Revisions
-
moritzschaefer revised this gist
Oct 19, 2022 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -14,7 +14,7 @@ # import torch with open('5t5f.pdb', 'r') as f: pdb_str = f.read() data = protein.from_pdb_string(pdb_str, 'H') -
moritzschaefer revised this gist
Oct 19, 2022 . 2 changed files with 81 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,81 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np from alphafold.common import protein from alphafold.model import geometry, r3 from alphafold.model.all_atom import ( atom37_to_atom14, atom37_to_torsion_angles, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames) from alphafold.model.folding_multimer import make_backbone_affine from alphafold.model.quat_affine import QuatAffine from alphafold.model.tf.data_transforms import make_atom14_masks # import torch with open('/home/moritz/Projects/guided-protein-diffusion/tests/5t5f.pdb', 'r') as f: pdb_str = f.read() data = protein.from_pdb_string(pdb_str, 'H') data = { 'all_atom_positions': jnp.array(data.atom_positions), 'aatype': jnp.array(data.aatype), 'all_atom_mask': jnp.array(data.atom_mask), 'residue_index': jnp.array(data.residue_index), 'b_factors': jnp.array(data.b_factors), } torsions = atom37_to_torsion_angles(data['aatype'][np.newaxis, ...], data['all_atom_positions'][np.newaxis, ...], data['all_atom_mask'][np.newaxis, ...]) data = make_atom14_masks(data) data = {k: jnp.array(v) for k, v in data.items()} # convert tensorflow to jax atom14 = atom37_to_atom14(data['all_atom_positions'], data) # gt_affine = quat_affine.QuatAffine.from_tensor( # batch['backbone_affine_tensor']) # gt_rigid = r3.rigids_from_quataffine(gt_affine) # backbone_mask = batch['backbone_affine_mask'] pos = geometry.Vec3Array(x=data['all_atom_positions'][..., 0], y=data['all_atom_positions'][..., 1], z=data['all_atom_positions'][..., 2], ) gt_rigid, gt_affine_mask = make_backbone_affine(pos, data['all_atom_mask'], data['aatype']) # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" rr=r3.Rigids(gt_rigid.rotation, gt_rigid.translation) # r3.Rigids with shape (N, 8). all_frames_to_global = torsion_angles_to_frames( data['aatype'], rr, torsions['torsion_angles_sin_cos'][0]) all_frames_to_global_alt = torsion_angles_to_frames( data['aatype'], rr, torsions['torsion_angles_sin_cos'][0]) # Use frames and literature positions to create the final atom coordinates. # r3.Vecs with shape (N, 14). pred_positions = frames_and_literature_positions_to_atom14_pos( data['aatype'], all_frames_to_global) pred_positions_alt = frames_and_literature_positions_to_atom14_pos( data['aatype'], all_frames_to_global_alt) import ipdb; ipdb.set_trace() rmsd = ((r3.vecs_to_tensor(pred_positions) * data['atom14_atom_exists'][..., None] - atom14 * data['atom14_atom_exists'][..., None]) ** 2).sum(axis=-1) ** 0.5 rmsd_alt = ((r3.vecs_to_tensor(pred_positions_alt) * data['atom14_atom_exists'][..., None] - atom14 * data['atom14_atom_exists'][..., None]) ** 2).sum(axis=-1) ** 0.5 plt.hist(rmsd[data['atom14_atom_exists'].astype(bool)], bins=100) plt.ylabel('atom count') plt.xlabel('RMSD (angstrom) for all atoms of a 200AA long protein') plt.title('Histogram of RMSD between original and frame/angle-derived PDB positions') plt.show() print(np.min(np.stack((rmsd, rmsd_alt)), axis=0)[0].max()) File renamed without changes. -
moritzschaefer created this gist
Oct 15, 2022 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,156 @@ import torch from torch import nn from openfold.data import data_transforms from openfold.data.input_pipeline import compose from openfold.np import protein from openfold.np import residue_constants as rc from openfold.utils.feats import ( frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames) from openfold.utils.rigid_utils import Rigid, Rotation # Loading an example protein with open('5t5f.pdb', 'r') as f: pdb_str = f.read() data = protein.from_pdb_string(pdb_str, 'H') # Convert to tensor-dict data = { 'all_atom_positions': torch.tensor(data.atom_positions), 'aatype': torch.tensor(data.aatype), 'all_atom_mask': torch.tensor(data.atom_mask), 'residue_index': torch.tensor(data.residue_index), 'b_factors': torch.tensor(data.b_factors), } # Prepare data: Generate atom14, frames and torsion angles _transforms = [ data_transforms.make_atom14_masks, data_transforms.make_atom14_positions, data_transforms.atom37_to_frames, data_transforms.atom37_to_torsion_angles(""), data_transforms.make_pseudo_beta(""), data_transforms.get_backbone_frames, data_transforms.get_chi_angles ] data = compose(_transforms)(data) # This class is adopted from StructureModule (OpenFold/AF2). The key function is 'frames_to_atom14' class _RigidConversion(nn.Module): def __init__(self): super(_RigidConversion, self).__init__() def _init_residue_constants(self, float_dtype, device): if not hasattr(self, "default_frames"): self.register_buffer( "default_frames", torch.tensor( rc.restype_rigid_group_default_frame, dtype=float_dtype, device=device, requires_grad=False, ), persistent=False, ) if not hasattr(self, "group_idx"): self.register_buffer( "group_idx", torch.tensor( rc.restype_atom14_to_rigid_group, device=device, requires_grad=False, ), persistent=False, ) if not hasattr(self, "atom_mask"): self.register_buffer( "atom_mask", torch.tensor( rc.restype_atom14_mask, dtype=float_dtype, device=device, requires_grad=False, ), persistent=False, ) if not hasattr(self, "lit_positions"): self.register_buffer( "lit_positions", torch.tensor( rc.restype_atom14_rigid_group_positions, dtype=float_dtype, device=device, requires_grad=False, ), persistent=False, ) def torsion_angles_to_frames(self, r, alpha, f): # Lazily initialize the residue constants on the correct device self._init_residue_constants(alpha.dtype, alpha.device) # Separated purely to make testing less annoying return torsion_angles_to_frames(r, alpha, f, self.default_frames) def frames_and_literature_positions_to_atom14_pos( self, r, f # [*, N, 8] # [*, N] ): # Lazily initialize the residue constants on the correct device self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) return frames_and_literature_positions_to_atom14_pos( r, f, self.default_frames, self.group_idx, self.atom_mask, self.lit_positions, ) def frames_to_atom14(self, rigids: Rigid, angles: torch.Tensor, aatypes: torch.Tensor, scale_translation): ''' Atopted from forward pass of StructureModule ''' rigids = rigids.scale_translation(scale_translation) # [*, N, 7, 2] # unnormalized_angles, angles = self.angle_resnet(s, s_initial) all_frames_to_global = self.torsion_angles_to_frames( rigids, angles, aatypes, ) xyz = self.frames_and_literature_positions_to_atom14_pos( all_frames_to_global, aatypes, ) return xyz rigid_conversion = _RigidConversion() # Calculate RMSD between loaded atom positions and calculated atom positions (atom positions -> frames/angles -> atom positions) # To rule out that alternative conformations (180 degree turns that lead to the same molecule) lead to errors, I (desperately) try all provided angles and positions and take the minimum deviations across them (there is no difference across them though...) out = rigid_conversion.frames_to_atom14( Rigid.from_tensor_4x4(data['backbone_rigid_tensor']), data['torsion_angles_sin_cos'].to(torch.float64), # normalize oder so? data['aatype'], 1.0 ) out_alt = rigid_conversion.frames_to_atom14( Rigid.from_tensor_4x4(data['backbone_rigid_tensor']), data['alt_torsion_angles_sin_cos'].to(torch.float64), # normalize oder so? data['aatype'], 1.0 ) rmsd = ((out * data['atom14_gt_exists'][..., None] - data['atom14_gt_positions'] * data['atom14_gt_exists'][..., None]) ** 2).sum(dim=-1).sqrt() rmsd2 = ((out_alt * data['atom14_gt_exists'][..., None] - data['atom14_gt_positions'] * data['atom14_gt_exists'][..., None]) ** 2).sum(dim=-1).sqrt() rmsd3 = ((out * data['atom14_alt_gt_exists'][..., None] - data['atom14_alt_gt_positions'] * data['atom14_alt_gt_exists'][..., None]) ** 2).sum(dim=-1).sqrt() rmsd4 = ((out_alt * data['atom14_alt_gt_exists'][..., None] - data['atom14_alt_gt_positions'] * data['atom14_alt_gt_exists'][..., None]) ** 2).sum(dim=-1).sqrt() # Should be eps (e.g. 0.000001), but is about 1 angstrom. print(torch.min(torch.stack((rmsd, rmsd2, rmsd3, rmsd4)), dim=0)[0].max())