Skip to content

Commit

Permalink
Ref/speed up ingestion (#616)
Browse files Browse the repository at this point in the history
* preload more attributes

* wip: speed up PUT

* add sqltap profiling asgi

* do not update has_coordinates or has_images if irrelevant attribute updated

* make openapi more permissive and style

* remove unused import

* be more selective when updating has_coordinates and has_images

* refactor how records are looked up

* preload analyses

* handle loading of annotations

* preload the correct attributes for annotations

* catch more custom annotation loading

* fix annotation loading attempt #1

* attempt #2

* attempt #3

* reassign q

* remove extraneous command, and load studyset

* style fixes

* comment out unused bits
  • Loading branch information
jdkent authored Nov 21, 2023
1 parent 77d3ac4 commit 340d225
Show file tree
Hide file tree
Showing 11 changed files with 361 additions and 65 deletions.
51 changes: 45 additions & 6 deletions store/neurostore/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from pathlib import Path
from werkzeug.middleware.profiler import ProfilerMiddleware

from connexion.middleware import MiddlewarePosition
from starlette.middleware.cors import CORSMiddleware
Expand All @@ -10,11 +9,45 @@
# from connexion.json_schema import default_handlers as json_schema_handlers
from connexion.resolver import MethodResolver
from flask_caching import Cache
import sqltap.wsgi

from .or_json import ORJSONDecoder, ORJSONEncoder
from .database import init_db

# from datetime import datetime

# import sqltap.wsgi
# import sqltap
# import yappi

# class SQLTapMiddleware:
# def __init__(self, app):
# self.app = app

# async def __call__(self, scope, receive, send):
# profiler = sqltap.start()
# await self.app(scope, receive, send)
# statistics = profiler.collect()
# sqltap.report(statistics, "report.txt", report_format="text")


# class LineProfilerMiddleware:
# def __init__(self, app):
# self.app = app

# async def __call__(self, scope, receive, send):
# yappi.start()
# await self.app(scope, receive, send)
# yappi.stop()
# filename = (
# scope["path"].lstrip("/").rstrip("/").replace("/", "-")
# + "-"
# + scope["method"].lower()
# + str(datetime.now())
# + ".prof"
# )
# stats = yappi.get_func_stats()
# stats.save(filename, type="pstat")


connexion_app = connexion.FlaskApp(__name__, specification_dir="openapi/")

Expand Down Expand Up @@ -45,6 +78,16 @@
allow_headers=["*"],
)

# add sqltap
# connexion_app.add_middleware(
# SQLTapMiddleware,
# )

# add profiling
# connexion_app.add_middleware(
# LineProfilerMiddleware
# )

connexion_app.add_api(
openapi_file,
base_path="/api",
Expand All @@ -68,9 +111,5 @@
},
)

if app.debug:
app.wsgi_app = sqltap.wsgi.SQLTapMiddleware(app.wsgi_app, path="/api/__sqltap__")
app = ProfilerMiddleware(app)

app.json_encoder = ORJSONEncoder
app.json_decoder = ORJSONDecoder
3 changes: 1 addition & 2 deletions store/neurostore/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ class BaseStudy(BaseMixin, db.Model):

user = relationship("User", backref=backref("base_studies"))
# retrieve versions of same study
versions = relationship(
"Study", backref=backref("base_study"))
versions = relationship("Study", backref=backref("base_study"))

def update_has_images_and_points(self):
# Calculate has_images and has_coordinates for the BaseStudy
Expand Down
81 changes: 74 additions & 7 deletions store/neurostore/models/event_listeners.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import inspect
from sqlalchemy.orm import joinedload
from flask_sqlalchemy.session import Session
from sqlalchemy import event
from .data import (
AnnotationAnalysis,
Annotation,
Studyset,
StudysetStudy,
BaseStudy,
Study,
Analysis,
Point,
Image,
_check_type,
)

from ..database import db


Expand Down Expand Up @@ -64,6 +68,27 @@ def create_blank_notes(studyset, annotation, initiator):


def add_annotation_analyses_studyset(studyset, studies, collection_adapter):
if not (inspect(studyset).pending or inspect(studyset).transient):
studyset = (
Studyset.query.filter_by(id=studyset.id)
.options(
joinedload(Studyset.studies).options(joinedload(Study.analyses)),
joinedload(Studyset.annotations),
)
.one()
)
all_studies = set(studyset.studies + studies)
existing_studies = [
s for s in all_studies if not (inspect(s).pending or inspect(s).transient)
]
study_query = (
Study.query.filter(Study.id.in_([s.id for s in existing_studies]))
.options(joinedload(Study.analyses))
.all()
)

all_studies.union(set(study_query))

all_analyses = [analysis for study in studies for analysis in study.analyses]
existing_analyses = [
analysis for study in studyset.studies for analysis in study.analyses
Expand Down Expand Up @@ -91,6 +116,17 @@ def add_annotation_analyses_studyset(studyset, studies, collection_adapter):


def add_annotation_analyses_study(study, analyses, collection_adapter):
if not (inspect(study).pending or inspect(study).transient):
study = (
Study.query.filter_by(id=study.id)
.options(
joinedload(Study.analyses),
joinedload(Study.studyset_studies)
.joinedload(StudysetStudy.studyset)
.joinedload(Studyset.annotations),
)
.one()
)
new_analyses = set(analyses) - set([a for a in study.analyses])

all_annotations = set(
Expand Down Expand Up @@ -150,14 +186,31 @@ def get_nested_attr(obj, nested_attr):

def get_base_study(obj):
base_study = None

if isinstance(obj, (Point, Image)):
base_study = get_nested_attr(obj, "analysis.study.base_study")
if isinstance(obj, Analysis):
base_study = get_nested_attr(obj, "study.base_study")
if isinstance(obj, Study):
base_study = obj.base_study
if isinstance(obj, BaseStudy):
base_study = obj
if obj in session.new or session.deleted:
base_study = get_nested_attr(obj, "analysis.study.base_study")
elif isinstance(obj, Analysis):
relevant_attrs = ("study", "points", "images")
for attr in relevant_attrs:
attr_history = get_nested_attr(inspect(obj), f"attrs.{attr}.history")
if attr_history.added or attr_history.deleted:
base_study = get_nested_attr(obj, "study.base_study")
break
elif isinstance(obj, Study):
relevant_attrs = ("base_study", "analyses")
for attr in relevant_attrs:
attr_history = get_nested_attr(inspect(obj), f"attrs.{attr}.history")
if attr_history.added or attr_history.deleted:
base_study = obj.base_study
break
elif isinstance(obj, BaseStudy):
relevant_attrs = ("versions",)
for attr in relevant_attrs:
attr_history = get_nested_attr(inspect(obj), f"attrs.{attr}.history")
if attr_history.added or attr_history.deleted:
base_study = obj
break

return base_study

Expand All @@ -169,4 +222,18 @@ def get_base_study(obj):

# Update the has_images and has_points for each unique BaseStudy
for base_study in unique_base_studies:
if (
inspect(base_study).attrs.versions.history.added
and base_study.has_coordinates is True
and base_study.has_images is True
):
continue

if (
inspect(base_study).attrs.versions.history.deleted
and base_study.has_coordinates is False
and base_study.has_images is False
):
continue

base_study.update_has_images_and_points()
2 changes: 1 addition & 1 deletion store/neurostore/openapi
116 changes: 87 additions & 29 deletions store/neurostore/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,18 @@ def post_nested_record_update(record):
"""
return record

def after_update_or_create(self, record):
    """Hook run after a record is created or updated.

    The base implementation is an identity passthrough; resource
    subclasses override this to perform model-specific post-processing.

    Args:
        record: the model instance that was just created or updated.

    Returns:
        The (possibly post-processed) record; here, unchanged.
    """
    return record

@classmethod
def update_or_create(cls, data, id=None, commit=True):
def load_nested_records(cls, data, record=None):
return data

@classmethod
def update_or_create(cls, data, id=None, user=None, record=None, commit=True):
"""
scenerios:
1. cloning a study
Expand All @@ -91,7 +101,7 @@ def update_or_create(cls, data, id=None, commit=True):
# Store all models so we can atomically update in one commit
to_commit = []

current_user = get_current_user()
current_user = user or get_current_user()
if not current_user:
current_user = create_user()

Expand All @@ -104,31 +114,35 @@ def update_or_create(cls, data, id=None, commit=True):

# allow compose bot to make changes
compose_bot = current_app.config["COMPOSE_AUTH0_CLIENT_ID"] + "@clients"
if id is None:
if id is None and record is None:
record = cls._model()
record.user = current_user
else:
elif record is None:
record = cls._model.query.filter_by(id=id).first()
if record is None:
abort(422)
elif (
record.user_id != current_user.external_id
and not only_ids
and current_user.external_id != compose_bot
):
abort(403)
elif only_ids:
to_commit.append(record)

if commit:
db.session.add_all(to_commit)
try:
db.session.commit()
except SQLAlchemyError:
db.session.rollback()
abort(400)

return record

data = cls.load_nested_records(data, record)

if (
not sa.inspect(record).pending
and record.user != current_user
and not only_ids
and current_user.external_id != compose_bot
):
abort(403)
elif only_ids:
to_commit.append(record)

if commit:
db.session.add_all(to_commit)
try:
db.session.commit()
except SQLAlchemyError:
db.session.rollback()
abort(400)

return record

# Update all non-nested attributes
for k, v in data.items():
Expand All @@ -151,7 +165,13 @@ def update_or_create(cls, data, id=None, commit=True):
}
else:
query_args = {"id": v["id"]}
v = LnCls._model.query.filter_by(**query_args).first()

if v.get("preloaded_data"):
v = v["preloaded_data"]
else:
q = LnCls._model.query.filter_by(**query_args)
v = q.first()

if v is None:
abort(400)

Expand All @@ -171,13 +191,40 @@ def update_or_create(cls, data, id=None, commit=True):
ResCls = getattr(viewdata, res_name)
if data.get(field) is not None:
if isinstance(data.get(field), list):
nested = [
ResCls.update_or_create(rec, commit=False)
for rec in data.get(field)
]
nested = []
for rec in data.get(field):
id = None
if isinstance(rec, dict) and rec.get("id"):
id = rec.get("id")
elif isinstance(rec, str):
id = rec
if data.get("preloaded_studies") and id:
nested_record = data["preloaded_studies"].get(id)
else:
nested_record = None
nested.append(
ResCls.update_or_create(
rec,
user=current_user,
record=nested_record,
commit=False,
)
)
to_commit.extend(nested)
else:
nested = ResCls.update_or_create(data.get(field), commit=False)
id = None
rec = data.get(field)
if isinstance(rec, dict) and rec.get("id"):
id = rec.get("id")
elif isinstance(rec, str):
id = rec
if data.get("preloaded_studies") and id:
nested_record = data["preloaded_studies"].get(id)
else:
nested_record = None
nested = ResCls.update_or_create(
rec, user=current_user, record=nested_record, commit=False
)
to_commit.append(nested)

setattr(record, field, nested)
Expand Down Expand Up @@ -298,7 +345,15 @@ def get(self, id):
q = self._model.query
if args["nested"] or self._model is Annotation:
q = q.options(nested_load(self))

if self._model is Annotation:
q = q.options(
joinedload(Annotation.annotation_analyses).options(
joinedload(AnnotationAnalysis.analysis),
joinedload(AnnotationAnalysis.studyset_study).options(
joinedload(StudysetStudy.study)
),
)
)
record = q.filter_by(id=id).first_or_404()
if self._model is Studyset and args["nested"]:
snapshot = StudysetSnapshot()
Expand All @@ -319,6 +374,7 @@ def put(self, id):
with db.session.no_autoflush:
record = self.__class__.update_or_create(data, id)

record = self.after_update_or_create(record)
# clear relevant caches
clear_cache(self.__class__, record, request.path)

Expand Down Expand Up @@ -481,6 +537,8 @@ def post(self):
with db.session.no_autoflush:
record = self.__class__.update_or_create(data)

record = self.after_update_or_create(record)

# clear the cache for this endpoint
clear_cache(self.__class__, record, request.path)

Expand Down
Loading

0 comments on commit 340d225

Please sign in to comment.