Skip to content

Commit

Permalink
Merge pull request #80 from graphql-python/feat-generic-reference-field
Browse files Browse the repository at this point in the history
Feat generic reference field
  • Loading branch information
abawchen authored Apr 22, 2019
2 parents f6a60e2 + ac2ebbf commit d8a2f46
Show file tree
Hide file tree
Showing 14 changed files with 334 additions and 253 deletions.
8 changes: 4 additions & 4 deletions graphene_mongo/advanced_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ def _resolve_type_coordinates(self, info):
return self['coordinates']


class _CoordinatesField(graphene.ObjectType):
class _TypeField(graphene.ObjectType):

type = graphene.String()

def resolve_type(self, info):
return self['type']


class PointFieldType(_CoordinatesField):
class PointFieldType(_TypeField):

coordinates = graphene.List(
graphene.Float, resolver=_resolve_type_coordinates)


class PolygonFieldType(_CoordinatesField):
class PolygonFieldType(_TypeField):

coordinates = graphene.List(
graphene.List(
Expand All @@ -34,7 +34,7 @@ class PolygonFieldType(_CoordinatesField):
)


class MultiPolygonFieldType(_CoordinatesField):
class MultiPolygonFieldType(_TypeField):

coordinates = graphene.List(
graphene.List(
Expand Down
38 changes: 35 additions & 3 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import mongoengine
import uuid
from graphene import (
ID,
Boolean,
Expand All @@ -9,14 +11,16 @@
List,
NonNull,
String,
Union,
is_node
)
from graphene.types.json import JSONString

import mongoengine
from mongoengine.base import get_document

from . import advanced_types
from .utils import import_single_dispatch, get_field_description
from .utils import (
import_single_dispatch, get_field_description,
)

singledispatch = import_single_dispatch()

Expand Down Expand Up @@ -110,6 +114,34 @@ def convert_field_to_list(field, registry=None):
return List(base_type, description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.GenericReferenceField)
def convert_field_to_union(field, registry=None):

_types = []
for choice in field.choices:
_field = mongoengine.ReferenceField(get_document(choice))
_field = convert_mongoengine_field(_field, registry)
_type = _field.get_type()
if _type:
_types.append(_type.type)
else:
# TODO: Register type auto-matically here.
pass

if len(_types) == 0:
return None

# XXX: Use uuid to avoid duplicate name
name = '{}_{}_union_{}'.format(
field._owner_document.__name__,
field.db_field,
str(uuid.uuid1()).replace('-', '')
)
Meta = type('Meta', (object, ), {'types': tuple(_types)})
_union = type(name, (Union, ), {'Meta': Meta})
return Field(_union)


@convert_mongoengine_field.register(mongoengine.EmbeddedDocumentField)
@convert_mongoengine_field.register(mongoengine.ReferenceField)
def convert_field_to_dynamic(field, registry=None):
Expand Down
17 changes: 13 additions & 4 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from collections import OrderedDict
from functools import partial, reduce

import graphene
import mongoengine
from graphene import PageInfo
from graphene.relay import ConnectionField
from graphene.types.argument import to_arguments
from graphene.types.dynamic import Dynamic
Expand Down Expand Up @@ -65,6 +65,13 @@ def args(self, args):

def _field_args(self, items):
def is_filterable(k):
"""
Args:
k (str): field name.
Returns:
bool
"""

if not hasattr(self.model, k):
return False
if isinstance(getattr(self.model, k), property):
Expand All @@ -75,8 +82,10 @@ def is_filterable(k):
return False
if isinstance(converted, (ConnectionField, Dynamic)):
return False
if callable(getattr(converted, 'type', None)) and isinstance(converted.type(),
(PointFieldType, MultiPolygonFieldType)):
if callable(getattr(converted, 'type', None)) \
and isinstance(
converted.type(),
(PointFieldType, MultiPolygonFieldType, graphene.Union)):
return False
return True

Expand Down Expand Up @@ -158,7 +167,7 @@ def default_resolver(self, _root, info, **args):
list_length=list_length,
connection_type=self.type,
edge_type=self.type.Edge,
pageinfo_type=PageInfo,
pageinfo_type=graphene.PageInfo,
)
connection.iterable = objs
connection.list_length = list_length
Expand Down
162 changes: 85 additions & 77 deletions graphene_mongo/tests/models.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
import mongoengine
from datetime import datetime
from mongoengine import (
connect, Document, EmbeddedDocument
)
from mongoengine.fields import (
DateTimeField, EmailField, EmbeddedDocumentField,
FloatField, EmbeddedDocumentListField, ListField, LazyReferenceField,
MapField, MultiPolygonField, PointField, PolygonField,
ReferenceField, StringField,
)

connect('graphene-mongo-test', host='mongomock://localhost', alias='default')
mongoengine.connect('graphene-mongo-test', host='mongomock://localhost', alias='default')


class Publisher(Document):
class Publisher(mongoengine.Document):

meta = {'collection': 'test_publisher'}
name = StringField()
name = mongoengine.StringField()

@property
def legal_name(self):
Expand All @@ -25,129 +17,145 @@ def bad_field(self):
return None


class Editor(Document):
class Editor(mongoengine.Document):
"""
An Editor of a publication.
"""

meta = {'collection': 'test_editor'}
id = StringField(primary_key=True)
first_name = StringField(required=True, help_text="Editor's first name.", db_field='fname')
last_name = StringField(required=True, help_text="Editor's last name.")
metadata = MapField(field=StringField(), help_text="Arbitrary metadata.")
company = LazyReferenceField(Publisher)
id = mongoengine.StringField(primary_key=True)
first_name = mongoengine.StringField(required=True, help_text="Editor's first name.", db_field='fname')
last_name = mongoengine.StringField(required=True, help_text="Editor's last name.")
metadata = mongoengine.MapField(field=mongoengine.StringField(), help_text="Arbitrary metadata.")
company = mongoengine.LazyReferenceField(Publisher)


class Article(Document):
class Article(mongoengine.Document):

meta = {'collection': 'test_article'}
headline = StringField(required=True, help_text="The article headline.")
pub_date = DateTimeField(default=datetime.now,
verbose_name="publication date",
help_text="The date of first press.")
editor = ReferenceField(Editor)
reporter = ReferenceField('Reporter')
headline = mongoengine.StringField(required=True, help_text="The article headline.")
pub_date = mongoengine.DateTimeField(
default=datetime.now,
verbose_name="publication date",
help_text="The date of first press.")
editor = mongoengine.ReferenceField(Editor)
reporter = mongoengine.ReferenceField('Reporter')
# Will not convert this field cause no chioces
generic_reference = mongoengine.GenericReferenceField()


class EmbeddedArticle(EmbeddedDocument):
class EmbeddedArticle(mongoengine.EmbeddedDocument):

meta = {'collection': 'test_embedded_article'}
headline = StringField(required=True)
pub_date = DateTimeField(default=datetime.now)
editor = ReferenceField(Editor)
reporter = ReferenceField('Reporter')
headline = mongoengine.StringField(required=True)
pub_date = mongoengine.DateTimeField(default=datetime.now)
editor = mongoengine.ReferenceField(Editor)
reporter = mongoengine.ReferenceField('Reporter')


class Reporter(Document):
class Reporter(mongoengine.Document):

meta = {'collection': 'test_reporter'}
id = StringField(primary_key=True)
first_name = StringField(required=True)
last_name = StringField(required=True)
email = EmailField()
awards = ListField(StringField())
articles = ListField(ReferenceField(Article))
embedded_articles = ListField(EmbeddedDocumentField(EmbeddedArticle))
embedded_list_articles = EmbeddedDocumentListField(EmbeddedArticle)


class Player(Document):
id = mongoengine.StringField(primary_key=True)
first_name = mongoengine.StringField(required=True)
last_name = mongoengine.StringField(required=True)
email = mongoengine.EmailField()
awards = mongoengine.ListField(mongoengine.StringField())
articles = mongoengine.ListField(mongoengine.ReferenceField(Article))
embedded_articles = mongoengine.ListField(mongoengine.EmbeddedDocumentField(EmbeddedArticle))
embedded_list_articles = mongoengine.EmbeddedDocumentListField(EmbeddedArticle)
id = mongoengine.StringField(primary_key=True)
first_name = mongoengine.StringField(required=True)
last_name = mongoengine.StringField(required=True)
email = mongoengine.EmailField()
awards = mongoengine.ListField(mongoengine.StringField())
articles = mongoengine.ListField(mongoengine.ReferenceField(Article))
embedded_articles = mongoengine.ListField(mongoengine.EmbeddedDocumentField(EmbeddedArticle))
embedded_list_articles = mongoengine.EmbeddedDocumentListField(EmbeddedArticle)
generic_reference = mongoengine.GenericReferenceField(
choices=[Article, Editor, ]
)


class Player(mongoengine.Document):

meta = {'collection': 'test_player'}
first_name = StringField(required=True)
last_name = StringField(required=True)
opponent = ReferenceField('Player')
players = ListField(ReferenceField('Player'))
articles = ListField(ReferenceField('Article'))
embedded_list_articles = EmbeddedDocumentListField(EmbeddedArticle)
first_name = mongoengine.StringField(required=True)
last_name = mongoengine.StringField(required=True)
opponent = mongoengine.ReferenceField('Player')
players = mongoengine.ListField(mongoengine.ReferenceField('Player'))
articles = mongoengine.ListField(mongoengine.ReferenceField('Article'))
embedded_list_articles = mongoengine.EmbeddedDocumentListField(EmbeddedArticle)


class Parent(Document):
class Parent(mongoengine.Document):

meta = {
'collection': 'test_parent',
'allow_inheritance': True
}
bar = StringField()
loc = MultiPolygonField()
bar = mongoengine.StringField()
loc = mongoengine.MultiPolygonField()


class CellTower(Document):
class CellTower(mongoengine.Document):

meta = {
'collection': 'test_cell_tower',
}
code = StringField()
base = PolygonField()
coverage_area = MultiPolygonField()
code = mongoengine.StringField()
base = mongoengine.PolygonField()
coverage_area = mongoengine.MultiPolygonField()


class Child(Parent):

meta = {'collection': 'test_child'}
baz = StringField()
loc = PointField()
baz = mongoengine.StringField()
loc = mongoengine.PointField()


class ProfessorMetadata(EmbeddedDocument):
class ProfessorMetadata(mongoengine.EmbeddedDocument):

meta = {'collection': 'test_professor_metadata'}
id = StringField(primary_key=False)
first_name = StringField()
last_name = StringField()
departments = ListField(StringField())
id = mongoengine.StringField(primary_key=False)
first_name = mongoengine.StringField()
last_name = mongoengine.StringField()
departments = mongoengine.ListField(mongoengine.StringField())


class ProfessorVector(Document):
class ProfessorVector(mongoengine.Document):

meta = {'collection': 'test_professor_vector'}
vec = ListField(FloatField())
metadata = EmbeddedDocumentField(ProfessorMetadata)
vec = mongoengine.ListField(mongoengine.FloatField())
metadata = mongoengine.EmbeddedDocumentField(ProfessorMetadata)


class ParentWithRelationship(Document):
class ParentWithRelationship(mongoengine.Document):

meta = {'collection': 'test_parent_reference'}
before_child = ListField(ReferenceField("ChildRegisteredBefore"))
after_child = ListField(ReferenceField("ChildRegisteredAfter"))
name = StringField()
before_child = mongoengine.ListField(
mongoengine.ReferenceField('ChildRegisteredBefore'))
after_child = mongoengine.ListField(
mongoengine.ReferenceField('ChildRegisteredAfter'))
name = mongoengine.StringField()


class ChildRegisteredBefore(Document):
class ChildRegisteredBefore(mongoengine.Document):

meta = {'collection': 'test_child_before_reference'}
parent = ReferenceField(ParentWithRelationship)
name = StringField()
parent = mongoengine.ReferenceField(ParentWithRelationship)
name = mongoengine.StringField()


class ChildRegisteredAfter(Document):
class ChildRegisteredAfter(mongoengine.Document):

meta = {'collection': 'test_child_after_reference'}
parent = ReferenceField(ParentWithRelationship)
name = StringField()
parent = mongoengine.ReferenceField(ParentWithRelationship)
name = mongoengine.StringField()


class ErroneousModel(Document):
class ErroneousModel(mongoengine.Document):
meta = {'collection': 'test_colliding_objects_model'}

objects = ListField(StringField())
objects = mongoengine.ListField(mongoengine.StringField())
Loading

0 comments on commit d8a2f46

Please sign in to comment.