Skip to content

Commit

Permalink
simplified get_function methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Tansito committed Dec 20, 2024
1 parent d50a4d2 commit 1fb6ba9
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 106 deletions.
71 changes: 17 additions & 54 deletions gateway/api/repositories/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class FunctionRepository:
The main objective of this class is to manage the access to the model
"""

# This repository should be in the use case implementatio
# This repository should be in the use case implementation
# but this class is not ready yet so it will live here
# in the meantime
user_repository = UserRepository()
Expand Down Expand Up @@ -109,7 +109,7 @@ def get_provider_functions_with_run_permissions(self, author) -> List[Function]:

return result_queryset

def get_user_function_by_title(self, author, title: str) -> Function | None:
def get_user_function(self, author, title: str) -> Function | None:
"""
Returns the user function associated to a title:
Expand Down Expand Up @@ -137,8 +137,8 @@ def get_user_function_by_title(self, author, title: str) -> Function | None:

return result_queryset

def get_provider_function_by_title_with_view_permissions(
self, author, title: str, provider_name: str
def get_provider_function_by_permission(
self, author, permission_name: str, title: str, provider_name: str
) -> Function | None:
"""
Returns the provider function associated to:
Expand All @@ -160,7 +160,7 @@ def get_provider_function_by_title_with_view_permissions(
# have it implemented yet we will do the check by now in the
# repository call
view_groups = self.user_repository.get_groups_by_permissions(
user=author, permission_name=VIEW_PROGRAM_PERMISSION
user=author, permission_name=permission_name
)
author_groups_with_view_permissions_criteria = Q(instances__in=view_groups)
author_criteria = Q(author=author)
Expand All @@ -181,52 +181,12 @@ def get_provider_function_by_title_with_view_permissions(

return result_queryset

def get_provider_function_by_title_with_run_permissions(
self, author, title: str, provider_name: str
) -> Function | None:
"""
Returns the provider function associated to:
- A Function title
- A Provider
- Author must have run permission to execute it or be the author
Args:
author: Django author from who retrieve the function
title: Title that the function must have to find it
provider: Provider associated to the function
Returns:
Program | None: provider function with the specific
title and provider
"""

# This access should be checked in the use-case but how we don't
# have it implemented yet we will do the check by now in the
# repository call
run_groups = self.user_repository.get_groups_by_permissions(
user=author, permission_name=RUN_PROGRAM_PERMISSION
)
author_groups_with_run_permissions_criteria = Q(instances__in=run_groups)
author_criteria = Q(author=author)
title_criteria = Q(title=title, provider__name=provider_name)

result_queryset = Function.objects.filter(
(author_criteria | author_groups_with_run_permissions_criteria)
& title_criteria
).first()

if result_queryset is None:
logger.warning(
"Function [%s/%s] was not found or author [%s] doesn't have access to it",
provider_name,
title,
author.id,
)

return result_queryset

def get_function_by_title_with_run_permissions(
self, user, function_title: str, provider_name: str | None
def get_function_by_permission(
self,
user,
permission_name: str,
function_title: str,
provider_name: str | None,
) -> None:
"""
This method returns the specified function if the user is
Expand All @@ -242,8 +202,11 @@ def get_function_by_title_with_run_permissions(
"""

if provider_name:
return self.get_provider_function_by_title_with_run_permissions(
author=user, title=function_title, provider_name=provider_name
return self.get_provider_function_by_permission(
author=user,
permission_name=permission_name,
title=function_title,
provider_name=provider_name,
)

return self.get_user_function_by_title(author=user, title=function_title)
return self.get_user_function(author=user, title=function_title)
89 changes: 41 additions & 48 deletions gateway/api/views/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from rest_framework.response import Response

from api.access_policies.providers import ProviderAccessPolicy
from api.models import RUN_PROGRAM_PERMISSION
from api.repositories.functions import FunctionRepository
from api.repositories.providers import ProviderRepository
from api.services.file_storage import FileStorage, WorkingDir
Expand Down Expand Up @@ -74,12 +75,11 @@ def list(self, request):
status=status.HTTP_400_BAD_REQUEST,
)

function = (
self.function_repository.get_function_by_title_with_run_permissions(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)
if not function:
if provider_name:
Expand Down Expand Up @@ -137,12 +137,11 @@ def provider_list(self, request):
status=status.HTTP_404_NOT_FOUND,
)

function = (
self.function_repository.get_function_by_title_with_run_permissions(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)
if not function:
return Response(
Expand Down Expand Up @@ -186,12 +185,11 @@ def download(self, request):
status=status.HTTP_400_BAD_REQUEST,
)

function = (
self.function_repository.get_function_by_title_with_run_permissions(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)
if not function:
if provider_name:
Expand Down Expand Up @@ -265,12 +263,11 @@ def provider_download(self, request):
status=status.HTTP_404_NOT_FOUND,
)

function = (
self.function_repository.get_function_by_title_with_run_permissions(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)
if not function:
return Response(
Expand Down Expand Up @@ -319,12 +316,11 @@ def delete(self, request):
status=status.HTTP_400_BAD_REQUEST,
)

function = (
self.function_repository.get_function_by_title_with_run_permissions(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)

if not function:
Expand Down Expand Up @@ -386,12 +382,11 @@ def provider_delete(self, request):
status=status.HTTP_404_NOT_FOUND,
)

function = (
self.function_repository.get_function_by_title_with_run_permissions(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)

if not function:
Expand Down Expand Up @@ -441,12 +436,11 @@ def upload(self, request):
status=status.HTTP_400_BAD_REQUEST,
)

function = (
self.function_repository.get_function_by_title_with_run_permissions(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)
if not function:
if provider_name:
Expand Down Expand Up @@ -506,12 +500,11 @@ def provider_upload(self, request):
status=status.HTTP_404_NOT_FOUND,
)

function = (
self.function_repository.get_function_by_title_with_run_permissions(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)
if not function:
return Response(
Expand Down
11 changes: 7 additions & 4 deletions gateway/api/views/programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
RunProgramSerializer,
UploadProgramSerializer,
)
from api.models import RUN_PROGRAM_PERMISSION, Program, Job
from api.models import RUN_PROGRAM_PERMISSION, VIEW_PROGRAM_PERMISSION, Program, Job
from api.views.enums.type_filter import TypeFilter

# pylint: disable=duplicate-code
Expand Down Expand Up @@ -309,11 +309,14 @@ def get_by_title(self, request, title):
)

if provider_name:
function = self.program_repository.get_provider_function_by_title_with_view_permissions(
author=author, title=function_title, provider_name=provider_name
function = self.program_repository.get_provider_function_by_permission(
author=author,
permission_name=VIEW_PROGRAM_PERMISSION,
title=function_title,
provider_name=provider_name,
)
else:
function = self.program_repository.get_user_function_by_title(
function = self.program_repository.get_user_function(
author=author, title=function_title
)

Expand Down

0 comments on commit 1fb6ba9

Please sign in to comment.