Skip to content

Commit

Permalink
Count (#1145)
Browse files Browse the repository at this point in the history
* FEAT added count endpoint

* Fixed bug in count endpoint when no filters have been provided

* Fixed helper tests

* FEAT made test for count endpoint

* FEAT added possibility to join in count filters

* FEAT added the possibility to create complex filters in count endpoint. By adding the possibility to expand and filter on expanded objects

* Removed commented out code
  • Loading branch information
jessevz authored Dec 6, 2024
1 parent bfd39df commit eed46ff
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 22 deletions.
26 changes: 22 additions & 4 deletions ci/apiv2/hashtopolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,19 @@ def delete(self, obj):

# TODO: Cleanup object to allow re-creation

def count(self, filter):
self.authenticate()
uri = self._api_endpoint + self._model_uri + "/count"
headers = self._headers
payload = {}
if filter:
for k, v in filter.items():
payload[f"filter[{k}]"] = v

logger.debug("Sending GET payload: %s to %s", json.dumps(payload), uri)
r = requests.get(uri, headers=headers, params=payload)
self.validate_status_code(r, [200], "Getting count failed")
return self.resp_to_json(r)['meta']

# Build Django ORM style django.query interface
class QuerySet():
Expand Down Expand Up @@ -434,6 +447,11 @@ def get_first(cls):
@classmethod
def get(cls, **filters):
return QuerySet(cls, filters=filters).get()

@classmethod
def count(cls, **filters):
return cls.get_conn().count(filter=filters)


@classmethod
def paginate(cls, **pages):
Expand Down Expand Up @@ -912,7 +930,7 @@ def import_cracked_hashes(self, hashlist, source_data, separator):
'separator': separator,
}
response = self._helper_request("importCrackedHashes", payload)
return response['data']
return response['meta']

def get_file(self, file, range=None):
payload = {
Expand All @@ -925,19 +943,19 @@ def recount_file_lines(self, file):
'fileId': file.id,
}
response = self._helper_request("recountFileLines", payload)
return File(**response['data'])
return File(**response['meta'])

def unassign_agent(self, agent):
payload = {
'agentId': agent.id,
}
response = self._helper_request("unassignAgent", payload)
return response['data']
return response['meta']

def assign_agent(self, agent, task):
payload = {
'agentId': agent.id,
'taskId': task.id,
}
response = self._helper_request("assignAgent", payload)
return response['data']
return response['meta']
23 changes: 23 additions & 0 deletions ci/apiv2/test_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from hashtopolis import HashType
from utils import BaseTest


class CountTest(BaseTest):
model_class = HashType

def create_test_objects(self, **kwargs):
objs = []
for i in range(90000, 90100, 10):
obj = HashType(hashTypeId=i,
description=f"Dummy HashType {i}",
isSalted=(i < 90050),
isSlowHash=False).save()
objs.append(obj)
self.delete_after_test(obj)
return objs

def test_count(self):
model_objs = self.create_test_objects()
model_count = len(model_objs)
api_count = HashType.objects.count(hashTypeId__gte=90000, hashTypeId__lte=91000)['count']
self.assertEqual(model_count, api_count)
19 changes: 19 additions & 0 deletions src/dba/AbstractModelFactory.class.php
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ public function countFilter($options) {
$query = $query . " FROM " . $this->getModelTable();

$vals = array();

if (array_key_exists('join', $options)) {
$query .= $this->applyJoins($options['join']);
}

if (array_key_exists("filter", $options)) {
$query .= $this->applyFilters($vals, $options['filter']);
Expand Down Expand Up @@ -750,6 +754,21 @@ private function applyOrder($orders) {
return " ORDER BY " . implode(", ", $orderQueries);
}

private function applyJoins($joins) {
$query = "";
foreach ($joins as $join) {
$joinFactory = $join->getOtherFactory();
$localFactory = $this;
if ($join->getOverrideOwnFactory() != null) {
$localFactory = $join->getOverrideOwnFactory();
}
$match1 = $join->getMatch1();
$match2 = $join->getMatch2();
$query .= " INNER JOIN " . $joinFactory->getModelTable() . " ON " . $localFactory->getModelTable() . "." . $match1 . "=" . $joinFactory->getModelTable() . "." . $match2 . " ";
}
return $query;
}

//applylimit is slightly different than the other apply functions, since you can only limit by a single value
//the $limit argument is a single object LimitFilter object instead of an array of objects.
private function applyLimit($limit) {
Expand Down
48 changes: 31 additions & 17 deletions src/inc/apiv2/common/AbstractBaseAPI.class.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use DBA\AgentStat;
use DBA\Assignment;
use DBA\Chunk;
use DBA\ComparisonFilter;
use DBA\Config;
use DBA\ConfigSection;
use DBA\CrackerBinary;
Expand Down Expand Up @@ -875,14 +876,19 @@ protected function getPrimaryKey(): string
}
}

function getFilters(Request $request) {
return $this->getQueryParameterFamily($request, 'filter');
}

/**
* Check for valid filter parameters and build QueryFilter
*/
protected function makeFilter(Request $request, array $features): array
// protected function makeFilter(Request $request, array $features): array
protected function makeFilter(array $filters, object $apiClass): array
{
$qFs = [];

$filters = $this->getQueryParameterFamily($request, 'filter');
$qFs = [];
$features = $apiClass->getAliasedFeatures();
$factory = $apiClass->getFactory();
foreach ($filters as $filter => $value) {

if (preg_match('/^(?P<key>[_a-zA-Z0-9]+?)(?<operator>|__eq|__ne|__lt|__lte|__gt|__gte|__contains|__startswith|__endswith|__icontains|__istartswith|__iendswith)$/', $filter, $matches) == 0) {
Expand Down Expand Up @@ -919,44 +925,52 @@ protected function makeFilter(Request $request, array $features): array
switch($matches['operator']) {
case '':
case '__eq':
array_push($qFs, new QueryFilter($remappedKey, $val, '='));
$operator = '=';
break;
case '__ne':
array_push($qFs, new QueryFilter($remappedKey, $val, '!='));
$operator = '!=';
break;
case '__lt':
array_push($qFs, new QueryFilter($remappedKey, $val, '<'));
$operator = '<';
break;
case '__lte':
array_push($qFs, new QueryFilter($remappedKey, $val, '<='));
$operator = '<=';
break;
case '__gt':
array_push($qFs, new QueryFilter($remappedKey, $val, '>'));
$operator = '>';
break;
case '__gte':
array_push($qFs, new QueryFilter($remappedKey, $val, '>='));
$operator = '>=';
break;
case '__contains':
array_push($qFs, new LikeFilter($remappedKey, "%" . $val . "%"));
array_push($qFs, new LikeFilter($remappedKey, "%" . $val . "%", $factory));
break;
case '__startswith':
array_push($qFs, new LikeFilter($remappedKey, $val . "%"));
array_push($qFs, new LikeFilter($remappedKey, $val . "%", $factory));
break;
case '__endswith':
array_push($qFs, new LikeFilter($remappedKey, "%" . $val));
array_push($qFs, new LikeFilter($remappedKey, "%" . $val, $factory));
break;
case '__icontains':
array_push($qFs, new LikeFilterInsensitive($remappedKey, "%" . $val . "%"));
array_push($qFs, new LikeFilterInsensitive($remappedKey, "%" . $val . "%", $factory));
break;
case '__istartswith':
array_push($qFs, new LikeFilterInsensitive($remappedKey, $val . "%"));
array_push($qFs, new LikeFilterInsensitive($remappedKey, $val . "%", $factory));
break;
case '__iendswith':
array_push($qFs, new LikeFilterInsensitive($remappedKey, "%" . $val));
array_push($qFs, new LikeFilterInsensitive($remappedKey, "%" . $val, $factory));
break;
default:
assert(False, "Operator '" . $matches['operator'] . "' not implemented");
}

if ($operator) {
if (array_key_exists($val, $features)) {
array_push($qFs, new ComparisonFilter($remappedKey, $val, $operator, $factory));
} else {
array_push($qFs, new QueryFilter($remappedKey, $val, $operator, $factory));
}
}
}
return $qFs;
}
Expand Down Expand Up @@ -1257,7 +1271,7 @@ protected static function getOneResource(object $apiClass, object $object, Reque

//Meta response for helper functions that do not respond with resource records
protected static function getMetaResponse(array $meta, Request $request, Response $response, int $statusCode=200) {
$ret = self::createJsonResponse($meta=$meta);
$ret = self::createJsonResponse(meta: $meta);
$body = $response->getBody();
$body->write(self::ret2json($ret));

Expand Down
106 changes: 105 additions & 1 deletion src/inc/apiv2/common/AbstractModelAPI.class.php
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ public static function getManyResources(object $apiClass, Request $request, Resp
$aFs = [];

/* Generate filters */
$qFs_Filter = $apiClass->makeFilter($request, $aliasedfeatures);
$filters = $apiClass->getFilters($request);
$qFs_Filter = $apiClass->makeFilter($filters, $apiClass);
$qFs_ACL = $apiClass->getFilterACL();
$qFs = array_merge($qFs_ACL, $qFs_Filter);
if (count($qFs) > 0) {
Expand Down Expand Up @@ -559,6 +560,108 @@ public function get(Request $request, Response $response, array $args): Response
return self::getManyResources($this, $request, $response);
}

/**
* Maps filters to the appropiate models based on their feautures.
*
* Helper function to get valid filters for the models. This is usefull when multiple objects
* have been included and the correct filters need to be mapped to the correct objects.
* Currently used to make complex filters for counting objects
*
* @param array $filters An associative array of filters where the key is the filter
* name and the value is the filter value. Filters should match
* the pattern `<field><operator>`, where `<operator>` can be
* one of the supported suffixes (e.g., `__eq`, `__ne`).
* @param array $models An array of model objects. Each model must have a `getFeatures()`
* method that returns an associative array of model features.
* The features should map filter keys to their respective
* attributes or aliases.
*
* @return array An associative array mapping model classes to their respective valid filters.
* The structure is:
* [
* ModelClassName => [
* 'filter' => 'value',
* ...
* ],
* ...
* ]
*
* @throws HTException If a filter key does not match the expected format or is invalid.
*/
public function filterObjectMap(array $filters, array $models) {

$modelFilterMap = [];
foreach ($filters as $filter => $value) {
if (preg_match('/^(?P<key>[_a-zA-Z0-9]+?)(?<operator>|__eq|__ne|__lt|__lte|__gt|__gte|__contains|__startswith|__endswith|__icontains|__istartswith|__iendswith)$/', $filter, $matches) == 0) {
throw new HTException("Filter parameter '" . $filter . "' is not valid");
}

foreach($models as $model) {
$features = $model->getFeatures();
// Special filtering of _id to use for uniform access to model primary key
$cast_key = $matches['key'] == '_id' ? array_column($features, 'alias', 'dbname')[$this->getPrimaryKey()] : $matches['key'];
if (array_key_exists($cast_key, $features) == false) {
continue; //not a valid filter for current model
};
$modelFilterMap[$model::class][$filter] = $value;
break; //filter has been found for current model, so break to go to next filter
}
}
return $modelFilterMap;
}

/**
* API entry point for retrieving count information of data
*/
public function count(Request $request, Response $response, array $args): Response
{
$this->preCommon($request);
$factory = $this->getFactory();

//resolve all expandables
$validExpandables = $this::getExpandables();
$expands = $this->makeExpandables($request, $validExpandables);

$objects = [$factory->getNullObject()];
//build join filters
foreach ($expands as $expand) {
$relation = $this->getToManyRelationships()[$expand];
$objects[] = $this->getModelFactory($relation["relationType"])->getNullObject();
$otherFactory = $this->getModelFactory($relation["relationType"]);
$primaryKey = $this->getPrimaryKey();
$aFs[Factory::JOIN][] = new JoinFilter($otherFactory, $relation["relationKey"], $primaryKey, $factory);
}

$filters = $this->getFilters($request);
$filterObjectMap = $this->filterObjectMap($filters, $objects);
$qFs = [];
foreach($filterObjectMap as $class => $cur_filters) {
$relationApiClass = new ($this->container->get('classMapper')->get($class))($this->container);
$current_qFs = $this->makeFilter($cur_filters, $relationApiClass);
$qFs = array_merge($qFs, $current_qFs);
}

if (count($qFs) > 0) {
$aFs[Factory::FILTER] = $qFs;
}

$count = $factory->countFilter($aFs);
$meta = ["count" => $count];

$include_total = $request->getQueryParams()['include_total'];
if ($include_total == "true") {
$meta["total_count"] = $factory->countFilter([]);
}

$ret = self::createJsonResponse(meta: $meta);

$body = $response->getBody();
$body->write($this->ret2json($ret));

return $response->withStatus(200)
->withHeader("Content-Type", 'application/vnd.api+json');
}

/**
* Get input field names valid for creation of object
*/
Expand Down Expand Up @@ -1106,6 +1209,7 @@ static public function register($app): void

if (in_array("GET", $available_methods)) {
$app->get($baseUri, $me . ':get')->setname($me . ':get');
$app->get($baseUri . "/count", $me . ':count')->setname($me . ':count');
}

foreach ($me::getToOneRelationships() as $name => $relationship) {
Expand Down

0 comments on commit eed46ff

Please sign in to comment.