Skip to content

Commit

Permalink
move some tests into the adapters test folder
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Sep 9, 2024
1 parent a51d695 commit 4595133
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 20 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Iterator

import pytest
import torch

Expand All @@ -10,48 +8,63 @@
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter


@pytest.fixture(scope="module", params=[True, False])
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet | SDXLUNet]:
xl: bool = request.param
unet = SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4)
yield unet
@pytest.fixture(scope="module")
def unet(
refiners_unet: SD1UNet | SDXLUNet,
) -> SD1UNet | SDXLUNet:
return refiners_unet


def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
def test_inject_eject_freeu(
unet: SD1UNet | SDXLUNet,
) -> None:
initial_repr = repr(unet)
freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])

assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
assert unet.parent is None
assert unet.find(FreeUResidualConcatenator) is None
assert repr(unet) == initial_repr

freeu.inject()
assert unet.parent is not None
assert unet.find(FreeUResidualConcatenator) is not None
assert repr(unet) != initial_repr

with pytest.raises(AssertionError) as exc:
freeu.eject()
assert "could not find" in str(exc.value)
freeu.eject()
assert unet.parent is None
assert unet.find(FreeUResidualConcatenator) is None
assert repr(unet) == initial_repr

freeu.inject()
assert len(list(unet.walk(FreeUResidualConcatenator))) == 2
assert unet.parent is not None
assert unet.find(FreeUResidualConcatenator) is not None
assert repr(unet) != initial_repr

freeu.eject()
assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
assert unet.parent is None
assert unet.find(FreeUResidualConcatenator) is None
assert repr(unet) == initial_repr


def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
num_blocks = len(unet.layer("UpBlocks", Chain))

with pytest.raises(AssertionError):
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))


def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
with pytest.raises(AssertionError):
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])
with pytest.raises(AssertionError):
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2, 1.2], skip_scales=[0.9, 0.9])


def test_freeu_identity_scales() -> None:
def test_freeu_identity_scales(unet: SD1UNet | SDXLUNet) -> None:
manual_seed(0)
text_embedding = torch.randn(1, 77, 768)
timestep = torch.randint(0, 999, size=(1, 1))
x = torch.randn(1, 4, 32, 32)
text_embedding = torch.randn(1, 77, 768, dtype=unet.dtype, device=unet.device)
timestep = torch.randint(0, 999, size=(1, 1), device=unet.device)
x = torch.randn(1, 4, 32, 32, dtype=unet.dtype, device=unet.device)

unet = SD1UNet(in_channels=4)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s

with no_grad():
Expand All @@ -65,5 +78,7 @@ def test_freeu_identity_scales() -> None:
unet.set_timestep(timestep=timestep)
y_2 = unet(x.clone())

freeu.eject()

# The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
assert torch.allclose(y_1, y_2, atol=1e-5)
File renamed without changes.

0 comments on commit 4595133

Please sign in to comment.