Improve cleaner #135

Closed · wants to merge 7 commits

Summary of the changes: _get_files_from_cache_entry moves from clean.py to extra_encoders.py; database.py gains a CacheFile model joined to CacheEntry through an association table, populated at caching time by a new CacheEntry._add_cache_files(); and the cleaner reads file names, sizes, and types from that table instead of decoding every cached result.
cacholote/cache.py (1 addition, 0 deletions)

@@ -115,6 +115,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
 
         with settings.instantiated_sessionmaker() as session:
             session.add(cache_entry)
+            cache_entry._add_cache_files()
             return _decode_and_update(session, cache_entry, settings)
 
     return cast(F, wrapper)
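
Note on the change above: wrapper() now registers each file referenced by a cached result in the new cache_files table, inside the same session that stores the entry. A minimal sketch of the effect, assuming a configured cacholote cache (the function name and path below are invented):

# Sketch only, not part of the PR.
from cacholote import cache

@cache.cacheable
def cached_open(path: str):
    # The returned file object is serialized by extra_encoders; the
    # copy written to the cache store is what _add_cache_files()
    # records as a CacheFile row linked to the new CacheEntry.
    return open(path, "rb")

cached_open("data.txt")  # first call: adds a CacheEntry and a CacheFile
cached_open("data.txt")  # second call: served from the cache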
cacholote/clean.py (13 additions, 42 deletions)

@@ -28,35 +28,6 @@
 
 from . import config, database, decode, encode, extra_encoders, utils
 
-FILE_RESULT_KEYS = ("type", "callable", "args", "kwargs")
-FILE_RESULT_CALLABLES = (
-    "cacholote.extra_encoders:decode_xr_dataarray",
-    "cacholote.extra_encoders:decode_xr_dataset",
-    "cacholote.extra_encoders:decode_io_object",
-)
-
-
-def _get_files_from_cache_entry(
-    cache_entry: database.CacheEntry, key: str | None
-) -> dict[str, Any]:
-    result = cache_entry.result
-    if not isinstance(result, (list, tuple, set)):
-        result = [result]
-
-    files = {}
-    for obj in result:
-        if (
-            isinstance(obj, dict)
-            and set(FILE_RESULT_KEYS) == set(obj)
-            and obj["callable"] in FILE_RESULT_CALLABLES
-        ):
-            fs, urlpath = extra_encoders._get_fs_and_urlpath(*obj["args"][:2])
-            value = obj["args"][0]
-            if key is not None:
-                value = value[key]
-            files[fs.unstrip_protocol(urlpath)] = value
-    return files
-
 
 def _remove_files(
     fs: fsspec.AbstractFileSystem,

@@ -90,14 +61,12 @@ def _delete_cache_entries(
     files_to_delete = []
     dirs_to_delete = []
     for cache_entry in cache_entries:
-        session.delete(cache_entry)
-
-        files = _get_files_from_cache_entry(cache_entry, key="type")
-        for file, file_type in files.items():
-            if file_type == "application/vnd+zarr":
-                dirs_to_delete.append(file)
+        for cache_file in cache_entry.cache_files:
+            if cache_file.type == "application/vnd+zarr":
+                dirs_to_delete.append(cache_file.name)
             else:
-                files_to_delete.append(file)
+                files_to_delete.append(cache_file.name)
+        session.delete(cache_entry)
     database._commit_or_rollback(session)
 
     _remove_files(fs, files_to_delete, recursive=False)

@@ -157,12 +126,12 @@ def stop_cleaning(self, maxsize: int) -> bool:
     @property
     def known_files(self) -> dict[str, int]:
         known_files: dict[str, int] = {}
+        filters = [database.CacheFile.name.startswith(self.urldir)]
         with config.get().instantiated_sessionmaker() as session:
-            for cache_entry in session.scalars(sa.select(database.CacheEntry)):
-                files = _get_files_from_cache_entry(cache_entry, key="file:size")
-                known_files.update(
-                    {k: v for k, v in files.items() if k.startswith(self.urldir)}
-                )
+            for cache_file in session.scalars(
+                sa.select(database.CacheFile).filter(*filters)
+            ):
+                known_files[cache_file.name] = cache_file.size
         return known_files
 
     def get_unknown_files(self, lock_validity_period: float | None) -> set[str]:

@@ -261,7 +230,9 @@ def delete_cache_files(
         for cache_entry in session.scalars(
            sa.select(database.CacheEntry).filter(*filters).order_by(*sorters)
         ):
-            files = _get_files_from_cache_entry(cache_entry, key="file:size")
+            files = extra_encoders._get_files_from_cache_entry(
+                cache_entry, key="file:size"
+            )
             if any(file.startswith(self.urldir) for file in files):
                 entries_to_delete.append(cache_entry)
                 for file in files:
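
Why this is an improvement: known_files previously decoded every CacheEntry result in Python just to learn file names and sizes; with the CacheFile table, the prefix filter is pushed into SQL (startswith compiles to a LIKE clause). A standalone sketch of the equivalent query, assuming a configured cacholote session (the helper name is invented):

import sqlalchemy as sa

from cacholote import config, database

def known_files_under(urldir: str) -> dict[str, int]:
    # name -> size for every cached file under urldir, resolved
    # entirely in the database instead of decoding cached results.
    with config.get().instantiated_sessionmaker() as session:
        stmt = sa.select(database.CacheFile).filter(
            database.CacheFile.name.startswith(urldir)
        )
        return {cf.name: cf.size for cf in session.scalars(stmt)}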
cacholote/database.py (66 additions, 1 deletion)

@@ -28,14 +28,29 @@
 import sqlalchemy.orm
 import sqlalchemy_utils
 
-from . import utils
+from . import extra_encoders, utils
 
 _DATETIME_MAX = datetime.datetime(
     datetime.MAXYEAR, 12, 31, tzinfo=datetime.timezone.utc
 )
 
 Base = sa.orm.declarative_base()
 
+association_table = sa.Table(
+    "association_table",
+    Base.metadata,
+    sa.Column(
+        "cache_entries_id",
+        sa.ForeignKey("cache_entries.id"),
+        primary_key=True,
+    ),
+    sa.Column(
+        "cache_files_name",
+        sa.ForeignKey("cache_files.name"),
+        primary_key=True,
+    ),
+)
+
 
 class CacheEntry(Base):
     __tablename__ = "cache_entries"

@@ -48,6 +63,23 @@ class CacheEntry(Base):
     updated_at = sa.Column(sa.DateTime, default=utils.utcnow, onupdate=utils.utcnow)
     counter = sa.Column(sa.Integer)
     tag = sa.Column(sa.String)
+    cache_files: sa.orm.Mapped[set[CacheFile]] = sa.orm.relationship(
+        secondary=association_table,
+        back_populates="cache_entries",
+        cascade="all, save-update",
+    )
+
+    def _add_cache_files(self) -> None:
+        for name, info in extra_encoders._get_files_from_cache_entry(
+            self, key=None
+        ).items():
+            self.cache_files.add(
+                CacheFile(
+                    name=name,
+                    size=info["file:size"],
+                    type=info["type"],
+                )
+            )
 
     @property
     def _result_as_string(self) -> str:

@@ -69,6 +101,39 @@ def __repr__(self) -> str:
         return f"CacheEntry({public_attrs_repr})"
 
 
+class CacheFile(Base):
+    __tablename__ = "cache_files"
+
+    name: str = sa.Column(sa.String(), primary_key=True)
+    size: int = sa.Column(sa.Integer())
+    type: str = sa.Column(sa.String())
+    cache_entries: sa.orm.Mapped[set[CacheEntry]] = sa.orm.relationship(
+        secondary=association_table,
+        back_populates="cache_files",
+        cascade="all, delete",
+    )
+
+    @property
+    def updated_at(self) -> datetime.datetime:
+        return max(
+            [
+                cache_entry.updated_at
+                for cache_entry in self.cache_entries
+                if cache_entry.updated_at
+            ]
+        )
+
+    @property
+    def count(self) -> int:
+        return sum(
+            [
+                cache_entry.counter
+                for cache_entry in self.cache_entries
+                if cache_entry.counter
+            ]
+        )
+
+
 @sa.event.listens_for(CacheEntry, "before_insert")
 def set_expiration_to_max(
     mapper: sa.orm.Mapper[CacheEntry],
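
The association table makes entries and files many-to-many: one entry can reference several files, and one file can back several entries. The cascade settings carry the cleaner's semantics: save-update on CacheEntry.cache_files persists new files together with their entry, while delete on CacheFile.cache_entries removes an entry when one of its files is deleted. A self-contained sketch of the same pattern with toy models and in-memory SQLite (not part of the PR):

import sqlalchemy as sa
import sqlalchemy.orm

Base = sa.orm.declarative_base()

association = sa.Table(
    "association",
    Base.metadata,
    sa.Column("entry_id", sa.ForeignKey("entries.id"), primary_key=True),
    sa.Column("file_name", sa.ForeignKey("files.name"), primary_key=True),
)

class Entry(Base):
    __tablename__ = "entries"
    id = sa.Column(sa.Integer, primary_key=True)
    files = sa.orm.relationship(
        "File",
        secondary=association,
        back_populates="entries",
        cascade="all, save-update",  # persisting an Entry persists its Files
    )

class File(Base):
    __tablename__ = "files"
    name = sa.Column(sa.String, primary_key=True)
    entries = sa.orm.relationship(
        "Entry",
        secondary=association,
        back_populates="files",
        cascade="all, delete",  # deleting a File deletes its Entries too
    )

engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)

with sa.orm.Session(engine) as session:
    shared = File(name="shared.nc")
    session.add_all([Entry(files=[shared]), Entry(files=[shared])])
    session.commit()

    session.delete(shared)  # cascades through the association
    session.commit()
    assert session.scalars(sa.select(Entry)).all() == []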
cacholote/extra_encoders.py (30 additions, 1 deletion)

@@ -41,7 +41,7 @@
 import fsspec.implementations.local
 import pydantic
 
-from . import config, encode, utils
+from . import config, database, encode, utils
 
 try:
     import dask

@@ -68,6 +68,13 @@
     fsspec.implementations.local.LocalFileOpener,
 ]
 
+FILE_RESULT_KEYS = ("type", "callable", "args", "kwargs")
+FILE_RESULT_CALLABLES = (
+    "cacholote.extra_encoders:decode_xr_dataarray",
+    "cacholote.extra_encoders:decode_xr_dataset",
+    "cacholote.extra_encoders:decode_io_object",
+)
+
 
 def _add_ext_to_mimetypes() -> None:
     """Add netcdf, grib, and zarr to mimetypes."""

@@ -129,6 +136,28 @@ def _logging_timer(event: str, **kwargs: Any) -> Generator[float, None, None]:
         context.upload_log(f"end {event}. {_kwargs_to_str(**kwargs)}")
 
 
+def _get_files_from_cache_entry(
+    cache_entry: database.CacheEntry, key: str | None
+) -> dict[str, Any]:
+    result = cache_entry.result
+    if not isinstance(result, (list, tuple, set)):
+        result = [result]
+
+    files = {}
+    for obj in result:
+        if (
+            isinstance(obj, dict)
+            and set(FILE_RESULT_KEYS) == set(obj)
+            and obj["callable"] in FILE_RESULT_CALLABLES
+        ):
+            fs, urlpath = _get_fs_and_urlpath(*obj["args"][:2])
+            value = obj["args"][0]
+            if key is not None:
+                value = value[key]
+            files[fs.unstrip_protocol(urlpath)] = value
+    return files
+
+
 class FileInfoModel(pydantic.BaseModel):
     type: str
     href: str
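
For reference, _get_files_from_cache_entry pattern-matches the dict form in which cacholote serializes file results: exactly the four FILE_RESULT_KEYS, with a known decode callable. A hypothetical instance (all field values invented; args[0] only needs the keys the callers read):

# Hypothetical serialized file result; values are invented.
obj = {
    "type": "application/netcdf",
    "callable": "cacholote.extra_encoders:decode_xr_dataset",
    "args": (
        {  # file info dict; args[0]
            "type": "application/netcdf",
            "href": "./cache_files/cd1234.nc",
            "file:size": 42_000,
        },
        {},  # fsspec storage options; args[1]
    ),
    "kwargs": {},
}
# key=None        -> {urlpath: <whole info dict>}  (used by _add_cache_files)
# key="file:size" -> {urlpath: 42_000}             (used by delete_cache_files)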
tests/test_50_io_encoder.py (25 additions, 1 deletion)

@@ -12,9 +12,10 @@
 import pytest
 import pytest_httpserver
 import pytest_structlog
+import sqlalchemy as sa
 import structlog
 
-from cacholote import cache, config, decode, encode, extra_encoders, utils
+from cacholote import cache, config, database, decode, encode, extra_encoders, utils
 
 
 @cache.cacheable

@@ -235,3 +236,26 @@ def test_io_logging(
         },
     ]
     assert log.events == expected
+
+
+def test_io_delete_cache_file(tmp_path: pathlib.Path) -> None:
+    # Create tmpfile and cache
+    tmpfile = tmp_path / "test.txt"
+    fsspec.filesystem("file").touch(tmpfile)
+
+    # Cache file
+    cached_open(tmpfile)
+    con = config.get().engine.raw_connection()
+    cur = con.cursor()
+    cur.execute("SELECT COUNT(*) FROM cache_entries", ())
+    assert cur.fetchone() == (1,)
+
+    # Delete cache file
+    with config.get().instantiated_sessionmaker() as session:
+        for cache_file in session.scalars(sa.select(database.CacheFile)):
+            session.delete(cache_file)
+        database._commit_or_rollback(session)
+
+    # cache-db must be empty
+    cur.execute("SELECT COUNT(*) FROM cache_entries", ())
+    assert cur.fetchone() == (0,)
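
The raw-cursor assertions could equivalently go through the ORM; a sketch under the same fixture assumptions as the test:

import sqlalchemy as sa

from cacholote import config, database

with config.get().instantiated_sessionmaker() as session:
    n_entries = session.scalar(
        sa.select(sa.func.count()).select_from(database.CacheEntry)
    )
assert n_entries == 0  # after all CacheFile rows were deleted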