From 81285a76d6e0fc9f9c3240f9348db892904a2486 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Thu, 5 Oct 2023 16:44:38 +0200 Subject: [PATCH] scheduler: add remove noise aka original sample prediction (or predict x0) E.g. useful for methods like self-attention guidance (see equation (2) in https://arxiv.org/pdf/2210.00939.pdf) --- .../latent_diffusion/schedulers/scheduler.py | 13 ++++++++-- .../latent_diffusion/test_schedulers.py | 25 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index b5413f55d..6e4767239 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -78,11 +78,20 @@ def add_noise( step: int, ) -> Tensor: timestep = self.timesteps[step] - cumulative_scale_factors = self.cumulative_scale_factors[timestep].unsqueeze(-1).unsqueeze(-1) - noise_stds = self.noise_std[timestep].unsqueeze(-1).unsqueeze(-1) + cumulative_scale_factors = self.cumulative_scale_factors[timestep] + noise_stds = self.noise_std[timestep] noised_x = cumulative_scale_factors * x + noise_stds * noise return noised_x + def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + timestep = self.timesteps[step] + cumulative_scale_factors = self.cumulative_scale_factors[timestep] + noise_stds = self.noise_std[timestep] + # See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like + # in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance) + denoised_x = (x - noise_stds * noise) / cumulative_scale_factors + return denoised_x + def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore if device is not None: self.device = Device(device) diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index 123ef18fc..391c109c3 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -71,6 +71,31 @@ def test_ddim_solver_diffusers(): assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" +def test_scheduler_remove_noise(): + from diffusers import DDIMScheduler # type: ignore + + diffusers_scheduler = DDIMScheduler( + beta_end=0.012, + beta_schedule="scaled_linear", + beta_start=0.00085, + num_train_timesteps=1000, + set_alpha_to_one=False, + steps_offset=1, + clip_sample=False, + ) + diffusers_scheduler.set_timesteps(30) + refiners_scheduler = DDIM(num_inference_steps=30) + + sample = randn(1, 4, 32, 32) + noise = randn(1, 4, 32, 32) + + for step, timestep in enumerate(diffusers_scheduler.timesteps): + diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore + refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step) + + assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" + + def test_scheduler_device(test_device: Device): if test_device.type == "cpu": warn("not running on CPU, skipping")