Skip to content

Commit

Permalink
deprecate LatentDiffusionAutoencoder's decode_latents
Browse files (browse the repository at this point in the history)
  • Loading branch information
Laurent2916 committed Oct 15, 2024
1 parent 9b19272 commit ffb486e
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 20 deletions.
10 changes: 5 additions & 5 deletions docs/guides/adapting_sdxl/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ with no_grad(): # Disable gradient calculation for memory-efficient inference
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)

predicted_image.save("vanilla_sdxl.png")

Expand Down Expand Up @@ -145,7 +145,7 @@ predicted_image.save("vanilla_sdxl.png")
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)

predicted_image.save("vanilla_sdxl.png")

Expand Down Expand Up @@ -318,7 +318,7 @@ manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.saf
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)

predicted_image.save("scifi_pixel_sdxl.png")

Expand Down Expand Up @@ -453,7 +453,7 @@ with torch.no_grad():
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)

predicted_image.save("scifi_pixel_IP_sdxl.png")

Expand Down Expand Up @@ -591,7 +591,7 @@ with torch.no_grad():
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)

predicted_image.save("scifi_pixel_IP_T2I_sdxl.png")

Expand Down
5 changes: 0 additions & 5 deletions src/refiners/foundationals/latent_diffusion/auto_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,6 @@ def images_to_latents(self, images: list[Image.Image]) -> Tensor:
x = 2 * x - 1
return self.encode(x)

# backward-compatibility alias
# TODO: deprecate this method
def decode_latents(self, x: Tensor) -> Image.Image:
return self.latents_to_image(x)

def latents_to_image(self, x: Tensor) -> Image.Image:
"""
Decode latents to an image.
Expand Down
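2 changes: 1 addition & 1 deletion src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
Original file line number Diff line number Diff line change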
Expand Up @@ -61,7 +61,7 @@ class StableDiffusion_1(LatentDiffusionModel):
for step in sd15.steps:
x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latents_to_image(x)
predicted_image.save("output.png")
```
"""
Expand Down
10 changes: 5 additions & 5 deletions tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ def test_diffusion_sdxl_control_lora(
)

# decode latent to image
predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)

# ensure the predicted image is similar to the expected image
ensure_similar_images(
Expand Down Expand Up @@ -1935,7 +1935,7 @@ def test_diffusion_ip_adapter_multi(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latents_to_image(x)

ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98)

Expand Down Expand Up @@ -2245,7 +2245,7 @@ def test_diffusion_sdxl_sliced_attention(
condition_scale=5,
)

predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


Expand Down Expand Up @@ -2279,7 +2279,7 @@ def test_diffusion_sdxl_euler_deterministic(
condition_scale=5,
)

predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image)


Expand Down Expand Up @@ -2604,7 +2604,7 @@ def test_style_aligned(
)

# decode latents
predicted_images = [sdxl.lda.decode_latents(latent.unsqueeze(0)) for latent in x]
predicted_images = sdxl.lda.latents_to_images(x)

# tile all images horizontally
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/test_doc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_guide_adapting_sdxl_vanilla(
time_ids=time_ids,
)

predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)


Expand Down Expand Up @@ -151,7 +151,7 @@ def test_guide_adapting_sdxl_single_lora(
time_ids=time_ids,
)

predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)


Expand Down Expand Up @@ -195,7 +195,7 @@ def test_guide_adapting_sdxl_multiple_loras(
time_ids=time_ids,
)

predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)


Expand Down Expand Up @@ -255,7 +255,7 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
time_ids=time_ids,
)

predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=29, min_ssim=0.98)


Expand Down

0 comments on commit ffb486e

Please sign in to comment.