diff --git a/core/api/utils/polymorphism.py b/core/api/utils/polymorphism.py index 56236648..da0aa2db 100644 --- a/core/api/utils/polymorphism.py +++ b/core/api/utils/polymorphism.py @@ -1,8 +1,7 @@ from __future__ import annotations -from functools import lru_cache from json import JSONDecodeError -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Protocol, Set +from typing import Any, Callable, Dict, Iterable, List, Optional, Protocol, Set from django.core.exceptions import BadRequest from django.db.models import Model, Q @@ -13,7 +12,7 @@ from core.api.v3.objects import * from core.api.v3.objects.base import BaseProvider -from core.utils.types import APIObjOperations +from core.utils.types import APIObjOperations, APISerializerOperations, ProviderDetails type IgnoredKey = str | Iterable[str] type SerializerItems = Dict[str, BaseSerializer] @@ -54,6 +53,14 @@ def split_dict(dictionary: SerializerItems) -> SplitDictResult: splitter = split_dict_wrapper("_") # ignore key for serializers +def get_path_by_provider(provider: BaseProvider) -> str: + return [ + provider_key + for provider_key, provider in providers.items() + if provider == provider + ][0] + + providers: Dict[str, BaseProvider] = ( { # k = request type (param passed in url), v = provider class "announcement": AnnouncementProvider, @@ -94,7 +101,7 @@ def get_provider(provider_name: provider_keys) -> Callable: def get_providers_by_operation( operation: APIObjOperations, return_provider: Optional[bool] = False -) -> List[str]: +) -> List[str] | List[BaseProvider]: """ returns a list of provider path names that support the given operation. @@ -102,7 +109,7 @@ def get_providers_by_operation( >>> get_providers_by_operation("single") ["announcement", "blog-post", "exhibit", "event", "organization", "flatpage", "user", "tag", "term", "timetable", "comment", "like", "course"] """ - + operation = operation.lower() return [ (prov if return_provider else key) for key, prov in providers.items() @@ -110,6 +117,23 @@ def get_providers_by_operation( ] +def get_operations_by_provider(provider: BaseProvider) -> ProviderDetails: + """ + Returns a list of operations supported by the given provider. + """ + options = set() + for operation in provider.supported_operations(): + data = APISerializerOperations( + operation=operation, + serializer=provider.raw_serializers.get( + operation, provider.raw_serializers.get("_") + ), + ) + options.add(data) + + return options + + class ObjectAPIView(generics.GenericAPIView): def initial(self, *args, **kwargs): super().initial(*args, **kwargs) @@ -264,7 +288,7 @@ def get_serializer_class(self): class Provider(Protocol): allow_list: bool allow_new: bool - kind: Literal["list", "new", "single", "retrieve"] + kind: APIObjOperations listing_filters_ignore: List[str] serializers: SplitDictResult diff --git a/core/api/v3/objects/base.py b/core/api/v3/objects/base.py index 5360faa2..2bae1ffa 100644 --- a/core/api/v3/objects/base.py +++ b/core/api/v3/objects/base.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Dict, Final, List +from typing import Dict, Final, List, Tuple from django.db.models.base import ModelBase from rest_framework.serializers import BaseSerializer @@ -70,3 +70,11 @@ def __new__(cls, request): def __init__(self, request): self.request = request + + @classmethod + def supported_operations(cls) -> Tuple[str]: + if not issubclass(cls, BaseProvider): + raise TypeError("This method can only be ran on subclasses of BaseProvider") + if "_" in cls.raw_serializers.keys(): + return "list", "new", "single", "retrieve" + return tuple(cls.raw_serializers.keys()) diff --git a/core/schema.py b/core/schema.py index 2a944ad6..40ff1b1d 100644 --- a/core/schema.py +++ b/core/schema.py @@ -1,20 +1,26 @@ import dataclasses -from dataclasses import dataclass from functools import wraps from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type from drf_spectacular.drainage import set_override +from drf_spectacular.generators import SchemaGenerator from drf_spectacular.utils import F, OpenApiExample from memoization import cached -from rest_framework.serializers import BaseSerializer, Serializer +from rest_framework.serializers import Serializer from core.api.utils.polymorphism import ( + get_operations_by_provider, + get_path_by_provider, get_provider, get_providers_by_operation, providers, ) from core.api.v3.objects import BaseProvider -from core.utils.types import APIObjOperations +from core.utils.types import ( + ObjectModificationData, + ProviderDetails, + SingleOperationData, +) def metro_extend_schema_serializer( # modified version of drf_spectacular.utils.extend_schema_serializer @@ -80,6 +86,11 @@ def wrapped_view(*args, **kwargs): @cached def run_fixers(result, generator, request, public): + """ + Run fixers on the schema to ensure that the API docs are properly formatted. + + Note: this should ALWAYS return the same result for the same input as it's cached to improve performance. + """ fixers = [Api3ObjSpliter] if fixers is None: raise ValueError("No fixers found, API3 obj docs will be broken.") @@ -90,45 +101,6 @@ def run_fixers(result, generator, request, public): return result -@dataclass -class SingleOperationData: - providers: List[BaseProvider] - operation: APIObjOperations - data: dict - - -@dataclass -class APISerializerOperations: - operation: APIObjOperations - serializer: BaseSerializer - # tags? - - -@dataclass -class ProviderDetails: - provider: BaseProvider - operations_supported: List[APISerializerOperations] - data: dict - - -@dataclass -class ObjectModificationData: - retrieve: Optional[SingleOperationData] = None - single: Optional[SingleOperationData] = None - list: Optional[SingleOperationData] = None - new: Optional[SingleOperationData] = None - - def __iter__(self): - return iter( - [ - ("retrieve", self.retrieve), - ("single", self.single), - ("list", self.list), - ("new", self.new), - ] - ) - - class Api3ObjSpliter: """ Split the API3 schema into the different paths based on the object type and provider @@ -139,7 +111,7 @@ class Api3ObjSpliter: } def __init__(self, schema): - self.operation_data: ObjectModificationData = ObjectModificationData() + self.operation_data = ObjectModificationData() self.keys_to_delete: Tuple = () self.schema = schema self._provider_details: Dict[str, ProviderDetails] = dict() @@ -150,7 +122,8 @@ def run(self): self.set_obj_paths(paths) for _, provider in providers.items(): - print(self._get_data_from_provider(provider)) + ... + # print(self._get_data_from_provider(provider)) for operation in dataclasses.fields(self.operation_data): self.create_obj_views(operation) @@ -230,3 +203,41 @@ def get_providers_from_name(enum: List[str]) -> List[BaseProvider]: return [get_provider(key) for key in enum] def create_obj_views(self, operation: SingleOperationData): ... + + +class MetroSchemaGenerator(SchemaGenerator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _get_paths_and_endpoints(self): + """ + Generate (path, method, view) given (path, method, callback) for paths. + """ + obj3 = set() + view_endpoints = super()._get_paths_and_endpoints() + for path, subpath, method, view in view_endpoints: + if path.startswith("/api/v3/obj/"): + name = view.__class__.__name__.lstrip("Object").casefold() + print(f"Found path: {name}") + for provider in get_providers_by_operation(name, return_provider=True): + provider: BaseProvider + data = Api3ObjSpliter._get_data_from_provider(provider) # noqa + obj3.add( + ProviderDetails( + provider=provider, + operations_supported=get_operations_by_provider(provider), + data=data, + url=path.replace( + "{type}", + get_path_by_provider(provider), + ), + ) + ) + + print(f"obj3: {obj3}") + formatted_obj3 = list() + for provider in obj3: + ... + + view_endpoints.extend(formatted_obj3) + return view_endpoints diff --git a/core/utils/types.py b/core/utils/types.py index 7716eaf2..96768241 100644 --- a/core/utils/types.py +++ b/core/utils/types.py @@ -1,4 +1,7 @@ -from typing import TYPE_CHECKING, Final, Literal, NamedTuple +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Final, List, Literal, NamedTuple, Optional + +from rest_framework.serializers import BaseSerializer if TYPE_CHECKING: from core.api.v3.objects import BaseProvider @@ -6,3 +9,50 @@ type APIObjOperations = Final[Literal["single", "new", "list", "retrieve"]] type PathData = NamedTuple[str, "BaseProvider", dict] + + +@dataclass +class APISerializerOperations: + operation: APIObjOperations + serializer: BaseSerializer + + def __hash__(self): + return hash(self.operation.__class__.__name__) + + # tags? + + +@dataclass +class SingleOperationData: + providers: List["BaseProvider"] + operation: APIObjOperations + data: dict + + +@dataclass +class ProviderDetails: + provider: "BaseProvider" + operations_supported: List[APISerializerOperations] + url: Optional[str] = None + data: Optional[Dict] = None + + def __hash__(self): + return hash(self.provider.__class__.__name__) + + +@dataclass +class ObjectModificationData: + retrieve: Optional[SingleOperationData] = None + single: Optional[SingleOperationData] = None + list: Optional[SingleOperationData] = None + new: Optional[SingleOperationData] = None + + def __iter__(self): + return iter( + [ + ("retrieve", self.retrieve), + ("single", self.single), + ("list", self.list), + ("new", self.new), + ] + ) diff --git a/metropolis/settings.py b/metropolis/settings.py index 79000ed0..2be549f5 100644 --- a/metropolis/settings.py +++ b/metropolis/settings.py @@ -306,8 +306,9 @@ "OAUTH2_AUTHORIZATION_URL": "/authorize", "OAUTH2_TOKEN_URL": "/api/auth/token", "OAUTH2_REFRESH_URL": "/api/auth/token/refresh", + "DEFAULT_GENERATOR_CLASS": "core.schema.MetroSchemaGenerator", "OAUTH2_SCOPES": OAUTH2_PROVIDER["SCOPES"].keys(), - "POSTPROCESSING_HOOKS": ["core.schema.run_fixers"], + # "POSTPROCESSING_HOOKS": ["core.schema.run_fixers"], #'SERVE_INCLUDE_SCHEMA': False, # OTHER SETTINGS }