Skip to content

Commit

Permalink
scheduler: add remove_noise
Browse files Browse the repository at this point in the history
aka original sample prediction (or predict x0)

E.g. useful for methods like self-attention guidance (see equation (2)
in https://arxiv.org/pdf/2210.00939.pdf)
  • Loading branch information
deltheil committed Oct 5, 2023
1 parent 665bcdc commit 81285a7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,20 @@ def add_noise(
step: int,
) -> Tensor:
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep].unsqueeze(-1).unsqueeze(-1)
noise_stds = self.noise_std[timestep].unsqueeze(-1).unsqueeze(-1)
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x

def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
    """Recover the predicted original sample ("predict x0") from a noised sample.

    Inverts equation (15) of https://arxiv.org/pdf/2006.11239.pdf. Useful to
    preview denoising progress, or for guidance methods such as self-attention
    guidance (see https://arxiv.org/pdf/2210.00939.pdf).
    """
    t = self.timesteps[step]
    scale_factor = self.cumulative_scale_factors[t]
    std = self.noise_std[t]
    # noised = scale * x0 + std * noise  =>  x0 = (noised - std * noise) / scale
    return (x - std * noise) / scale_factor

def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
if device is not None:
self.device = Device(device)
Expand Down
25 changes: 25 additions & 0 deletions tests/foundationals/latent_diffusion/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_ddim_solver_diffusers():
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


def test_scheduler_remove_noise():
    """Check that DDIM.remove_noise matches diffusers' pred_original_sample at every step."""
    from diffusers import DDIMScheduler  # type: ignore

    hf_scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        set_alpha_to_one=False,
        steps_offset=1,
        clip_sample=False,
    )
    hf_scheduler.set_timesteps(30)
    our_scheduler = DDIM(num_inference_steps=30)

    sample, noise = randn(1, 4, 32, 32), randn(1, 4, 32, 32)

    for step, timestep in enumerate(hf_scheduler.timesteps):
        expected = cast(Tensor, hf_scheduler.step(noise, timestep, sample).pred_original_sample)  # type: ignore
        actual = our_scheduler.remove_noise(x=sample, noise=noise, step=step)

        assert allclose(expected, actual, rtol=0.01), f"outputs differ at step {step}"


def test_scheduler_device(test_device: Device):
if test_device.type == "cpu":
warn("not running on CPU, skipping")
Expand Down

0 comments on commit 81285a7

Please sign in to comment.