Skip to content

Commit

Permalink
fix spec updates (#617)
Browse files Browse the repository at this point in the history
* fix spec updates

* fix flaky test
  • Loading branch information
jdkent authored Nov 11, 2023
1 parent 7795865 commit ece81ce
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 25 deletions.
7 changes: 6 additions & 1 deletion compose/neurosynth_compose/models/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ class SpecificationCondition(BaseMixin, db.Model):
db.Text, db.ForeignKey("conditions.id"), index=True, primary_key=True
)
condition = relationship("Condition", backref=backref("specification_conditions"))
specification = relationship("Specification", backref=backref("specification_conditions"))
specification = relationship(
"Specification", backref=backref("specification_conditions")
)
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
user = relationship("User", backref=backref("specification_conditions"))


class Specification(BaseMixin, db.Model):
Expand All @@ -59,6 +63,7 @@ class Specification(BaseMixin, db.Model):
filter = db.Column(db.Text)
weights = association_proxy("specification_conditions", "weight")
conditions = association_proxy("specification_conditions", "condition")
database_studyset = db.Column(db.Text)
corrector = db.Column(db.JSON)
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
user = relationship("User", backref=backref("specifications"))
Expand Down
32 changes: 26 additions & 6 deletions compose/neurosynth_compose/resources/analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from collections import ChainMap
import pathlib
from operator import itemgetter
Expand Down Expand Up @@ -119,7 +118,10 @@ def update_or_create(cls, data, id=None, commit=True):
only_ids = set(data.keys()) - set(["id"]) == set()

if cls._model is Condition:
record = cls._model.query.filter_by(name=data.get('name')).first() or cls._model()
record = (
cls._model.query.filter_by(name=data.get("name")).first()
or cls._model()
)
if id is None:
record = cls._model()
record.user = current_user
Expand Down Expand Up @@ -149,7 +151,8 @@ def update_or_create(cls, data, id=None, commit=True):

# get nested attributes
nested_keys = [
item for key in cls._nested.keys()
item
for key in cls._nested.keys()
for item in (key if isinstance(key, tuple) else (key,))
]

Expand All @@ -164,25 +167,42 @@ def update_or_create(cls, data, id=None, commit=True):
# Update nested attributes recursively
for field, res_name in cls._nested.items():
field = (field,) if not isinstance(field, tuple) else field
if set(data.keys()).issubset(field):
field = (list(data.keys())[0],)

try:
rec_data = itemgetter(*field)(data)
except KeyError:
rec_data = None

ResCls = globals()[res_name]

if rec_data is not None:
if isinstance(rec_data, tuple):
rec_data = [dict(ChainMap(*rc)) for rc in zip(*rec_data)]
# get ids of existing nested attributes
existing_nested = None
if cls._attribute_name:
existing_nested = getattr(record, cls._attribute_name, None)

if existing_nested and len(existing_nested) == len(rec_data):
_ = [
rd.update({"id": ns.id})
for rd, ns in zip(
rec_data, getattr(record, cls._attribute_name)
)
]
if isinstance(rec_data, list):
nested = [
ResCls.update_or_create(rec, commit=False)
for rec in rec_data
ResCls.update_or_create(rec, commit=False) for rec in rec_data
]
to_commit.extend(nested)
else:
nested = ResCls.update_or_create(rec_data, commit=False)
to_commit.append(nested)
update_field = field if len(field) == 1 else (cls._attribute_name,)
update_field = (
field if not cls._attribute_name else (cls._attribute_name,)
)
for f in update_field:
setattr(record, f, nested)

Expand Down
22 changes: 8 additions & 14 deletions compose/neurosynth_compose/schemas/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,7 @@ class ConditionSchema(Schema):
description = PGSQLString()


class SpecificationConditionSchema(Schema):
id = PGSQLString()
created_at = fields.DateTime()
updated_at = fields.DateTime(allow_none=True)
class SpecificationConditionSchema(BaseSchema):
condition = fields.Pluck(ConditionSchema, "name")
weight = fields.Number()

Expand All @@ -152,7 +149,7 @@ class StudysetReferenceSchema(Schema):
exclude=("snapshot",),
metadata={"pluck": "id"},
many=True,
dump_only=True
dump_only=True,
)


Expand All @@ -165,6 +162,7 @@ class SpecificationSchema(BaseSchema):
mask = PGSQLString(allow_none=True)
transformer = PGSQLString(allow_none=True)
estimator = fields.Nested("EstimatorSchema")
database_studyset = PGSQLString(allow_none=True)
contrast = PGSQLString(allow_none=True)
filter = PGSQLString(allow_none=True)
corrector = fields.Dict(allow_none=True)
Expand All @@ -178,11 +176,7 @@ class SpecificationSchema(BaseSchema):
data_key="conditions",
)
conditions = fields.Pluck(
ConditionSchema,
"name",
many=True,
allow_none=True,
dump_only=True
ConditionSchema, "name", many=True, allow_none=True, dump_only=True
)
weights = fields.List(
fields.Float(),
Expand Down Expand Up @@ -213,7 +207,7 @@ def to_bool(self, data, **kwargs):
output_conditions[i] = True
elif cond.lower() == "false":
output_conditions[i] = False
data['conditions'] = conditions
data["conditions"] = conditions

return data

Expand All @@ -224,10 +218,10 @@ def to_string(self, data, **kwargs):
output_conditions = conditions[:]
for i, cond in enumerate(conditions):
if cond is True:
output_conditions[i] = 'true'
output_conditions[i] = "true"
elif cond is False:
output_conditions[i] = 'false'
data['conditions'] = output_conditions
output_conditions[i] = "false"
data["conditions"] = output_conditions

return data

Expand Down
43 changes: 42 additions & 1 deletion compose/neurosynth_compose/tests/api/test_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_get_specification(session, app, auth_client, user_data):
"corrector": {"type": "FDRCorrector"},
"filter": "eyes",
},
]
],
)
def test_create_and_get_spec(session, app, auth_client, user_data, specification_data):
create_spec = auth_client.post("/api/specifications", data=specification_data)
Expand All @@ -35,3 +35,44 @@ def test_create_and_get_spec(session, app, auth_client, user_data, specification
view_spec = auth_client.get(f"/api/specifications/{create_spec.json['id']}")

assert create_spec.json == view_spec.json


@pytest.mark.parametrize(
"attribute,value",
[
("estimator", {"type": "MKDA"}),
("type", "ibma"),
("conditions", ["yes", "no"]),
("weights", [1, 1]),
("corrector", {"type": "FWECorrector"}),
("filter", "bunny"),
("database_studyset", "neurostore"),
],
)
def test_update_spec(session, app, auth_client, user_data, attribute, value):
specification_data = {
"estimator": {"type": "ALE"},
"type": "cbma",
"conditions": ["open", "closed"],
"weights": [1, -1],
"corrector": {"type": "FDRCorrector"},
"filter": "eyes",
}
create_spec = auth_client.post("/api/specifications", data=specification_data)

assert create_spec.status_code == 200

spec_id = create_spec.json["id"]

update_spec = auth_client.put(
f"/api/specifications/{spec_id}", data={attribute: value}
)
assert update_spec.status_code == 200

get_spec = auth_client.get(f"/api/specifications/{spec_id}")
assert get_spec.status_code == 200

if isinstance(value, list):
assert set(get_spec.json[attribute]) == set(value)
else:
assert get_spec.json[attribute] == value
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@

def test_studyset_references(session, app, auth_client, user_data):
nonnested = auth_client.get("/api/studyset-references?nested=false")
nested = auth_client.get("/api/studyset-references?nested=true")

assert nonnested.status_code == nested.status_code == 200
assert isinstance(nonnested.json['results'][0]['studysets'][0], str)
assert isinstance(nested.json['results'][0]['studysets'][0], dict)
assert isinstance(nonnested.json["results"][0]["studysets"][0], str)
assert isinstance(nested.json["results"][0]["studysets"][0], dict)

0 comments on commit ece81ce

Please sign in to comment.