Skip to content

Commit

Permalink
(wip) harmonize from_json and from (data) json
Browse files Browse the repository at this point in the history
  • Loading branch information
rwxayheee committed Dec 19, 2024
1 parent 5ec1bf7 commit cba5994
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 110 deletions.
3 changes: 2 additions & 1 deletion meeko/cli/mk_prepare_receptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,8 @@ def main():
fn = args.write_json[0]
else: # args.write_json is empty list (was used without arg)
fn = str(outpath) + ".json"
polymer.to_json_file(fn)
with open(fn, "w") as f:
f.write(polymer.to_json())
written_files_log["filename"].append(fn)
written_files_log["description"].append("parameterized receptor")

Expand Down
10 changes: 5 additions & 5 deletions meeko/data/residue_chem_templates.json
Original file line number Diff line number Diff line change
Expand Up @@ -150,27 +150,27 @@
"padders": {
"5-prime": {
"rxn_smarts": "[PX4h1:1]>>[P:1][O:11][C:12]",
"adjacent_res_smarts": "[O+0X2h1:11][CX4:12]1CC(n)[OX2]C1",
"adjacent_smarts": "[O+0X2h1:11][CX4:12]1CC(n)[OX2]C1",
"auto_blunt": true
},
"3-prime": {
"rxn_smarts": "[O+0X2h1:1][CX4:2]>>[O:1]([C:2])[P:11]([O-:12])(=[O:13])[O:14][C:15]",
"adjacent_res_smarts": "[PX4h1:11]([O-:12])(=[O:13])[O:14][C:15]",
"adjacent_smarts": "[PX4h1:11]([O-:12])(=[O:13])[O:14][C:15]",
"auto_blunt": true
},
"N-term": {
"rxn_smarts": "[C:1](=[O:2])[C:3][N:4]>>[C:1](=[O:2])[C:3][N:4][C:11](=[O:12])[C:13]",
"adjacent_res_smarts": "[C:11](=[O:12])[C:13][N]",
"adjacent_smarts": "[C:11](=[O:12])[C:13][N]",
"auto_blunt": true
},
"C-term": {
"rxn_smarts": "[C:1](=[O:2])[C:3][N:4]>>[C:11][N:12][C:1](=[O:2])[C:3][N:4]",
"adjacent_res_smarts": "[C](=O)[C:11][N:12]",
"adjacent_smarts": "[C](=O)[C:11][N:12]",
"auto_blunt": true
},
"dissulfide": {
"rxn_smarts": "[C:1][S:2]>>[C:1][S:2][S:11][C:12]",
"adjacent_res_smarts": "[S:11][C:12]",
"adjacent_smarts": "[S:11][C:12]",
"auto_blunt": false
}
},
Expand Down
12 changes: 6 additions & 6 deletions meeko/molsetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _decode_object(cls, obj: dict[str, Any]):
interaction_vectors = [np.asarray(i) for i in obj["interaction_vectors"]]
is_dummy = obj["is_dummy"]
is_pseudo_atom = obj["is_pseudo_atom"]
output_atom = Atom(
output_atom = cls(
index,
pdbinfo,
charge,
Expand Down Expand Up @@ -304,7 +304,7 @@ def _decode_object(cls, obj: dict[str, Any]):
index1 = obj["index1"]
index2 = obj["index2"]
rotatable = obj["rotatable"]
output_bond = Bond(index1, index2, rotatable)
output_bond = cls(index1, index2, rotatable)
return output_bond
# endregion

Expand Down Expand Up @@ -351,7 +351,7 @@ def _decode_object(cls, obj: dict[str, Any]):

# Constructs a Ring object from the provided keys.
ring_id = string_to_tuple(obj["ring_id"], int)
output_ring = Ring(ring_id)
output_ring = cls(ring_id)
return output_ring
# endregion

Expand Down Expand Up @@ -397,7 +397,7 @@ def _decode_object(cls, obj: dict[str, Any]):
target_coords = tuple(obj["target_coords"])
kcal_per_angstrom_square = obj["kcal_per_angstrom_square"]
delay_angstroms = obj["delay_angstroms"]
output_restraint = Restraint(
output_restraint = cls(
atom_index, target_coords, kcal_per_angstrom_square, delay_angstroms
)
return output_restraint
Expand Down Expand Up @@ -527,7 +527,7 @@ def _decode_object(cls, obj: dict[str, Any]):
# Constructs a MoleculeSetup object and restores the expected attributes
name = obj["name"]
is_sidechain = obj["is_sidechain"]
molsetup = MoleculeSetup(name, is_sidechain)
molsetup = cls(name, is_sidechain)

molsetup.pseudoatom_count = obj["pseudoatom_count"]
molsetup.atoms = [Atom.json_decoder(x) for x in obj["atoms"]]
Expand Down Expand Up @@ -1551,7 +1551,7 @@ def json_encoder(cls, obj: "RDKitMoleculeSetup") -> Optional[dict[str, Any]]:
def _decode_object(cls, obj: dict[str, Any]):

base_molsetup = MoleculeSetup.json_decoder(obj)
rdkit_molsetup = RDKitMoleculeSetup(source = base_molsetup)
rdkit_molsetup = cls(source = base_molsetup)

# Restores RDKitMoleculeSetup-specific attributes from the json dict
rdkit_molsetup.mol = rdkit_mol_from_json(obj["mol"])
Expand Down
128 changes: 45 additions & 83 deletions meeko/polymer.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,47 +619,11 @@ def _decode_object(cls, obj: dict[str, Any]):
}
padders = {k: ResiduePadder.json_decoder(v) for k, v in obj["padders"].items()}

residue_chem_templates = ResidueChemTemplates(templates, padders, obj["ambiguous"])
residue_chem_templates = cls(templates, padders, obj["ambiguous"])

return residue_chem_templates
# endregion

@classmethod
def from_dict(cls, alldata):
"""
constructs ResidueTemplates and ResiduePadders from a dictionary
with raw data such as that in data/residue_chem_templates.json
This is pretty much a JSON deserializer that takes a dictionary
as input to allow users to modify the input dict in Python
"""

ambiguous = {k: v.copy() for k, v in alldata["ambiguous"].items()}
residue_templates = {}
padders = {}
for key, data in alldata["residue_templates"].items():
res_template = cls.residue_template_from_dict(data)
residue_templates[key] = res_template
for link_label, data in alldata["padders"].items():
padders[link_label] = cls.padder_from_dict(data)
return cls(residue_templates, padders, ambiguous)

@staticmethod
def residue_template_from_dict(data):
if "link_labels" in data:
link_labels = convert_to_int_keyed_dict(data["link_labels"])
else:
link_labels = None
atom_names = data.get("atom_name", None)
return ResidueTemplate(data["smiles"], link_labels, atom_names)

@staticmethod
def padder_from_dict(data):
rxn_smarts = data["rxn_smarts"]
adjacent_res_smarts = data.get("adjacent_res_smarts", None)
auto_blunt = data.get("auto_blunt", False)
padder = ResiduePadder(rxn_smarts, adjacent_res_smarts, auto_blunt)
return padder

def add_dict(self, data, overwrite=False):
bad_keys = set(data) - {"ambiguous", "residue_templates", "padders"}
if bad_keys:
Expand All @@ -673,11 +637,11 @@ def add_dict(self, data, overwrite=False):
self.ambiguous = new_ambiguous
for key, value in data.get("residue_templates", {}).items():
if overwrite or key not in self.residue_templates:
res_template = self.residue_template_from_dict(value)
res_template = ResidueTemplate.from_dict(value)
self.residue_templates[key] = res_template
for link_label, value in data.get("padders", {}).items():
if overwrite or key not in self.padders:
padder = self.padder_from_dict(data)
padder = ResiduePadder.from_dict(data)
self.padders[link_label] = padder
return

Expand All @@ -698,9 +662,18 @@ def from_json_file(cls, filename):
filename = cls.lookup_filename(filename, data_path)
with open(filename) as f:
jsonstr = f.read()
data = json.loads(jsonstr)
return cls.from_dict(data)
alldata = json.loads(jsonstr)

ambiguous = {k: v.copy() for k,v in alldata.get("ambiguous", {}).items()}
residue_templates = {}
padders = {}
for key, data in alldata.get("residue_templates", {}).items():
res_template = ResidueTemplate.from_dict(data)
residue_templates[key] = res_template
for link_label, data in alldata.get("padders", {}).items():
padders[link_label] = ResiduePadder.from_dict(data)
return cls(residue_templates, padders, ambiguous)

@classmethod
def create_from_defaults(cls):
return cls.from_json_file("residue_chem_templates")
Expand Down Expand Up @@ -992,7 +965,7 @@ def _decode_object(cls, obj: dict[str, Any]):
obj["residue_chem_templates"]
)

polymer = Polymer({}, {}, residue_chem_templates)
polymer = cls({}, {}, residue_chem_templates)

polymer.monomers = {
k: Monomer.json_decoder(v) for k, v in obj["monomers"].items()
Expand Down Expand Up @@ -1876,10 +1849,15 @@ def to_pdb(self, new_positions: Optional[dict]=None):
icode = ""
resnum = int(resnum)

for i, atom in enumerate(rdkit_mol.GetAtoms()):
atoms_in_rdkitmol = [atom for atom in rdkit_mol.GetAtoms()]
atom_names = self.monomers[res_id].atom_names
if not atom_names:
self.monomers[res_id]._set_pdbinfo(res_id)

for i, atom in enumerate(atoms_in_rdkitmol):
atom_count += 1
props = atom.GetPropsAsDict()
atom_name = self.monomers[res_id].atom_names[i]
atom_name = atom_names[i]
x, y, z = positions[i]
element = mini_periodic_table[atom.GetAtomicNum()]
pdbout += pdb_line.format(
Expand Down Expand Up @@ -2182,29 +2160,6 @@ def _decode_object(cls, obj: dict[str, Any]):
return monomer
# endregion

def set_atom_names(self, atom_names_list):
"""
Parameters
----------
atom_names_list
Returns
-------
"""
if self.rdkit_mol is None:
raise RuntimeError("can't set atom_names if rdkit_mol is not set yet")
if len(atom_names_list) != self.rdkit_mol.GetNumAtoms():
raise ValueError(
f"{len(atom_names_list)=} differs from {self.rdkit_mol.GetNumAtoms()=}"
)
name_types = set([type(name) for name in atom_names_list])
if name_types != {str}:
raise ValueError(f"atom names must be str but {name_types=}")
self.atom_names = atom_names_list
return

def parameterize(self, mk_prep, residue_id):

molsetups = mk_prep(self.padded_mol)
Expand Down Expand Up @@ -2296,7 +2251,7 @@ class ResiduePadder(BaseJSONParsable):
# reaction should not delete atoms, not even Hs
# reaction should create bonds at non-real Hs (implicit or explicit rdktt H)

def __init__(self, rxn_smarts: str, adjacent_res_smarts: str = None, auto_blunt:bool=False):
def __init__(self, rxn_smarts: str, adjacent_smarts: str = None, auto_blunt:bool=False):
"""
Initialize the ResiduePadder with reaction SMARTS and optional adjacent residue SMARTS.
Expand All @@ -2306,9 +2261,9 @@ def __init__(self, rxn_smarts: str, adjacent_res_smarts: str = None, auto_blunt:
Reaction SMARTS to pad a link atom of a Monomer molecule.
Product atoms that are not mapped in the reactants will have
their coordinates set from an adjacent residue molecule, given
that adjacent_res_smarts is provided and the atom labels match
that adjacent_smarts is provided and the atom labels match
the unmapped product atoms of rxn_smarts.
adjacent_res_smarts: str
adjacent_smarts: str
SMARTS pattern to identify atoms in molecule of adjacent residue
and copy their positions to padding atoms. The SMARTS atom labels
must match those of the product atoms of rxn_smarts that are
Expand All @@ -2323,13 +2278,13 @@ def __init__(self, rxn_smarts: str, adjacent_res_smarts: str = None, auto_blunt:
self.auto_blunt = auto_blunt

# Fill in adjacent_smartsmol_mapidx
if adjacent_res_smarts is None:
if adjacent_smarts is None:
self.adjacent_smartsmol = None
self.adjacent_smartsmol_mapidx = None
return

# Ensure adjacent_res_smarts is None or a valid SMARTS
self.adjacent_smartsmol = self._initialize_adj_smartsmol(adjacent_res_smarts)
# Ensure adjacent_smarts is None or a valid SMARTS
self.adjacent_smartsmol = self._initialize_adj_smartsmol(adjacent_smarts)

# Ensure the mapping numbers are the same in adjacent_smartsmol and rxn_smarts's product
self._check_adj_smarts(self.rxn, self.adjacent_smartsmol)
Expand All @@ -2352,11 +2307,11 @@ def _validate_rxn_smarts(rxn_smarts: str) -> rdChemReactions.ChemicalReaction:
return rxn

@staticmethod
def _initialize_adj_smartsmol(adjacent_res_smarts: str) -> Chem.Mol:
"""Validate adjacent_res_smarts and return adjacent_smartsmol"""
adjacent_smartsmol = Chem.MolFromSmarts(adjacent_res_smarts)
def _initialize_adj_smartsmol(adjacent_smarts: str) -> Chem.Mol:
"""Validate adjacent_smarts and return adjacent_smartsmol"""
adjacent_smartsmol = Chem.MolFromSmarts(adjacent_smarts)
if adjacent_smartsmol is None:
raise RuntimeError("Invalid SMARTS pattern in adjacent_res_smarts")
raise RuntimeError("Invalid SMARTS pattern in adjacent_smarts")
return adjacent_smartsmol

@staticmethod
Expand Down Expand Up @@ -2504,7 +2459,7 @@ def _check_adjacent_mol(expected_adjacent_smartsmol: Chem.Mol, adjacent_mol: Che
there's exactly one match that includes atom with adjacent_required_atom_index
"""
if expected_adjacent_smartsmol is None:
raise RuntimeError("adjacent_res_smarts must be initialized to support adjacent_mol.")
raise RuntimeError("adjacent_smarts must be initialized to support adjacent_mol.")

hits = adjacent_mol.GetSubstructMatches(expected_adjacent_smartsmol)
if adjacent_required_atom_index is not None:
Expand Down Expand Up @@ -2544,8 +2499,10 @@ def json_encoder(cls, obj: "ResiduePadder") -> Optional[dict[str, Any]]:

@classmethod
def _decode_object(cls, obj: dict[str, Any]):

residue_padder = cls(obj["rxn_smarts"], obj["adjacent_smarts"], obj.get("auto_blunt", False))

return ResiduePadder(obj["rxn_smarts"], obj["adjacent_smarts"], obj["auto_blunt"])
return residue_padder
# endregion

# Utility Functions
Expand Down Expand Up @@ -2654,13 +2611,18 @@ def json_encoder(cls, obj: "ResidueTemplate") -> Optional[dict[str, Any]]:
def _decode_object(cls, obj: dict[str, Any]):

# Converting ResidueTemplate init values that need conversion
deserialized_mol = rdkit_mol_from_json(obj["mol"])
deserialized_mol = rdkit_mol_from_json(obj.get("mol"))
# do not write canonical smiles to preserve original atom order
mol_smiles = rdkit.Chem.MolToSmiles(deserialized_mol, canonical=False)
link_labels = convert_to_int_keyed_dict(obj["link_labels"])
if deserialized_mol:
mol_smiles = rdkit.Chem.MolToSmiles(deserialized_mol, canonical=False)
# if dry json (data) is supplied
else:
mol_smiles = obj.get("smiles")

link_labels = convert_to_int_keyed_dict(obj.get("link_labels"))

# Construct a ResidueTemplate object
residue_template = ResidueTemplate(mol_smiles, None, obj["atom_names"])
residue_template = cls(mol_smiles, None, obj.get("atom_names"))
# Separately ensure that link_labels is restored to the value we expect it to be so there are not errors in
# the constructor
residue_template.link_labels = link_labels
Expand Down
18 changes: 3 additions & 15 deletions meeko/utils/jsonutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def _decode_object(cls, obj):

# Inheritable JSON Interchange Functions
@classmethod
def json_decoder(cls, obj: dict[str, Any]):
def json_decoder(cls, obj: dict[str, Any], check_keys: bool = True):
# Avoid using json_decoder as object_hook for nested objects
if not isinstance(obj, dict):
return obj
if not cls.expected_json_keys.issubset(obj.keys()):
if check_keys and not cls.expected_json_keys.issubset(obj.keys()):
return obj

# Delegate specific decoding logic to a subclass-defined method
Expand Down Expand Up @@ -160,19 +160,7 @@ def from_json(cls, json_string: str):

@classmethod
def from_dict(cls, obj: dict) -> "BaseJSONParsable":
return cls.json_decoder(obj)

@classmethod
def from_json_file(cls, json_file) -> "BaseJSONParsable":
with open(json_file, "r") as f:
json_string = f.read()
return cls.from_json(json_string)
return cls.json_decoder(obj, check_keys=False)

def to_json(self):
return json.dumps(self, default=self.__class__.json_encoder)

def to_json_file(self, json_file):
json_string = self.to_json()
with open(json_file, "w") as f:
f.write(json_string)

0 comments on commit cba5994

Please sign in to comment.