From b50547e9e9ea109405d230a42eea54bd2540e492 Mon Sep 17 00:00:00 2001 From: dennemark Date: Wed, 23 Oct 2024 10:43:11 +0200 Subject: [PATCH] fix: :bug: fix aggregate and groupBy queries --- src/filterQueryResults.ts | 16 +++++++--- src/helpers.ts | 6 ++-- src/index.ts | 2 +- test/extension.test.ts | 67 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 7 deletions(-) diff --git a/src/filterQueryResults.ts b/src/filterQueryResults.ts index 86e8681..7894f4f 100644 --- a/src/filterQueryResults.ts +++ b/src/filterQueryResults.ts @@ -2,10 +2,10 @@ import { AbilityTuple, PureAbility } from "@casl/ability"; import { PrismaQuery } from "@casl/prisma"; import { Prisma } from "@prisma/client"; import { CreationTree } from "./convertCreationTreeToSelect"; -import { getPermittedFields, getSubject, isSubset, PrismaExtensionCaslOptions, relationFieldsByModel } from "./helpers"; +import { caslOperationDict, getPermittedFields, getSubject, isSubset, PrismaCaslOperation, PrismaExtensionCaslOptions, relationFieldsByModel } from "./helpers"; import { storePermissions } from "./storePermissions"; -export function filterQueryResults(result: any, mask: any, creationTree: CreationTree | undefined, abilities: PureAbility, model: string, opts?: PrismaExtensionCaslOptions) { +export function filterQueryResults(result: any, mask: any, creationTree: CreationTree | undefined, abilities: PureAbility, model: string, operation: PrismaCaslOperation, opts?: PrismaExtensionCaslOptions) { if (typeof result === 'number') { return result } @@ -13,6 +13,7 @@ export function filterQueryResults(result: any, mask: any, creationTree: Creatio if (!prismaModel) { throw new Error(`Model ${model} does not exist on Prisma Client`) } + const operationFields = caslOperationDict[operation].operationFields const filterPermittedFields = (entry: any) => { if (!entry) { return null } @@ -52,11 +53,18 @@ export function filterQueryResults(result: any, mask: any, creationTree: Creatio const permittedFields = getPermittedFields(abilities, 'read', model, entry) let hasKeys = false - Object.keys(entry).filter((field) => field !== opts?.permissionField).forEach((field) => { + Object.keys(entry).filter((field) => { + if (operationFields?.includes(field)) { + hasKeys = true + return false + } else { + return field !== opts?.permissionField + } + }).forEach((field) => { const relationField = relationFieldsByModel[model][field] if (relationField) { const nestedCreationTree = creationTree && field in creationTree.children ? creationTree.children[field] : undefined - const res = filterQueryResults(entry[field], mask?.[field], nestedCreationTree, abilities, relationField.type) + const res = filterQueryResults(entry[field], mask?.[field], nestedCreationTree, abilities, relationField.type, operation) entry[field] = Array.isArray(res) ? res.length > 0 ? res : null : res } if ((!permittedFields.includes(field) && !relationField) || mask?.[field] === true) { diff --git a/src/helpers.ts b/src/helpers.ts index e7449e7..8a599cf 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -46,6 +46,8 @@ export const caslOperationDict: Record< dataQuery: boolean whereQuery: boolean includeSelectQuery: boolean + // optional fields for certain actions that should be allowed to access + operationFields?: string[] } > = { create: { action: 'create', dataQuery: true, whereQuery: false, includeSelectQuery: true }, @@ -57,9 +59,9 @@ export const caslOperationDict: Record< findMany: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: true }, findUnique: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: true }, findUniqueOrThrow: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: true }, - aggregate: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false }, + aggregate: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false, operationFields: ['_min', '_max', '_avg', '_count', '_sum'] }, count: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false }, - groupBy: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false }, + groupBy: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false, operationFields: ['_min', '_max', '_avg', '_count', '_sum'] }, update: { action: 'update', dataQuery: true, whereQuery: true, includeSelectQuery: true }, updateMany: { action: 'update', dataQuery: true, whereQuery: true, includeSelectQuery: false }, delete: { action: 'delete', dataQuery: false, whereQuery: true, includeSelectQuery: true }, diff --git a/src/index.ts b/src/index.ts index 5cbcd43..b4c57d7 100644 --- a/src/index.ts +++ b/src/index.ts @@ -111,7 +111,7 @@ export function useCaslAbilities(getAbilityFactory: () => AbilityBuilder { }) }) + describe('aggregate', () => { + it('can aggregate data', async () => { + function builderFactory() { + const builder = abilityBuilder() + const { can, cannot } = builder + can('read', 'User') + + return builder + } + const client = seedClient.$extends( + useCaslAbilities(builderFactory) + ) + const result = await client.user.aggregate({ + _avg: { id: true }, + _count: { id: true }, + _min: { id: true }, + _max: { id: true }, + _sum: { id: true } + }) + expect(result).toEqual({ + _avg: { id: 0.5 }, + _count: { id: 2 }, + _min: { id: 0 }, + _max: { id: 1 }, + _sum: { id: 1 }, + }) + }) + }) + describe('count', () => { + it('can count data', async () => { + function builderFactory() { + const builder = abilityBuilder() + const { can, cannot } = builder + can('read', 'User') + + return builder + } + const client = seedClient.$extends( + useCaslAbilities(builderFactory) + ) + const result = await client.user.count() + expect(result).toEqual(2) + }) + }) + describe('groupBy', () => { + it('can groupBy data', async () => { + function builderFactory() { + const builder = abilityBuilder() + const { can, cannot } = builder + can('read', 'User') + + return builder + } + const client = seedClient.$extends( + useCaslAbilities(builderFactory) + ) + const result = await client.user.groupBy({ + by: ['email'], + _avg: { id: true }, + _count: { id: true }, + _min: { id: true }, + _max: { id: true }, + _sum: { id: true } + }) + expect(result).toEqual([{ _avg: { id: 0 }, _count: { id: 1 }, _max: { id: 0 }, _min: { id: 0 }, _sum: { id: 0 }, email: "0" }, { _avg: { id: 1 }, _count: { id: 1 }, _max: { id: 1 }, _min: { id: 1 }, _sum: { id: 1 }, email: "1" }]) + }) + }) describe('fluent api queries', () => { it('can do chained queries if abilities exist', async () => { function builderFactory() {