Skip to content

Commit

Permalink
add karras sigmas to dpm solver
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Sep 6, 2024
1 parent cf247a1 commit b6240c6
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 88 deletions.
181 changes: 133 additions & 48 deletions src/refiners/foundationals/latent_diffusion/solvers/dpm.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
import dataclasses
from collections import deque
from typing import NamedTuple

import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype

from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
ModelPredictionType,
NoiseSchedule,
Solver,
TimestepSpacing,
)


def safe_log(x: torch.Tensor, lower_bound: float = 1e-6) -> torch.Tensor:
"""Compute the log of a tensor with a lower bound."""
return torch.log(torch.maximum(x, torch.tensor(lower_bound)))


def safe_sqrt(x: torch.Tensor) -> torch.Tensor:
"""Compute the square root of a tensor ensuring that the input is non-negative"""
return torch.sqrt(torch.maximum(x, torch.tensor(0)))


class SolverTensors(NamedTuple):
cumulative_scale_factors: torch.Tensor
noise_std: torch.Tensor
signal_to_noise_ratios: torch.Tensor


class DPMSolver(Solver):
"""Diffusion probabilistic models (DPMs) solver.
Expand All @@ -37,9 +54,9 @@ def __init__(
first_inference_step: int = 0,
params: BaseSolverParams | None = None,
last_step_first_order: bool = False,
device: Device | str = "cpu",
dtype: Dtype = torch.float32,
):
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32,
) -> None:
"""Initializes a new DPM solver.
Args:
Expand All @@ -64,6 +81,14 @@ def __init__(
)
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order
sigmas = self.noise_std / self.cumulative_scale_factors
self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
sigma_min = sigmas[0:1] # corresponds to `final_sigmas_type="sigma_min" in diffusers`
self.sigmas = torch.cat([self.sigmas, sigma_min])
self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(
self.sigmas
)
self.timesteps = self._timesteps_from_sigmas(sigmas)

def rebuild(
self: "DPMSolver",
Expand All @@ -83,7 +108,7 @@ def rebuild(
r.last_step_first_order = self.last_step_first_order
return r

def _generate_timesteps(self) -> Tensor:
def _generate_timesteps(self) -> torch.Tensor:
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
return super()._generate_timesteps()

Expand All @@ -96,9 +121,75 @@ def _generate_timesteps(self) -> Tensor:
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
return torch.tensor(np_space).flip(0)

def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Generate the sigmas used by the solver."""
assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = sigmas.flip(0)
rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
return sigmas, rescaled_sigmas

def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
"""Rescale the sigmas according to the sigma schedule."""
match sigma_schedule:
case NoiseSchedule.UNIFORM:
rho = 1
case NoiseSchedule.QUADRATIC:
rho = 2
case NoiseSchedule.KARRAS:
rho = 7
case None:
return torch.tensor(
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
device=self.device,
)

linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
first_sigma = sigmas[0]
last_sigma = sigmas[-1]
rescaled_sigmas = (
first_sigma ** (1 / rho) + linear_schedule * (last_sigma ** (1 / rho) - first_sigma ** (1 / rho))
) ** rho
return rescaled_sigmas.flip(0)

def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor:
"""Generate the timesteps from the sigmas."""
log_sigmas = safe_log(sigmas)
timesteps: list[torch.Tensor] = []
for sigma in self.sigmas[:-1]:
log_sigma = safe_log(sigma)
distance_matrix = log_sigma - log_sigmas.unsqueeze(1)

# Determine the range of sigma indices
low_indices = (distance_matrix >= 0).cumsum(dim=0).argmax(dim=0).clip(max=sigmas.size(0) - 2)
high_indices = low_indices + 1

low_log_sigma = log_sigmas[low_indices]
high_log_sigma = log_sigmas[high_indices]

# Interpolate sigma values
interpolation_weights = (low_log_sigma - log_sigma) / (low_log_sigma - high_log_sigma)
interpolation_weights = torch.clamp(interpolation_weights, 0, 1)
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
timesteps.append(timestep)

return torch.cat(timesteps).round()

def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
"""Generate the tensors from the sigmas."""
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
noise_std = sigmas * cumulative_scale_factors
signal_to_noise_ratios = safe_log(cumulative_scale_factors) - safe_log(noise_std)
return SolverTensors(
cumulative_scale_factors=cumulative_scale_factors,
noise_std=noise_std,
signal_to_noise_ratios=signal_to_noise_ratios,
)

def dpm_solver_first_order_update(
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
) -> Tensor:
self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
) -> torch.Tensor:
"""Applies a first-order backward Euler update to the input data `x`.
Args:
Expand All @@ -109,32 +200,29 @@ def dpm_solver_first_order_update(
Returns:
The denoised version of the input data `x`.
"""
current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
current_ratio = self.signal_to_noise_ratios[step]
next_ratio = self.signal_to_noise_ratios[step + 1]

previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
next_scale_factor = self.cumulative_scale_factors[step + 1]

previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
next_noise_std = self.noise_std[step + 1]
current_noise_std = self.noise_std[step]

previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]

ratio_delta = current_ratio - previous_ratio
ratio_delta = current_ratio - next_ratio

if sde_noise is None:
return (previous_noise_std / current_noise_std) * x + (
1.0 - torch.exp(ratio_delta)
) * previous_scale_factor * noise
return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise

factor = 1.0 - torch.exp(2.0 * ratio_delta)
return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * noise
+ previous_noise_std * torch.sqrt(factor) * sde_noise
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ next_scale_factor * factor * noise
+ next_noise_std * safe_sqrt(factor) * sde_noise
)

def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
def multistep_dpm_solver_second_order_update(
self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
) -> torch.Tensor:
"""Applies a second-order backward Euler update to the input data `x`.
Args:
Expand All @@ -144,43 +232,41 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noi
Returns:
The denoised version of the input data `x`.
"""
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1]

current_data_estimation = self.estimated_data[-1]
next_data_estimation = self.estimated_data[-2]
previous_data_estimation = self.estimated_data[-2]

previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
next_ratio = self.signal_to_noise_ratios[next_timestep]
next_ratio = self.signal_to_noise_ratios[step + 1]
current_ratio = self.signal_to_noise_ratios[step]
previous_ratio = self.signal_to_noise_ratios[step - 1]

previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]
next_scale_factor = self.cumulative_scale_factors[step + 1]
next_noise_std = self.noise_std[step + 1]
current_noise_std = self.noise_std[step]

estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
estimation_delta = (current_data_estimation - previous_data_estimation) / (
(current_ratio - previous_ratio) / (next_ratio - current_ratio)
)
ratio_delta = current_ratio - previous_ratio
ratio_delta = current_ratio - next_ratio

if sde_noise is None:
factor = 1.0 - torch.exp(ratio_delta)
return (
(previous_noise_std / current_noise_std) * x
+ previous_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta
(next_noise_std / current_noise_std) * x
+ next_scale_factor * factor * current_data_estimation
+ 0.5 * next_scale_factor * factor * estimation_delta
)

factor = 1.0 - torch.exp(2.0 * ratio_delta)
return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta
+ previous_noise_std * torch.sqrt(factor) * sde_noise
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ next_scale_factor * factor * current_data_estimation
+ 0.5 * next_scale_factor * factor * estimation_delta
+ next_noise_std * safe_sqrt(factor) * sde_noise
)

def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
def __call__(
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
) -> torch.Tensor:
"""Apply one step of the backward diffusion process.
Note:
Expand All @@ -199,9 +285,8 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen
"""
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"

current_timestep = self.timesteps[step]
scale_factor = self.cumulative_scale_factors[current_timestep]
noise_ratio = self.noise_std[current_timestep]
scale_factor = self.cumulative_scale_factors[step]
noise_ratio = self.noise_std[step]
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
variance = self.params.sde_variance
Expand Down
12 changes: 7 additions & 5 deletions src/refiners/foundationals/latent_diffusion/solvers/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class BaseSolverParams:
initial_diffusion_rate: float | None
final_diffusion_rate: float | None
noise_schedule: NoiseSchedule | None
sigma_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType | None
sde_variance: float

Expand All @@ -91,6 +92,7 @@ class SolverParams(BaseSolverParams):
initial_diffusion_rate: float | None = None
final_diffusion_rate: float | None = None
noise_schedule: NoiseSchedule | None = None
sigma_schedule: NoiseSchedule | None = None
model_prediction_type: ModelPredictionType | None = None
sde_variance: float = 0.0

Expand All @@ -103,6 +105,7 @@ class ResolvedSolverParams(BaseSolverParams):
initial_diffusion_rate: float
final_diffusion_rate: float
noise_schedule: NoiseSchedule
sigma_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType
sde_variance: float

Expand Down Expand Up @@ -140,6 +143,7 @@ class Solver(fl.Module, ABC):
initial_diffusion_rate=8.5e-4,
final_diffusion_rate=1.2e-2,
noise_schedule=NoiseSchedule.QUADRATIC,
sigma_schedule=None,
model_prediction_type=ModelPredictionType.NOISE,
sde_variance=0.0,
)
Expand Down Expand Up @@ -404,14 +408,12 @@ def sample_noise_schedule(self) -> Tensor:
A tensor representing the noise schedule.
"""
match self.params.noise_schedule:
case "uniform":
case NoiseSchedule.UNIFORM:
return 1 - self.sample_power_distribution(1)
case "quadratic":
case NoiseSchedule.QUADRATIC:
return 1 - self.sample_power_distribution(2)
case "karras":
case NoiseSchedule.KARRAS:
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")

def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
"""Move the solver to the specified device and data type.
Expand Down
38 changes: 38 additions & 0 deletions tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
Expand Down Expand Up @@ -913,6 +918,39 @@ def test_diffusion_std_sde_random_init(
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)


@no_grad()
def test_diffusion_std_sde_karras_random_init(
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_karras_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std_sde

prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

sd15.solver = DPMSolver(
num_inference_steps=18,
last_step_first_order=True,
params=SolverParams(sde_variance=1.0, sigma_schedule=NoiseSchedule.KARRAS),
device=test_device,
)

manual_seed(2)
x = sd15.init_latents((512, 512))

for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)

predicted_image = sd15.lda.latents_to_image(x)

ensure_similar_images(predicted_image, expected_image_std_sde_karras_random_init)


@no_grad()
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
sd15 = sd15_std
Expand Down
Loading

0 comments on commit b6240c6

Please sign in to comment.