Skip to content

Commit

Permalink
refactor: moved types to types.py
Browse files Browse the repository at this point in the history
started working on replacing modification via a generator

created some new helper methods with polymorphic  views

Signed-off-by: Jason <git@jasoncameron.dev>
  • Loading branch information
JasonLovesDoggo committed Feb 18, 2024
1 parent 3f3ba4d commit 94186d4
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 53 deletions.
36 changes: 30 additions & 6 deletions core/api/utils/polymorphism.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

from functools import lru_cache
from json import JSONDecodeError
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Protocol, Set
from typing import Any, Callable, Dict, Iterable, List, Optional, Protocol, Set

from django.core.exceptions import BadRequest
from django.db.models import Model, Q
Expand All @@ -13,7 +12,7 @@

from core.api.v3.objects import *
from core.api.v3.objects.base import BaseProvider
from core.utils.types import APIObjOperations
from core.utils.types import APIObjOperations, APISerializerOperations, ProviderDetails

type IgnoredKey = str | Iterable[str]
type SerializerItems = Dict[str, BaseSerializer]
Expand Down Expand Up @@ -54,6 +53,14 @@ def split_dict(dictionary: SerializerItems) -> SplitDictResult:
splitter = split_dict_wrapper("_") # ignore key for serializers


def get_path_by_provider(provider: BaseProvider) -> str:
return [
provider_key
for provider_key, provider in providers.items()
if provider == provider
][0]


providers: Dict[str, BaseProvider] = (
{ # k = request type (param passed in url), v = provider class
"announcement": AnnouncementProvider,
Expand Down Expand Up @@ -94,22 +101,39 @@ def get_provider(provider_name: provider_keys) -> Callable:

def get_providers_by_operation(
operation: APIObjOperations, return_provider: Optional[bool] = False
) -> List[str]:
) -> List[str] | List[BaseProvider]:
"""
returns a list of provider path names that support the given operation.
Example:
>>> get_providers_by_operation("single")
["announcement", "blog-post", "exhibit", "event", "organization", "flatpage", "user", "tag", "term", "timetable", "comment", "like", "course"]
"""

operation = operation.lower()
return [
(prov if return_provider else key)
for key, prov in providers.items()
if getattr(prov, f"allow_{operation}", True) == True
]


def get_operations_by_provider(provider: BaseProvider) -> ProviderDetails:
"""
Returns a list of operations supported by the given provider.
"""
options = set()
for operation in provider.supported_operations():
data = APISerializerOperations(
operation=operation,
serializer=provider.raw_serializers.get(
operation, provider.raw_serializers.get("_")
),
)
options.add(data)

return options


class ObjectAPIView(generics.GenericAPIView):
def initial(self, *args, **kwargs):
super().initial(*args, **kwargs)
Expand Down Expand Up @@ -264,7 +288,7 @@ def get_serializer_class(self):
class Provider(Protocol):
allow_list: bool
allow_new: bool
kind: Literal["list", "new", "single", "retrieve"]
kind: APIObjOperations
listing_filters_ignore: List[str]

serializers: SplitDictResult
Expand Down
10 changes: 9 additions & 1 deletion core/api/v3/objects/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Dict, Final, List
from typing import Dict, Final, List, Tuple

from django.db.models.base import ModelBase
from rest_framework.serializers import BaseSerializer
Expand Down Expand Up @@ -70,3 +70,11 @@ def __new__(cls, request):

def __init__(self, request):
self.request = request

@classmethod
def supported_operations(cls) -> Tuple[str]:
if not issubclass(cls, BaseProvider):
raise TypeError("This method can only be ran on subclasses of BaseProvider")
if "_" in cls.raw_serializers.keys():
return "list", "new", "single", "retrieve"
return tuple(cls.raw_serializers.keys())
99 changes: 55 additions & 44 deletions core/schema.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import dataclasses
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type

from drf_spectacular.drainage import set_override
from drf_spectacular.generators import SchemaGenerator
from drf_spectacular.utils import F, OpenApiExample
from memoization import cached
from rest_framework.serializers import BaseSerializer, Serializer
from rest_framework.serializers import Serializer

from core.api.utils.polymorphism import (
get_operations_by_provider,
get_path_by_provider,
get_provider,
get_providers_by_operation,
providers,
)
from core.api.v3.objects import BaseProvider
from core.utils.types import APIObjOperations
from core.utils.types import (
ObjectModificationData,
ProviderDetails,
SingleOperationData,
)


def metro_extend_schema_serializer( # modified version of drf_spectacular.utils.extend_schema_serializer
Expand Down Expand Up @@ -80,6 +86,11 @@ def wrapped_view(*args, **kwargs):

@cached
def run_fixers(result, generator, request, public):
"""
Run fixers on the schema to ensure that the API docs are properly formatted.
Note: this should ALWAYS return the same result for the same input as it's cached to improve performance.
"""
fixers = [Api3ObjSpliter]
if fixers is None:
raise ValueError("No fixers found, API3 obj docs will be broken.")
Expand All @@ -90,45 +101,6 @@ def run_fixers(result, generator, request, public):
return result


@dataclass
class SingleOperationData:
providers: List[BaseProvider]
operation: APIObjOperations
data: dict


@dataclass
class APISerializerOperations:
operation: APIObjOperations
serializer: BaseSerializer
# tags?


@dataclass
class ProviderDetails:
provider: BaseProvider
operations_supported: List[APISerializerOperations]
data: dict


@dataclass
class ObjectModificationData:
retrieve: Optional[SingleOperationData] = None
single: Optional[SingleOperationData] = None
list: Optional[SingleOperationData] = None
new: Optional[SingleOperationData] = None

def __iter__(self):
return iter(
[
("retrieve", self.retrieve),
("single", self.single),
("list", self.list),
("new", self.new),
]
)


class Api3ObjSpliter:
"""
Split the API3 schema into the different paths based on the object type and provider
Expand All @@ -139,7 +111,7 @@ class Api3ObjSpliter:
}

def __init__(self, schema):
self.operation_data: ObjectModificationData = ObjectModificationData()
self.operation_data = ObjectModificationData()
self.keys_to_delete: Tuple = ()
self.schema = schema
self._provider_details: Dict[str, ProviderDetails] = dict()
Expand All @@ -150,7 +122,8 @@ def run(self):
self.set_obj_paths(paths)

for _, provider in providers.items():
print(self._get_data_from_provider(provider))
...
# print(self._get_data_from_provider(provider))
for operation in dataclasses.fields(self.operation_data):
self.create_obj_views(operation)

Expand Down Expand Up @@ -230,3 +203,41 @@ def get_providers_from_name(enum: List[str]) -> List[BaseProvider]:
return [get_provider(key) for key in enum]

def create_obj_views(self, operation: SingleOperationData): ...


class MetroSchemaGenerator(SchemaGenerator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _get_paths_and_endpoints(self):
"""
Generate (path, method, view) given (path, method, callback) for paths.
"""
obj3 = set()
view_endpoints = super()._get_paths_and_endpoints()
for path, subpath, method, view in view_endpoints:
if path.startswith("/api/v3/obj/"):
name = view.__class__.__name__.lstrip("Object").casefold()
print(f"Found path: {name}")
for provider in get_providers_by_operation(name, return_provider=True):
provider: BaseProvider
data = Api3ObjSpliter._get_data_from_provider(provider) # noqa
obj3.add(
ProviderDetails(
provider=provider,
operations_supported=get_operations_by_provider(provider),
data=data,
url=path.replace(
"{type}",
get_path_by_provider(provider),
),
)
)

print(f"obj3: {obj3}")
formatted_obj3 = list()
for provider in obj3:
...

view_endpoints.extend(formatted_obj3)
return view_endpoints
52 changes: 51 additions & 1 deletion core/utils/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,58 @@
from typing import TYPE_CHECKING, Final, Literal, NamedTuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Final, List, Literal, NamedTuple, Optional

from rest_framework.serializers import BaseSerializer

if TYPE_CHECKING:
from core.api.v3.objects import BaseProvider

type APIObjOperations = Final[Literal["single", "new", "list", "retrieve"]]

type PathData = NamedTuple[str, "BaseProvider", dict]


@dataclass
class APISerializerOperations:
operation: APIObjOperations
serializer: BaseSerializer

def __hash__(self):
return hash(self.operation.__class__.__name__)

# tags?


@dataclass
class SingleOperationData:
providers: List["BaseProvider"]
operation: APIObjOperations
data: dict


@dataclass
class ProviderDetails:
provider: "BaseProvider"
operations_supported: List[APISerializerOperations]
url: Optional[str] = None
data: Optional[Dict] = None

def __hash__(self):
return hash(self.provider.__class__.__name__)


@dataclass
class ObjectModificationData:
retrieve: Optional[SingleOperationData] = None
single: Optional[SingleOperationData] = None
list: Optional[SingleOperationData] = None
new: Optional[SingleOperationData] = None

def __iter__(self):
return iter(
[
("retrieve", self.retrieve),
("single", self.single),
("list", self.list),
("new", self.new),
]
)
3 changes: 2 additions & 1 deletion metropolis/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,9 @@
"OAUTH2_AUTHORIZATION_URL": "/authorize",
"OAUTH2_TOKEN_URL": "/api/auth/token",
"OAUTH2_REFRESH_URL": "/api/auth/token/refresh",
"DEFAULT_GENERATOR_CLASS": "core.schema.MetroSchemaGenerator",
"OAUTH2_SCOPES": OAUTH2_PROVIDER["SCOPES"].keys(),
"POSTPROCESSING_HOOKS": ["core.schema.run_fixers"],
# "POSTPROCESSING_HOOKS": ["core.schema.run_fixers"],
#'SERVE_INCLUDE_SCHEMA': False,
# OTHER SETTINGS
}
Expand Down

0 comments on commit 94186d4

Please sign in to comment.