Commit

Dino loss
pschaldenbrand committed Oct 13, 2023
1 parent 43903d4 commit c3ea074
Showing 4 changed files with 388 additions and 1 deletion.
377 changes: 377 additions & 0 deletions src/losses/dino_loss.py
@@ -0,0 +1,377 @@
import argparse
import torch
from torch import nn
from torchvision import transforms
import torch.nn.modules.utils as nn_utils
import math
import types
from pathlib import Path
from typing import Union, List, Tuple
from PIL import Image

"""
Adaptation from: https://github.com/ShirAmir/dino-vit-features
Updated to use Dino-v2
@article{amir2021deep,
author = {Shir Amir and Yossi Gandelsman and Shai Bagon and Tali Dekel},
title = {Deep ViT Features as Dense Visual Descriptors},
journal = {arXiv preprint arXiv:2112.05814},
year = {2021}
}
"""

class ViTExtractor:
""" This class facilitates extraction of features, descriptors, and saliency maps from a ViT.
We use the following notation in the documentation of the module's methods:
B - batch size
    h - number of heads. usually takes the place of the channel dimension in pytorch's convention BxCxHxW
    p - patch size of the ViT, e.g. 14 for the DINOv2 models.
    t - number of tokens. equals the number of patches + 1, i.e. HW / p**2 + 1, where H and W are the height and width
        of the input image.
d - the embedding dimension in the ViT.
"""

    def __init__(self, model_type: str = 'dinov2_vitb14', stride: int = 14, model: nn.Module = None, device: str = 'cuda'):
"""
        :param model_type: A string specifying the type of model to extract from, e.g.
                          [dinov2_vits14 | dinov2_vitb14 | dinov2_vitl14 | dinov2_vitg14]
                          (a torch.hub entry point in the facebookresearch/dinov2 repository).
:param stride: stride of first convolution layer. small stride -> higher resolution.
:param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor.
should be compatible with model_type.
"""
self.model_type = model_type
self.device = device
if model is not None:
self.model = model
else:
self.model = ViTExtractor.create_model(model_type)

self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride)
self.model.eval()
self.model.to(self.device)
        self.p = self.model.patch_embed.patch_size[0]  # DINOv2 stores patch_size as a (h, w) tuple
self.stride = self.model.patch_embed.proj.stride

self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5)
self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5)

self._feats = []
self.hook_handlers = []
self.load_size = None
self.num_patches = None

@staticmethod
def create_model(model_type: str) -> nn.Module:
"""
        :param model_type: a string specifying which DINOv2 model to load from torch.hub.
                           [dinov2_vits14 | dinov2_vitb14 | dinov2_vitl14 | dinov2_vitg14]
:return: the model
"""
model = torch.hub.load('facebookresearch/dinov2', model_type)
return model
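        # e.g., ViTExtractor.create_model('dinov2_vitb14') pulls the pretrained ViT-B/14 backbone
        # from torch.hub (network access is needed the first time the weights are downloaded).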

@staticmethod
def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]):
"""
Creates a method for position encoding interpolation.
:param patch_size: patch size of the model.
:param stride_hw: A tuple containing the new height and width stride respectively.
:return: the interpolation method
"""
def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
class_pos_embed = self.pos_embed[:, 0]
patch_pos_embed = self.pos_embed[:, 1:]
dim = x.shape[-1]
# compute number of tokens taking stride into account
w0 = 1 + (w - patch_size) // stride_hw[1]
h0 = 1 + (h - patch_size) // stride_hw[0]
assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and
stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}"""
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
align_corners=False, recompute_scale_factor=False
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return interpolate_pos_encoding
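        # Example (illustrative): for a 224x224 input with patch_size 14 and stride (7, 7),
        # w0 = h0 = 1 + (224 - 14) // 7 = 31, so the pretrained positional encoding grid is
        # interpolated to a 31x31 grid (961 patch tokens plus the CLS token).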

@staticmethod
def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
"""
change resolution of model output by changing the stride of the patch extraction.
:param model: the model to change resolution for.
:param stride: the new stride parameter.
:return: the adjusted model
"""
patch_size = model.patch_embed.patch_size
print(stride, patch_size)
if stride == patch_size[0] and stride == patch_size[1]: # nothing to do
return model

stride = nn_utils._pair(stride)
print(stride, patch_size)
assert all([(patch_size[i] // stride[i]) * stride[i] == patch_size[i] for i in
range(len(stride))]), f'stride {stride} should divide patch_size {patch_size}'

# fix the stride
model.patch_embed.proj.stride = stride
# fix the positional encoding code
        model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size[0], stride), model)  # patches are square
return model

def preprocess(self, image_path: Union[str, Path],
load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]:
"""
Preprocesses an image before extraction.
:param image_path: path to image to be extracted.
:param load_size: optional. Size to resize image before the rest of preprocessing.
:return: a tuple containing:
        (1) the preprocessed image as a tensor of shape BxCxHxW, ready to be fed to the model.
        (2) the PIL image after resizing.
"""
pil_image = Image.open(image_path).convert('RGB')
if load_size is not None:
pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image)
prep = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std)
])
prep_img = prep(pil_image)[None, ...]
return prep_img, pil_image
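        # Example (illustrative; 'image.jpg' is a placeholder path):
        #   tensor_img, pil_img = extractor.preprocess('image.jpg', load_size=224)
        # With an int load_size, the shorter image side is resized to 224 before normalization.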

def _get_hook(self, facet: str):
"""
generate a hook method for a specific block and facet.
"""
if facet in ['attn', 'token']:
def _hook(model, input, output):
self._feats.append(output)
return _hook

if facet == 'query':
facet_idx = 0
elif facet == 'key':
facet_idx = 1
elif facet == 'value':
facet_idx = 2
else:
raise TypeError(f"{facet} is not a supported facet.")

def _inner_hook(module, input, output):
input = input[0]
B, N, C = input.shape
qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
self._feats.append(qkv[facet_idx]) #Bxhxtxd
return _inner_hook

def _register_hooks(self, layers: List[int], facet: str) -> None:
"""
register hook to extract features.
:param layers: layers from which to extract features.
:param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
"""
for block_idx, block in enumerate(self.model.blocks):
if block_idx in layers:
if facet == 'token':
self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
elif facet == 'attn':
self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
elif facet in ['key', 'query', 'value']:
self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
else:
raise TypeError(f"{facet} is not a supported facet.")

def _unregister_hooks(self) -> None:
"""
unregisters the hooks. should be called after feature extraction.
"""
for handle in self.hook_handlers:
handle.remove()
self.hook_handlers = []

    def _extract_features(self, batch: torch.Tensor, layers: List[int] = [11], facet: str = 'key') -> List[torch.Tensor]:
"""
extract features from the model
:param batch: batch to extract features for. Has shape BxCxHxW.
        :param layers: layers to extract from. Each is a number between 0 and 11.
:param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
:return : tensor of features.
if facet is 'key' | 'query' | 'value' has shape Bxhxtxd
if facet is 'attn' has shape Bxhxtxt
if facet is 'token' has shape Bxtxd
"""
B, C, H, W = batch.shape
self._feats = []
self._register_hooks(layers, facet)
_ = self.model(batch)
self._unregister_hooks()
self.load_size = (H, W)
self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
return self._feats
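        # Example (illustrative, dinov2_vitb14, stride 14, a 1x3x224x224 batch): num_patches
        # becomes (16, 16), and the 'token' facet yields a single tensor of shape 1x257x768
        # (256 patch tokens + 1 CLS token, embedding dimension 768).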

def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
"""
create a log-binned descriptor.
:param x: tensor of features. Has shape Bxhxtxd.
:param hierarchy: how many bin hierarchies to use.
"""
B = x.shape[0]
num_bins = 1 + 8 * hierarchy

bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh)
bin_x = bin_x.permute(0, 2, 1)
bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1])
# Bx(dxh)xnum_patches[0]xnum_patches[1]
sub_desc_dim = bin_x.shape[1]

avg_pools = []
# compute bins of all sizes for all spatial locations.
for k in range(0, hierarchy):
# avg pooling with kernel 3**kx3**k
win_size = 3 ** k
avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
avg_pools.append(avg_pool(bin_x))

bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device)
for y in range(self.num_patches[0]):
for x in range(self.num_patches[1]):
part_idx = 0
# fill all bins for a spatial location (y, x)
for k in range(0, hierarchy):
kernel_size = 3 ** k
for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
if i == y and j == x and k != 0:
continue
if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]:
bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
:, :, i, j]
else: # handle padding in a more delicate way than zero padding
temp_i = max(0, min(i, self.num_patches[0] - 1))
temp_j = max(0, min(j, self.num_patches[1] - 1))
bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
:, :, temp_i,
temp_j]
part_idx += 1
bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
# Bx1x(t-1)x(dxh)
return bin_x
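        # Example (illustrative): with hierarchy=2 there are 1 + 8*2 = 17 bins per spatial
        # location, so a key/query/value input with flattened dimension d*h = 768 (ViT-B)
        # produces a log-binned descriptor of dimension 17 * 768 = 13056 per patch token.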

def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key',
bin: bool = False, include_cls: bool = False) -> torch.Tensor:
"""
extract descriptors from the model
:param batch: batch to extract descriptors for. Has shape BxCxHxW.
        :param layer: layer to extract from. A number between 0 and 11.
:param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
:param bin: apply log binning to the descriptor. default is False.
:return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
"""
assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors.
choose from ['key' | 'query' | 'value' | 'token'] """
self._extract_features(batch, [layer], facet)
x = self._feats[0]
if facet == 'token':
x.unsqueeze_(dim=1) #Bx1xtxd
if not include_cls:
x = x[:, :, 1:, :] # remove cls token
else:
assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
if not bin:
desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh)
else:
desc = self._log_bin(x)
return desc
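        # Example (illustrative, dinov2_vitb14, stride 14, a 1x3x224x224 batch, facet='token',
        # bin=False, include_cls=True): the returned descriptor has shape 1x1x257x768 --
        # one 768-d token descriptor per patch plus the CLS token.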

def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
"""
        extract saliency maps. The saliency maps are extracted by averaging the attention of the CLS token
        over several heads in the last layer. All values are then normalized to the range [0, 1].
:param batch: batch to extract saliency maps for. Has shape BxCxHxW.
:return: a tensor of saliency maps. has shape Bxt-1
"""
assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type."
self._extract_features(batch, [11], 'attn')
head_idxs = [0, 2, 4, 5]
curr_feats = self._feats[0] #Bxhxtxt
cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) #Bx(t-1)
temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0]
cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins) # normalize to range [0,1]
return cls_attn_maps

def chunk_cosine_sim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
""" Computes cosine similarity between all possible pairs in two sets of vectors.
Operates on chunks so no large amount of GPU RAM is required.
    :param x: a tensor of descriptors of shape Bx1x(t_x)xd' where d' is the dimensionality of the descriptors and t_x
is the number of tokens in x.
:param y: a tensor of descriptors of shape Bx1x(t_y)xd' where d' is the dimensionality of the descriptors and t_y
is the number of tokens in y.
:return: cosine similarity between all descriptors in x and all descriptors in y. Has shape of Bx1x(t_x)x(t_y) """
result_list = []
num_token_x = x.shape[2]
for token_idx in range(num_token_x):
token = x[:, :, token_idx, :].unsqueeze(dim=2) # Bx1x1xd'
result_list.append(torch.nn.CosineSimilarity(dim=3)(token, y)) # Bx1xt
return torch.stack(result_list, dim=2) # Bx1x(t_x)x(t_y)
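# Shape example (illustrative): descriptors x and y of shape 1x1x257x768 produce a similarity
# tensor of shape 1x1x257x257 -- one cosine similarity per (token in x, token in y) pair.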

""" taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse"""
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
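# Example (illustrative; '--use_bin' is a hypothetical flag):
#   parser.add_argument('--use_bin', type=str2bool, nargs='?', const=True, default=False)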


mean = (0.485, 0.456, 0.406) #if "dino" in model_type else (0.5, 0.5, 0.5)
std = (0.229, 0.224, 0.225) #if "dino" in model_type else (0.5, 0.5, 0.5)
prep = transforms.Compose([
# transforms.ToTensor(),
transforms.Resize((224,224)),
transforms.Normalize(mean=mean, std=std)
])

extractor = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cosine_sim = torch.nn.CosineSimilarity(dim=2)

def dino_loss(img0, img1, model_type='dinov2_vitb14',
stride=14, layer=3, facet='token', bin=False):
"""
:param img0: image tensor [batch,3,h,w]
:param img1: image tensor [batch,3,h,w]
:param layer: layer from which to extract features [0,11]
    :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
    :return: 1 minus the mean cosine similarity between the two images' DINO descriptors (0 for identical images).
    """
global extractor
if extractor is None:
        extractor = ViTExtractor(model_type, stride, device=device)
        # Freeze the backbone weights; gradients still flow back to the input images.
        extractor.model.requires_grad_(False)

descs_a = extractor.extract_descriptors(prep(img0.to(device)),
layer, facet, bin, include_cls=True)

descs_b = extractor.extract_descriptors(prep(img1.to(device)),
layer, facet, bin, include_cls=True)

similarities = cosine_sim(descs_a, descs_b)
similarities = torch.mean(similarities, dim=2)

return 1 - similarities.mean()
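

# Illustrative usage sketch (names and hyperparameters are placeholders, not part of the
# repository's API): optimize an image so its DINOv2 descriptors match a target's by
# minimizing dino_loss. Gradients flow back through the frozen ViT to the image pixels.
if __name__ == '__main__':
    target = torch.rand(1, 3, 256, 256, device=device)  # stand-in for a target image in [0, 1]
    canvas = torch.rand(1, 3, 256, 256, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([canvas], lr=1e-2)
    for step in range(200):
        optimizer.zero_grad()
        loss = dino_loss(canvas, target)  # scalar in [0, 2]; 0 means identical descriptors
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            print(f'step {step}: dino loss = {loss.item():.4f}')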
3 changes: 2 additions & 1 deletion src/paint.py
@@ -89,7 +89,8 @@

# Clean paint brush and/or get more paint
if not painter.opt.ink:
color_ind, _ = nearest_color(stroke.color_transform.detach().cpu().numpy(), color_palette)
color_ind, _ = nearest_color(stroke.color_transform.detach().cpu().numpy(),
color_palette.detach().cpu().numpy())
new_paint_color = color_ind != curr_color
if new_paint_color or consecutive_strokes_no_clean > 12:
painter.clean_paint_brush()