Support synthclip #129

Open · wants to merge 2 commits into main
4 changes: 3 additions & 1 deletion clip_benchmark/models/__init__.py
@@ -2,11 +2,13 @@
import torch
from .open_clip import load_open_clip
from .japanese_clip import load_japanese_clip
from .synthclip import load_synthclip

# loading function must return (model, transform, tokenizer)
TYPE2FUNC = {
"open_clip": load_open_clip,
"ja_clip": load_japanese_clip
"ja_clip": load_japanese_clip,
"synthclip": load_synthclip,
}
MODEL_TYPES = list(TYPE2FUNC.keys())

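The registry above is the integration point: a loader registered in TYPE2FUNC must return a (model, transform, tokenizer) triple for the requested model type. A minimal, hypothetical dispatch sketch (load_clip here is illustrative; the real entry point in clip_benchmark may differ in name and signature):

def load_clip(model_type: str, **kwargs):
    # resolve the registered loader, e.g. "synthclip" -> load_synthclip
    if model_type not in TYPE2FUNC:
        raise ValueError(f"unknown model_type {model_type!r}; expected one of {MODEL_TYPES}")
    model, transform, tokenizer = TYPE2FUNC[model_type](**kwargs)
    return model, transform, tokenizer
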
227 changes: 227 additions & 0 deletions clip_benchmark/models/synthclip.py
@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from github.com/openai/CLIP
from collections import OrderedDict

import numpy as np
import timm
import torch
import torchvision.transforms as transforms
import open_clip
from torch import nn

# `losses` comes from the SynthCLIP training code and is only needed by get_loss();
# it is not required for benchmarking, so the import stays commented out.
# import losses


class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""

def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)


class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()

self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask

def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None
else None
)
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x


class Transformer(nn.Module):
def __init__(
self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(
*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
)

def forward(self, x: torch.Tensor):
return self.resblocks(x)


class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
# vision
vision_width: int,
vision_model: nn.Module,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
**kwargs,
):
super().__init__()

self.context_length = context_length
self.vision_width = vision_width

self.visual = vision_model

self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask(),
)

self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width)
)
self.ln_final = LayerNorm(transformer_width)

self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

self.initialize_parameters()

def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)

proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers) ** -0.5
)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

nn.init.normal_(self.image_projection, std=self.vision_width**-0.5)
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask

def encode_image(self, image):
x = self.visual(image)
x = x @ self.image_projection

return x

def encode_text(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)

# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

return x

def forward(self, image, text):
image_embed = self.encode_image(image)
text_embed = self.encode_text(text)

return {
"image_embed": image_embed,
"text_embed": text_embed,
"logit_scale": self.logit_scale.exp(),
}


def get_loss(gather_with_grad=False):
    # Training-only helper: requires the SynthCLIP `losses` module (see the
    # commented-out import above) and is unused when benchmarking.
    return losses.CLIPLoss(gather_with_grad=gather_with_grad)


def get_metric_names():
return ["loss", "clip_loss", "clip_acc"]


def CLIP_VITB16(pretrained: bool = False, cache_dir: str = None, **kwargs):
    # `pretrained` is timm's boolean flag for ImageNet-initialised ViT weights;
    # the actual SynthCLIP weights are loaded from the checkpoint in load_synthclip().
    vision_model = timm.create_model("vit_base_patch16_224", num_classes=0,
                                     pretrained=pretrained, cache_dir=cache_dir)
model = CLIP(
embed_dim=512,
vision_width=768,
vision_model=vision_model,
context_length=77,
vocab_size=49408,
transformer_width=512,
transformer_heads=8,
transformer_layers=12,
**kwargs,
)

return model


def load_synthclip(
        model: str = "ViT-B-16",
        pretrained: str = "./checkpoints/synthclip-30m/checkpoint_best.pt",
        device="cpu", **kwargs):
    # `model` is accepted for interface compatibility; only the ViT-B/16 variant
    # used by SynthCLIP is supported here.
    model = CLIP_VITB16()
    # Load the SynthCLIP checkpoint. This assumes the layout written by the
    # SynthCLIP training script: weights stored under "state_dict" with keys
    # prefixed by "module." from DistributedDataParallel.
    if pretrained:
        checkpoint = torch.load(pretrained, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        state_dict = {
            k[len("module."):] if k.startswith("module.") else k: v
            for k, v in state_dict.items()
        }
        model.load_state_dict(state_dict)
    # Preprocessing adapted from
    # https://github.com/hammoudhasan/SynthCLIP/blob/02ef69764d8dc921650bcac4a98bd0f477790787/Training/main.py#L240
    # The ColorJitter used there is a training-time augmentation and is dropped
    # here so that evaluation stays deterministic.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x,  # force RGB
            normalize,
        ]
    )
    model = model.to(device)
    tokenizer = open_clip.get_tokenizer("ViT-B-16")
    return model, transform, tokenizer
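
For reference, a minimal zero-shot sketch built only on what this file defines (the checkpoint path and image file below are placeholders):

import torch
from PIL import Image
from clip_benchmark.models.synthclip import load_synthclip

model, transform, tokenizer = load_synthclip(
    pretrained="./checkpoints/synthclip-30m/checkpoint_best.pt",  # placeholder path
    device="cpu",
)

image = transform(Image.open("example.jpg")).unsqueeze(0)  # placeholder image
text = tokenizer(["a photo of a dog", "a photo of a cat"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # temperature-scaled cosine similarities over the candidate captions
    probs = (model.logit_scale.exp() * image_features @ text_features.T).softmax(dim=-1)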