prevent submodule registration with setattr for the Chain class
limiteinductive committed Oct 9, 2023
1 parent d0164e3 commit cb05b2e
Showing 2 changed files with 85 additions and 59 deletions.
5 changes: 4 additions & 1 deletion src/refiners/fluxion/layers/chain.py
@@ -145,7 +145,10 @@ def __init__(self, *args: Module | Iterable[Module]) -> None:

     def __setattr__(self, name: str, value: Any) -> None:
         if isinstance(value, torch.nn.Module):
-            raise ValueError("Chain does not support setting modules by attribute. Instead, use the append method.")
+            raise ValueError(
+                "Chain does not support setting modules by attribute. Instead, use a mutation method like `append` or"
+                " wrap it within a single element list to prevent pytorch from registering it as a submodule."
+            )
         super().__setattr__(name, value)

     @property
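As context for the new error message (not part of the commit), here is a minimal sketch of the restriction and of the suggested workaround, assuming refiners.fluxion.layers exposes Chain and Linear with the signatures used below:

import refiners.fluxion.layers as fl

chain = fl.Chain(fl.Linear(in_features=8, out_features=8))

# Assigning a torch.nn.Module as an attribute now raises ValueError: PyTorch's
# Module.__setattr__ would otherwise register the value as a submodule and
# bypass the Chain's own bookkeeping.
# chain.extra = fl.Linear(in_features=8, out_features=8)  # raises ValueError

# Workaround suggested by the message: wrap the module in a single-element list
# so torch.nn.Module.__setattr__ does not register it as a submodule.
chain._extra = [fl.Linear(in_features=8, out_features=8)]
assert "_extra" not in dict(chain.named_children())  # not registered as a child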
139 changes: 81 additions & 58 deletions src/refiners/foundationals/clip/concepts.py
@@ -10,64 +10,6 @@
 import re


-class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
-    """
-    Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
-    Example:
-        import torch
-        from refiners.foundationals.clip.concepts import ConceptExtender
-        from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-        from refiners.fluxion.utils import load_from_safetensors
-        encoder = CLIPTextEncoderL(device="cuda")
-        tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
-        encoder.load_state_dict(tensors)
-        cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
-        dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]
-        extender = ConceptExtender(encoder)
-        extender.add_concept(token="<this-cat>", embedding=cat_embedding)
-        extender.inject()
-        # New concepts can be added at any time
-        extender.add_concept(token="<that-dog>", embedding=dog_embedding)
-        # Now the encoder can be used with the new concepts
-    """

-    def __init__(self, target: CLIPTextEncoder) -> None:
-        with self.setup_adapter(target):
-            super().__init__(target)

-        try:
-            token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
-        except StopIteration:
-            raise RuntimeError("TokenEncoder not found.")

-        try:
-            clip_tokenizer, self.clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
-        except StopIteration:
-            raise RuntimeError("Tokenizer not found.")

-        self.embedding_extender = EmbeddingExtender(token_encoder)
-        self.token_extender = TokenExtender(clip_tokenizer)

-    def add_concept(self, token: str, embedding: Tensor) -> None:
-        self.embedding_extender.add_embedding(embedding)
-        self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)

-    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
-        self.embedding_extender.inject(self.token_encoder_parent)
-        self.token_extender.inject(self.clip_tokenizer_parent)
-        return super().inject(parent)

-    def eject(self) -> None:
-        self.embedding_extender.eject()
-        self.token_extender.eject()
-        super().eject()


 class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     old_weight: Parameter
     new_weight: Parameter
@@ -122,3 +64,84 @@ def add_token(self, token: str, token_id: int) -> None:
         tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
         # Define the keyword as its own smallest subtoken
         tokenizer.byte_pair_encoding_cache[token] = token


+class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
+    """
+    Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
+    Example:
+        import torch
+        from refiners.foundationals.clip.concepts import ConceptExtender
+        from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+        from refiners.fluxion.utils import load_from_safetensors
+        encoder = CLIPTextEncoderL(device="cuda")
+        tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
+        encoder.load_state_dict(tensors)
+        cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
+        dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]
+        extender = ConceptExtender(encoder)
+        extender.add_concept(token="<this-cat>", embedding=cat_embedding)
+        extender.inject()
+        # New concepts can be added at any time
+        extender.add_concept(token="<that-dog>", embedding=dog_embedding)
+        # Now the encoder can be used with the new concepts
+    """

+    def __init__(self, target: CLIPTextEncoder) -> None:
+        with self.setup_adapter(target):
+            super().__init__(target)

+        try:
+            token_encoder, token_encoder_parent = next(target.walk(TokenEncoder))
+            self._token_encoder_parent = [token_encoder_parent]

+        except StopIteration:
+            raise RuntimeError("TokenEncoder not found.")

+        try:
+            clip_tokenizer, clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
+            self._clip_tokenizer_parent = [clip_tokenizer_parent]
+        except StopIteration:
+            raise RuntimeError("Tokenizer not found.")

+        self._embedding_extender = [EmbeddingExtender(token_encoder)]
+        self._token_extender = [TokenExtender(clip_tokenizer)]

+    @property
+    def embedding_extender(self) -> EmbeddingExtender:
+        assert len(self._embedding_extender) == 1, "EmbeddingExtender not found."
+        return self._embedding_extender[0]

+    @property
+    def token_extender(self) -> TokenExtender:
+        assert len(self._token_extender) == 1, "TokenExtender not found."
+        return self._token_extender[0]

+    @property
+    def token_encoder_parent(self) -> fl.Chain:
+        assert len(self._token_encoder_parent) == 1, "TokenEncoder parent not found."
+        return self._token_encoder_parent[0]

+    @property
+    def clip_tokenizer_parent(self) -> fl.Chain:
+        assert len(self._clip_tokenizer_parent) == 1, "Tokenizer parent not found."
+        return self._clip_tokenizer_parent[0]

+    def add_concept(self, token: str, embedding: Tensor) -> None:
+        self.embedding_extender.add_embedding(embedding)
+        self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)

+    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
+        self.embedding_extender.inject(self.token_encoder_parent)
+        self.token_extender.inject(self.clip_tokenizer_parent)
+        return super().inject(parent)

+    def eject(self) -> None:
+        self.embedding_extender.eject()
+        self.token_extender.eject()
+        super().eject()
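The refactor above applies the same workaround inside ConceptExtender itself: the helper adapters and their parent chains are stored in single-element lists and exposed through properties, so they are never registered as submodules of the ConceptExtender chain before inject() places them into the target. A stripped-down sketch of the pattern (illustrative only, plain PyTorch, hypothetical Holder class):

import torch

class Holder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Direct assignment would register the Linear as a submodule (it would
        # appear in named_children() and state_dict()); hiding it inside a
        # single-element list keeps it out of torch.nn.Module.__setattr__'s
        # registration path.
        self._helper = [torch.nn.Linear(4, 4)]

    @property
    def helper(self) -> torch.nn.Linear:
        return self._helper[0]

holder = Holder()
assert list(holder.named_children()) == []  # the wrapped module is not registered
assert holder.helper.in_features == 4       # but it is still accessible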
