diff --git a/comfy/samplers.py b/comfy/samplers.py
index c05e3e084af..89464a42ac6 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -1,12 +1,13 @@
 from __future__ import annotations
 from .k_diffusion import sampling as k_diffusion_sampling
 from .extra_samplers import uni_pc
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, NamedTuple
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.model_base import BaseModel
     from comfy.controlnet import ControlBase
 import torch
+from functools import partial
 import collections
 from comfy import model_management
 import math
@@ -920,31 +921,37 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
 
 
-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic", "kl_optimal"]
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
 
-def calculate_sigmas(model_sampling, scheduler_name, steps):
-    if scheduler_name == "karras":
-        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
-    elif scheduler_name == "exponential":
-        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
-    elif scheduler_name == "normal":
-        sigmas = normal_scheduler(model_sampling, steps)
-    elif scheduler_name == "simple":
-        sigmas = simple_scheduler(model_sampling, steps)
-    elif scheduler_name == "ddim_uniform":
-        sigmas = ddim_scheduler(model_sampling, steps)
-    elif scheduler_name == "sgm_uniform":
-        sigmas = normal_scheduler(model_sampling, steps, sgm=True)
-    elif scheduler_name == "beta":
-        sigmas = beta_scheduler(model_sampling, steps)
-    elif scheduler_name == "linear_quadratic":
-        sigmas = linear_quadratic_schedule(model_sampling, steps)
-    elif scheduler_name == "kl_optimal":
-        sigmas = kl_optimal_scheduler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
-    else:
-        logging.error("error invalid scheduler {}".format(scheduler_name))
-    return sigmas
+class SchedulerHandler(NamedTuple):
+    handler: Callable[..., torch.Tensor]
+    # Boolean indicates whether to call the handler like:
+    #  scheduler_function(model_sampling, steps) or
+    #  scheduler_function(n, sigma_min: float, sigma_max: float)
+    use_ms: bool = True
+
+SCHEDULER_HANDLERS = {
+    "normal": SchedulerHandler(normal_scheduler),
+    "karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
+    "exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
+    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
+    "simple": SchedulerHandler(simple_scheduler),
+    "ddim_uniform": SchedulerHandler(ddim_scheduler),
+    "beta": SchedulerHandler(beta_scheduler),
+    "linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
+    "kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
+}
+SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
+
+def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
+    handler = SCHEDULER_HANDLERS.get(scheduler_name)
+    if handler is None:
+        err = f"error invalid scheduler {scheduler_name}"
+        logging.error(err)
+        raise ValueError(err)
+    if handler.use_ms:
+        return handler.handler(model_sampling, steps)
+    return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
 
 def sampler_object(name):
     if name == "uni_pc":
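
One practical consequence of this refactor is that schedulers become pluggable: adding a dict entry to SCHEDULER_HANDLERS is enough for calculate_sigmas() to dispatch to a new schedule, where the old if/elif chain required editing the function body. Below is a minimal sketch of that, not part of the PR: the "my_cosine" name, my_cosine_scheduler, and the ModelSamplingStub stand-in are hypothetical, used only to illustrate the registry.

    import math
    import torch
    import comfy.samplers as samplers

    class ModelSamplingStub:
        # Hypothetical stand-in exposing the two attributes that
        # use_ms=False handlers read via calculate_sigmas() and that
        # model_sampling-based schedulers typically consume.
        sigma_min = torch.tensor(0.0292)
        sigma_max = torch.tensor(14.6146)

    def my_cosine_scheduler(model_sampling, steps):
        # use_ms=True call signature: (model_sampling, steps).
        s_min = float(model_sampling.sigma_min)
        s_max = float(model_sampling.sigma_max)
        # Cosine interpolation from sigma_max down to sigma_min.
        t = torch.linspace(0.0, math.pi / 2.0, steps)
        sigmas = s_min + (s_max - s_min) * torch.cos(t)
        # Append the trailing 0.0 that ComfyUI sigma schedules end with.
        return torch.cat([sigmas, sigmas.new_zeros([1])])

    # Registering after import works for dispatch, but note that
    # SCHEDULER_NAMES is snapshotted at import time via
    # list(SCHEDULER_HANDLERS), so late registrations will not appear
    # in UI dropdowns built from it.
    samplers.SCHEDULER_HANDLERS["my_cosine"] = samplers.SchedulerHandler(my_cosine_scheduler)

    print(samplers.calculate_sigmas(ModelSamplingStub(), "my_cosine", steps=10))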