Skip to content

Commit

Permalink
[fmt] Format topology module and tests (#4849)
Browse files Browse the repository at this point in the history
  • Loading branch information
RMeli authored Dec 29, 2024
1 parent b710e57 commit 453be6c
Show file tree
Hide file tree
Showing 56 changed files with 2,504 additions and 1,559 deletions.
78 changes: 51 additions & 27 deletions package/MDAnalysis/topology/CRDParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class CRDParser(TopologyReaderBase):
Type and mass are no longer guessed here. Until 3.0 these will still be
set by default through the universe.guess_TopologyAttrs() API.
"""
format = 'CRD'

format = "CRD"

def parse(self, **kwargs):
"""Create the Topology object
Expand All @@ -102,8 +103,10 @@ def parse(self, **kwargs):
----
Could use the resnum and temp factor better
"""
extformat = FORTRANReader('2I10,2X,A8,2X,A8,3F20.10,2X,A8,2X,A8,F20.10')
stdformat = FORTRANReader('2I5,1X,A4,1X,A4,3F10.5,1X,A4,1X,A4,F10.5')
extformat = FORTRANReader(
"2I10,2X,A8,2X,A8,3F20.10,2X,A8,2X,A8,F20.10"
)
stdformat = FORTRANReader("2I5,1X,A4,1X,A4,3F10.5,1X,A4,1X,A4,F10.5")

atomids = []
atomnames = []
Expand All @@ -116,21 +119,36 @@ def parse(self, **kwargs):
with openany(self.filename) as crd:
for linenum, line in enumerate(crd):
# reading header
if line.split()[0] == '*':
if line.split()[0] == "*":
continue
elif line.split()[-1] == 'EXT' and int(line.split()[0]):
elif line.split()[-1] == "EXT" and int(line.split()[0]):
r = extformat
continue
elif line.split()[0] == line.split()[-1] and line.split()[0] != '*':
elif (
line.split()[0] == line.split()[-1]
and line.split()[0] != "*"
):
r = stdformat
continue
# anything else should be an atom
try:
(serial, resnum, resName, name,
x, y, z, segid, resid, tempFactor) = r.read(line)
(
serial,
resnum,
resName,
name,
x,
y,
z,
segid,
resid,
tempFactor,
) = r.read(line)
except Exception:
errmsg = (f"Check CRD format at line {linenum + 1}: "
f"{line.rstrip()}")
errmsg = (
f"Check CRD format at line {linenum + 1}: "
f"{line.rstrip()}"
)
raise ValueError(errmsg) from None

atomids.append(serial)
Expand All @@ -150,22 +168,28 @@ def parse(self, **kwargs):
resnums = np.array(resnums, dtype=np.int32)
segids = np.array(segids, dtype=object)

atom_residx, (res_resids, res_resnames, res_resnums, res_segids) = change_squash(
(resids, resnames), (resids, resnames, resnums, segids))
res_segidx, (seg_segids,) = change_squash(
(res_segids,), (res_segids,))

top = Topology(len(atomids), len(res_resids), len(seg_segids),
attrs=[
Atomids(atomids),
Atomnames(atomnames),
Tempfactors(tempfactors),
Resids(res_resids),
Resnames(res_resnames),
Resnums(res_resnums),
Segids(seg_segids),
],
atom_resindex=atom_residx,
residue_segindex=res_segidx)
atom_residx, (res_resids, res_resnames, res_resnums, res_segids) = (
change_squash(
(resids, resnames), (resids, resnames, resnums, segids)
)
)
res_segidx, (seg_segids,) = change_squash((res_segids,), (res_segids,))

top = Topology(
len(atomids),
len(res_resids),
len(seg_segids),
attrs=[
Atomids(atomids),
Atomnames(atomnames),
Tempfactors(tempfactors),
Resids(res_resids),
Resnames(res_resnames),
Resnums(res_resnums),
Segids(seg_segids),
],
atom_resindex=atom_residx,
residue_segindex=res_segidx,
)

return top
20 changes: 10 additions & 10 deletions package/MDAnalysis/topology/DLPolyParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class ConfigParser(TopologyReaderBase):
Removed type and mass guessing (attributes guessing takes place now
through universe.guess_TopologyAttrs() API).
"""
format = 'CONFIG'

format = "CONFIG"

def parse(self, **kwargs):
with openany(self.filename) as inf:
Expand Down Expand Up @@ -117,10 +118,9 @@ def parse(self, **kwargs):
Atomids(ids),
Resids(np.array([1])),
Resnums(np.array([1])),
Segids(np.array(['SYSTEM'], dtype=object)),
Segids(np.array(["SYSTEM"], dtype=object)),
]
top = Topology(n_atoms, 1, 1,
attrs=attrs)
top = Topology(n_atoms, 1, 1, attrs=attrs)

return top

Expand All @@ -130,7 +130,8 @@ class HistoryParser(TopologyReaderBase):
.. versionadded:: 0.10.1
"""
format = 'HISTORY'

format = "HISTORY"

def parse(self, **kwargs):
with openany(self.filename) as inf:
Expand All @@ -143,10 +144,10 @@ def parse(self, **kwargs):
line = inf.readline()
while not (len(line.split()) == 4 or len(line.split()) == 5):
line = inf.readline()
if line == '':
if line == "":
raise EOFError("End of file reached when reading HISTORY.")

while line and not line.startswith('timestep'):
while line and not line.startswith("timestep"):
name = line[:8].strip()
names.append(name)
try:
Expand Down Expand Up @@ -179,9 +180,8 @@ def parse(self, **kwargs):
Atomids(ids),
Resids(np.array([1])),
Resnums(np.array([1])),
Segids(np.array(['SYSTEM'], dtype=object)),
Segids(np.array(["SYSTEM"], dtype=object)),
]
top = Topology(n_atoms, 1, 1,
attrs=attrs)
top = Topology(n_atoms, 1, 1, attrs=attrs)

return top
99 changes: 54 additions & 45 deletions package/MDAnalysis/topology/DMSParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@

class Atomnums(AtomAttr):
"""The number for each Atom"""
attrname = 'atomnums'
singular = 'atomnum'

attrname = "atomnums"
singular = "atomnum"


class DMSParser(TopologyReaderBase):
Expand Down Expand Up @@ -100,7 +101,8 @@ class DMSParser(TopologyReaderBase):
through universe.guess_TopologyAttrs() API).
"""
format = 'DMS'

format = "DMS"

def parse(self, **kwargs):
"""Parse DMS file *filename* and return the Topology object"""
Expand All @@ -121,28 +123,29 @@ def dict_factory(cursor, row):
attrs = {}

# Row factories for different data types
facs = {np.int32: lambda c, r: r[0],
np.float32: lambda c, r: r[0],
object: lambda c, r: str(r[0].strip())}
facs = {
np.int32: lambda c, r: r[0],
np.float32: lambda c, r: r[0],
object: lambda c, r: str(r[0].strip()),
}

with sqlite3.connect(self.filename) as con:
# Selecting single column, so just strip tuple
for attrname, dt in [
('id', np.int32),
('anum', np.int32),
('mass', np.float32),
('charge', np.float32),
('name', object),
('resname', object),
('resid', np.int32),
('chain', object),
('segid', object),
("id", np.int32),
("anum", np.int32),
("mass", np.float32),
("charge", np.float32),
("name", object),
("resname", object),
("resid", np.int32),
("chain", object),
("segid", object),
]:
try:
cur = con.cursor()
cur.row_factory = facs[dt]
cur.execute('SELECT {} FROM particle'
''.format(attrname))
cur.execute("SELECT {} FROM particle" "".format(attrname))
vals = cur.fetchall()
except sqlite3.DatabaseError:
errmsg = "Failed reading the atoms from DMS Database"
Expand All @@ -152,7 +155,7 @@ def dict_factory(cursor, row):

try:
cur.row_factory = dict_factory
cur.execute('SELECT * FROM bond')
cur.execute("SELECT * FROM bond")
bonds = cur.fetchall()
except sqlite3.DatabaseError:
errmsg = "Failed reading the bonds from DMS Database"
Expand All @@ -161,44 +164,46 @@ def dict_factory(cursor, row):
bondlist = []
bondorder = {}
for b in bonds:
desc = tuple(sorted([b['p0'], b['p1']]))
desc = tuple(sorted([b["p0"], b["p1"]]))
bondlist.append(desc)
bondorder[desc] = b['order']
attrs['bond'] = bondlist
attrs['bondorder'] = bondorder
bondorder[desc] = b["order"]
attrs["bond"] = bondlist
attrs["bondorder"] = bondorder

topattrs = []
# Bundle in Atom level objects
for attr, cls in [
('id', Atomids),
('anum', Atomnums),
('mass', Masses),
('charge', Charges),
('name', Atomnames),
('chain', ChainIDs),
("id", Atomids),
("anum", Atomnums),
("mass", Masses),
("charge", Charges),
("name", Atomnames),
("chain", ChainIDs),
]:
topattrs.append(cls(attrs[attr]))

# Residues
atom_residx, (res_resids,
res_resnums,
res_resnames,
res_segids) = change_squash(
(attrs['resid'], attrs['resname'], attrs['segid']),
(attrs['resid'],
attrs['resid'].copy(),
attrs['resname'],
attrs['segid']),
atom_residx, (res_resids, res_resnums, res_resnames, res_segids) = (
change_squash(
(attrs["resid"], attrs["resname"], attrs["segid"]),
(
attrs["resid"],
attrs["resid"].copy(),
attrs["resname"],
attrs["segid"],
),
)
)

n_residues = len(res_resids)
topattrs.append(Resids(res_resids))
topattrs.append(Resnums(res_resnums))
topattrs.append(Resnames(res_resnames))

if any(res_segids) and not any(val is None for val in res_segids):
res_segidx, (res_segids,) = change_squash((res_segids,),
(res_segids,))
res_segidx, (res_segids,) = change_squash(
(res_segids,), (res_segids,)
)

uniq_seg = np.unique(res_segids)
idx2seg = {idx: res_segids[idx] for idx in res_segidx}
Expand All @@ -211,14 +216,18 @@ def dict_factory(cursor, row):
topattrs.append(Segids(res_segids))
else:
n_segments = 1
topattrs.append(Segids(np.array(['SYSTEM'], dtype=object)))
topattrs.append(Segids(np.array(["SYSTEM"], dtype=object)))
res_segidx = None

topattrs.append(Bonds(attrs['bond']))
topattrs.append(Bonds(attrs["bond"]))

top = Topology(len(attrs['id']), n_residues, n_segments,
attrs=topattrs,
atom_resindex=atom_residx,
residue_segindex=res_segidx)
top = Topology(
len(attrs["id"]),
n_residues,
n_segments,
attrs=topattrs,
atom_resindex=atom_residx,
residue_segindex=res_segidx,
)

return top
Loading

0 comments on commit 453be6c

Please sign in to comment.