From 4921c5ece7cb7c5de620fec31bb28fd745bb847f Mon Sep 17 00:00:00 2001 From: malmans2 Date: Wed, 30 Oct 2024 08:45:55 +0100 Subject: [PATCH] move get func to extra encoders --- cacholote/clean.py | 35 ++++------------------------------- cacholote/database.py | 4 ++-- cacholote/extra_encoders.py | 31 ++++++++++++++++++++++++++++++- 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/cacholote/clean.py b/cacholote/clean.py index 40e1589..54ad1ad 100644 --- a/cacholote/clean.py +++ b/cacholote/clean.py @@ -28,35 +28,6 @@ from . import config, database, decode, encode, extra_encoders, utils -FILE_RESULT_KEYS = ("type", "callable", "args", "kwargs") -FILE_RESULT_CALLABLES = ( - "cacholote.extra_encoders:decode_xr_dataarray", - "cacholote.extra_encoders:decode_xr_dataset", - "cacholote.extra_encoders:decode_io_object", -) - - -def _get_files_from_cache_entry( - cache_entry: database.CacheEntry, key: str | None -) -> dict[str, Any]: - result = cache_entry.result - if not isinstance(result, (list, tuple, set)): - result = [result] - - files = {} - for obj in result: - if ( - isinstance(obj, dict) - and set(FILE_RESULT_KEYS) == set(obj) - and obj["callable"] in FILE_RESULT_CALLABLES - ): - fs, urlpath = extra_encoders._get_fs_and_urlpath(*obj["args"][:2]) - value = obj["args"][0] - if key is not None: - value = value[key] - files[fs.unstrip_protocol(urlpath)] = value - return files - def _remove_files( fs: fsspec.AbstractFileSystem, @@ -92,7 +63,7 @@ def _delete_cache_entries( for cache_entry in cache_entries: session.delete(cache_entry) - files = _get_files_from_cache_entry(cache_entry, key="type") + files = extra_encoders._get_files_from_cache_entry(cache_entry, key="type") for file, file_type in files.items(): if file_type == "application/vnd+zarr": dirs_to_delete.append(file) @@ -261,7 +232,9 @@ def delete_cache_files( for cache_entry in session.scalars( sa.select(database.CacheEntry).filter(*filters).order_by(*sorters) ): - files = _get_files_from_cache_entry(cache_entry, key="file:size") + files = extra_encoders._get_files_from_cache_entry( + cache_entry, key="file:size" + ) if any(file.startswith(self.urldir) for file in files): entries_to_delete.append(cache_entry) for file in files: diff --git a/cacholote/database.py b/cacholote/database.py index 922fa3e..9041220 100644 --- a/cacholote/database.py +++ b/cacholote/database.py @@ -28,7 +28,7 @@ import sqlalchemy.orm import sqlalchemy_utils -from . import clean, utils +from . import extra_encoders, utils _DATETIME_MAX = datetime.datetime( datetime.MAXYEAR, 12, 31, tzinfo=datetime.timezone.utc @@ -129,7 +129,7 @@ def add_cache_files( connection: sa.Connection, target: CacheEntry, ) -> None: - for name, size in clean._get_files_from_cache_entry( + for name, size in extra_encoders._get_files_from_cache_entry( target, key="file:size" ).items(): target.cache_files.add(CacheFile(name=name, size=size)) diff --git a/cacholote/extra_encoders.py b/cacholote/extra_encoders.py index 2509951..676401a 100644 --- a/cacholote/extra_encoders.py +++ b/cacholote/extra_encoders.py @@ -41,7 +41,7 @@ import fsspec.implementations.local import pydantic -from . import config, encode, utils +from . import config, database, encode, utils try: import dask @@ -68,6 +68,13 @@ fsspec.implementations.local.LocalFileOpener, ] +FILE_RESULT_KEYS = ("type", "callable", "args", "kwargs") +FILE_RESULT_CALLABLES = ( + "cacholote.extra_encoders:decode_xr_dataarray", + "cacholote.extra_encoders:decode_xr_dataset", + "cacholote.extra_encoders:decode_io_object", +) + def _add_ext_to_mimetypes() -> None: """Add netcdf, grib, and zarr to mimetypes.""" @@ -129,6 +136,28 @@ def _logging_timer(event: str, **kwargs: Any) -> Generator[float, None, None]: context.upload_log(f"end {event}. {_kwargs_to_str(**kwargs)}") +def _get_files_from_cache_entry( + cache_entry: database.CacheEntry, key: str | None +) -> dict[str, Any]: + result = cache_entry.result + if not isinstance(result, (list, tuple, set)): + result = [result] + + files = {} + for obj in result: + if ( + isinstance(obj, dict) + and set(FILE_RESULT_KEYS) == set(obj) + and obj["callable"] in FILE_RESULT_CALLABLES + ): + fs, urlpath = _get_fs_and_urlpath(*obj["args"][:2]) + value = obj["args"][0] + if key is not None: + value = value[key] + files[fs.unstrip_protocol(urlpath)] = value + return files + + class FileInfoModel(pydantic.BaseModel): type: str href: str