diff --git a/compose/neurosynth_compose/models/analysis.py b/compose/neurosynth_compose/models/analysis.py index ca08a96ab..98d23d8b7 100644 --- a/compose/neurosynth_compose/models/analysis.py +++ b/compose/neurosynth_compose/models/analysis.py @@ -48,7 +48,11 @@ class SpecificationCondition(BaseMixin, db.Model): db.Text, db.ForeignKey("conditions.id"), index=True, primary_key=True ) condition = relationship("Condition", backref=backref("specification_conditions")) - specification = relationship("Specification", backref=backref("specification_conditions")) + specification = relationship( + "Specification", backref=backref("specification_conditions") + ) + user_id = db.Column(db.Text, db.ForeignKey("users.external_id")) + user = relationship("User", backref=backref("specification_conditions")) class Specification(BaseMixin, db.Model): @@ -59,6 +63,7 @@ class Specification(BaseMixin, db.Model): filter = db.Column(db.Text) weights = association_proxy("specification_conditions", "weight") conditions = association_proxy("specification_conditions", "condition") + database_studyset = db.Column(db.Text) corrector = db.Column(db.JSON) user_id = db.Column(db.Text, db.ForeignKey("users.external_id")) user = relationship("User", backref=backref("specifications")) diff --git a/compose/neurosynth_compose/resources/analysis.py b/compose/neurosynth_compose/resources/analysis.py index aa21c9451..6507e6bba 100644 --- a/compose/neurosynth_compose/resources/analysis.py +++ b/compose/neurosynth_compose/resources/analysis.py @@ -1,4 +1,3 @@ - from collections import ChainMap import pathlib from operator import itemgetter @@ -119,7 +118,10 @@ def update_or_create(cls, data, id=None, commit=True): only_ids = set(data.keys()) - set(["id"]) == set() if cls._model is Condition: - record = cls._model.query.filter_by(name=data.get('name')).first() or cls._model() + record = ( + cls._model.query.filter_by(name=data.get("name")).first() + or cls._model() + ) if id is None: record = cls._model() record.user = current_user @@ -149,7 +151,8 @@ def update_or_create(cls, data, id=None, commit=True): # get nested attributes nested_keys = [ - item for key in cls._nested.keys() + item + for key in cls._nested.keys() for item in (key if isinstance(key, tuple) else (key,)) ] @@ -164,25 +167,42 @@ def update_or_create(cls, data, id=None, commit=True): # Update nested attributes recursively for field, res_name in cls._nested.items(): field = (field,) if not isinstance(field, tuple) else field + if set(data.keys()).issubset(field): + field = (list(data.keys())[0],) + try: rec_data = itemgetter(*field)(data) except KeyError: rec_data = None ResCls = globals()[res_name] + if rec_data is not None: if isinstance(rec_data, tuple): rec_data = [dict(ChainMap(*rc)) for rc in zip(*rec_data)] + # get ids of existing nested attributes + existing_nested = None + if cls._attribute_name: + existing_nested = getattr(record, cls._attribute_name, None) + + if existing_nested and len(existing_nested) == len(rec_data): + _ = [ + rd.update({"id": ns.id}) + for rd, ns in zip( + rec_data, getattr(record, cls._attribute_name) + ) + ] if isinstance(rec_data, list): nested = [ - ResCls.update_or_create(rec, commit=False) - for rec in rec_data + ResCls.update_or_create(rec, commit=False) for rec in rec_data ] to_commit.extend(nested) else: nested = ResCls.update_or_create(rec_data, commit=False) to_commit.append(nested) - update_field = field if len(field) == 1 else (cls._attribute_name,) + update_field = ( + field if not cls._attribute_name else (cls._attribute_name,) + ) for f in update_field: setattr(record, f, nested) diff --git a/compose/neurosynth_compose/schemas/analysis.py b/compose/neurosynth_compose/schemas/analysis.py index 3e6c4d076..bda8aa5b6 100644 --- a/compose/neurosynth_compose/schemas/analysis.py +++ b/compose/neurosynth_compose/schemas/analysis.py @@ -130,10 +130,7 @@ class ConditionSchema(Schema): description = PGSQLString() -class SpecificationConditionSchema(Schema): - id = PGSQLString() - created_at = fields.DateTime() - updated_at = fields.DateTime(allow_none=True) +class SpecificationConditionSchema(BaseSchema): condition = fields.Pluck(ConditionSchema, "name") weight = fields.Number() @@ -152,7 +149,7 @@ class StudysetReferenceSchema(Schema): exclude=("snapshot",), metadata={"pluck": "id"}, many=True, - dump_only=True + dump_only=True, ) @@ -165,6 +162,7 @@ class SpecificationSchema(BaseSchema): mask = PGSQLString(allow_none=True) transformer = PGSQLString(allow_none=True) estimator = fields.Nested("EstimatorSchema") + database_studyset = PGSQLString(allow_none=True) contrast = PGSQLString(allow_none=True) filter = PGSQLString(allow_none=True) corrector = fields.Dict(allow_none=True) @@ -178,11 +176,7 @@ class SpecificationSchema(BaseSchema): data_key="conditions", ) conditions = fields.Pluck( - ConditionSchema, - "name", - many=True, - allow_none=True, - dump_only=True + ConditionSchema, "name", many=True, allow_none=True, dump_only=True ) weights = fields.List( fields.Float(), @@ -213,7 +207,7 @@ def to_bool(self, data, **kwargs): output_conditions[i] = True elif cond.lower() == "false": output_conditions[i] = False - data['conditions'] = conditions + data["conditions"] = conditions return data @@ -224,10 +218,10 @@ def to_string(self, data, **kwargs): output_conditions = conditions[:] for i, cond in enumerate(conditions): if cond is True: - output_conditions[i] = 'true' + output_conditions[i] = "true" elif cond is False: - output_conditions[i] = 'false' - data['conditions'] = output_conditions + output_conditions[i] = "false" + data["conditions"] = output_conditions return data diff --git a/compose/neurosynth_compose/tests/api/test_specification.py b/compose/neurosynth_compose/tests/api/test_specification.py index 7620cc90d..f79551b2d 100644 --- a/compose/neurosynth_compose/tests/api/test_specification.py +++ b/compose/neurosynth_compose/tests/api/test_specification.py @@ -25,7 +25,7 @@ def test_get_specification(session, app, auth_client, user_data): "corrector": {"type": "FDRCorrector"}, "filter": "eyes", }, - ] + ], ) def test_create_and_get_spec(session, app, auth_client, user_data, specification_data): create_spec = auth_client.post("/api/specifications", data=specification_data) @@ -35,3 +35,44 @@ def test_create_and_get_spec(session, app, auth_client, user_data, specification view_spec = auth_client.get(f"/api/specifications/{create_spec.json['id']}") assert create_spec.json == view_spec.json + + +@pytest.mark.parametrize( + "attribute,value", + [ + ("estimator", {"type": "MKDA"}), + ("type", "ibma"), + ("conditions", ["yes", "no"]), + ("weights", [1, 1]), + ("corrector", {"type": "FWECorrector"}), + ("filter", "bunny"), + ("database_studyset", "neurostore"), + ], +) +def test_update_spec(session, app, auth_client, user_data, attribute, value): + specification_data = { + "estimator": {"type": "ALE"}, + "type": "cbma", + "conditions": ["open", "closed"], + "weights": [1, -1], + "corrector": {"type": "FDRCorrector"}, + "filter": "eyes", + } + create_spec = auth_client.post("/api/specifications", data=specification_data) + + assert create_spec.status_code == 200 + + spec_id = create_spec.json["id"] + + update_spec = auth_client.put( + f"/api/specifications/{spec_id}", data={attribute: value} + ) + assert update_spec.status_code == 200 + + get_spec = auth_client.get(f"/api/specifications/{spec_id}") + assert get_spec.status_code == 200 + + if isinstance(value, list): + assert set(get_spec.json[attribute]) == set(value) + else: + assert get_spec.json[attribute] == value diff --git a/compose/neurosynth_compose/tests/api/test_studyset_reference.py b/compose/neurosynth_compose/tests/api/test_studyset_reference.py index bd2752989..8f53d2a7e 100644 --- a/compose/neurosynth_compose/tests/api/test_studyset_reference.py +++ b/compose/neurosynth_compose/tests/api/test_studyset_reference.py @@ -1,8 +1,7 @@ - def test_studyset_references(session, app, auth_client, user_data): nonnested = auth_client.get("/api/studyset-references?nested=false") nested = auth_client.get("/api/studyset-references?nested=true") assert nonnested.status_code == nested.status_code == 200 - assert isinstance(nonnested.json['results'][0]['studysets'][0], str) - assert isinstance(nested.json['results'][0]['studysets'][0], dict) + assert isinstance(nonnested.json["results"][0]["studysets"][0], str) + assert isinstance(nested.json["results"][0]["studysets"][0], dict)