Skip to content

Commit

Permalink
Fix numbering bug (#33)
Browse files Browse the repository at this point in the history
* Fix return atom indexes of clashes (#32)

---------

Co-authored-by: Ondřej Bouček <60691909+Endyff@users.noreply.github.com>
  • Loading branch information
maabuu and Endyff authored Mar 19, 2024
1 parent 58e7581 commit 3c467ab
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
2 changes: 1 addition & 1 deletion posebusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
"check_volume_overlap",
]

__version__ = "0.2.11"
__version__ = "0.2.12"
37 changes: 24 additions & 13 deletions posebusters/modules/intermolecular_distance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module to check intermolecular distances between ligand and protein."""

from __future__ import annotations

from typing import Any
Expand Down Expand Up @@ -46,32 +47,30 @@ def check_intermolecular_distance( # noqa: PLR0913
atoms_ligand = np.array([a.GetSymbol() for a in mol_pred.GetAtoms()])
atoms_protein_all = np.array([a.GetSymbol() for a in mol_cond.GetAtoms()])

idxs_ligand = np.array([a.GetIdx() for a in mol_pred.GetAtoms()])
idxs_protein = np.array([a.GetIdx() for a in mol_cond.GetAtoms()])

mask = [a.GetSymbol() != "H" for a in mol_pred.GetAtoms()]
coords_ligand = coords_ligand[mask, :]
atoms_ligand = atoms_ligand[mask]

mask_ligand_idxs = idxs_ligand[mask]
if ignore_types:
mask = get_atom_type_mask(mol_cond, ignore_types)
coords_protein = coords_protein[mask, :]
atoms_protein_all = atoms_protein_all[mask]
mask_protein_idxs = idxs_protein[mask]

# get radii
if radius_type == "vdw":
radius_ligand = np.array([_periodic_table.GetRvdw(a) for a in atoms_ligand])
radius_protein_all = np.array([_periodic_table.GetRvdw(a) for a in atoms_protein_all])
elif radius_type == "covalent":
radius_ligand = np.array([_periodic_table.GetRcovalent(a) for a in atoms_ligand])
radius_protein_all = np.array([_periodic_table.GetRcovalent(a) for a in atoms_protein_all])
else:
raise ValueError(f"Unknown radius type {radius_type}. Valid values are 'vdw' and 'covalent'.")

distances_all = _pairwise_distance(coords_ligand, coords_protein)
radius_ligand = _get_radii(atoms_ligand, radius_type)
radius_protein_all = _get_radii(atoms_protein_all, radius_type)

# select atoms that are close to ligand to check for clash
distances_all = _pairwise_distance(coords_ligand, coords_protein)
mask_protein = distances_all.min(axis=0) <= search_distance
distances = distances_all[:, mask_protein]
radius_protein = radius_protein_all[mask_protein]
atoms_protein = atoms_protein_all[mask_protein]
mask_protein_idxs = mask_protein_idxs[mask_protein]

radius_sum = radius_ligand[:, None] + radius_protein[None, :]
relative_distance = distances / radius_sum
Expand All @@ -81,11 +80,13 @@ def check_intermolecular_distance( # noqa: PLR0913
violations[np.unravel_index(distances.argmin(), distances.shape)] = True # add smallest distances as info
violations[np.unravel_index(relative_distance.argmin(), relative_distance.shape)] = True
violation_ligand, violation_protein = np.where(violations)
reverse_ligand_idxs = mask_ligand_idxs[violation_ligand]
reverse_protein_idxs = mask_protein_idxs[violation_protein]

# collect details around those violations in a dataframe
details = pd.DataFrame()
details["ligand_atom_id"] = violation_ligand
details["protein_atom_id"] = violation_protein
details["ligand_atom_id"] = reverse_ligand_idxs
details["protein_atom_id"] = reverse_protein_idxs
details["ligand_element"] = [atoms_ligand[i] for i in violation_ligand]
details["protein_element"] = [atoms_protein[i] for i in violation_protein]
details["ligand_vdw"] = [radius_ligand[i] for i in violation_ligand]
Expand All @@ -103,6 +104,7 @@ def check_intermolecular_distance( # noqa: PLR0913
"no_clashes": not details["clash"].any(),
}

# add most extreme values to results table
i = np.argmin(details["relative_distance"]) if len(details) > 0 else None
most_extreme = {"most_extreme_" + c: details.loc[i][str(c)] if i is not None else pd.NA for c in details.columns}
results = {**results, **most_extreme}
Expand All @@ -112,3 +114,12 @@ def check_intermolecular_distance( # noqa: PLR0913

def _pairwise_distance(x: np.ndarray, y: np.ndarray) -> np.ndarray:
return np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)


def _get_radii(atoms: np.ndarray, radius_type: str) -> np.ndarray:
if radius_type == "vdw":
return np.array([_periodic_table.GetRvdw(a) for a in atoms])
elif radius_type == "covalent":
return np.array([_periodic_table.GetRcovalent(a) for a in atoms])
else:
raise ValueError(f"Unknown radius type {radius_type}. Valid values are 'vdw' and 'covalent'.")

0 comments on commit 3c467ab

Please sign in to comment.