Skip to content

Commit

Permalink
move get func to extra encoders
Browse files Browse the repository at this point in the history
  • Loading branch information
malmans2 committed Oct 30, 2024
1 parent b642e16 commit 4921c5e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 34 deletions.
35 changes: 4 additions & 31 deletions cacholote/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,6 @@

from . import config, database, decode, encode, extra_encoders, utils

FILE_RESULT_KEYS = ("type", "callable", "args", "kwargs")
FILE_RESULT_CALLABLES = (
"cacholote.extra_encoders:decode_xr_dataarray",
"cacholote.extra_encoders:decode_xr_dataset",
"cacholote.extra_encoders:decode_io_object",
)


def _get_files_from_cache_entry(
cache_entry: database.CacheEntry, key: str | None
) -> dict[str, Any]:
result = cache_entry.result
if not isinstance(result, (list, tuple, set)):
result = [result]

files = {}
for obj in result:
if (
isinstance(obj, dict)
and set(FILE_RESULT_KEYS) == set(obj)
and obj["callable"] in FILE_RESULT_CALLABLES
):
fs, urlpath = extra_encoders._get_fs_and_urlpath(*obj["args"][:2])
value = obj["args"][0]
if key is not None:
value = value[key]
files[fs.unstrip_protocol(urlpath)] = value
return files


def _remove_files(
fs: fsspec.AbstractFileSystem,
Expand Down Expand Up @@ -92,7 +63,7 @@ def _delete_cache_entries(
for cache_entry in cache_entries:
session.delete(cache_entry)

files = _get_files_from_cache_entry(cache_entry, key="type")
files = extra_encoders._get_files_from_cache_entry(cache_entry, key="type")
for file, file_type in files.items():
if file_type == "application/vnd+zarr":
dirs_to_delete.append(file)
Expand Down Expand Up @@ -261,7 +232,9 @@ def delete_cache_files(
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(*filters).order_by(*sorters)
):
files = _get_files_from_cache_entry(cache_entry, key="file:size")
files = extra_encoders._get_files_from_cache_entry(
cache_entry, key="file:size"
)
if any(file.startswith(self.urldir) for file in files):
entries_to_delete.append(cache_entry)
for file in files:
Expand Down
4 changes: 2 additions & 2 deletions cacholote/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sqlalchemy.orm
import sqlalchemy_utils

from . import clean, utils
from . import extra_encoders, utils

_DATETIME_MAX = datetime.datetime(
datetime.MAXYEAR, 12, 31, tzinfo=datetime.timezone.utc
Expand Down Expand Up @@ -129,7 +129,7 @@ def add_cache_files(
connection: sa.Connection,
target: CacheEntry,
) -> None:
for name, size in clean._get_files_from_cache_entry(
for name, size in extra_encoders._get_files_from_cache_entry(
target, key="file:size"
).items():
target.cache_files.add(CacheFile(name=name, size=size))
Expand Down
31 changes: 30 additions & 1 deletion cacholote/extra_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import fsspec.implementations.local
import pydantic

from . import config, encode, utils
from . import config, database, encode, utils

try:
import dask
Expand All @@ -68,6 +68,13 @@
fsspec.implementations.local.LocalFileOpener,
]

FILE_RESULT_KEYS = ("type", "callable", "args", "kwargs")
FILE_RESULT_CALLABLES = (
"cacholote.extra_encoders:decode_xr_dataarray",
"cacholote.extra_encoders:decode_xr_dataset",
"cacholote.extra_encoders:decode_io_object",
)


def _add_ext_to_mimetypes() -> None:
"""Add netcdf, grib, and zarr to mimetypes."""
Expand Down Expand Up @@ -129,6 +136,28 @@ def _logging_timer(event: str, **kwargs: Any) -> Generator[float, None, None]:
context.upload_log(f"end {event}. {_kwargs_to_str(**kwargs)}")


def _get_files_from_cache_entry(
cache_entry: database.CacheEntry, key: str | None
) -> dict[str, Any]:
result = cache_entry.result
if not isinstance(result, (list, tuple, set)):
result = [result]

files = {}
for obj in result:
if (
isinstance(obj, dict)
and set(FILE_RESULT_KEYS) == set(obj)
and obj["callable"] in FILE_RESULT_CALLABLES
):
fs, urlpath = _get_fs_and_urlpath(*obj["args"][:2])
value = obj["args"][0]
if key is not None:
value = value[key]
files[fs.unstrip_protocol(urlpath)] = value
return files


class FileInfoModel(pydantic.BaseModel):
type: str
href: str
Expand Down

0 comments on commit 4921c5e

Please sign in to comment.