diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py
index 1c64f6cfa..e1e9bdbe6 100644
--- a/src/refiners/fluxion/layers/chain.py
+++ b/src/refiners/fluxion/layers/chain.py
@@ -145,7 +145,10 @@ def __init__(self, *args: Module | Iterable[Module]) -> None:
 
     def __setattr__(self, name: str, value: Any) -> None:
         if isinstance(value, torch.nn.Module):
-            raise ValueError("Chain does not support setting modules by attribute. Instead, use the append method.")
+            raise ValueError(
+                "Chain does not support setting modules by attribute. Instead, use a mutation method like `append` or"
+                " wrap it in a single-element list to prevent PyTorch from registering it as a submodule."
+            )
         super().__setattr__(name, value)
 
     @property
diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py
index 902112d40..3380e5a19 100644
--- a/src/refiners/foundationals/clip/concepts.py
+++ b/src/refiners/foundationals/clip/concepts.py
@@ -10,64 +10,6 @@
 import re
 
 
-class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
-    """
-    Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
-
-    Example:
-        import torch
-        from refiners.foundationals.clip.concepts import ConceptExtender
-        from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-        from refiners.fluxion.utils import load_from_safetensors
-
-        encoder = CLIPTextEncoderL(device="cuda")
-        tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
-        encoder.load_state_dict(tensors)
-
-        cat_embedding = torch.load("cat_embedding.bin")["<cat>"]
-        dog_embedding = torch.load("dog_embedding.bin")["<dog>"]
-
-        extender = ConceptExtender(encoder)
-        extender.add_concept(token="<cat>", embedding=cat_embedding)
-        extender.inject()
-        # New concepts can be added at any time
-        extender.add_concept(token="<dog>", embedding=dog_embedding)
-
-        # Now the encoder can be used with the new concepts
-    """
-
-    def __init__(self, target: CLIPTextEncoder) -> None:
-        with self.setup_adapter(target):
-            super().__init__(target)
-
-        try:
-            token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
-        except StopIteration:
-            raise RuntimeError("TokenEncoder not found.")
-
-        try:
-            clip_tokenizer, self.clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
-        except StopIteration:
-            raise RuntimeError("Tokenizer not found.")
-
-        self.embedding_extender = EmbeddingExtender(token_encoder)
-        self.token_extender = TokenExtender(clip_tokenizer)
-
-    def add_concept(self, token: str, embedding: Tensor) -> None:
-        self.embedding_extender.add_embedding(embedding)
-        self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
-
-    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
-        self.embedding_extender.inject(self.token_encoder_parent)
-        self.token_extender.inject(self.clip_tokenizer_parent)
-        return super().inject(parent)
-
-    def eject(self) -> None:
-        self.embedding_extender.eject()
-        self.token_extender.eject()
-        super().eject()
-
-
 class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     old_weight: Parameter
     new_weight: Parameter
@@ -122,3 +64,84 @@ def add_token(self, token: str, token_id: int) -> None:
         tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
         # Define the keyword as its own smallest subtoken
         tokenizer.byte_pair_encoding_cache[token] = token
+
+
+class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
+    """
+    Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
+
+    Example:
+        import torch
+        from refiners.foundationals.clip.concepts import ConceptExtender
+        from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+        from refiners.fluxion.utils import load_from_safetensors
+
+        encoder = CLIPTextEncoderL(device="cuda")
+        tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
+        encoder.load_state_dict(tensors)
+
+        cat_embedding = torch.load("cat_embedding.bin")["<cat>"]
+        dog_embedding = torch.load("dog_embedding.bin")["<dog>"]
+
+        extender = ConceptExtender(encoder)
+        extender.add_concept(token="<cat>", embedding=cat_embedding)
+        extender.inject()
+        # New concepts can be added at any time
+        extender.add_concept(token="<dog>", embedding=dog_embedding)
+
+        # Now the encoder can be used with the new concepts
+    """
+
+    def __init__(self, target: CLIPTextEncoder) -> None:
+        with self.setup_adapter(target):
+            super().__init__(target)
+
+        try:
+            token_encoder, token_encoder_parent = next(target.walk(TokenEncoder))
+            self._token_encoder_parent = [token_encoder_parent]
+
+        except StopIteration:
+            raise RuntimeError("TokenEncoder not found.")
+
+        try:
+            clip_tokenizer, clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
+            self._clip_tokenizer_parent = [clip_tokenizer_parent]
+        except StopIteration:
+            raise RuntimeError("Tokenizer not found.")
+
+        self._embedding_extender = [EmbeddingExtender(token_encoder)]
+        self._token_extender = [TokenExtender(clip_tokenizer)]
+
+    @property
+    def embedding_extender(self) -> EmbeddingExtender:
+        assert len(self._embedding_extender) == 1, "EmbeddingExtender not found."
+        return self._embedding_extender[0]
+
+    @property
+    def token_extender(self) -> TokenExtender:
+        assert len(self._token_extender) == 1, "TokenExtender not found."
+        return self._token_extender[0]
+
+    @property
+    def token_encoder_parent(self) -> fl.Chain:
+        assert len(self._token_encoder_parent) == 1, "TokenEncoder parent not found."
+        return self._token_encoder_parent[0]
+
+    @property
+    def clip_tokenizer_parent(self) -> fl.Chain:
+        assert len(self._clip_tokenizer_parent) == 1, "Tokenizer parent not found."
+        return self._clip_tokenizer_parent[0]
+
+    def add_concept(self, token: str, embedding: Tensor) -> None:
+        self.embedding_extender.add_embedding(embedding)
+        self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
+
+    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
+        self.embedding_extender.inject(self.token_encoder_parent)
+        self.token_extender.inject(self.clip_tokenizer_parent)
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        self.embedding_extender.eject()
+        self.token_extender.eject()
+        super().eject()
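
A note on the mechanism behind both changes: torch.nn.Module.__setattr__ registers any nn.Module value as a submodule, but it does not look inside plain Python lists (use torch.nn.ModuleList when registration of list contents is what you want). The new Chain error message points at that escape hatch, and the relocated ConceptExtender relies on it: the sub-adapters and parent chains are stored in private single-element lists and exposed through properties, so assigning them never triggers submodule registration. Below is a minimal sketch of the behavior in plain PyTorch, independent of refiners; the Holder class is made up for illustration.

    import torch

    class Holder(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Assigned directly: PyTorch registers it, so it appears in
            # named_modules() and state_dict() and follows .to() calls.
            self.registered = torch.nn.Linear(2, 2)
            # Wrapped in a single-element list: invisible to PyTorch, it
            # contributes no parameters and no state dict entries.
            self.hidden = [torch.nn.Linear(2, 2)]

    holder = Holder()
    print([name for name, _ in holder.named_modules()])  # ['', 'registered']
    print(sorted(holder.state_dict()))  # ['registered.bias', 'registered.weight']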
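
One usage detail, assuming the usual refiners adapter semantics: inject and eject are symmetric here, so a ConceptExtender can be ejected and re-injected without losing the concepts it has accumulated. Continuing the docstring example above:

    extender.eject()   # the original encoder behavior is restored
    extender.inject()  # "<cat>" and "<dog>" are recognized again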