diff --git a/cacholote/cache.py b/cacholote/cache.py
index a553d92..bf4bcdc 100644
--- a/cacholote/cache.py
+++ b/cacholote/cache.py
@@ -82,7 +82,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
                     return _decode_and_update(session, cache_entry, settings)
                 except decode.DecodeError as ex:
                     warnings.warn(str(ex), UserWarning)
-                    clean._delete_cache_entry(session, cache_entry)
+                    clean._delete_cache_entries(session, cache_entry)
 
         result = func(*args, **kwargs)
         cache_entry = database.CacheEntry(
diff --git a/cacholote/clean.py b/cacholote/clean.py
index 578e8d8..62cd7ec 100644
--- a/cacholote/clean.py
+++ b/cacholote/clean.py
@@ -20,6 +20,7 @@
 import posixpath
 from typing import Any, Callable, Literal, Optional
 
+import fsspec
 import pydantic
 import sqlalchemy as sa
 import sqlalchemy.orm
@@ -35,7 +36,9 @@
 )
 
 
-def _get_files_from_cache_entry(cache_entry: database.CacheEntry) -> dict[str, str]:
+def _get_files_from_cache_entry(
+    cache_entry: database.CacheEntry, key: str | None
+) -> dict[str, Any]:
     result = cache_entry.result
     if not isinstance(result, (list, tuple, set)):
         result = [result]
@@ -48,27 +51,57 @@ def _get_files_from_cache_entry(cache_entry: database.CacheEntry) -> dict[str, s
             and obj["callable"] in FILE_RESULT_CALLABLES
         ):
             fs, urlpath = extra_encoders._get_fs_and_urlpath(*obj["args"][:2])
-            files[fs.unstrip_protocol(urlpath)] = obj["args"][0]["type"]
+            value = obj["args"][0]
+            if key is not None:
+                value = value[key]
+            files[fs.unstrip_protocol(urlpath)] = value
     return files
 
 
-def _delete_cache_entry(
-    session: sa.orm.Session, cache_entry: database.CacheEntry
+def _remove_files(
+    fs: fsspec.AbstractFileSystem,
+    files: list[str],
+    max_tries: int = 10,
+    **kwargs: Any,
 ) -> None:
-    fs, _ = utils.get_cache_files_fs_dirname()
-    files_to_delete = _get_files_from_cache_entry(cache_entry)
-    logger = config.get().logger
+    assert max_tries >= 1
+    if not files:
+        return
+
+    config.get().logger.info("deleting files", n_files_to_delete=len(files), **kwargs)
+
+    n_tries = 0
+    while files:
+        n_tries += 1
+        try:
+            fs.rm(files, **kwargs)
+            return
+        except FileNotFoundError:
+            # Another concurrent process might have deleted files
+            if n_tries >= max_tries:
+                raise
+            files = [file for file in files if fs.exists(file)]
 
-    # First, delete database entry
-    logger.info("deleting cache entry", cache_entry=cache_entry)
-    session.delete(cache_entry)
+
+def _delete_cache_entries(
+    session: sa.orm.Session, *cache_entries: database.CacheEntry
+) -> None:
+    fs, _ = utils.get_cache_files_fs_dirname()
+    files_to_delete = []
+    dirs_to_delete = []
+    for cache_entry in cache_entries:
+        session.delete(cache_entry)
+
+        files = _get_files_from_cache_entry(cache_entry, key="type")
+        for file, file_type in files.items():
+            if file_type == "application/vnd+zarr":
+                dirs_to_delete.append(file)
+            else:
+                files_to_delete.append(file)
     database._commit_or_rollback(session)
 
-    # Then, delete files
-    for urlpath, file_type in files_to_delete.items():
-        if fs.exists(urlpath):
-            logger.info("deleting cache file", urlpath=urlpath)
-            fs.rm(urlpath, recursive=file_type == "application/vnd+zarr")
+    _remove_files(fs, files_to_delete, recursive=False)
+    _remove_files(fs, dirs_to_delete, recursive=True)
 
 
 def delete(func_to_del: str | Callable[..., Any], *args: Any, **kwargs: Any) -> None:
@@ -88,25 +121,25 @@ def delete(func_to_del: str | Callable[..., Any], *args: Any, **kwargs: Any) ->
         for cache_entry in session.scalars(
             sa.select(database.CacheEntry).filter(database.CacheEntry.key == hexdigest)
         ):
-            _delete_cache_entry(session, cache_entry)
+            _delete_cache_entries(session, cache_entry)
 
 
 class _Cleaner:
-    def __init__(self) -> None:
+    def __init__(self, depth: int, use_database: bool) -> None:
         self.logger = config.get().logger
         self.fs, self.dirname = utils.get_cache_files_fs_dirname()
-        urldir = self.fs.unstrip_protocol(self.dirname)
+        self.urldir = self.fs.unstrip_protocol(self.dirname)
 
         self.logger.info("getting disk usage")
         self.file_sizes: dict[str, int] = collections.defaultdict(int)
-        for path, size in self.fs.du(self.dirname, total=False).items():
+        du = self.known_files if use_database else self.fs.du(self.dirname, total=False)
+        for path, size in du.items():
             # Group dirs
             urlpath = self.fs.unstrip_protocol(path)
-            basename, *_ = urlpath.replace(urldir, "", 1).strip("/").split("/")
-            if basename:
-                self.file_sizes[posixpath.join(urldir, basename)] += size
-
+            parts = urlpath.replace(self.urldir, "", 1).strip("/").split("/")
+            if parts:
+                self.file_sizes[posixpath.join(self.urldir, *parts[:depth])] += size
         self.disk_usage = sum(self.file_sizes.values())
         self.log_disk_usage()
 
@@ -121,6 +154,16 @@ def log_disk_usage(self) -> None:
     def stop_cleaning(self, maxsize: int) -> bool:
         return self.disk_usage <= maxsize
 
+    @property
+    def known_files(self) -> dict[str, int]:
+        known_files: dict[str, int] = {}
+        with config.get().instantiated_sessionmaker() as session:
+            for cache_entry in session.scalars(sa.select(database.CacheEntry)):
+                known_files.update(
+                    _get_files_from_cache_entry(cache_entry, key="file:size")
+                )
+        return known_files
+
     def get_unknown_files(self, lock_validity_period: float | None) -> set[str]:
         self.logger.info("getting unknown files")
 
@@ -138,14 +181,7 @@ def get_unknown_files(self, lock_validity_period: float | None) -> set[str]:
                 locked_files.add(urlpath)
                 locked_files.add(urlpath.rsplit(".lock", 1)[0])
 
-        if unknown_files := (set(self.file_sizes) - locked_files):
-            with config.get().instantiated_sessionmaker() as session:
-                for cache_entry in session.scalars(sa.select(database.CacheEntry)):
-                    for known_file in _get_files_from_cache_entry(cache_entry):
-                        unknown_files.discard(known_file)
-                    if not unknown_files:
-                        break
-        return unknown_files
+        return set(self.file_sizes) - locked_files - set(self.known_files)
 
     def delete_unknown_files(
         self, lock_validity_period: float | None, recursive: bool
@@ -153,10 +189,7 @@ def delete_unknown_files(
         unknown_files = self.get_unknown_files(lock_validity_period)
         for urlpath in unknown_files:
            self.pop_file_size(urlpath)
-        self.remove_files(
-            list(unknown_files),
-            recursive=recursive,
-        )
+        _remove_files(self.fs, list(unknown_files), recursive=recursive)
         self.log_disk_usage()
 
     @staticmethod
@@ -208,30 +241,6 @@ def _get_method_sorters(
         sorters.append(database.CacheEntry.expiration)
         return sorters
 
-    def remove_files(
-        self,
-        files: list[str],
-        max_tries: int = 10,
-        **kwargs: Any,
-    ) -> None:
-        assert max_tries >= 1
-        if not files:
-            return
-
-        self.logger.info("deleting files", n_files_to_delete=len(files), **kwargs)
-
-        n_tries = 0
-        while files:
-            n_tries += 1
-            try:
-                self.fs.rm(files, **kwargs)
-                return
-            except FileNotFoundError:
-                # Another concurrent process might have deleted files
-                if n_tries >= max_tries:
-                    raise
-                files = [file for file in files if self.fs.exists(file)]
-
     def delete_cache_files(
         self,
         maxsize: int,
@@ -245,37 +254,27 @@ def delete_cache_files(
         if self.stop_cleaning(maxsize):
             return
 
-        files_to_delete = []
-        dirs_to_delete = []
+        entries_to_delete = []
         self.logger.info("getting cache entries to delete")
-        n_entries_to_delete = 0
         with config.get().instantiated_sessionmaker() as session:
             for cache_entry in session.scalars(
                 sa.select(database.CacheEntry).filter(*filters).order_by(*sorters)
             ):
-                files = _get_files_from_cache_entry(cache_entry)
-                if files:
-                    n_entries_to_delete += 1
-                    session.delete(cache_entry)
-
-                for file, file_type in files.items():
-                    self.pop_file_size(file)
-                    if file_type == "application/vnd+zarr":
-                        dirs_to_delete.append(file)
-                    else:
-                        files_to_delete.append(file)
+                files = _get_files_from_cache_entry(cache_entry, key="file:size")
+                if any(file.startswith(self.urldir) for file in files):
+                    entries_to_delete.append(cache_entry)
+                    for file in files:
+                        self.pop_file_size(file)
 
                 if self.stop_cleaning(maxsize):
                     break
 
-            if n_entries_to_delete:
+            if entries_to_delete:
                 self.logger.info(
-                    "deleting cache entries", n_entries_to_delete=n_entries_to_delete
+                    "deleting cache entries", n_entries_to_delete=len(entries_to_delete)
                 )
-            database._commit_or_rollback(session)
+                _delete_cache_entries(session, *entries_to_delete)
 
-        self.remove_files(files_to_delete, recursive=False)
-        self.remove_files(dirs_to_delete, recursive=True)
         self.log_disk_usage()
 
         if not self.stop_cleaning(maxsize):
@@ -296,6 +295,8 @@ def clean_cache_files(
     lock_validity_period: float | None = None,
     tags_to_clean: list[str | None] | None = None,
    tags_to_keep: list[str | None] | None = None,
+    depth: int = 1,
+    use_database: bool = False,
 ) -> None:
     """Clean cache files.
 
@@ -316,8 +317,17 @@
         Tags to clean/keep. If None, delete all cache entries.
         To delete/keep untagged entries, add None in the list (e.g., [None, 'tag1', ...]).
         tags_to_clean and tags_to_keep are mutually exclusive.
+    depth: int, default: 1
+        depth for grouping cache files
+    use_database: bool, default: False
+        Whether to infer disk usage from the cacholote database
     """
-    cleaner = _Cleaner()
+    if use_database and delete_unknown_files:
+        raise ValueError(
+            "'use_database' and 'delete_unknown_files' are mutually exclusive"
+        )
+
+    cleaner = _Cleaner(depth=depth, use_database=use_database)
 
     if delete_unknown_files:
         cleaner.delete_unknown_files(lock_validity_period, recursive)
@@ -350,7 +360,7 @@ def clean_invalid_cache_entries(
         for cache_entry in session.scalars(
             sa.select(database.CacheEntry).filter(*filters)
         ):
-            _delete_cache_entry(session, cache_entry)
+            _delete_cache_entries(session, cache_entry)
 
     if try_decode:
         with config.get().instantiated_sessionmaker() as session:
@@ -358,13 +368,14 @@ def clean_invalid_cache_entries(
                 try:
                     decode.loads(cache_entry._result_as_string)
                 except decode.DecodeError:
-                    _delete_cache_entry(session, cache_entry)
+                    _delete_cache_entries(session, cache_entry)
 
 
 def expire_cache_entries(
     tags: list[str] | None = None,
     before: datetime.datetime | None = None,
     after: datetime.date | None = None,
+    delete: bool = False,
 ) -> int:
     now = utils.utcnow()
 
@@ -376,12 +387,14 @@ def expire_cache_entries(
     if after is not None:
         filters.append(database.CacheEntry.created_at > after)
 
-    count = 0
     with config.get().instantiated_sessionmaker() as session:
-        for cache_entry in session.scalars(
-            sa.select(database.CacheEntry).filter(*filters)
-        ):
-            count += 1
-            cache_entry.expiration = now
-        database._commit_or_rollback(session)
-    return count
+        cache_entries = list(
+            session.scalars(sa.select(database.CacheEntry).filter(*filters))
+        )
+        if delete:
+            _delete_cache_entries(session, *cache_entries)
+        else:
+            for cache_entry in cache_entries:
+                cache_entry.expiration = now
+            database._commit_or_rollback(session)
+        return len(cache_entries)
diff --git a/tests/test_60_clean.py b/tests/test_60_clean.py
index c27a8f4..3deca80 100644
--- a/tests/test_60_clean.py
+++ b/tests/test_60_clean.py
@@ -40,32 +40,40 @@ def cached_now() -> datetime.datetime:
 
 @pytest.mark.parametrize("method", ["LRU", "LFU"])
 @pytest.mark.parametrize("set_cache", ["file", "cads"], indirect=True)
+@pytest.mark.parametrize("folder,depth", [("", 1), ("", 2), ("foo", 2)])
+@pytest.mark.parametrize("use_database", [True, False])
 def test_clean_cache_files(
     tmp_path: pathlib.Path,
     set_cache: str,
     method: Literal["LRU", "LFU"],
+    folder: str,
+    depth: int,
+    use_database: bool,
 ) -> None:
     con = config.get().engine.raw_connection()
     cur = con.cursor()
-    fs, dirname = utils.get_cache_files_fs_dirname()
 
-    # Create files
-    for algorithm in ("LRU", "LFU"):
-        filename = tmp_path / f"{algorithm}.txt"
-        fsspec.filesystem("file").pipe_file(filename, ONE_BYTE)
+    cache_files_urlpath = os.path.join(config.get().cache_files_urlpath, folder)
+    with config.set(cache_files_urlpath=cache_files_urlpath):
+        fs, dirname = utils.get_cache_files_fs_dirname()
 
-    # Copy to cache
-    (lru_path,) = {open_url(tmp_path / "LRU.txt").path for _ in range(2)}
-    lfu_path = open_url(tmp_path / "LFU.txt").path
-    assert set(fs.ls(dirname)) == {lru_path, lfu_path}
+        # Create files
+        for algorithm in ("LRU", "LFU"):
+            filename = tmp_path / f"{algorithm}.txt"
+            fsspec.filesystem("file").pipe_file(filename, ONE_BYTE)
+
+        # Copy to cache
+        (lru_path,) = {open_url(tmp_path / "LRU.txt").path for _ in range(2)}
+        lfu_path = open_url(tmp_path / "LFU.txt").path
+        assert set(fs.ls(dirname)) == {lru_path, lfu_path}
 
     # Do not clean
-    clean.clean_cache_files(2, method=method)
+    clean.clean_cache_files(2, method=method, depth=depth, use_database=use_database)
     cur.execute("SELECT COUNT(*) FROM cache_entries", ())
     assert cur.fetchone() == (fs.du(dirname),) == (2,)
 
     # Delete one file
-    clean.clean_cache_files(1, method=method)
+    clean.clean_cache_files(1, method=method, depth=depth, use_database=use_database)
     cur.execute("SELECT COUNT(*) FROM cache_entries", ())
     assert cur.fetchone() == (fs.du(dirname),) == (1,)
     assert not fs.exists(lru_path if method == "LRU" else lfu_path)
@@ -320,21 +328,33 @@ def test_clean_multiple_files(tmp_path: pathlib.Path) -> None:
         (["foo"], TOMORROW, YESTERDAY),
     ],
 )
+@pytest.mark.parametrize("delete,n_entries", [(True, 0), (False, 1)])
 def test_expire_cache_entries(
     tags: None | list[str],
     before: None | datetime.datetime,
     after: None | datetime.datetime,
+    delete: bool,
+    n_entries: int,
 ) -> None:
+    con = config.get().engine.raw_connection()
+    cur = con.cursor()
+
     with config.set(tag="foo"):
         now = cached_now()
 
     # Do not expire
-    count = clean.expire_cache_entries(tags=["bar"], before=YESTERDAY, after=TOMORROW)
+    count = clean.expire_cache_entries(
+        tags=["bar"], before=YESTERDAY, after=TOMORROW, delete=delete
+    )
     assert count == 0
     assert now == cached_now()
 
     # Expire
-    count = clean.expire_cache_entries(tags=tags, before=before, after=after)
+    count = clean.expire_cache_entries(
+        tags=tags, before=before, after=after, delete=delete
+    )
+    cur.execute("SELECT COUNT(*) FROM cache_entries", ())
+    assert cur.fetchone() == (n_entries,)
     assert count == 1
     assert now != cached_now()
 
@@ -349,3 +369,22 @@ def test_expire_cache_entries_created_at() -> None:
     assert clean.expire_cache_entries(after=toc) == 0
     assert clean.expire_cache_entries(before=toc) == 1
     assert clean.expire_cache_entries(after=tic) == 1
+
+
+def test_multiple(tmp_path: pathlib.Path) -> None:
+    oldpath = tmp_path / "old.txt"
+    oldpath.write_bytes(ONE_BYTE)
+    with config.set(cache_files_urlpath=str(tmp_path / "old")):
+        cached_oldpath = pathlib.Path(open_url(oldpath).path)
+    assert cached_oldpath.exists()
+
+    newpath = tmp_path / "new.txt"
+    newpath.write_bytes(ONE_BYTE)
+    with config.set(cache_files_urlpath=str(tmp_path / "new")):
+        cached_newpath = pathlib.Path(open_url(newpath).path)
+    assert cached_newpath.exists()
+
+    with config.set(cache_files_urlpath=str(tmp_path / "new")):
+        clean.clean_cache_files(0)
+    assert not cached_newpath.exists()
+    assert cached_oldpath.exists()
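Usage sketch (illustrative, not part of the patch): the calls below exercise the new depth, use_database, and delete parameters the same way the tests above do; the cache directory path and the size limit are placeholder values.

from cacholote import clean, config

with config.set(cache_files_urlpath="/tmp/cacholote-cache"):  # placeholder location
    # Size the cache from the cacholote database instead of walking the
    # filesystem, grouping on-disk usage one directory level deep.
    clean.clean_cache_files(1_000_000, method="LRU", depth=1, use_database=True)

    # Expire entries tagged "foo" and delete them (entries and files) outright,
    # rather than only marking them as expired.
    n_removed = clean.expire_cache_entries(tags=["foo"], delete=True)

Note that the patch makes use_database=True and delete_unknown_files=True mutually exclusive (clean_cache_files raises a ValueError if both are set), since with use_database the cleaner only sees files recorded in the database.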