Skip to content

Commit

Permalink
improve image_to_tensor and tensor_to_image utils
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Oct 17, 2023
1 parent 585c7ad commit 8ae5a00
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
44 changes: 40 additions & 4 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,49 @@ def default_sigma(kernel_size: int) -> float:


def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
0
)
"""
Convert a PIL Image to a Tensor.
If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise
`[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
Values are clamped to the range `[0, 1]`.
"""
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)

match image.mode:
case "L":
image_tensor = image_tensor.unsqueeze(0)
case "RGBA" | "RGB":
image_tensor = image_tensor.permute(2, 0, 1)
case _:
raise ValueError(f"Unsupported image mode: {image.mode}")

return image_tensor.unsqueeze(0)


def tensor_to_image(tensor: Tensor) -> Image.Image:
return Image.fromarray((tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")) # type: ignore
"""
Convert a Tensor to a PIL Image.
The tensor must have shape `[1, channels, height, width]` where the number of
channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).
Expected values are in the range `[0, 1]` and are clamped to this range.
"""
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
num_channels = tensor.shape[1]
tensor = tensor.clamp(0, 1).squeeze(0)

match num_channels:
case 1:
tensor = tensor.squeeze(0)
case 3 | 4:
tensor = tensor.permute(1, 2, 0)
case _:
raise ValueError(f"Unsupported number of channels: {num_channels}")

return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]


def safe_open(
Expand Down
17 changes: 16 additions & 1 deletion tests/fluxion/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
from torch import device as Device, dtype as DType
from PIL import Image
import pytest
import torch

from refiners.fluxion.utils import gaussian_blur, manual_seed
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image


@dataclass
Expand Down Expand Up @@ -47,3 +48,17 @@ def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)

assert torch.equal(our_blur, ref_blur)


def test_image_to_tensor() -> None:
image = Image.new("RGB", (512, 512))

assert image_to_tensor(image).shape == (1, 3, 512, 512)
assert image_to_tensor(image.convert("L")).shape == (1, 1, 512, 512)
assert image_to_tensor(image.convert("RGBA")).shape == (1, 4, 512, 512)


def test_tensor_to_image() -> None:
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"

0 comments on commit 8ae5a00

Please sign in to comment.