diff --git a/slim_trees/__init__.py b/slim_trees/__init__.py index 221aedb..ec1ac75 100644 --- a/slim_trees/__init__.py +++ b/slim_trees/__init__.py @@ -13,7 +13,7 @@ import importlib.metadata import warnings from pathlib import Path -from typing import Any, BinaryIO, Optional, Union +from typing import Any, BinaryIO, Optional, Union, overload from slim_trees.pickling import ( dump_compressed, @@ -41,6 +41,22 @@ ] +@overload +def dump_sklearn_compressed( + model: Any, + file: BinaryIO, + compression: Union[str, dict], +): ... + + +@overload +def dump_sklearn_compressed( + model: Any, + file: Union[str, Path], + compression: Optional[Union[str, dict]] = None, +): ... + + def dump_sklearn_compressed( model: Any, file: Union[str, Path, BinaryIO], @@ -59,7 +75,7 @@ def dump_sklearn_compressed( """ from slim_trees.sklearn_tree import dump - dump_compressed(model, file, compression, dump) + dump_compressed(model, file, compression, dump) # type: ignore def dumps_sklearn_compressed( @@ -80,6 +96,22 @@ def dumps_sklearn_compressed( return dumps_compressed(model, compression, dumps) +@overload +def dump_lgbm_compressed( + model: Any, + file: BinaryIO, + compression: Union[str, dict], +): ... + + +@overload +def dump_lgbm_compressed( + model: Any, + file: Union[str, Path], + compression: Optional[Union[str, dict]] = None, +): ... + + def dump_lgbm_compressed( model: Any, file: Union[str, Path, BinaryIO], @@ -98,7 +130,7 @@ def dump_lgbm_compressed( """ from slim_trees.lgbm_booster import dump - dump_compressed(model, file, compression, dump) + dump_compressed(model, file, compression, dump) # type: ignore def dumps_lgbm_compressed( diff --git a/slim_trees/pickling.py b/slim_trees/pickling.py index a324cad..3bd5d55 100644 --- a/slim_trees/pickling.py +++ b/slim_trees/pickling.py @@ -5,7 +5,7 @@ import pathlib import pickle from collections.abc import Callable -from typing import Any, BinaryIO, Dict, Optional, Tuple, Union +from typing import Any, BinaryIO, Dict, Optional, Tuple, Union, overload class _NoCompression: @@ -68,6 +68,24 @@ def _unpack_compression_args( raise ValueError("File must be a path or compression must not be None.") +@overload +def dump_compressed( + obj: Any, + file: BinaryIO, + compression: Union[str, dict], + dump_function: Optional[Callable] = None, +): ... + + +@overload +def dump_compressed( + obj: Any, + file: Union[str, pathlib.Path], + compression: Optional[Union[str, dict]] = None, + dump_function: Optional[Callable] = None, +): ... + + def dump_compressed( obj: Any, file: Union[str, pathlib.Path, BinaryIO], @@ -123,6 +141,22 @@ def dumps_compressed( return _get_compression_library(compression_method).compress(data_uncompressed) +@overload +def load_compressed( + file: BinaryIO, + compression: Union[str, dict], + unpickler_class: type = pickle.Unpickler, +): ... + + +@overload +def load_compressed( + file: Union[str, pathlib.Path], + compression: Optional[Union[str, dict]] = None, + unpickler_class: type = pickle.Unpickler, +): ... + + def load_compressed( file: Union[str, pathlib.Path, BinaryIO], compression: Optional[Union[str, dict]] = None,