diff --git a/clip_benchmark/models/__init__.py b/clip_benchmark/models/__init__.py
index 96a6129..2ab62f5 100644
--- a/clip_benchmark/models/__init__.py
+++ b/clip_benchmark/models/__init__.py
@@ -2,11 +2,13 @@ import torch
 
 from .open_clip import load_open_clip
 from .japanese_clip import load_japanese_clip
+from .synthclip import load_synthclip
 
 # loading function must return (model, transform, tokenizer)
 TYPE2FUNC = {
     "open_clip": load_open_clip,
-    "ja_clip": load_japanese_clip
+    "ja_clip": load_japanese_clip,
+    "synthclip": load_synthclip,
 }
 MODEL_TYPES = list(TYPE2FUNC.keys())
 
diff --git a/clip_benchmark/models/synthclip.py b/clip_benchmark/models/synthclip.py
new file mode 100644
index 0000000..7274fb0
--- /dev/null
+++ b/clip_benchmark/models/synthclip.py
@@ -0,0 +1,249 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from github.com/openai/CLIP
+from collections import OrderedDict
+
+import numpy as np
+import timm
+import torch
+import torchvision.transforms as transforms
+import open_clip
+from torch import nn
+
+# import losses
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict(
+                [
+                    ("c_fc", nn.Linear(d_model, d_model * 4)),
+                    ("gelu", QuickGELU()),
+                    ("c_proj", nn.Linear(d_model * 4, d_model)),
+                ]
+            )
+        )
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = (
+            self.attn_mask.to(dtype=x.dtype, device=x.device)
+            if self.attn_mask is not None
+            else None
+        )
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
+    ):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(
+            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class CLIP(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        # vision
+        vision_width: int,
+        vision_model: nn.Module,
+        # text
+        context_length: int,
+        vocab_size: int,
+        transformer_width: int,
+        transformer_heads: int,
+        transformer_layers: int,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.context_length = context_length
+        self.vision_width = vision_width
+
+        self.visual = vision_model
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask(),
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(
+            torch.empty(self.context_length, transformer_width)
+        )
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        proj_std = (self.transformer.width**-0.5) * (
+            (2 * self.transformer.layers) ** -0.5
+        )
+        attn_std = self.transformer.width**-0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        nn.init.normal_(self.image_projection, std=self.vision_width**-0.5)
+        nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def encode_image(self, image):
+        x = self.visual(image)
+        x = x @ self.image_projection
+
+        return x
+
+    def encode_text(self, text):
+        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+    def forward(self, image, text):
+        image_embed = self.encode_image(image)
+        text_embed = self.encode_text(text)
+
+        return {
+            "image_embed": image_embed,
+            "text_embed": text_embed,
+            "logit_scale": self.logit_scale.exp(),
+        }
+
+
+def get_loss(gather_with_grad=False):
+    # `losses` ships with SynthCLIP's training code and is not part of this
+    # benchmark; import it lazily so this module stays importable without it.
+    import losses
+    return losses.CLIPLoss(gather_with_grad=gather_with_grad)
+
+
+def get_metric_names():
+    return ["loss", "clip_loss", "clip_acc"]
+
+
+def CLIP_VITB16(pretrained: bool = False, cache_dir: str = None, **kwargs):
+    # `pretrained` toggles timm's own pretrained weights for the vision tower; the
+    # SynthCLIP checkpoint itself is loaded in `load_synthclip` below. Forward
+    # `cache_dir` only when it is set, since older timm releases do not accept it.
+    timm_kwargs = {"cache_dir": cache_dir} if cache_dir is not None else {}
+    vision_model = timm.create_model(
+        "vit_base_patch16_224", num_classes=0, pretrained=pretrained, **timm_kwargs
+    )
+    model = CLIP(
+        embed_dim=512,
+        vision_width=768,
+        vision_model=vision_model,
+        context_length=77,
+        vocab_size=49408,
+        transformer_width=512,
+        transformer_heads=8,
+        transformer_layers=12,
+        **kwargs,
+    )
+
+    return model
+
+
+def load_synthclip(
+    model_name: str = "ViT-B-16",
+    pretrained: str = "./checkpoints/synthclip-30m/checkpoint_best.pt",
+    device="cpu",
+    **kwargs,
+):
+    # Only the ViT-B-16 configuration released with SynthCLIP is supported.
+    model = CLIP_VITB16()
+    # The checkpoint is assumed to follow the layout written by SynthCLIP's
+    # training script: model weights stored under "state_dict", possibly with a
+    # DDP "module." prefix. Adjust here if your checkpoint differs.
+    if pretrained:
+        checkpoint = torch.load(pretrained, map_location="cpu")
+        state_dict = checkpoint.get("state_dict", checkpoint)
+        state_dict = {
+            (k[len("module."):] if k.startswith("module.") else k): v
+            for k, v in state_dict.items()
+        }
+        model.load_state_dict(state_dict)
+    # Preprocessing adapted from
+    # https://github.com/hammoudhasan/SynthCLIP/blob/02ef69764d8dc921650bcac4a98bd0f477790787/Training/main.py#L240
+    # with the train-time ColorJitter dropped so that evaluation is deterministic.
+    normalize = transforms.Normalize(
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+    )
+    transform = transforms.Compose(
+        [
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x,  # force RGB
+            normalize,
+        ]
+    )
+    model = model.to(device)
+    tokenizer = open_clip.get_tokenizer("ViT-B-16")
+    return model, transform, tokenizer
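
For reviewers, a minimal usage sketch of the new loader as a quick smoke test; the image file name is a placeholder, and within the benchmark the loader is reached through the synthclip entry registered in TYPE2FUNC rather than called directly:

import torch
from PIL import Image

from clip_benchmark.models.synthclip import load_synthclip

# Placeholder paths; substitute a real SynthCLIP checkpoint and image.
model, transform, tokenizer = load_synthclip(
    pretrained="./checkpoints/synthclip-30m/checkpoint_best.pt", device="cpu"
)

image = transform(Image.open("example.jpg")).unsqueeze(0)
texts = tokenizer(["a photo of a dog", "a photo of a cat"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(texts)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    probs = (model.logit_scale.exp() * image_features @ text_features.T).softmax(dim=-1)

print(probs)  # relative probability of each caption for the image

The tokenizer returned is open_clip's ViT-B-16 tokenizer, which matches the vocabulary size (49408) and context length (77) hard-coded into the text tower above.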