Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ref/speed up ingestion #616

Merged
merged 20 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions store/neurostore/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from pathlib import Path
from werkzeug.middleware.profiler import ProfilerMiddleware

from connexion.middleware import MiddlewarePosition
from starlette.middleware.cors import CORSMiddleware
Expand All @@ -10,11 +9,45 @@
# from connexion.json_schema import default_handlers as json_schema_handlers
from connexion.resolver import MethodResolver
from flask_caching import Cache
import sqltap.wsgi

from .or_json import ORJSONDecoder, ORJSONEncoder
from .database import init_db

# from datetime import datetime

# import sqltap.wsgi
# import sqltap
# import yappi

# class SQLTapMiddleware:
# def __init__(self, app):
# self.app = app

# async def __call__(self, scope, receive, send):
# profiler = sqltap.start()
# await self.app(scope, receive, send)
# statistics = profiler.collect()
# sqltap.report(statistics, "report.txt", report_format="text")


# class LineProfilerMiddleware:
# def __init__(self, app):
# self.app = app

# async def __call__(self, scope, receive, send):
# yappi.start()
# await self.app(scope, receive, send)
# yappi.stop()
# filename = (
# scope["path"].lstrip("/").rstrip("/").replace("/", "-")
# + "-"
# + scope["method"].lower()
# + str(datetime.now())
# + ".prof"
# )
# stats = yappi.get_func_stats()
# stats.save(filename, type="pstat")


connexion_app = connexion.FlaskApp(__name__, specification_dir="openapi/")

Expand Down Expand Up @@ -45,6 +78,16 @@
allow_headers=["*"],
)

# add sqltap
# connexion_app.add_middleware(
# SQLTapMiddleware,
# )

# add profiling
# connexion_app.add_middleware(
# LineProfilerMiddleware
# )

connexion_app.add_api(
openapi_file,
base_path="/api",
Expand All @@ -68,9 +111,5 @@
},
)

if app.debug:
app.wsgi_app = sqltap.wsgi.SQLTapMiddleware(app.wsgi_app, path="/api/__sqltap__")
app = ProfilerMiddleware(app)

app.json_encoder = ORJSONEncoder
app.json_decoder = ORJSONDecoder
3 changes: 1 addition & 2 deletions store/neurostore/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ class BaseStudy(BaseMixin, db.Model):

user = relationship("User", backref=backref("base_studies"))
# retrieve versions of same study
versions = relationship(
"Study", backref=backref("base_study"))
versions = relationship("Study", backref=backref("base_study"))

def update_has_images_and_points(self):
# Calculate has_images and has_coordinates for the BaseStudy
Expand Down
81 changes: 74 additions & 7 deletions store/neurostore/models/event_listeners.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import inspect
from sqlalchemy.orm import joinedload
from flask_sqlalchemy.session import Session
from sqlalchemy import event
from .data import (
AnnotationAnalysis,
Annotation,
Studyset,
StudysetStudy,
BaseStudy,
Study,
Analysis,
Point,
Image,
_check_type,
)

from ..database import db


Expand Down Expand Up @@ -64,6 +68,27 @@ def create_blank_notes(studyset, annotation, initiator):


def add_annotation_analyses_studyset(studyset, studies, collection_adapter):
if not (inspect(studyset).pending or inspect(studyset).transient):
studyset = (
Studyset.query.filter_by(id=studyset.id)
.options(
joinedload(Studyset.studies).options(joinedload(Study.analyses)),
joinedload(Studyset.annotations),
)
.one()
)
all_studies = set(studyset.studies + studies)
existing_studies = [
s for s in all_studies if not (inspect(s).pending or inspect(s).transient)
]
study_query = (
Study.query.filter(Study.id.in_([s.id for s in existing_studies]))
.options(joinedload(Study.analyses))
.all()
)

all_studies.union(set(study_query))

all_analyses = [analysis for study in studies for analysis in study.analyses]
existing_analyses = [
analysis for study in studyset.studies for analysis in study.analyses
Expand Down Expand Up @@ -91,6 +116,17 @@ def add_annotation_analyses_studyset(studyset, studies, collection_adapter):


def add_annotation_analyses_study(study, analyses, collection_adapter):
if not (inspect(study).pending or inspect(study).transient):
study = (
Study.query.filter_by(id=study.id)
.options(
joinedload(Study.analyses),
joinedload(Study.studyset_studies)
.joinedload(StudysetStudy.studyset)
.joinedload(Studyset.annotations),
)
.one()
)
new_analyses = set(analyses) - set([a for a in study.analyses])

all_annotations = set(
Expand Down Expand Up @@ -150,14 +186,31 @@ def get_nested_attr(obj, nested_attr):

def get_base_study(obj):
base_study = None

if isinstance(obj, (Point, Image)):
base_study = get_nested_attr(obj, "analysis.study.base_study")
if isinstance(obj, Analysis):
base_study = get_nested_attr(obj, "study.base_study")
if isinstance(obj, Study):
base_study = obj.base_study
if isinstance(obj, BaseStudy):
base_study = obj
if obj in session.new or session.deleted:
base_study = get_nested_attr(obj, "analysis.study.base_study")
elif isinstance(obj, Analysis):
relevant_attrs = ("study", "points", "images")
for attr in relevant_attrs:
attr_history = get_nested_attr(inspect(obj), f"attrs.{attr}.history")
if attr_history.added or attr_history.deleted:
base_study = get_nested_attr(obj, "study.base_study")
break
elif isinstance(obj, Study):
relevant_attrs = ("base_study", "analyses")
for attr in relevant_attrs:
attr_history = get_nested_attr(inspect(obj), f"attrs.{attr}.history")
if attr_history.added or attr_history.deleted:
base_study = obj.base_study
break
elif isinstance(obj, BaseStudy):
relevant_attrs = ("versions",)
for attr in relevant_attrs:
attr_history = get_nested_attr(inspect(obj), f"attrs.{attr}.history")
if attr_history.added or attr_history.deleted:
base_study = obj
break

return base_study

Expand All @@ -169,4 +222,18 @@ def get_base_study(obj):

# Update the has_images and has_points for each unique BaseStudy
for base_study in unique_base_studies:
if (
inspect(base_study).attrs.versions.history.added
and base_study.has_coordinates is True
and base_study.has_images is True
):
continue

if (
inspect(base_study).attrs.versions.history.deleted
and base_study.has_coordinates is False
and base_study.has_images is False
):
continue

base_study.update_has_images_and_points()
2 changes: 1 addition & 1 deletion store/neurostore/openapi
116 changes: 87 additions & 29 deletions store/neurostore/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,18 @@ def post_nested_record_update(record):
"""
return record

def after_update_or_create(self, record):
    """
    Hook invoked on a record after ``update_or_create`` finishes.

    The base implementation is a deliberate no-op that returns the record
    unchanged; resource-specific subclasses override it to perform their
    own post-processing (per the original comment: "defined in specific
    classes").

    Args:
        record: the model instance that was just updated or created.

    Returns:
        The (possibly post-processed) record; here, the input unchanged.
    """
    return record

@classmethod
def update_or_create(cls, data, id=None, commit=True):
def load_nested_records(cls, data, record=None):
return data

@classmethod
def update_or_create(cls, data, id=None, user=None, record=None, commit=True):
"""
scenerios:
1. cloning a study
Expand All @@ -91,7 +101,7 @@ def update_or_create(cls, data, id=None, commit=True):
# Store all models so we can atomically update in one commit
to_commit = []

current_user = get_current_user()
current_user = user or get_current_user()
if not current_user:
current_user = create_user()

Expand All @@ -104,31 +114,35 @@ def update_or_create(cls, data, id=None, commit=True):

# allow compose bot to make changes
compose_bot = current_app.config["COMPOSE_AUTH0_CLIENT_ID"] + "@clients"
if id is None:
if id is None and record is None:
record = cls._model()
record.user = current_user
else:
elif record is None:
record = cls._model.query.filter_by(id=id).first()
if record is None:
abort(422)
elif (
record.user_id != current_user.external_id
and not only_ids
and current_user.external_id != compose_bot
):
abort(403)
elif only_ids:
to_commit.append(record)

if commit:
db.session.add_all(to_commit)
try:
db.session.commit()
except SQLAlchemyError:
db.session.rollback()
abort(400)

return record

data = cls.load_nested_records(data, record)

if (
not sa.inspect(record).pending
and record.user != current_user
and not only_ids
and current_user.external_id != compose_bot
):
abort(403)
elif only_ids:
to_commit.append(record)

if commit:
db.session.add_all(to_commit)
try:
db.session.commit()
except SQLAlchemyError:
db.session.rollback()
abort(400)

return record

# Update all non-nested attributes
for k, v in data.items():
Expand All @@ -151,7 +165,13 @@ def update_or_create(cls, data, id=None, commit=True):
}
else:
query_args = {"id": v["id"]}
v = LnCls._model.query.filter_by(**query_args).first()

if v.get("preloaded_data"):
v = v["preloaded_data"]
else:
q = LnCls._model.query.filter_by(**query_args)
v = q.first()

if v is None:
abort(400)

Expand All @@ -171,13 +191,40 @@ def update_or_create(cls, data, id=None, commit=True):
ResCls = getattr(viewdata, res_name)
if data.get(field) is not None:
if isinstance(data.get(field), list):
nested = [
ResCls.update_or_create(rec, commit=False)
for rec in data.get(field)
]
nested = []
for rec in data.get(field):
id = None
if isinstance(rec, dict) and rec.get("id"):
id = rec.get("id")
elif isinstance(rec, str):
id = rec
if data.get("preloaded_studies") and id:
nested_record = data["preloaded_studies"].get(id)
else:
nested_record = None
nested.append(
ResCls.update_or_create(
rec,
user=current_user,
record=nested_record,
commit=False,
)
)
to_commit.extend(nested)
else:
nested = ResCls.update_or_create(data.get(field), commit=False)
id = None
rec = data.get(field)
if isinstance(rec, dict) and rec.get("id"):
id = rec.get("id")
elif isinstance(rec, str):
id = rec
if data.get("preloaded_studies") and id:
nested_record = data["preloaded_studies"].get(id)
else:
nested_record = None
nested = ResCls.update_or_create(
rec, user=current_user, record=nested_record, commit=False
)
to_commit.append(nested)

setattr(record, field, nested)
Expand Down Expand Up @@ -298,7 +345,15 @@ def get(self, id):
q = self._model.query
if args["nested"] or self._model is Annotation:
q = q.options(nested_load(self))

if self._model is Annotation:
q = q.options(
joinedload(Annotation.annotation_analyses).options(
joinedload(AnnotationAnalysis.analysis),
joinedload(AnnotationAnalysis.studyset_study).options(
joinedload(StudysetStudy.study)
),
)
)
record = q.filter_by(id=id).first_or_404()
if self._model is Studyset and args["nested"]:
snapshot = StudysetSnapshot()
Expand All @@ -319,6 +374,7 @@ def put(self, id):
with db.session.no_autoflush:
record = self.__class__.update_or_create(data, id)

record = self.after_update_or_create(record)
# clear relevant caches
clear_cache(self.__class__, record, request.path)

Expand Down Expand Up @@ -481,6 +537,8 @@ def post(self):
with db.session.no_autoflush:
record = self.__class__.update_or_create(data)

record = self.after_update_or_create(record)

# clear the cache for this endpoint
clear_cache(self.__class__, record, request.path)

Expand Down
Loading
Loading