Skip to content

Commit

Permalink
feat(improve-api-endpoints): Added Datasets and Annotation APIs
Browse files Browse the repository at this point in the history
- Added AnnotationReplyActionApi, AnnotationListApi, AnnotationUpdateDeleteApi
- Added GET and PATCH endpoints to DatasetApi
- Applied current_user in validate_app_token just like what it is done in validate_dataset_token

Related to: #11727
  • Loading branch information
jasonfish568 committed Dec 30, 2024
1 parent fc29f20 commit a47b514
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 5 deletions.
2 changes: 1 addition & 1 deletion api/controllers/service_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
api = ExternalApi(bp)

from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .app import annotation, app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, segment
94 changes: 94 additions & 0 deletions api/controllers/service_api/app/annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden

from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.annotation_fields import (
annotation_fields,
)
from libs.login import current_user
from models.model import App, EndUser
from services.annotation_service import AppAnnotationService


class AnnotationReplyActionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, app_id, action):
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200


class AnnotationListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, app_id):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)

app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser, app_id):
print(current_user)
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation


class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def delete(self, app_model: App, end_user: EndUser, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()

app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {"result": "success"}, 200


api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
150 changes: 148 additions & 2 deletions api/controllers/service_api/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask import request
from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import Forbidden, NotFound

import services.dataset_service
from controllers.service_api import api
Expand All @@ -11,7 +11,7 @@
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetService
from services.dataset_service import DatasetPermissionService, DatasetService


def _validate_name(name):
Expand All @@ -20,6 +20,12 @@ def _validate_name(name):
return name


def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description


class DatasetListApi(DatasetApiResource):
"""Resource for datasets."""

Expand Down Expand Up @@ -132,6 +138,145 @@ def post(self, tenant_id):
class DatasetApi(DatasetApiResource):
"""Resource for dataset."""

def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})

# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
else:
data["embedding_available"] = False
else:
data["embedding_available"] = True

if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})

return data, 200

def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")

parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")

parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)

parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)

parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()

# check embedding model setting
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)

# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)

dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)

if dataset is None:
raise NotFound("Dataset not found.")

result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id

if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)

partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})

return result_data, 200

def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
Expand All @@ -152,6 +297,7 @@ def delete(self, _, dataset_id):

try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
else:
raise NotFound("Dataset not found.")
Expand Down
25 changes: 23 additions & 2 deletions api/controllers/service_api/wraps.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ def decorated_view(*args, **kwargs):
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")

tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")

kwargs["app_model"] = app_model

if fetch_user_arg:
Expand Down Expand Up @@ -193,7 +214,7 @@ def validate_and_get_api_token(scope=None):
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
)
)
.first()
)

Expand All @@ -220,7 +241,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
)
)
.first()
)

Expand Down

0 comments on commit a47b514

Please sign in to comment.