simplify implementation of load_from_safetensors #438

Merged (1 commit, Sep 12, 2024)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }]
 license = { text = "MIT License" }
 dependencies = [
     "torch>=2.1.1",
-    "safetensors>=0.4.0",
+    "safetensors>=0.4.5",
     "pillow>=10.4.0",
     "jaxtyping>=0.2.23",
     "packaging>=23.2",
4 changes: 3 additions & 1 deletion requirements.lock
@@ -338,7 +338,7 @@ rpds-py==0.19.1
     # via referencing
 s3transfer==0.10.2
     # via boto3
-safetensors==0.4.3
+safetensors==0.4.5
     # via diffusers
     # via refiners
     # via timm
@@ -347,6 +347,8 @@ segment-anything-hq==0.3
     # via refiners
 segment-anything-py==1.0.1
     # via refiners
+sentencepiece==0.2.0
+    # via refiners
 sentry-sdk==2.12.0
     # via wandb
 setproctitle==1.3.3
36 changes: 3 additions & 33 deletions src/refiners/fluxion/utils.py
@@ -1,13 +1,12 @@
 import warnings
 from pathlib import Path
-from typing import Any, Iterable, Literal, TypeVar, cast
+from typing import Any, Iterable, TypeVar, cast

 import torch
 from jaxtyping import Float
 from numpy import array, float32
 from PIL import Image
-from safetensors import safe_open as _safe_open  # type: ignore
-from safetensors.torch import save_file as _save_file  # type: ignore
+from safetensors.torch import load_file as _load_file, save_file as _save_file  # type: ignore
 from torch import Tensor, device as Device, dtype as DType
 from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad  # type: ignore
@@ -186,34 +185,6 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
     return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8"))  # type: ignore[reportUnknownType]


-def safe_open(
-    path: Path | str,
-    framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
-    device: Device | str = "cpu",
-) -> dict[str, Tensor]:
-    """Open a SafeTensor file from disk.
-
-    Args:
-        path: The path to the file.
-        framework: The framework used to save the file.
-        device: The device to use for the tensors.
-
-    Returns:
-        The loaded tensors.
-    """
-    framework_mapping = {
-        "pytorch": "pt",
-        "tensorflow": "tf",
-        "flax": "flax",
-        "numpy": "numpy",
-    }
-    return _safe_open(
-        str(path),
-        framework=framework_mapping[framework],
-        device=str(device),
-    )  # type: ignore
-
-
 def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
     """Load tensors from a file saved with `torch.save` from disk.
@@ -247,8 +218,7 @@ def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
     Returns:
         The loaded tensors.
     """
-    with safe_open(path=path, framework="pytorch", device=device) as tensors:  # type: ignore
-        return {key: tensors.get_tensor(key) for key in tensors.keys()}  # type: ignore
+    return _load_file(path, str(device))


 def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
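For readers skimming the diff: safetensors.torch.load_file already returns a plain dict[str, Tensor], which is why the safe_open wrapper and the manual per-key loop could be dropped. Below is a minimal sketch of the behaviour the simplified load_from_safetensors now relies on; the file name and tensor key are made up for illustration.

from pathlib import Path

import torch
from safetensors.torch import load_file, save_file

# Hypothetical checkpoint path, used only for this example.
ckpt = Path("example.safetensors")

# Write a tiny state dict so the round trip is self-contained.
save_file({"weight": torch.ones(2, 2)}, str(ckpt))

# load_file takes a path and a device string and returns dict[str, Tensor],
# so no safe_open context manager or per-key get_tensor calls are needed.
tensors = load_file(ckpt, device="cpu")
assert tensors["weight"].shape == (2, 2)

Passing the device as a string mirrors the str(device) conversion in the new one-liner, since the helper still accepts a torch.device.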