diff --git a/posebusters/__init__.py b/posebusters/__init__.py index 0511b74..e20dc02 100644 --- a/posebusters/__init__.py +++ b/posebusters/__init__.py @@ -24,4 +24,4 @@ "check_volume_overlap", ] -__version__ = "0.2.11" +__version__ = "0.2.12" diff --git a/posebusters/modules/intermolecular_distance.py b/posebusters/modules/intermolecular_distance.py index 8730a39..5ac80eb 100644 --- a/posebusters/modules/intermolecular_distance.py +++ b/posebusters/modules/intermolecular_distance.py @@ -1,4 +1,5 @@ """Module to check intermolecular distances between ligand and protein.""" + from __future__ import annotations from typing import Any @@ -46,32 +47,30 @@ def check_intermolecular_distance( # noqa: PLR0913 atoms_ligand = np.array([a.GetSymbol() for a in mol_pred.GetAtoms()]) atoms_protein_all = np.array([a.GetSymbol() for a in mol_cond.GetAtoms()]) + idxs_ligand = np.array([a.GetIdx() for a in mol_pred.GetAtoms()]) + idxs_protein = np.array([a.GetIdx() for a in mol_cond.GetAtoms()]) + mask = [a.GetSymbol() != "H" for a in mol_pred.GetAtoms()] coords_ligand = coords_ligand[mask, :] atoms_ligand = atoms_ligand[mask] - + mask_ligand_idxs = idxs_ligand[mask] if ignore_types: mask = get_atom_type_mask(mol_cond, ignore_types) coords_protein = coords_protein[mask, :] atoms_protein_all = atoms_protein_all[mask] + mask_protein_idxs = idxs_protein[mask] # get radii - if radius_type == "vdw": - radius_ligand = np.array([_periodic_table.GetRvdw(a) for a in atoms_ligand]) - radius_protein_all = np.array([_periodic_table.GetRvdw(a) for a in atoms_protein_all]) - elif radius_type == "covalent": - radius_ligand = np.array([_periodic_table.GetRcovalent(a) for a in atoms_ligand]) - radius_protein_all = np.array([_periodic_table.GetRcovalent(a) for a in atoms_protein_all]) - else: - raise ValueError(f"Unknown radius type {radius_type}. Valid values are 'vdw' and 'covalent'.") - - distances_all = _pairwise_distance(coords_ligand, coords_protein) + radius_ligand = _get_radii(atoms_ligand, radius_type) + radius_protein_all = _get_radii(atoms_protein_all, radius_type) # select atoms that are close to ligand to check for clash + distances_all = _pairwise_distance(coords_ligand, coords_protein) mask_protein = distances_all.min(axis=0) <= search_distance distances = distances_all[:, mask_protein] radius_protein = radius_protein_all[mask_protein] atoms_protein = atoms_protein_all[mask_protein] + mask_protein_idxs = mask_protein_idxs[mask_protein] radius_sum = radius_ligand[:, None] + radius_protein[None, :] relative_distance = distances / radius_sum @@ -81,11 +80,13 @@ def check_intermolecular_distance( # noqa: PLR0913 violations[np.unravel_index(distances.argmin(), distances.shape)] = True # add smallest distances as info violations[np.unravel_index(relative_distance.argmin(), relative_distance.shape)] = True violation_ligand, violation_protein = np.where(violations) + reverse_ligand_idxs = mask_ligand_idxs[violation_ligand] + reverse_protein_idxs = mask_protein_idxs[violation_protein] # collect details around those violations in a dataframe details = pd.DataFrame() - details["ligand_atom_id"] = violation_ligand - details["protein_atom_id"] = violation_protein + details["ligand_atom_id"] = reverse_ligand_idxs + details["protein_atom_id"] = reverse_protein_idxs details["ligand_element"] = [atoms_ligand[i] for i in violation_ligand] details["protein_element"] = [atoms_protein[i] for i in violation_protein] details["ligand_vdw"] = [radius_ligand[i] for i in violation_ligand] @@ -103,6 +104,7 @@ def check_intermolecular_distance( # noqa: PLR0913 "no_clashes": not details["clash"].any(), } + # add most extreme values to results table i = np.argmin(details["relative_distance"]) if len(details) > 0 else None most_extreme = {"most_extreme_" + c: details.loc[i][str(c)] if i is not None else pd.NA for c in details.columns} results = {**results, **most_extreme} @@ -112,3 +114,12 @@ def check_intermolecular_distance( # noqa: PLR0913 def _pairwise_distance(x: np.ndarray, y: np.ndarray) -> np.ndarray: return np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1) + + +def _get_radii(atoms: np.ndarray, radius_type: str) -> np.ndarray: + if radius_type == "vdw": + return np.array([_periodic_table.GetRvdw(a) for a in atoms]) + elif radius_type == "covalent": + return np.array([_periodic_table.GetRcovalent(a) for a in atoms]) + else: + raise ValueError(f"Unknown radius type {radius_type}. Valid values are 'vdw' and 'covalent'.")