diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py index b7298291c..08d56c46a 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -1,18 +1,35 @@ import dataclasses from collections import deque +from typing import NamedTuple import numpy as np import torch -from torch import Generator, Tensor, device as Device, dtype as Dtype from refiners.foundationals.latent_diffusion.solvers.solver import ( BaseSolverParams, ModelPredictionType, + NoiseSchedule, Solver, TimestepSpacing, ) +def safe_log(x: torch.Tensor, lower_bound: float = 1e-6) -> torch.Tensor: + """Compute the log of a tensor with a lower bound.""" + return torch.log(torch.maximum(x, torch.tensor(lower_bound))) + + +def safe_sqrt(x: torch.Tensor) -> torch.Tensor: + """Compute the square root of a tensor ensuring that the input is non-negative""" + return torch.sqrt(torch.maximum(x, torch.tensor(0))) + + +class SolverTensors(NamedTuple): + cumulative_scale_factors: torch.Tensor + noise_std: torch.Tensor + signal_to_noise_ratios: torch.Tensor + + class DPMSolver(Solver): """Diffusion probabilistic models (DPMs) solver. @@ -37,9 +54,9 @@ def __init__( first_inference_step: int = 0, params: BaseSolverParams | None = None, last_step_first_order: bool = False, - device: Device | str = "cpu", - dtype: Dtype = torch.float32, - ): + device: torch.device | str = "cpu", + dtype: torch.dtype = torch.float32, + ) -> None: """Initializes a new DPM solver. 
Args: @@ -64,6 +81,14 @@ def __init__( ) self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2) self.last_step_first_order = last_step_first_order + sigmas = self.noise_std / self.cumulative_scale_factors + self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule) + sigma_min = sigmas[0:1] # corresponds to `final_sigmas_type="sigma_min"` in diffusers + self.sigmas = torch.cat([self.sigmas, sigma_min]) + self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas( + self.sigmas + ) + self.timesteps = self._timesteps_from_sigmas(sigmas) def rebuild( self: "DPMSolver", @@ -83,7 +108,7 @@ def rebuild( r.last_step_first_order = self.last_step_first_order return r - def _generate_timesteps(self) -> Tensor: + def _generate_timesteps(self) -> torch.Tensor: if self.params.timesteps_spacing != TimestepSpacing.CUSTOM: return super()._generate_timesteps() @@ -96,9 +121,75 @@ def _generate_timesteps(self) -> Tensor: np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:] return torch.tensor(np_space).flip(0) + def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]: + """Generate the sigmas used by the solver.""" + assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver" + sigmas = self.noise_std / self.cumulative_scale_factors + sigmas = sigmas.flip(0) + rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule) + rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])]) + return sigmas, rescaled_sigmas + + def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor: + """Rescale the sigmas according to the sigma schedule.""" + match sigma_schedule: + case NoiseSchedule.UNIFORM: + rho = 1 + case NoiseSchedule.QUADRATIC: + rho = 2 + case NoiseSchedule.KARRAS: + rho = 7 + case None: + return torch.tensor( + np.interp(self.timesteps.cpu(), np.arange(0, 
len(sigmas)), sigmas.cpu()), + device=self.device, + ) + + linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device) + first_sigma = sigmas[0] + last_sigma = sigmas[-1] + rescaled_sigmas = ( + first_sigma ** (1 / rho) + linear_schedule * (last_sigma ** (1 / rho) - first_sigma ** (1 / rho)) + ) ** rho + return rescaled_sigmas.flip(0) + + def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor: + """Generate the timesteps from the sigmas.""" + log_sigmas = safe_log(sigmas) + timesteps: list[torch.Tensor] = [] + for sigma in self.sigmas[:-1]: + log_sigma = safe_log(sigma) + distance_matrix = log_sigma - log_sigmas.unsqueeze(1) + + # Determine the range of sigma indices + low_indices = (distance_matrix >= 0).cumsum(dim=0).argmax(dim=0).clip(max=sigmas.size(0) - 2) + high_indices = low_indices + 1 + + low_log_sigma = log_sigmas[low_indices] + high_log_sigma = log_sigmas[high_indices] + + # Interpolate sigma values + interpolation_weights = (low_log_sigma - log_sigma) / (low_log_sigma - high_log_sigma) + interpolation_weights = torch.clamp(interpolation_weights, 0, 1) + timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices + timesteps.append(timestep) + + return torch.cat(timesteps).round() + + def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors: + """Generate the tensors from the sigmas.""" + cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1) + noise_std = sigmas * cumulative_scale_factors + signal_to_noise_ratios = safe_log(cumulative_scale_factors) - safe_log(noise_std) + return SolverTensors( + cumulative_scale_factors=cumulative_scale_factors, + noise_std=noise_std, + signal_to_noise_ratios=signal_to_noise_ratios, + ) + def dpm_solver_first_order_update( - self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None - ) -> Tensor: + self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None + ) -> 
torch.Tensor: """Applies a first-order backward Euler update to the input data `x`. Args: @@ -109,32 +200,29 @@ def dpm_solver_first_order_update( Returns: The denoised version of the input data `x`. """ - current_timestep = self.timesteps[step] - previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0]) + current_ratio = self.signal_to_noise_ratios[step] + next_ratio = self.signal_to_noise_ratios[step + 1] - previous_ratio = self.signal_to_noise_ratios[previous_timestep] - current_ratio = self.signal_to_noise_ratios[current_timestep] + next_scale_factor = self.cumulative_scale_factors[step + 1] - previous_scale_factor = self.cumulative_scale_factors[previous_timestep] + next_noise_std = self.noise_std[step + 1] + current_noise_std = self.noise_std[step] - previous_noise_std = self.noise_std[previous_timestep] - current_noise_std = self.noise_std[current_timestep] - - ratio_delta = current_ratio - previous_ratio + ratio_delta = current_ratio - next_ratio if sde_noise is None: - return (previous_noise_std / current_noise_std) * x + ( - 1.0 - torch.exp(ratio_delta) - ) * previous_scale_factor * noise + return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise factor = 1.0 - torch.exp(2.0 * ratio_delta) return ( - (previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x - + previous_scale_factor * factor * noise - + previous_noise_std * torch.sqrt(factor) * sde_noise + (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x + + next_scale_factor * factor * noise + + next_noise_std * safe_sqrt(factor) * sde_noise ) - def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor: + def multistep_dpm_solver_second_order_update( + self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None + ) -> torch.Tensor: """Applies a second-order backward Euler update to the input data `x`. 
Args: @@ -144,43 +232,41 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noi Returns: The denoised version of the input data `x`. """ - previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0]) - current_timestep = self.timesteps[step] - next_timestep = self.timesteps[step - 1] - current_data_estimation = self.estimated_data[-1] - next_data_estimation = self.estimated_data[-2] + previous_data_estimation = self.estimated_data[-2] - previous_ratio = self.signal_to_noise_ratios[previous_timestep] - current_ratio = self.signal_to_noise_ratios[current_timestep] - next_ratio = self.signal_to_noise_ratios[next_timestep] + next_ratio = self.signal_to_noise_ratios[step + 1] + current_ratio = self.signal_to_noise_ratios[step] + previous_ratio = self.signal_to_noise_ratios[step - 1] - previous_scale_factor = self.cumulative_scale_factors[previous_timestep] - previous_noise_std = self.noise_std[previous_timestep] - current_noise_std = self.noise_std[current_timestep] + next_scale_factor = self.cumulative_scale_factors[step + 1] + next_noise_std = self.noise_std[step + 1] + current_noise_std = self.noise_std[step] - estimation_delta = (current_data_estimation - next_data_estimation) / ( - (current_ratio - next_ratio) / (previous_ratio - current_ratio) + estimation_delta = (current_data_estimation - previous_data_estimation) / ( + (current_ratio - previous_ratio) / (next_ratio - current_ratio) ) - ratio_delta = current_ratio - previous_ratio + ratio_delta = current_ratio - next_ratio if sde_noise is None: factor = 1.0 - torch.exp(ratio_delta) return ( - (previous_noise_std / current_noise_std) * x - + previous_scale_factor * factor * current_data_estimation - + 0.5 * previous_scale_factor * factor * estimation_delta + (next_noise_std / current_noise_std) * x + + next_scale_factor * factor * current_data_estimation + + 0.5 * next_scale_factor * factor * estimation_delta ) factor = 1.0 - torch.exp(2.0 
* ratio_delta) return ( - (previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x - + previous_scale_factor * factor * current_data_estimation - + 0.5 * previous_scale_factor * factor * estimation_delta - + previous_noise_std * torch.sqrt(factor) * sde_noise + (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x + + next_scale_factor * factor * current_data_estimation + + 0.5 * next_scale_factor * factor * estimation_delta + + next_noise_std * safe_sqrt(factor) * sde_noise ) - def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: + def __call__( + self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None + ) -> torch.Tensor: """Apply one step of the backward diffusion process. Note: @@ -199,9 +285,8 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen """ assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" - current_timestep = self.timesteps[step] - scale_factor = self.cumulative_scale_factors[current_timestep] - noise_ratio = self.noise_std[current_timestep] + scale_factor = self.cumulative_scale_factors[step] + noise_ratio = self.noise_std[step] estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor self.estimated_data.append(estimated_denoised_data) variance = self.params.sde_variance diff --git a/src/refiners/foundationals/latent_diffusion/solvers/solver.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py index 088a14934..dabe5d0a0 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/solver.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py @@ -67,6 +67,7 @@ class BaseSolverParams: initial_diffusion_rate: float | None final_diffusion_rate: float | None noise_schedule: NoiseSchedule | None + sigma_schedule: NoiseSchedule | None model_prediction_type: ModelPredictionType | None sde_variance: 
float @@ -91,6 +92,7 @@ class SolverParams(BaseSolverParams): initial_diffusion_rate: float | None = None final_diffusion_rate: float | None = None noise_schedule: NoiseSchedule | None = None + sigma_schedule: NoiseSchedule | None = None model_prediction_type: ModelPredictionType | None = None sde_variance: float = 0.0 @@ -103,6 +105,7 @@ class ResolvedSolverParams(BaseSolverParams): initial_diffusion_rate: float final_diffusion_rate: float noise_schedule: NoiseSchedule + sigma_schedule: NoiseSchedule | None model_prediction_type: ModelPredictionType sde_variance: float @@ -140,6 +143,7 @@ class Solver(fl.Module, ABC): initial_diffusion_rate=8.5e-4, final_diffusion_rate=1.2e-2, noise_schedule=NoiseSchedule.QUADRATIC, + sigma_schedule=None, model_prediction_type=ModelPredictionType.NOISE, sde_variance=0.0, ) @@ -404,14 +408,12 @@ def sample_noise_schedule(self) -> Tensor: A tensor representing the noise schedule. """ match self.params.noise_schedule: - case "uniform": + case NoiseSchedule.UNIFORM: return 1 - self.sample_power_distribution(1) - case "quadratic": + case NoiseSchedule.QUADRATIC: return 1 - self.sample_power_distribution(2) - case "karras": + case NoiseSchedule.KARRAS: return 1 - self.sample_power_distribution(7) - case _: - raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}") def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver": """Move the solver to the specified device and data type. 
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index c8320dcf9..dbcd16ebc 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -97,6 +97,11 @@ def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB") +@pytest.fixture +def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB") + + @pytest.fixture def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB") @@ -913,6 +918,39 @@ def test_diffusion_std_sde_random_init( ensure_similar_images(predicted_image, expected_image_std_sde_random_init) +@no_grad() +def test_diffusion_std_sde_karras_random_init( + sd15_std_sde: StableDiffusion_1, expected_image_std_sde_karras_random_init: Image.Image, test_device: torch.device +): + sd15 = sd15_std_sde + + prompt = "a cute cat, detailed high-quality professional image" + negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + + sd15.solver = DPMSolver( + num_inference_steps=18, + last_step_first_order=True, + params=SolverParams(sde_variance=1.0, sigma_schedule=NoiseSchedule.KARRAS), + device=test_device, + ) + + manual_seed(2) + x = sd15.init_latents((512, 512)) + + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + + predicted_image = sd15.lda.latents_to_image(x) + + ensure_similar_images(predicted_image, expected_image_std_sde_karras_random_init) + + @no_grad() def test_diffusion_batch2(sd15_std: StableDiffusion_1): sd15 = sd15_std diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index 
a0ef33776..fa6742dc1 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -97,6 +97,29 @@ manual_seed(2) image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0] ``` +- `expected_std_sde_karras_random_init.png` is generated with the following code (diffusers 0.30.2): + +```python +import torch +from diffusers import StableDiffusionPipeline +from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler +from refiners.fluxion.utils import manual_seed + +model_id = "botp/stable-diffusion-v1-5" +pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32) +pipe = pipe.to("cuda:1") + +config = {**pipe.scheduler.config} +config["use_karras_sigmas"] = True +config["algorithm_type"] = "sde-dpmsolver++" +pipe.scheduler = DPMSolverMultistepScheduler.from_config(config) + +prompt = "a cute cat, detailed high-quality professional image" +negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" +manual_seed(2) +image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=18, guidance_scale=7.5).images[0] +``` + - `kitchen_mask.png` is made manually. 
- Controlnet guides have been manually generated (x) using open source software and models, namely: diff --git a/tests/e2e/test_diffusion_ref/expected_std_sde_karras_random_init.png b/tests/e2e/test_diffusion_ref/expected_std_sde_karras_random_init.png new file mode 100644 index 000000000..e51f2d9ec Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_std_sde_karras_random_init.png differ diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py index 5522027b2..a6146ddbd 100644 --- a/tests/foundationals/latent_diffusion/test_solvers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -1,3 +1,4 @@ +import itertools from typing import cast from warnings import warn @@ -29,38 +30,11 @@ def test_ddpm_diffusers(): assert equal(diffusers_scheduler.timesteps, solver.timesteps) -@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)]) -def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool): - from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore - - manual_seed(0) - - diffusers_scheduler = DiffuserScheduler( - beta_schedule="scaled_linear", - beta_start=0.00085, - beta_end=0.012, - lower_order_final=False, - euler_at_final=last_step_first_order, - final_sigmas_type="sigma_min", # default before Diffusers 0.26.0 - ) - diffusers_scheduler.set_timesteps(n_steps) - solver = DPMSolver( - num_inference_steps=n_steps, - last_step_first_order=last_step_first_order, - ) - assert equal(solver.timesteps, diffusers_scheduler.timesteps) - - sample = randn(1, 3, 32, 32) - predicted_noise = randn(1, 3, 32, 32) - - for step, timestep in enumerate(diffusers_scheduler.timesteps): - diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore - refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step) - assert 
allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" - - -@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)]) -def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool): +@pytest.mark.parametrize( + "n_steps, last_step_first_order, sde_variance, use_karras_sigmas", + list(itertools.product([5, 30], [False, True], [0.0, 1.0], [False, True])), +) +def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_variance: float, use_karras_sigmas: bool): from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore manual_seed(0) @@ -72,13 +46,17 @@ def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool): lower_order_final=False, euler_at_final=last_step_first_order, final_sigmas_type="sigma_min", # default before Diffusers 0.26.0 - algorithm_type="sde-dpmsolver++", + algorithm_type="sde-dpmsolver++" if sde_variance == 1.0 else "dpmsolver++", + use_karras_sigmas=use_karras_sigmas, ) diffusers_scheduler.set_timesteps(n_steps) solver = DPMSolver( num_inference_steps=n_steps, last_step_first_order=last_step_first_order, - params=SolverParams(sde_variance=1.0), + params=SolverParams( + sde_variance=sde_variance, + sigma_schedule=NoiseSchedule.KARRAS if use_karras_sigmas else None, + ), ) assert equal(solver.timesteps, diffusers_scheduler.timesteps) @@ -94,8 +72,9 @@ def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool): manual_seed(37) refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)] + atol = 1e-4 if use_karras_sigmas else 1e-6 for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)): - assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}" + assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs 
differ at step {step}" def test_ddim_diffusers():