diff --git a/ci/apiv2/hashtopolis.py b/ci/apiv2/hashtopolis.py index 66b7aa95..b29d4fe3 100644 --- a/ci/apiv2/hashtopolis.py +++ b/ci/apiv2/hashtopolis.py @@ -291,6 +291,19 @@ def delete(self, obj): # TODO: Cleanup object to allow re-creation + def count(self, filter): + self.authenticate() + uri = self._api_endpoint + self._model_uri + "/count" + headers = self._headers + payload = {} + if filter: + for k, v in filter.items(): + payload[f"filter[{k}]"] = v + + logger.debug("Sending GET payload: %s to %s", json.dumps(payload), uri) + r = requests.get(uri, headers=headers, params=payload) + self.validate_status_code(r, [200], "Getting count failed") + return self.resp_to_json(r)['meta'] # Build Django ORM style django.query interface class QuerySet(): @@ -434,6 +447,11 @@ def get_first(cls): @classmethod def get(cls, **filters): return QuerySet(cls, filters=filters).get() + + @classmethod + def count(cls, **filters): + return cls.get_conn().count(filter=filters) + @classmethod def paginate(cls, **pages): @@ -912,7 +930,7 @@ def import_cracked_hashes(self, hashlist, source_data, separator): 'separator': separator, } response = self._helper_request("importCrackedHashes", payload) - return response['data'] + return response['meta'] def get_file(self, file, range=None): payload = { @@ -925,14 +943,14 @@ def recount_file_lines(self, file): 'fileId': file.id, } response = self._helper_request("recountFileLines", payload) - return File(**response['data']) + return File(**response['meta']) def unassign_agent(self, agent): payload = { 'agentId': agent.id, } response = self._helper_request("unassignAgent", payload) - return response['data'] + return response['meta'] def assign_agent(self, agent, task): payload = { @@ -940,4 +958,4 @@ def assign_agent(self, agent, task): 'taskId': task.id, } response = self._helper_request("assignAgent", payload) - return response['data'] + return response['meta'] diff --git a/ci/apiv2/test_count.py b/ci/apiv2/test_count.py new file mode 100644 index 00000000..82c0312d --- /dev/null +++ b/ci/apiv2/test_count.py @@ -0,0 +1,23 @@ +from hashtopolis import HashType +from utils import BaseTest + + +class CountTest(BaseTest): + model_class = HashType + + def create_test_objects(self, **kwargs): + objs = [] + for i in range(90000, 90100, 10): + obj = HashType(hashTypeId=i, + description=f"Dummy HashType {i}", + isSalted=(i < 90050), + isSlowHash=False).save() + objs.append(obj) + self.delete_after_test(obj) + return objs + + def test_count(self): + model_objs = self.create_test_objects() + model_count = len(model_objs) + api_count = HashType.objects.count(hashTypeId__gte=90000, hashTypeId__lte=91000)['count'] + self.assertEqual(model_count, api_count) diff --git a/src/dba/AbstractModelFactory.class.php b/src/dba/AbstractModelFactory.class.php index 1a0612a0..b7d2bc16 100755 --- a/src/dba/AbstractModelFactory.class.php +++ b/src/dba/AbstractModelFactory.class.php @@ -471,6 +471,10 @@ public function countFilter($options) { $query = $query . " FROM " . $this->getModelTable(); $vals = array(); + + if (array_key_exists('join', $options)) { + $query .= $this->applyJoins($options['join']); + } if (array_key_exists("filter", $options)) { $query .= $this->applyFilters($vals, $options['filter']); @@ -750,6 +754,21 @@ private function applyOrder($orders) { return " ORDER BY " . implode(", ", $orderQueries); } + private function applyJoins($joins) { + $query = ""; + foreach ($joins as $join) { + $joinFactory = $join->getOtherFactory(); + $localFactory = $this; + if ($join->getOverrideOwnFactory() != null) { + $localFactory = $join->getOverrideOwnFactory(); + } + $match1 = $join->getMatch1(); + $match2 = $join->getMatch2(); + $query .= " INNER JOIN " . $joinFactory->getModelTable() . " ON " . $localFactory->getModelTable() . "." . $match1 . "=" . $joinFactory->getModelTable() . "." . $match2 . " "; + } + return $query; + } + //applylimit is slightly different than the other apply functions, since you can only limit by a single value //the $limit argument is a single object LimitFilter object instead of an array of objects. private function applyLimit($limit) { diff --git a/src/inc/apiv2/common/AbstractBaseAPI.class.php b/src/inc/apiv2/common/AbstractBaseAPI.class.php index 1a15e1df..87d0ccb5 100644 --- a/src/inc/apiv2/common/AbstractBaseAPI.class.php +++ b/src/inc/apiv2/common/AbstractBaseAPI.class.php @@ -15,6 +15,7 @@ use DBA\AgentStat; use DBA\Assignment; use DBA\Chunk; +use DBA\ComparisonFilter; use DBA\Config; use DBA\ConfigSection; use DBA\CrackerBinary; @@ -875,14 +876,19 @@ protected function getPrimaryKey(): string } } + function getFilters(Request $request) { + return $this->getQueryParameterFamily($request, 'filter'); + } + /** * Check for valid filter parameters and build QueryFilter */ - protected function makeFilter(Request $request, array $features): array + // protected function makeFilter(Request $request, array $features): array + protected function makeFilter(array $filters, object $apiClass): array { - $qFs = []; - - $filters = $this->getQueryParameterFamily($request, 'filter'); + $qFs = []; + $features = $apiClass->getAliasedFeatures(); + $factory = $apiClass->getFactory(); foreach ($filters as $filter => $value) { if (preg_match('/^(?P[_a-zA-Z0-9]+?)(?|__eq|__ne|__lt|__lte|__gt|__gte|__contains|__startswith|__endswith|__icontains|__istartswith|__iendswith)$/', $filter, $matches) == 0) { @@ -919,44 +925,52 @@ protected function makeFilter(Request $request, array $features): array switch($matches['operator']) { case '': case '__eq': - array_push($qFs, new QueryFilter($remappedKey, $val, '=')); + $operator = '='; break; case '__ne': - array_push($qFs, new QueryFilter($remappedKey, $val, '!=')); + $operator = '!='; break; case '__lt': - array_push($qFs, new QueryFilter($remappedKey, $val, '<')); + $operator = '<'; break; case '__lte': - array_push($qFs, new QueryFilter($remappedKey, $val, '<=')); + $operator = '<='; break; case '__gt': - array_push($qFs, new QueryFilter($remappedKey, $val, '>')); + $operator = '>'; break; case '__gte': - array_push($qFs, new QueryFilter($remappedKey, $val, '>=')); + $operator = '>='; break; case '__contains': - array_push($qFs, new LikeFilter($remappedKey, "%" . $val . "%")); + array_push($qFs, new LikeFilter($remappedKey, "%" . $val . "%", $factory)); break; case '__startswith': - array_push($qFs, new LikeFilter($remappedKey, $val . "%")); + array_push($qFs, new LikeFilter($remappedKey, $val . "%", $factory)); break; case '__endswith': - array_push($qFs, new LikeFilter($remappedKey, "%" . $val)); + array_push($qFs, new LikeFilter($remappedKey, "%" . $val, $factory)); break; case '__icontains': - array_push($qFs, new LikeFilterInsensitive($remappedKey, "%" . $val . "%")); + array_push($qFs, new LikeFilterInsensitive($remappedKey, "%" . $val . "%", $factory)); break; case '__istartswith': - array_push($qFs, new LikeFilterInsensitive($remappedKey, $val . "%")); + array_push($qFs, new LikeFilterInsensitive($remappedKey, $val . "%", $factory)); break; case '__iendswith': - array_push($qFs, new LikeFilterInsensitive($remappedKey, "%" . $val)); + array_push($qFs, new LikeFilterInsensitive($remappedKey, "%" . $val, $factory)); break; default: assert(False, "Operator '" . $matches['operator'] . "' not implemented"); } + + if ($operator) { + if (array_key_exists($val, $features)) { + array_push($qFs, new ComparisonFilter($remappedKey, $val, $operator, $factory)); + } else { + array_push($qFs, new QueryFilter($remappedKey, $val, $operator, $factory)); + } + } } return $qFs; } @@ -1257,7 +1271,7 @@ protected static function getOneResource(object $apiClass, object $object, Reque //Meta response for helper functions that do not respond with resource records protected static function getMetaResponse(array $meta, Request $request, Response $response, int $statusCode=200) { - $ret = self::createJsonResponse($meta=$meta); + $ret = self::createJsonResponse(meta: $meta); $body = $response->getBody(); $body->write(self::ret2json($ret)); diff --git a/src/inc/apiv2/common/AbstractModelAPI.class.php b/src/inc/apiv2/common/AbstractModelAPI.class.php index 9a07c6ec..4caea634 100644 --- a/src/inc/apiv2/common/AbstractModelAPI.class.php +++ b/src/inc/apiv2/common/AbstractModelAPI.class.php @@ -392,7 +392,8 @@ public static function getManyResources(object $apiClass, Request $request, Resp $aFs = []; /* Generate filters */ - $qFs_Filter = $apiClass->makeFilter($request, $aliasedfeatures); + $filters = $apiClass->getFilters($request); + $qFs_Filter = $apiClass->makeFilter($filters, $apiClass); $qFs_ACL = $apiClass->getFilterACL(); $qFs = array_merge($qFs_ACL, $qFs_Filter); if (count($qFs) > 0) { @@ -559,6 +560,108 @@ public function get(Request $request, Response $response, array $args): Response return self::getManyResources($this, $request, $response); } + /** + * Maps filters to the appropiate models based on their feautures. + * + * Helper function to get valid filters for the models. This is usefull when multiple objects + * have been included and the correct filters need to be mapped to the correct objects. + * Currently used to make complex filters for counting objects + * + * @param array $filters An associative array of filters where the key is the filter + * name and the value is the filter value. Filters should match + * the pattern ``, where `` can be + * one of the supported suffixes (e.g., `__eq`, `__ne`). + * @param array $models An array of model objects. Each model must have a `getFeatures()` + * method that returns an associative array of model features. + * The features should map filter keys to their respective + * attributes or aliases. + * + * @return array An associative array mapping model classes to their respective valid filters. + * The structure is: + * [ + * ModelClassName => [ + * 'filter' => 'value', + * ... + * ], + * ... + * ] + * + * @throws HTException If a filter key does not match the expected format or is invalid. + */ + public function filterObjectMap(array $filters, array $models) { + + $modelFilterMap = []; + foreach ($filters as $filter => $value) { + if (preg_match('/^(?P[_a-zA-Z0-9]+?)(?|__eq|__ne|__lt|__lte|__gt|__gte|__contains|__startswith|__endswith|__icontains|__istartswith|__iendswith)$/', $filter, $matches) == 0) { + throw new HTException("Filter parameter '" . $filter . "' is not valid"); + } + + foreach($models as $model) { + $features = $model->getFeatures(); + // Special filtering of _id to use for uniform access to model primary key + $cast_key = $matches['key'] == '_id' ? array_column($features, 'alias', 'dbname')[$this->getPrimaryKey()] : $matches['key']; + if (array_key_exists($cast_key, $features) == false) { + continue; //not a valid filter for current model + }; + $modelFilterMap[$model::class][$filter] = $value; + break; //filter has been found for current model, so break to go to next filter + } + } + return $modelFilterMap; + } + + /** + * API entry point for retrieving count information of data + */ + public function count(Request $request, Response $response, array $args): Response + { + $this->preCommon($request); + $factory = $this->getFactory(); + + //resolve all expandables + $validExpandables = $this::getExpandables(); + $expands = $this->makeExpandables($request, $validExpandables); + + $objects = [$factory->getNullObject()]; + //build join filters + foreach ($expands as $expand) { + $relation = $this->getToManyRelationships()[$expand]; + $objects[] = $this->getModelFactory($relation["relationType"])->getNullObject(); + $otherFactory = $this->getModelFactory($relation["relationType"]); + $primaryKey = $this->getPrimaryKey(); + $aFs[Factory::JOIN][] = new JoinFilter($otherFactory, $relation["relationKey"], $primaryKey, $factory); + } + + $filters = $this->getFilters($request); + $filterObjectMap = $this->filterObjectMap($filters, $objects); + $qFs = []; + foreach($filterObjectMap as $class => $cur_filters) { + $relationApiClass = new ($this->container->get('classMapper')->get($class))($this->container); + $current_qFs = $this->makeFilter($cur_filters, $relationApiClass); + $qFs = array_merge($qFs, $current_qFs); + } + + if (count($qFs) > 0) { + $aFs[Factory::FILTER] = $qFs; + } + + $count = $factory->countFilter($aFs); + $meta = ["count" => $count]; + + $include_total = $request->getQueryParams()['include_total']; + if ($include_total == "true") { + $meta["total_count"] = $factory->countFilter([]); + } + + $ret = self::createJsonResponse(meta: $meta); + + $body = $response->getBody(); + $body->write($this->ret2json($ret)); + + return $response->withStatus(200) + ->withHeader("Content-Type", 'application/vnd.api+json'); + } + /** * Get input field names valid for creation of object */ @@ -1106,6 +1209,7 @@ static public function register($app): void if (in_array("GET", $available_methods)) { $app->get($baseUri, $me . ':get')->setname($me . ':get'); + $app->get($baseUri . "/count", $me . ':count')->setname($me . ':count'); } foreach ($me::getToOneRelationships() as $name => $relationship) {