fix use of dtypes in autoencoder tests
catwell committed Nov 28, 2024
1 parent 79c9991 commit 41a81f9
Showing 2 changed files with 21 additions and 30 deletions.
20 changes: 0 additions & 20 deletions tests/foundationals/latent_diffusion/conftest.py
@@ -1,7 +1,6 @@
 from pathlib import Path

 import pytest
-import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
@@ -93,25 +92,6 @@ def refiners_sdxl(
     )


-@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
-def refiners_autoencoder(
-    request: pytest.FixtureRequest,
-    refiners_sd15_autoencoder: SD1Autoencoder,
-    refiners_sdxl_autoencoder: SDXLAutoencoder,
-    test_dtype_fp32_bf16_fp16: torch.dtype,
-) -> SD1Autoencoder | SDXLAutoencoder:
-    model_version = request.param
-    match (model_version, test_dtype_fp32_bf16_fp16):
-        case ("SD1.5", _):
-            return refiners_sd15_autoencoder
-        case ("SDXL", torch.float16):
-            return refiners_sdxl_autoencoder
-        case ("SDXL", _):
-            return refiners_sdxl_autoencoder
-        case _:
-            raise ValueError(f"Unknown model version: {model_version}")
-
-
 @pytest.fixture(scope="module")
 def diffusers_sd15_pipeline(
     sd15_diffusers_runwayml_path: str,
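Note on the removal above: the refiners_autoencoder fixture received test_dtype_fp32_bf16_fp16 but never applied it, and its ("SDXL", torch.float16) and ("SDXL", _) match arms returned the same object, so the dtype had no effect. The replacement fixture in test_autoencoders.py below applies the dtype explicitly. The dtype fixture itself is not shown in this diff; as a minimal sketch, assuming only what its name suggests, it is presumably a module-scoped parametrized fixture along these lines:

import pytest
import torch

# Hypothetical sketch: the real test_dtype_fp32_bf16_fp16 fixture lives elsewhere
# in the test suite; the parameter list below is inferred from its name only.
@pytest.fixture(scope="module", params=[torch.float32, torch.bfloat16, torch.float16])
def test_dtype_fp32_bf16_fp16(request: pytest.FixtureRequest) -> torch.dtype:
    return request.param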
31 changes: 21 additions & 10 deletions tests/foundationals/latent_diffusion/test_autoencoders.py
@@ -7,7 +7,11 @@
 from tests.utils import ensure_similar_images

 from refiners.fluxion.utils import no_grad
-from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
+from refiners.foundationals.latent_diffusion import (
+    LatentDiffusionAutoencoder,
+    SD1Autoencoder,
+    SDXLAutoencoder,
+)


 @pytest.fixture(scope="module")
@@ -16,25 +20,32 @@ def sample_image() -> Image.Image:
     if not test_image.is_file():
         warn(f"could not reference image at {test_image}, skipping")
         pytest.skip(allow_module_level=True)
-    img = Image.open(test_image)  # type: ignore
+    img = Image.open(test_image)
     assert img.size == (512, 512)
     return img


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
 def autoencoder(
-    refiners_autoencoder: LatentDiffusionAutoencoder,
+    request: pytest.FixtureRequest,
+    refiners_sd15_autoencoder: SD1Autoencoder,
+    refiners_sdxl_autoencoder: SDXLAutoencoder,
     test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
 ) -> LatentDiffusionAutoencoder:
-    return refiners_autoencoder.to(test_device)
+    model_version = request.param
+    if model_version == "SDXL" and test_dtype_fp32_bf16_fp16 == torch.float16:
+        pytest.skip("SDXL autoencoder does not support float16")
+    ae = refiners_sd15_autoencoder if model_version == "SD1.5" else refiners_sdxl_autoencoder
+    return ae.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)


 @no_grad()
 def test_encode_decode_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
     encoded = autoencoder.image_to_latents(sample_image)
     decoded = autoencoder.latents_to_image(encoded)

-    assert decoded.mode == "RGB"  # type: ignore
+    assert decoded.mode == "RGB"

     # Ensure no saturation. The green channel (band = 1) must not max out.
     assert max(iter(decoded.getdata(band=1))) < 255  # type: ignore
@@ -53,7 +64,7 @@ def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_im

 @no_grad()
 def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    sample_image = sample_image.resize((2048, 2048))  # type: ignore
+    sample_image = sample_image.resize((2048, 2048))

     with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
         encoded = autoencoder.tiled_image_to_latents(sample_image)
@@ -64,7 +75,7 @@ def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image

 @no_grad()
 def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    sample_image = sample_image.resize((2048, 2048))  # type: ignore
+    sample_image = sample_image.resize((2048, 2048))

     with autoencoder.tiled_inference(sample_image, tile_size=(512, 1024)):
         encoded = autoencoder.tiled_image_to_latents(sample_image)
@@ -75,7 +86,7 @@ def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoenc

 @no_grad()
 def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    sample_image = sample_image.resize((1024, 1024))  # type: ignore
+    sample_image = sample_image.resize((1024, 1024))

     with autoencoder.tiled_inference(sample_image, tile_size=(2048, 2048)):
         encoded = autoencoder.tiled_image_to_latents(sample_image)
@@ -87,7 +98,7 @@ def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, s
 @no_grad()
 def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
     sample_image = sample_image.crop((0, 0, 300, 500))
-    sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4))  # type: ignore
+    sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4))

     with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)):
         encoded = autoencoder.tiled_image_to_latents(sample_image)
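With the fixture consolidated into test_autoencoders.py, each test in the module now runs once per (model version, dtype) combination, skipping SDXL with float16, and the autoencoder is moved to the requested device and dtype before use. As a rough illustration only, assuming the autoencoder behaves like a standard torch.nn.Module (the diff itself only shows the .to(device=..., dtype=...) call) and reusing the imports of this test file, a test exercising the new parametrization could be sketched like this:

# Hypothetical example, not part of the commit. With params=["SD1.5", "SDXL"] on the
# autoencoder fixture and three dtypes on test_dtype_fp32_bf16_fp16, this would run
# for every combination except the skipped (SDXL, float16) pair.
@no_grad()
def test_autoencoder_dtype(
    autoencoder: LatentDiffusionAutoencoder,
    test_dtype_fp32_bf16_fp16: torch.dtype,
) -> None:
    # Assumes the refiners autoencoder exposes its weights like any torch.nn.Module.
    assert next(autoencoder.parameters()).dtype == test_dtype_fp32_bf16_fp16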
