Skip to content

Commit

Permalink
implement _add_noise for dpm solver
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Sep 9, 2024
1 parent a51d695 commit 5ca6ca2
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion src/refiners/foundationals/latent_diffusion/solvers/dpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,30 @@ def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor:
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
timesteps.append(timestep)

return torch.cat(timesteps).round()
return torch.cat(timesteps).round().int()

def _add_noise(
self,
x: torch.Tensor,
noise: torch.Tensor,
step: int,
) -> torch.Tensor:
"""Add noise to the input tensor using the solver's parameters.
Args:
x: The input tensor to add noise to.
noise: The noise tensor to add to the input tensor.
step: The current step of the diffusion process.
Returns:
The input tensor with added noise.
"""
cumulative_scale_factors = self.cumulative_scale_factors[step]
noise_stds = self.noise_std[step]

# noisify the latents, arXiv:2006.11239 Eq. 4
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x

def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
"""Generate the tensors from the sigmas."""
Expand Down

0 comments on commit 5ca6ca2

Please sign in to comment.