Skip to content

Instantly share code, notes, and snippets.

@moritzschaefer
Last active October 19, 2022 09:33
Show Gist options
  • Select an option

  • Save moritzschaefer/2b2c97c266b93d0df2559ab1b9ef512f to your computer and use it in GitHub Desktop.

Select an option

Save moritzschaefer/2b2c97c266b93d0df2559ab1b9ef512f to your computer and use it in GitHub Desktop.

Revisions

  1. moritzschaefer revised this gist Oct 19, 2022. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion test_alphafold.py
    Original file line number Diff line number Diff line change
    @@ -14,7 +14,7 @@
    # import torch


    with open('/home/moritz/Projects/guided-protein-diffusion/tests/5t5f.pdb', 'r') as f:
    with open('5t5f.pdb', 'r') as f:
    pdb_str = f.read()
    data = protein.from_pdb_string(pdb_str, 'H')

  2. moritzschaefer revised this gist Oct 19, 2022. 2 changed files with 81 additions and 0 deletions.
    81 changes: 81 additions & 0 deletions test_alphafold.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,81 @@
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    import numpy as np

    from alphafold.common import protein
    from alphafold.model import geometry, r3
    from alphafold.model.all_atom import (
    atom37_to_atom14, atom37_to_torsion_angles,
    frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames)
    from alphafold.model.folding_multimer import make_backbone_affine
    from alphafold.model.quat_affine import QuatAffine
    from alphafold.model.tf.data_transforms import make_atom14_masks

    # import torch


    with open('/home/moritz/Projects/guided-protein-diffusion/tests/5t5f.pdb', 'r') as f:
    pdb_str = f.read()
    data = protein.from_pdb_string(pdb_str, 'H')

    data = {
    'all_atom_positions': jnp.array(data.atom_positions),
    'aatype': jnp.array(data.aatype),
    'all_atom_mask': jnp.array(data.atom_mask),
    'residue_index': jnp.array(data.residue_index),
    'b_factors': jnp.array(data.b_factors),
    }

    torsions = atom37_to_torsion_angles(data['aatype'][np.newaxis, ...], data['all_atom_positions'][np.newaxis, ...], data['all_atom_mask'][np.newaxis, ...])


    data = make_atom14_masks(data)
    data = {k: jnp.array(v) for k, v in data.items()} # convert tensorflow to jax

    atom14 = atom37_to_atom14(data['all_atom_positions'], data)


    # gt_affine = quat_affine.QuatAffine.from_tensor(
    # batch['backbone_affine_tensor'])
    # gt_rigid = r3.rigids_from_quataffine(gt_affine)
    # backbone_mask = batch['backbone_affine_mask']
    pos = geometry.Vec3Array(x=data['all_atom_positions'][..., 0],
    y=data['all_atom_positions'][..., 1],
    z=data['all_atom_positions'][..., 2],
    )
    gt_rigid, gt_affine_mask = make_backbone_affine(pos,
    data['all_atom_mask'],
    data['aatype'])
    # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates"

    rr=r3.Rigids(gt_rigid.rotation, gt_rigid.translation)
    # r3.Rigids with shape (N, 8).
    all_frames_to_global = torsion_angles_to_frames(
    data['aatype'],
    rr,
    torsions['torsion_angles_sin_cos'][0])

    all_frames_to_global_alt = torsion_angles_to_frames(
    data['aatype'],
    rr,
    torsions['torsion_angles_sin_cos'][0])

    # Use frames and literature positions to create the final atom coordinates.
    # r3.Vecs with shape (N, 14).
    pred_positions = frames_and_literature_positions_to_atom14_pos(
    data['aatype'], all_frames_to_global)

    pred_positions_alt = frames_and_literature_positions_to_atom14_pos(
    data['aatype'], all_frames_to_global_alt)
    import ipdb; ipdb.set_trace()

    rmsd = ((r3.vecs_to_tensor(pred_positions) * data['atom14_atom_exists'][..., None] - atom14 * data['atom14_atom_exists'][..., None]) ** 2).sum(axis=-1) ** 0.5
    rmsd_alt = ((r3.vecs_to_tensor(pred_positions_alt) * data['atom14_atom_exists'][..., None] - atom14 * data['atom14_atom_exists'][..., None]) ** 2).sum(axis=-1) ** 0.5


    plt.hist(rmsd[data['atom14_atom_exists'].astype(bool)], bins=100)
    plt.ylabel('atom count')
    plt.xlabel('RMSD (angstrom) for all atoms of a 200AA long protein')
    plt.title('Histogram of RMSD between original and frame/angle-derived PDB positions')
    plt.show()
    print(np.min(np.stack((rmsd, rmsd_alt)), axis=0)[0].max())
    File renamed without changes.
  3. moritzschaefer created this gist Oct 15, 2022.
    5,023 changes: 5,023 additions & 0 deletions 5t5f.pdb
    5,023 additions, 0 deletions not shown because the diff is too large. Please use a local Git client to view these changes.
    156 changes: 156 additions & 0 deletions test.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,156 @@
    import torch
    from torch import nn

    from openfold.data import data_transforms
    from openfold.data.input_pipeline import compose
    from openfold.np import protein
    from openfold.np import residue_constants as rc
    from openfold.utils.feats import (
    frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames)
    from openfold.utils.rigid_utils import Rigid, Rotation

    # Loading an example protein
    with open('5t5f.pdb', 'r') as f:
    pdb_str = f.read()
    data = protein.from_pdb_string(pdb_str, 'H')
    # Convert to tensor-dict
    data = {
    'all_atom_positions': torch.tensor(data.atom_positions),
    'aatype': torch.tensor(data.aatype),
    'all_atom_mask': torch.tensor(data.atom_mask),
    'residue_index': torch.tensor(data.residue_index),
    'b_factors': torch.tensor(data.b_factors),
    }

    # Prepare data: Generate atom14, frames and torsion angles
    _transforms = [
    data_transforms.make_atom14_masks,
    data_transforms.make_atom14_positions,
    data_transforms.atom37_to_frames,
    data_transforms.atom37_to_torsion_angles(""),
    data_transforms.make_pseudo_beta(""),
    data_transforms.get_backbone_frames,
    data_transforms.get_chi_angles
    ]
    data = compose(_transforms)(data)

    # This class is adopted from StructureModule (OpenFold/AF2). The key function is 'frames_to_atom14'
    class _RigidConversion(nn.Module):
    def __init__(self):
    super(_RigidConversion, self).__init__()

    def _init_residue_constants(self, float_dtype, device):
    if not hasattr(self, "default_frames"):
    self.register_buffer(
    "default_frames",
    torch.tensor(
    rc.restype_rigid_group_default_frame,
    dtype=float_dtype,
    device=device,
    requires_grad=False,
    ),
    persistent=False,
    )
    if not hasattr(self, "group_idx"):
    self.register_buffer(
    "group_idx",
    torch.tensor(
    rc.restype_atom14_to_rigid_group,
    device=device,
    requires_grad=False,
    ),
    persistent=False,
    )
    if not hasattr(self, "atom_mask"):
    self.register_buffer(
    "atom_mask",
    torch.tensor(
    rc.restype_atom14_mask,
    dtype=float_dtype,
    device=device,
    requires_grad=False,
    ),
    persistent=False,
    )
    if not hasattr(self, "lit_positions"):
    self.register_buffer(
    "lit_positions",
    torch.tensor(
    rc.restype_atom14_rigid_group_positions,
    dtype=float_dtype,
    device=device,
    requires_grad=False,
    ),
    persistent=False,
    )

    def torsion_angles_to_frames(self, r, alpha, f):
    # Lazily initialize the residue constants on the correct device
    self._init_residue_constants(alpha.dtype, alpha.device)
    # Separated purely to make testing less annoying
    return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(
    self, r, f # [*, N, 8] # [*, N]
    ):
    # Lazily initialize the residue constants on the correct device
    self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
    return frames_and_literature_positions_to_atom14_pos(
    r,
    f,
    self.default_frames,
    self.group_idx,
    self.atom_mask,
    self.lit_positions,
    )

    def frames_to_atom14(self, rigids: Rigid, angles: torch.Tensor, aatypes: torch.Tensor, scale_translation):
    '''
    Atopted from forward pass of StructureModule
    '''

    rigids = rigids.scale_translation(scale_translation)

    # [*, N, 7, 2]
    # unnormalized_angles, angles = self.angle_resnet(s, s_initial)

    all_frames_to_global = self.torsion_angles_to_frames(
    rigids,
    angles,
    aatypes,
    )

    xyz = self.frames_and_literature_positions_to_atom14_pos(
    all_frames_to_global,
    aatypes,
    )
    return xyz

    rigid_conversion = _RigidConversion()


    # Calculate RMSD between loaded atom positions and calculated atom positions (atom positions -> frames/angles -> atom positions)

    # To rule out that alternative conformations (180 degree turns that lead to the same molecule) lead to errors, I (desperately) try all provided angles and positions and take the minimum deviations across them (there is no difference across them though...)
    out = rigid_conversion.frames_to_atom14(
    Rigid.from_tensor_4x4(data['backbone_rigid_tensor']),
    data['torsion_angles_sin_cos'].to(torch.float64), # normalize oder so?
    data['aatype'],
    1.0
    )

    out_alt = rigid_conversion.frames_to_atom14(
    Rigid.from_tensor_4x4(data['backbone_rigid_tensor']),
    data['alt_torsion_angles_sin_cos'].to(torch.float64), # normalize oder so?
    data['aatype'],
    1.0
    )

    rmsd = ((out * data['atom14_gt_exists'][..., None] - data['atom14_gt_positions'] * data['atom14_gt_exists'][..., None]) ** 2).sum(dim=-1).sqrt()
    rmsd2 = ((out_alt * data['atom14_gt_exists'][..., None] - data['atom14_gt_positions'] * data['atom14_gt_exists'][..., None]) ** 2).sum(dim=-1).sqrt()

    rmsd3 = ((out * data['atom14_alt_gt_exists'][..., None] - data['atom14_alt_gt_positions'] * data['atom14_alt_gt_exists'][..., None]) ** 2).sum(dim=-1).sqrt()
    rmsd4 = ((out_alt * data['atom14_alt_gt_exists'][..., None] - data['atom14_alt_gt_positions'] * data['atom14_alt_gt_exists'][..., None]) ** 2).sum(dim=-1).sqrt()

    # Should be eps (e.g. 0.000001), but is about 1 angstrom.
    print(torch.min(torch.stack((rmsd, rmsd2, rmsd3, rmsd4)), dim=0)[0].max())