Skip to content

Commit

Permalink
Add schema update time verification to insert and upsert so as to use…
Browse files Browse the repository at this point in the history
… the cache

Signed-off-by: Xianhui.Lin <xianhui.lin@zilliz.com>

add default param

Signed-off-by: Xianhui.Lin <xianhui.lin@zilliz.com>

improve

Signed-off-by: Xianhui.Lin <xianhui.lin@zilliz.com>

improve

Signed-off-by: Xianhui.Lin <xianhui.lin@zilliz.com>
  • Loading branch information
JsDove committed Jan 10, 2025
1 parent 1b555d3 commit ee31cef
Show file tree
Hide file tree
Showing 10 changed files with 516 additions and 441 deletions.
5 changes: 3 additions & 2 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(self, raw: Any):
self.num_shards = 0
self.num_partitions = 0
self.enable_dynamic_field = False

self.update_timestamp = 0
if self._raw:
self.__pack(self._raw)

Expand All @@ -209,7 +209,7 @@ def __pack(self, raw: Any):
# for kv in raw.extra_params:

self.fields = [FieldSchema(f) for f in raw.schema.fields]

self.update_timestamp = raw.update_timestamp
self.functions = [FunctionSchema(f) for f in raw.schema.functions]
function_output_field_names = [f for fn in self.functions for f in fn.output_field_names]
for field in self.fields:
Expand Down Expand Up @@ -247,6 +247,7 @@ def dict(self):
"properties": self.properties,
"num_partitions": self.num_partitions,
"enable_dynamic_field": self.enable_dynamic_field,
"update_timestamp": self.update_timestamp,
}
self._rewrite_schema_dict(_dict)
return _dict
Expand Down
64 changes: 60 additions & 4 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
self._setup_db_interceptor(kwargs.get("db_name"))
self._setup_grpc_channel()
self.callbacks = []
self.schema_cache = {}

def register_state_change_callback(self, callback: Callable):
self.callbacks.append(callback)
Expand Down Expand Up @@ -161,6 +162,7 @@ def close(self):
self._channel.close()

def reset_db_name(self, db_name: str):
self.schema_cache.clear()
self._setup_db_interceptor(db_name)
self._setup_grpc_channel()
self._setup_identifier_interceptor(self._user)
Expand Down Expand Up @@ -526,10 +528,28 @@ def insert_rows(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
resp = self._stub.Insert(request=request, timeout=timeout)
if resp.status.error_code == common_pb2.SchemaMismatch:
schema = self.update_schema(collection_name, timeout)
request = self._prepare_row_insert_request(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
resp = self._stub.Insert(request=request, timeout=timeout)
check_status(resp.status)
ts_utils.update_collection_ts(collection_name, resp.timestamp)
return MutationResult(resp)

def update_schema(self, collection_name: str, timeout: Optional[float] = None):
self.schema_cache.pop(collection_name, None)
schema = self.describe_collection(collection_name, timeout=timeout)
schema_timestamp = schema.get("update_timestamp", 0)

self.schema_cache[collection_name] = {
"schema": schema,
"schema_timestamp": schema_timestamp,
}

return schema

def _prepare_row_insert_request(
self,
collection_name: str,
Expand All @@ -542,9 +562,9 @@ def _prepare_row_insert_request(
if isinstance(entity_rows, dict):
entity_rows = [entity_rows]

if not isinstance(schema, dict):
schema = self.describe_collection(collection_name, timeout=timeout)

schema, schema_timestamp = self._get_schema_from_cache_or_remote(
collection_name, schema, timeout
)
fields_info = schema.get("fields")
enable_dynamic = schema.get("enable_dynamic_field", False)

Expand All @@ -554,8 +574,33 @@ def _prepare_row_insert_request(
partition_name,
fields_info,
enable_dynamic=enable_dynamic,
schema_timestamp=schema_timestamp,
)

def _get_schema_from_cache_or_remote(
self, collection_name: str, schema: Optional[dict] = None, timeout: Optional[float] = None
):
"""
checks the cache for the schema. If not found, it fetches it remotely and updates the cache
"""
if collection_name in self.schema_cache:
# Use the cached schema and timestamp
schema = self.schema_cache[collection_name]["schema"]
schema_timestamp = self.schema_cache[collection_name]["schema_timestamp"]
else:
# Fetch the schema remotely if not in cache
if not isinstance(schema, dict):
schema = self.describe_collection(collection_name, timeout=timeout)
schema_timestamp = schema.get("update_timestamp", 0)

# Cache the fetched schema and timestamp
self.schema_cache[collection_name] = {
"schema": schema,
"schema_timestamp": schema_timestamp,
}

return schema, schema_timestamp

def _prepare_batch_insert_request(
self,
collection_name: str,
Expand Down Expand Up @@ -723,13 +768,18 @@ def _prepare_row_upsert_request(
if not isinstance(rows, list):
raise ParamError(message="'rows' must be a list, please provide valid row data.")

fields_info, enable_dynamic = self._get_info(collection_name, timeout, **kwargs)
schema, schema_timestamp = self._get_schema_from_cache_or_remote(
collection_name, timeout=timeout
)
fields_info = schema.get("fields")
enable_dynamic = schema.get("enable_dynamic_field", False)
return Prepare.row_upsert_param(
collection_name,
rows,
partition_name,
fields_info,
enable_dynamic=enable_dynamic,
schema_timestamp=schema_timestamp,
)

@retry_on_rpc_failure()
Expand All @@ -748,6 +798,12 @@ def upsert_rows(
)
rf = self._stub.Upsert.future(request, timeout=timeout)
response = rf.result()
if response.status.error_code == common_pb2.SchemaMismatch:
schema = self.update_schema(collection_name, timeout)
request = self._prepare_row_insert_request(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
response = self._stub.Insert(request=request, timeout=timeout)
check_status(response.status)
m = MutationResult(response)
ts_utils.update_collection_ts(collection_name, m.timestamp)
Expand Down
4 changes: 4 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def row_insert_param(
entities: List,
partition_name: str,
fields_info: Dict,
schema_timestamp: int = 0,
enable_dynamic: bool = False,
):
if not fields_info:
Expand All @@ -617,6 +618,7 @@ def row_insert_param(
collection_name=collection_name,
partition_name=p_name,
num_rows=len(entities),
schema_timestamp=schema_timestamp,
)

return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
Expand All @@ -629,6 +631,7 @@ def row_upsert_param(
partition_name: str,
fields_info: Any,
enable_dynamic: bool = False,
schema_timestamp: int = 0,
):
if not fields_info:
raise ParamError(message="Missing collection meta to validate entities")
Expand All @@ -639,6 +642,7 @@ def row_upsert_param(
collection_name=collection_name,
partition_name=p_name,
num_rows=len(entities),
schema_timestamp=schema_timestamp,
)

return cls._parse_upsert_row_request(request, fields_info, enable_dynamic, entities)
Expand Down
Loading

0 comments on commit ee31cef

Please sign in to comment.