Skip to content

Commit

Permalink
Merge pull request #431 from atlanhq/FT-814
Browse files Browse the repository at this point in the history
FT-814:Add projection support for groups endpoint
  • Loading branch information
Aryamanz29 authored Dec 3, 2024
2 parents edd5b46 + deb8fbe commit bd94635
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 6 deletions.
5 changes: 4 additions & 1 deletion pyatlan/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

ROLE_API = "roles"
GROUP_API = "groups"
GROUP_API_V2 = "v2/groups"
USER_API = "users"
QUERY_API = "query"
IMAGE_API = "images"
Expand All @@ -25,7 +26,9 @@
GET_ROLES = API(ROLE_API, HTTPMethod.GET, HTTPStatus.OK, endpoint=EndPoint.HERACLES)

# Group APIs
GET_GROUPS = API(GROUP_API, HTTPMethod.GET, HTTPStatus.OK, endpoint=EndPoint.HERACLES)
GET_GROUPS = API(
GROUP_API_V2, HTTPMethod.GET, HTTPStatus.OK, endpoint=EndPoint.HERACLES
)
CREATE_GROUP = API(
GROUP_API, HTTPMethod.POST, HTTPStatus.OK, endpoint=EndPoint.HERACLES
)
Expand Down
18 changes: 14 additions & 4 deletions pyatlan/client/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def get(
sort: Optional[str] = None,
count: bool = True,
offset: int = 0,
columns: Optional[List[str]] = None,
) -> GroupResponse:
"""
Retrieves a GroupResponse object which contains a list of the groups defined in Atlan.
Expand All @@ -107,11 +108,17 @@ def get(
:param sort: property by which to sort the results
:param count: whether to return the total number of records (True) or not (False)
:param offset: starting point for results to return, for paging
:param columns: provides columns projection support for groups endpoint
:returns: a GroupResponse object which contains a list of groups that match the provided criteria
:raises AtlanError: on any API communication issue
"""
request = GroupRequest(
post_filter=post_filter, limit=limit, sort=sort, count=count, offset=offset
post_filter=post_filter,
limit=limit,
sort=sort,
count=count,
offset=offset,
columns=columns,
)
endpoint = GET_GROUPS.format_path_with_params()
raw_json = self._client._call_api(
Expand All @@ -134,17 +141,20 @@ def get_all(
limit: int = 20,
offset: int = 0,
sort: Optional[str] = "name",
columns: Optional[List[str]] = None,
) -> List[AtlanGroup]:
"""
Retrieve all groups defined in Atlan.
:param limit: maximum number of results to be returned
:param offset: starting point for the list of groups when paging
:param sort: property by which to sort the results, by default : `name`
:param sort: property by which to sort the results, by default : name
:param columns: provides columns projection support for groups endpoint
:returns: a list of all the groups in Atlan
"""
response: GroupResponse = self.get(offset=offset, limit=limit, sort=sort)
return [group for group in response]
if response := self.get(offset=offset, limit=limit, sort=sort, columns=columns):
return response.records # type: ignore
return None # type: ignore

@validate_arguments
def get_by_name(
Expand Down
7 changes: 7 additions & 0 deletions pyatlan/model/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Attributes(AtlanObject):
attributes: Optional[AtlanGroup.Attributes] = Field(
default=None, description="Detailed attributes of the group."
)
roles: Optional[List[str]] = Field(default=None, description="TBC")
decentralized_roles: Optional[List[Any]] = Field(default=None, description="TBC")
id: Optional[str] = Field(
default=None, description="Unique identifier for the group (GUID)."
Expand Down Expand Up @@ -192,6 +193,10 @@ class GroupRequest(AtlanObject):
default=20,
description="Maximum number of groups to return per page.",
)
columns: Optional[List[str]] = Field(
default=None,
description="List of specific fields to include in the response.",
)

@property
def query_params(self) -> dict:
Expand All @@ -203,6 +208,8 @@ def query_params(self) -> dict:
qp["count"] = self.count
qp["offset"] = self.offset
qp["limit"] = self.limit
if self.columns:
qp["columns"] = self.columns
return qp


Expand Down
57 changes: 56 additions & 1 deletion tests/integration/admin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_retrieve_all_groups(client: AtlanClient, group: CreateGroupResponse):


def test_group_get_pagination(client: AtlanClient, group: CreateGroupResponse):
response = client.group.get(limit=1)
response = client.group.get(limit=1, count=True)

assert response
assert response.total_record is not None
Expand Down Expand Up @@ -295,3 +295,58 @@ def test_retrieve_admin_logs(
if count >= 1000:
break
assert count > 0


def test_get_all_with_limit(client: AtlanClient, group: CreateGroupResponse):
limit = 2
groups = client.group.get_all(limit=limit)
assert groups
assert len(groups) == limit

for group1 in groups:
assert group1.id
assert group1.name
assert group1.path is not None


def test_get_all_with_columns(client: AtlanClient, group: CreateGroupResponse):
columns = ["path"]
groups = client.group.get_all(columns=columns)

assert groups
assert len(groups) >= 1

for group1 in groups:
assert group1.name
assert group1.path is not None
assert group1.attributes is None
assert group1.roles is None


def test_get_all_with_sorting(client: AtlanClient, group: CreateGroupResponse):
groups = client.group.get_all(sort="name")

assert groups
assert len(groups) >= 1

sorted_names = [group.name for group in groups if group.name is not None]
assert sorted_names == sorted(sorted_names)


def test_get_all_with_everything(client: AtlanClient, group: CreateGroupResponse):
limit = 2
columns = ["path", "attributes"]
sort = "name"

groups = client.group.get_all(limit=limit, columns=columns, sort=sort)

assert groups
assert len(groups) == limit
sorted_names = [group.name for group in groups if group.name is not None]
assert sorted_names == sorted(sorted_names)

for group1 in groups:
assert group1.name
assert group1.path is not None
assert group1.roles is None
assert group1.attributes is not None
75 changes: 75 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ def client():
return AtlanClient()


@pytest.fixture
def group_client(mock_api_caller):
return GroupClient(client=mock_api_caller)


@pytest.fixture
def mock_asset_client():
return Mock(AssetClient)
Expand Down Expand Up @@ -2219,3 +2224,73 @@ def test_atlan_client_headers(client: AtlanClient):
"x-atlan-client-origin": "product_sdk",
}
assert expected == client._session.headers


def test_get_all_pagation(group_client, mock_api_caller):
mock_page_1 = [
{"id": "1", "alias": "Group3"},
{"id": "2", "alias": "Group4"},
]
mock_api_caller._call_api.side_effect = [
{"records": mock_page_1},
]

groups = group_client.get_all(limit=2)

assert len(groups) == 2
assert groups[0].id == "1"
assert groups[1].id == "2"
assert mock_api_caller._call_api.call_count == 1
mock_api_caller.reset_mock()


def test_get_all_empty_response_with_raw_records(group_client, mock_api_caller):
mock_page_1 = []
mock_api_caller._call_api.side_effect = [
{"records": mock_page_1},
]

groups = group_client.get_all()
assert len(groups) == 0
mock_api_caller.reset_mock()


def test_get_all_with_columns(group_client, mock_api_caller):
mock_page_1 = [
{"id": "1", "alias": "Group1"},
{"id": "2", "alias": "Group2"},
]
mock_api_caller._call_api.side_effect = [
{"records": mock_page_1},
]

columns = ["alias"]
groups = group_client.get_all(limit=10, columns=columns)

assert len(groups) == 2
assert groups[0].id == "1"
assert groups[0].alias == "Group1"
mock_api_caller._call_api.assert_called_once()
query_params = mock_api_caller._call_api.call_args.kwargs["query_params"]
assert query_params["columns"] == columns
mock_api_caller.reset_mock()


def test_get_all_sorting(group_client, mock_api_caller):
mock_page_1 = [
{"id": "1", "alias": "Group1"},
{"id": "2", "alias": "Group2"},
]
mock_api_caller._call_api.side_effect = [
{"records": mock_page_1},
]

groups = group_client.get_all(limit=10, sort="alias")

assert len(groups) == 2
assert groups[0].id == "1"
assert groups[0].alias == "Group1"
mock_api_caller._call_api.assert_called_once()
query_params = mock_api_caller._call_api.call_args.kwargs["query_params"]
assert query_params["sort"] == "alias"
mock_api_caller.reset_mock()

0 comments on commit bd94635

Please sign in to comment.