diff --git a/core/api/urls.py b/core/api/urls.py index 171650d9..0086cee4 100644 --- a/core/api/urls.py +++ b/core/api/urls.py @@ -3,7 +3,7 @@ from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView from .views import * -from .views.objects import * +from .views.objects.main import ObjectList, ObjectNew, ObjectRetrieve, ObjectSingle router = SimpleRouter() diff --git a/core/api/views/objects/__init__.py b/core/api/views/objects/__init__.py index 15b6a64b..896e50c1 100644 --- a/core/api/views/objects/__init__.py +++ b/core/api/views/objects/__init__.py @@ -1 +1,11 @@ -from .main import * +from .announcement import * +from .blog_post import * +from .post_interactions import * +from .event import * +from .courses import * +from .flatpage import * +from .organization import * +from .tag import * +from .term import * +from .timetable import * +from .user import * diff --git a/core/api/views/objects/main.py b/core/api/views/objects/main.py index f580012f..9b4f9d04 100644 --- a/core/api/views/objects/main.py +++ b/core/api/views/objects/main.py @@ -1,53 +1,33 @@ from __future__ import annotations -import os from json import JSONDecodeError -from typing import Dict, Callable, List, Tuple, Set, Final +from typing import Dict, Callable, List, Tuple, Set, Literal -from django.conf import settings from django.core.exceptions import ObjectDoesNotExist, BadRequest from django.db.models import Model, Q, QuerySet from django.http import QueryDict from django.shortcuts import get_object_or_404 -from django.urls import reverse, NoReverseMatch -from rest_framework import generics, permissions -from rest_framework.response import Response +from django.urls import NoReverseMatch +from drf_spectacular.utils import extend_schema, OpenApiParameter +from rest_framework import generics -from .base import BaseProvider +from .__init__ import * +from .event import EventProvider +from .exhibit import ExhibitProvider from ...utils import GenericAPIViewWithDebugInfo, GenericAPIViewWithLastModified __all__ = ["ObjectList", "ObjectSingle", "ObjectRetrieve", "ObjectNew"] -def gen_get_provider(mapping: Dict[str, str]): - for file in os.listdir(os.path.dirname(__file__)): - if file.endswith(".py") and file not in ["__init__.py", "base.py", "main.py"]: - __import__(f"core.api.views.objects.{file[:-3]}", fromlist=["*"]) - - provClasses = BaseProvider.__subclasses__() - try: - ProvReqNames = [ - mapping[cls.__name__.rsplit("Provider")[0].lower()] for cls in provClasses - ] - except KeyError as e: - raise NotImplementedError( - f"Provider class {e} is missing a request name. Please add it to the mapping." - ) from e - provClassMapping = {key: value for key, value in zip(ProvReqNames, provClasses)} - - def get_provider(provider_name: str): - """ - Gets a provider by type name. - """ - if provider_name not in ProvReqNames: - raise BadRequest( - "Object type not found. Valid types are: " - + ", ".join(ProvReqNames) - + "." - ) - return provClassMapping[provider_name] - - return get_provider +def get_provider(provider_name: str): + """ + Gets a provider by type name. + """ + if provider_name not in providers: + raise BadRequest( + "Object type not found. Valid types are: " + ", ".join(providers) + "." + ) + return providers[provider_name] get_provider = gen_get_provider( # k = Provider class name e.g. comment in CommentProvider, v = request name @@ -67,6 +47,23 @@ def get_provider(provider_name: str): "course": "course", } ) +providers = { # k = request type (param passed in url), v = provider class + "announcement": AnnouncementProvider, + "blog-post": BlogPostProvider, + "exhibit": ExhibitProvider, + "event": EventProvider, + "organization": OrganizationProvider, + "flatpage": FlatPageProvider, + "user": UserProvider, + "tag": TagProvider, + "term": TermProvider, + "timetable": TimetableProvider, + "comment": CommentProvider, + "like": LikeProvider, + "course": CourseProvider, +} + + class ObjectAPIView(generics.GenericAPIView):