Skip to content

Commit

Permalink
rework compute_clip_image_embedding overloads + improve docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Sep 27, 2024
1 parent 1fc2ad3 commit 8212d54
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions src/refiners/foundationals/latent_diffusion/image_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,36 +455,49 @@ def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})

@overload
def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: ...
def compute_clip_image_embedding(
self,
image_prompt: Image.Image,
) -> Tensor: ...

@overload
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: ...
def compute_clip_image_embedding(
self,
image_prompt: Tensor,
weights: list[float] | None = None,
concat_batches: bool = True,
) -> Tensor: ...

@overload
def compute_clip_image_embedding(
self, image_prompt: list[Image.Image], weights: list[float] | None = None
self,
image_prompt: list[Image.Image],
weights: list[float] | None = None,
concat_batches: bool = True,
) -> Tensor: ...

def compute_clip_image_embedding(
self,
image_prompt: Tensor | Image.Image | list[Image.Image],
image_prompt: Image.Image | list[Image.Image] | Tensor,
weights: list[float] | None = None,
concat_batches: bool = True,
) -> Tensor:
"""Compute the CLIP image embedding.
"""Compute CLIP image embedding(s).
Args:
image_prompt: The image prompt to use.
weights: The scale to use for the image prompt.
concat_batches: Whether to concatenate the batches.
image_prompt: The image prompt(s) to use.
weights: The scale(s) to use for the image prompt(s).
concat_batches: Whether to concatenate the image embeddings along the feature dimension.
Returns:
The CLIP image embedding.
The CLIP image embedding(s).
"""
if isinstance(image_prompt, Image.Image):
image_prompt = self.preprocess_image(image_prompt)
elif isinstance(image_prompt, list):
assert all(isinstance(image, Image.Image) for image in image_prompt)
assert all(
isinstance(image, Image.Image) for image in image_prompt
), "All elements of `image_prompt` must be of PIL Images."
image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])

negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
Expand Down

0 comments on commit 8212d54

Please sign in to comment.