Skip to content

Commit

Permalink
fix: 🐛 fix aggregate and groupBy queries
Browse files Browse the repository at this point in the history
  • Loading branch information
dennemark committed Oct 23, 2024
1 parent b8fc84e commit b50547e
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 7 deletions.
16 changes: 12 additions & 4 deletions src/filterQueryResults.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ import { AbilityTuple, PureAbility } from "@casl/ability";
import { PrismaQuery } from "@casl/prisma";
import { Prisma } from "@prisma/client";
import { CreationTree } from "./convertCreationTreeToSelect";
import { getPermittedFields, getSubject, isSubset, PrismaExtensionCaslOptions, relationFieldsByModel } from "./helpers";
import { caslOperationDict, getPermittedFields, getSubject, isSubset, PrismaCaslOperation, PrismaExtensionCaslOptions, relationFieldsByModel } from "./helpers";
import { storePermissions } from "./storePermissions";

export function filterQueryResults(result: any, mask: any, creationTree: CreationTree | undefined, abilities: PureAbility<AbilityTuple, PrismaQuery>, model: string, opts?: PrismaExtensionCaslOptions) {
export function filterQueryResults(result: any, mask: any, creationTree: CreationTree | undefined, abilities: PureAbility<AbilityTuple, PrismaQuery>, model: string, operation: PrismaCaslOperation, opts?: PrismaExtensionCaslOptions) {
if (typeof result === 'number') {
return result
}
const prismaModel = model in relationFieldsByModel ? model as Prisma.ModelName : undefined
if (!prismaModel) {
throw new Error(`Model ${model} does not exist on Prisma Client`)
}
const operationFields = caslOperationDict[operation].operationFields

const filterPermittedFields = (entry: any) => {
if (!entry) { return null }
Expand Down Expand Up @@ -52,11 +53,18 @@ export function filterQueryResults(result: any, mask: any, creationTree: Creatio
const permittedFields = getPermittedFields(abilities, 'read', model, entry)

let hasKeys = false
Object.keys(entry).filter((field) => field !== opts?.permissionField).forEach((field) => {
Object.keys(entry).filter((field) => {
if (operationFields?.includes(field)) {
hasKeys = true
return false
} else {
return field !== opts?.permissionField
}
}).forEach((field) => {
const relationField = relationFieldsByModel[model][field]
if (relationField) {
const nestedCreationTree = creationTree && field in creationTree.children ? creationTree.children[field] : undefined
const res = filterQueryResults(entry[field], mask?.[field], nestedCreationTree, abilities, relationField.type)
const res = filterQueryResults(entry[field], mask?.[field], nestedCreationTree, abilities, relationField.type, operation)
entry[field] = Array.isArray(res) ? res.length > 0 ? res : null : res
}
if ((!permittedFields.includes(field) && !relationField) || mask?.[field] === true) {
Expand Down
6 changes: 4 additions & 2 deletions src/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ export const caslOperationDict: Record<
dataQuery: boolean
whereQuery: boolean
includeSelectQuery: boolean
// optional fields for certain actions that should be allowed to access
operationFields?: string[]
}
> = {
create: { action: 'create', dataQuery: true, whereQuery: false, includeSelectQuery: true },
Expand All @@ -57,9 +59,9 @@ export const caslOperationDict: Record<
findMany: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: true },
findUnique: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: true },
findUniqueOrThrow: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: true },
aggregate: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false },
aggregate: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false, operationFields: ['_min', '_max', '_avg', '_count', '_sum'] },
count: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false },
groupBy: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false },
groupBy: { action: 'read', dataQuery: false, whereQuery: true, includeSelectQuery: false, operationFields: ['_min', '_max', '_avg', '_count', '_sum'] },
update: { action: 'update', dataQuery: true, whereQuery: true, includeSelectQuery: true },
updateMany: { action: 'update', dataQuery: true, whereQuery: true, includeSelectQuery: false },
delete: { action: 'delete', dataQuery: false, whereQuery: true, includeSelectQuery: true },
Expand Down
2 changes: 1 addition & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export function useCaslAbilities(getAbilityFactory: () => AbilityBuilder<PureAbi
// on fluent models we need to take mask of the relation
caslQuery.mask = fluentRelationModel && fluentRelationModel in caslQuery.mask ? caslQuery.mask[fluentRelationModel] : {}
}
const filteredResult = filterQueryResults(result, caslQuery.mask, caslQuery.creationTree, abilities, fluentModel, opts)
const filteredResult = filterQueryResults(result, caslQuery.mask, caslQuery.creationTree, abilities, fluentModel, op, opts)

if (perf) {
perf.mark('prisma-casl-extension-4')
Expand Down
67 changes: 67 additions & 0 deletions test/extension.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,73 @@ describe('prisma extension casl', () => {
})

})
describe('aggregate', () => {
it('can aggregate data', async () => {
function builderFactory() {
const builder = abilityBuilder()
const { can, cannot } = builder
can('read', 'User')

return builder
}
const client = seedClient.$extends(
useCaslAbilities(builderFactory)
)
const result = await client.user.aggregate({
_avg: { id: true },
_count: { id: true },
_min: { id: true },
_max: { id: true },
_sum: { id: true }
})
expect(result).toEqual({
_avg: { id: 0.5 },
_count: { id: 2 },
_min: { id: 0 },
_max: { id: 1 },
_sum: { id: 1 },
})
})
})
describe('count', () => {
it('can count data', async () => {
function builderFactory() {
const builder = abilityBuilder()
const { can, cannot } = builder
can('read', 'User')

return builder
}
const client = seedClient.$extends(
useCaslAbilities(builderFactory)
)
const result = await client.user.count()
expect(result).toEqual(2)
})
})
describe('groupBy', () => {
it('can groupBy data', async () => {
function builderFactory() {
const builder = abilityBuilder()
const { can, cannot } = builder
can('read', 'User')

return builder
}
const client = seedClient.$extends(
useCaslAbilities(builderFactory)
)
const result = await client.user.groupBy({
by: ['email'],
_avg: { id: true },
_count: { id: true },
_min: { id: true },
_max: { id: true },
_sum: { id: true }
})
expect(result).toEqual([{ _avg: { id: 0 }, _count: { id: 1 }, _max: { id: 0 }, _min: { id: 0 }, _sum: { id: 0 }, email: "0" }, { _avg: { id: 1 }, _count: { id: 1 }, _max: { id: 1 }, _min: { id: 1 }, _sum: { id: 1 }, email: "1" }])
})
})
describe('fluent api queries', () => {
it('can do chained queries if abilities exist', async () => {
function builderFactory() {
Expand Down

0 comments on commit b50547e

Please sign in to comment.