diff --git a/oidc/provider.py b/oidc/provider.py index e270113..e845d95 100644 --- a/oidc/provider.py +++ b/oidc/provider.py @@ -1,8 +1,17 @@ +from __future__ import annotations + +from collections.abc import Callable + +from django.http import HttpRequest + import time import requests from sentry.auth.provider import MigratingIdentityId from sentry.auth.providers.oauth2 import OAuth2Callback, OAuth2Login, OAuth2Provider +from sentry.auth.services.auth.model import RpcAuthProvider +from sentry.organizations.services.organization.model import RpcOrganization +from sentry.plugins.base.response import DeferredResponse from .constants import ( AUTHORIZATION_ENDPOINT, @@ -14,7 +23,7 @@ TOKEN_ENDPOINT, USERINFO_ENDPOINT, ) -from .views import FetchUser, OIDCConfigureView +from .views import FetchUser, oidc_configure_view class OIDCLogin(OAuth2Login): @@ -63,8 +72,10 @@ def get_client_id(self): def get_client_secret(self): return CLIENT_SECRET - def get_configure_view(self): - return OIDCConfigureView.as_view() + def get_configure_view( + self, + ) -> Callable[[HttpRequest, RpcOrganization, RpcAuthProvider], DeferredResponse]: + return oidc_configure_view def get_auth_pipeline(self): return [ diff --git a/oidc/views.py b/oidc/views.py index 2c9293a..8d3c9d0 100644 --- a/oidc/views.py +++ b/oidc/views.py @@ -1,7 +1,15 @@ +from __future__ import annotations + import logging -from sentry.auth.view import AuthView, ConfigureView +from django.http import HttpRequest +from rest_framework.response import Response + +from sentry.auth.services.auth.model import RpcAuthProvider +from sentry.auth.view import AuthView from sentry.utils import json +from sentry.organizations.services.organization.model import RpcOrganization +from sentry.plugins.base.response import DeferredResponse from sentry.utils.signing import urlsafe_b64decode from .constants import ERR_INVALID_RESPONSE, ISSUER @@ -15,7 +23,7 @@ def __init__(self, domains, version, *args, **kwargs): self.version = version super().__init__(*args, **kwargs) - def dispatch(self, request, helper): + def dispatch(self, request: HttpRequest, helper) -> Response: # type: ignore data = helper.fetch_state("data") try: @@ -52,17 +60,18 @@ def dispatch(self, request, helper): return helper.next_step() -class OIDCConfigureView(ConfigureView): - def dispatch(self, request, organization, auth_provider): - config = auth_provider.config - if config.get("domain"): - domains = [config["domain"]] - else: - domains = config.get("domains") - return self.render( - "oidc/configure.html", - {"provider_name": ISSUER or "", "domains": domains or []}, - ) +def oidc_configure_view( + request: HttpRequest, organization: RpcOrganization, auth_provider: RpcAuthProvider +) -> DeferredResponse: + config = auth_provider.config + if config.get("domain"): + domains: list[str] | None + domains = [config["domain"]] + else: + domains = config.get("domains") + return DeferredResponse( + "oidc/configure.html", {"provider_name": ISSUER or "", "domains": domains or []} + ) def extract_domain(email):