Skip to content

Commit

Permalink
Add create, update and delete subscriptions
Browse files Browse the repository at this point in the history
Use graphene-luna for testing purposes.

Also add a generic "DjangoSignalSubscription" type that allows you to subscribe to any Django signal.

Signed-off-by: Tormod Haugland <tormod.haugland@gmail.com>
  • Loading branch information
tOgg1 committed Sep 16, 2024
1 parent d8f6a7c commit b6b98e7
Show file tree
Hide file tree
Showing 9 changed files with 694 additions and 2 deletions.
2 changes: 2 additions & 0 deletions graphene_django_cud/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@

USE_ID_SUFFIXES_FOR_FK_SETTINGS_KEY = "GRAPHENE_DJANGO_CUD_USE_ID_SUFFIXES_FOR_FK"
USE_ID_SUFFIXES_FOR_M2M_SETTINGS_KEY = "GRAPHENE_DJANGO_CUD_USE_ID_SUFFIXES_FOR_M2M"

USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY = "GRAPHENE_DJANGO_CUD_USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS"
58 changes: 58 additions & 0 deletions graphene_django_cud/subscriptions/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import graphene
from graphql import GraphQLError


class SubscriptionField(graphene.Field):
"""
This is an extension of the graphene.Field class that exists
to allow our DjangoCudSubscriptionBase classes to pass a subscribe
method to the Field instantiation, which we use here in the
`wrap_subscribe` method. `wrap_subscribe` is called internally in graphene
to figure out which resolver to use for a subscription field.
"""

def __init__(self, *args, subscribe=None, **kwargs):
self.subscribe = subscribe
super().__init__(*args, **kwargs)

def wrap_subscribe(self, parent_subscribe):
return self.subscribe


class DjangoCudSubscriptionBase(graphene.ObjectType):
"""Base class for DjangoCud subscriptions"""

@classmethod
def get_permissions(cls, root, info, *args, **kwargs):
return cls._meta.permissions

@classmethod
def check_permissions(cls, root, info, *args, **kwargs) -> None:
get_permissions = getattr(cls, "get_permissions", None)
if not callable(get_permissions):
raise TypeError("The `get_permissions` attribute of a subscription must be callable.")

permissions = cls.get_permissions(root, info, *args, **kwargs)

if permissions and len(permissions) > 0:
if not info.context.user.has_perms(permissions):
raise GraphQLError("Not permitted to access this subscription.")

@classmethod
def Field(cls, name=None, description=None, deprecation_reason=None, required=False):
"""Create a field for the subscription that automatically creates a subscription resolver"""
return SubscriptionField(
cls._meta.output,
resolver=cls._meta.resolver,
subscribe=cls._meta.subscribe,
name=name,
description=description or cls._meta.description,
deprecation_reason=deprecation_reason,
required=required,
)

@classmethod
async def subscribe(cls, *args, **kwargs):
"""Dummy subscribe method. Must be implemented by subclasses"""
raise NotImplementedError("`subscribe` must be implemented by the implementing subclass. "
"This is likely a bug in graphene-django-cud.")
129 changes: 129 additions & 0 deletions graphene_django_cud/subscriptions/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio
from collections import OrderedDict
from typing import Optional

import graphene
from asgiref.sync import async_to_sync
from django.conf import settings
from django.db.models.signals import post_save
from django.dispatch import Signal
from graphene.types.objecttype import ObjectTypeOptions
from graphene_django.registry import get_global_registry

from graphene_django_cud.consts import USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY
from graphene_django_cud.signals import post_create_mutation
from graphene_django_cud.subscriptions.core import DjangoCudSubscriptionBase
from graphene_django_cud.util import to_snake_case


class DjangoCreateSubscriptionOptions(ObjectTypeOptions):
model = None
return_field_name = None
permissions = None
signal: Optional[Signal] = None


class DjangoCreateSubscription(DjangoCudSubscriptionBase):
# All active subscriptions are stored in this centralized dictionary.
# We need to do this to keep track of which subscriptions are listening to
# which signals.
subscribers = {}

@classmethod
def __init_subclass_with_meta__(
cls,
_meta=None,
model=None,
permissions=None,
return_field_name=None,
signal=post_create_mutation if getattr(
settings,
USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY,
False
) else post_save,
**kwargs,
):
registry = get_global_registry()
model_type = registry.get_type_for_model(model)

if not _meta:
_meta = DjangoCreateSubscriptionOptions(cls)

if not return_field_name:
return_field_name = to_snake_case(model.__name__)

output_fields = OrderedDict()
output_fields[return_field_name] = graphene.Field(model_type)

_meta.model = model
_meta.model_type = model_type
_meta.fields = output_fields
_meta.output = cls
_meta.permissions = permissions

# Importantly, this needs to be set to either nothing or the identity.
# Internally in graphene it will be defaulted to the identity function. If it
# isn't, graphene will try to pass the value resolve from the "subscribe" method
# through this resolver. If it is also set to "subscribe", we will get an issue with
# graphene trying to return an AsyncIterator.
_meta.resolver = None

# This is set to be the subscription resolver in the SubscriptionField class.
_meta.subscribe = cls.subscribe
_meta.return_field_name = return_field_name

# Connect to the model's post_save (or your custom) signal
signal.connect(cls._model_created_handler, sender=model)

super().__init_subclass_with_meta__(_meta=_meta, **kwargs)

@classmethod
def _model_created_handler(cls, sender, instance, created=None, **kwargs):
"""Handle model creation and notify subscribers"""
if created or created is None:
print(sender, instance, created, kwargs)
new_instance = cls.handle_object_created(sender, instance, **kwargs)

assert new_instance is None or isinstance(new_instance, cls._meta.model)

if new_instance:
instance = new_instance

# Notify all subscribers for the model
for subscriber in cls.subscribers.get(sender, []):
async_to_sync(subscriber)(instance)

@classmethod
def handle_object_created(cls, sender, instance, **kwargs):
"""Handle and modify any instance created"""
pass

@classmethod
def check_permissions(cls, root, info, *args, **kwargs) -> None:
return super().check_permissions(root, info, *args, **kwargs)

@classmethod
async def subscribe(cls, root, info, *args, **kwargs):
"""Subscribe to the model creation events asynchronously"""

cls.check_permissions(root, info, *args, **kwargs)

model = cls._meta.model
queue = asyncio.Queue()

# Ensure there's a list of subscribers for the model
if model not in cls.subscribers:
cls.subscribers[model] = []

# Add the queue's put method to the subscribers for this model
cls.subscribers[model].append(queue.put)

try:
while True:
# Wait for the next model instance to be created
instance = await queue.get()
data = {cls._meta.return_field_name: instance}
yield cls(**data)
finally:
# Clean up the subscriber when the subscription ends
cls.subscribers[model].remove(queue.put)
152 changes: 152 additions & 0 deletions graphene_django_cud/subscriptions/delete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import asyncio
from collections import OrderedDict
from typing import Optional

import graphene
from asgiref.sync import async_to_sync
from django.db.models.signals import post_save, post_delete
from graphene.types.objecttype import ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
from graphene_django.registry import get_global_registry
from requests import delete

from graphene_django_cud.subscriptions.core import DjangoCudSubscriptionBase
from graphene_django_cud.util import to_snake_case

from graphene_django_cud.util.dict import get_any_of
import logging

logger = logging.getLogger(__name__)


class DjangoDeleteSubscriptionOptions(ObjectTypeOptions):
model = None
return_field_name = None
permissions = None
signal = None


class DjangoDeleteSubscription(DjangoCudSubscriptionBase):
# All active subscriptions are stored in this centralized dictionary.
# We need to do this to keep track of which subscriptions are listening to
# which signals.
subscribers = {}

@classmethod
def __init_subclass_with_meta__(
cls,
_meta=None,
model=None,
permissions=None,
return_field_name=None,
signal=post_delete,
**kwargs,
):
registry = get_global_registry()
model_type = registry.get_type_for_model(model)

if not _meta:
_meta = DjangoDeleteSubscriptionOptions(cls)

if not return_field_name:
return_field_name = to_snake_case(model.__name__)

output_fields = OrderedDict()
output_fields["id"] = graphene.String()

_meta.model = model
_meta.model_type = model_type
_meta.fields = yank_fields_from_attrs(output_fields, _as=graphene.Field)
_meta.output = cls
_meta.permissions = permissions

# Importantly, this needs to be set to either nothing or the identity.
# Internally in graphene it will be defaulted to the identity function.
_meta.resolver = None

# This is set to be the subscription resolver in the SubscriptionField class.
_meta.subscribe = cls.subscribe
_meta.return_field_name = return_field_name

# Connect to the model's post_save signal
signal.connect(cls._model_deleted_handler, sender=model)

super().__init_subclass_with_meta__(_meta=_meta, **kwargs)

@classmethod
def _model_deleted_handler(cls, sender, *args, **kwargs):
"""Handle model updating and notify subscribers"""

Model = cls._meta.model

instance: Optional[Model] = kwargs.get("instance", None) or next(filter(
lambda x: isinstance(x, Model), args
), None)

deleted_id = get_any_of(
kwargs,
[
"pk",
"raw_id",
"input_id",
"id"
]
) if not instance else get_any_of(
instance,
[
"pk",
"id",
]
)

print(kwargs, args, deleted_id)

if deleted_id is None:
logger.warning("Received a delete signal for a model without an instance or an id being passed to the "
"signal handler. Are you using a compatible signal? Read the documentation for "
"graphene-django-cud for more information.")
return

new_deleted_id = cls.handle_object_deleted(sender, deleted_id, **kwargs)

if new_deleted_id is not None:
deleted_id = new_deleted_id

# Notify all subscribers for the model
for subscriber in cls.subscribers.get(sender, []):
async_to_sync(subscriber)(deleted_id)

@classmethod
def handle_object_deleted(cls, sender, deleted_id, **kwargs):
"""Handle and modify any instance created"""
pass

@classmethod
def check_permissions(cls, root, info, *args, **kwargs) -> None:
return super().check_permissions(root, info, *args, **kwargs)

@classmethod
async def subscribe(cls, root, info, *args, **kwargs):
"""Subscribe to the model creation events asynchronously"""

cls.check_permissions(root, info, *args, **kwargs)

model = cls._meta.model
queue = asyncio.Queue()

# Ensure there's a list of subscribers for the model
if model not in cls.subscribers:
cls.subscribers[model] = []

# Add the queue's put method to the subscribers for this model
cls.subscribers[model].append(queue.put)

try:
while True:
# Wait for the next model instance to be deleted
_id = await queue.get()

yield cls(id=_id)
finally:
# Clean up the subscriber when the subscription ends
cls.subscribers[model].remove(queue.put)
Loading

0 comments on commit b6b98e7

Please sign in to comment.