diff --git a/src/losses/dino_loss.py b/src/losses/dino_loss.py
new file mode 100644
index 0000000..f89644b
--- /dev/null
+++ b/src/losses/dino_loss.py
@@ -0,0 +1,377 @@
+import argparse
+import torch
+from torch import nn
+from torchvision import transforms
+import torch.nn.modules.utils as nn_utils
+import math
+import types
+from pathlib import Path
+from typing import Union, List, Tuple
+from PIL import Image
+
+"""
+Adaptation from: https://github.com/ShirAmir/dino-vit-features
+Updated to use DINOv2
+@article{amir2021deep,
+    author  = {Shir Amir and Yossi Gandelsman and Shai Bagon and Tali Dekel},
+    title   = {Deep ViT Features as Dense Visual Descriptors},
+    journal = {arXiv preprint arXiv:2112.05814},
+    year    = {2021}
+}
+"""
+
+class ViTExtractor:
+    """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT.
+
+    We use the following notation in the documentation of the module's methods:
+    B - batch size
+    h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW
+    p - patch size of the ViT. 8 or 16 for DINO, 14 for DINOv2.
+    t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1, where H and W are the height and width
+    of the input image.
+    d - the embedding dimension in the ViT.
+    """
+
+    def __init__(self, model_type: str = 'dinov2_vitb14', stride: int = 14, model: nn.Module = None, device: str = 'cuda'):
+        """
+        :param model_type: A string specifying the type of model to extract from. One of the DINOv2 torch hub models:
+                           [dinov2_vits14 | dinov2_vitb14 | dinov2_vitl14 | dinov2_vitg14]
+        :param stride: stride of first convolution layer. small stride -> higher resolution.
+        :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor.
+                      should be compatible with model_type.
+        :param device: device to run the model on.
+        """
+        self.model_type = model_type
+        self.device = device
+        if model is not None:
+            self.model = model
+        else:
+            self.model = ViTExtractor.create_model(model_type)
+
+        self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride)
+        self.model.eval()
+        self.model.to(self.device)
+        self.p = self.model.patch_embed.patch_size[0]  # DINOv2's PatchEmbed stores patch_size as a (h, w) tuple
+        self.stride = self.model.patch_embed.proj.stride
+
+        self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5)
+        self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5)
+
+        self._feats = []
+        self.hook_handlers = []
+        self.load_size = None
+        self.num_patches = None
+
+    @staticmethod
+    def create_model(model_type: str) -> nn.Module:
+        """
+        :param model_type: a string specifying which DINOv2 model to load from torch hub.
+                           [dinov2_vits14 | dinov2_vitb14 | dinov2_vitl14 | dinov2_vitg14]
+        :return: the model
+        """
+        model = torch.hub.load('facebookresearch/dinov2', model_type)
+        return model
+
+    @staticmethod
+    def _fix_pos_enc(patch_size: Tuple[int, int], stride_hw: Tuple[int, int]):
+        """
+        Creates a method for position encoding interpolation.
+        :param patch_size: patch size of the model as a (height, width) tuple.
+        :param stride_hw: A tuple containing the new height and width stride respectively.
+        :return: the interpolation method
+        """
+        def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
+            npatch = x.shape[1] - 1
+            N = self.pos_embed.shape[1] - 1
+            if npatch == N and w == h:
+                return self.pos_embed
+            class_pos_embed = self.pos_embed[:, 0]
+            patch_pos_embed = self.pos_embed[:, 1:]
+            dim = x.shape[-1]
+            # compute number of tokens taking stride into account
+            w0 = 1 + (w - patch_size[1]) // stride_hw[1]
+            h0 = 1 + (h - patch_size[0]) // stride_hw[0]
+            assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and
+                                            stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}"""
+            # we add a small number to avoid floating point error in the interpolation
+            # see discussion at https://github.com/facebookresearch/dino/issues/8
+            w0, h0 = w0 + 0.1, h0 + 0.1
+            patch_pos_embed = nn.functional.interpolate(
+                patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+                scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+                mode='bicubic',
+                align_corners=False, recompute_scale_factor=False
+            )
+            assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return interpolate_pos_encoding
+
+    @staticmethod
+    def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
+        """
+        change resolution of model output by changing the stride of the patch extraction.
+        :param model: the model to change resolution for.
+        :param stride: the new stride parameter.
+        :return: the adjusted model
+        """
+        patch_size = model.patch_embed.patch_size  # (height, width) tuple in DINOv2
+        if stride == patch_size[0] and stride == patch_size[1]:  # nothing to do
+            return model
+
+        stride = nn_utils._pair(stride)
+        assert all([(patch_size[i] // stride[i]) * stride[i] == patch_size[i] for i in
+                    range(len(stride))]), f'stride {stride} should divide patch_size {patch_size}'
+
+        # fix the stride
+        model.patch_embed.proj.stride = stride
+        # fix the positional encoding code
+        model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model)
+        return model
+
+    def preprocess(self, image_path: Union[str, Path],
+                   load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]:
+        """
+        Preprocesses an image before extraction.
+        :param image_path: path to image to be extracted.
+        :param load_size: optional. Size to resize image before the rest of preprocessing.
+        :return: a tuple containing:
+                 (1) the preprocessed image as a tensor to feed to the model, of shape BxCxHxW.
+                 (2) the pil image in relevant dimensions
+        """
+        pil_image = Image.open(image_path).convert('RGB')
+        if load_size is not None:
+            pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image)
+        prep = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=self.mean, std=self.std)
+        ])
+        prep_img = prep(pil_image)[None, ...]
+        return prep_img, pil_image
+
+    def _get_hook(self, facet: str):
+        """
+        generate a hook method for a specific block and facet.
+ """ + if facet in ['attn', 'token']: + def _hook(model, input, output): + self._feats.append(output) + return _hook + + if facet == 'query': + facet_idx = 0 + elif facet == 'key': + facet_idx = 1 + elif facet == 'value': + facet_idx = 2 + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _inner_hook(module, input, output): + input = input[0] + B, N, C = input.shape + qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) + self._feats.append(qkv[facet_idx]) #Bxhxtxd + return _inner_hook + + def _register_hooks(self, layers: List[int], facet: str) -> None: + """ + register hook to extract features. + :param layers: layers from which to extract features. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + """ + for block_idx, block in enumerate(self.model.blocks): + if block_idx in layers: + if facet == 'token': + self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) + elif facet == 'attn': + self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) + elif facet in ['key', 'query', 'value']: + self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _unregister_hooks(self) -> None: + """ + unregisters the hooks. should be called after feature extraction. + """ + for handle in self.hook_handlers: + handle.remove() + self.hook_handlers = [] + + def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]: + """ + extract features from the model + :param batch: batch to extract features for. Has shape BxCxHxW. + :param layers: layer to extract. A number between 0 to 11. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + :return : tensor of features. + if facet is 'key' | 'query' | 'value' has shape Bxhxtxd + if facet is 'attn' has shape Bxhxtxt + if facet is 'token' has shape Bxtxd + """ + B, C, H, W = batch.shape + self._feats = [] + self._register_hooks(layers, facet) + _ = self.model(batch) + self._unregister_hooks() + self.load_size = (H, W) + self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) + return self._feats + + def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: + """ + create a log-binned descriptor. + :param x: tensor of features. Has shape Bxhxtxd. + :param hierarchy: how many bin hierarchies to use. + """ + B = x.shape[0] + num_bins = 1 + 8 * hierarchy + + bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh) + bin_x = bin_x.permute(0, 2, 1) + bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1]) + # Bx(dxh)xnum_patches[0]xnum_patches[1] + sub_desc_dim = bin_x.shape[1] + + avg_pools = [] + # compute bins of all sizes for all spatial locations. 
+        for k in range(0, hierarchy):
+            # avg pooling with kernel 3**k x 3**k
+            win_size = 3 ** k
+            avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
+            avg_pools.append(avg_pool(bin_x))
+
+        bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device)
+        for y in range(self.num_patches[0]):
+            for x in range(self.num_patches[1]):
+                part_idx = 0
+                # fill all bins for a spatial location (y, x)
+                for k in range(0, hierarchy):
+                    kernel_size = 3 ** k
+                    for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
+                        for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
+                            if i == y and j == x and k != 0:
+                                continue
+                            if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]:
+                                bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
+                                    :, :, i, j]
+                            else:  # handle padding in a more delicate way than zero padding
+                                temp_i = max(0, min(i, self.num_patches[0] - 1))
+                                temp_j = max(0, min(j, self.num_patches[1] - 1))
+                                bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
+                                    :, :, temp_i, temp_j]
+                            part_idx += 1
+        bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
+        # Bx1x(t-1)x(dxh)
+        return bin_x
+
+    def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key',
+                            bin: bool = False, include_cls: bool = False) -> torch.Tensor:
+        """
+        extract descriptors from the model
+        :param batch: batch to extract descriptors for. Has shape BxCxHxW.
+        :param layer: layer to extract from. A number between 0 and 11.
+        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
+        :param bin: apply log binning to the descriptor. default is False.
+        :param include_cls: whether to keep the CLS token in the output. default is False.
+        :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
+        """
+        assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors.
+                                                                choose from ['key' | 'query' | 'value' | 'token'] """
+        self._extract_features(batch, [layer], facet)
+        x = self._feats[0]
+        if facet == 'token':
+            x.unsqueeze_(dim=1)  # Bx1xtxd
+        if not include_cls:
+            x = x[:, :, 1:, :]  # remove cls token
+        else:
+            assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
+        if not bin:
+            desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
+        else:
+            desc = self._log_bin(x)
+        return desc
+
+    def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
+        """
+        extract saliency maps. The saliency maps are extracted by averaging several attention heads of the CLS token
+        from the last layer. All values are then normalized to range between 0 and 1.
+        :param batch: batch to extract saliency maps for. Has shape BxCxHxW.
+        :return: a tensor of saliency maps. has shape Bx(t-1)
+        """
+        assert self.model_type == "dino_vits8", "saliency maps are supported only for dino_vits8 model_type."
+        self._extract_features(batch, [11], 'attn')
+        head_idxs = [0, 2, 4, 5]
+        curr_feats = self._feats[0]  # Bxhxtxt
+        cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1)  # Bx(t-1)
+        temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0]
+        cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins)  # normalize to range [0,1]
+        return cls_attn_maps
+
+def chunk_cosine_sim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """ Computes cosine similarity between all possible pairs in two sets of vectors.
+    Operates on chunks so no large amount of GPU RAM is required.
+    :param x: a tensor of descriptors of shape Bx1x(t_x)xd' where d' is the dimensionality of the descriptors and t_x
+              is the number of tokens in x.
+    :param y: a tensor of descriptors of shape Bx1x(t_y)xd' where d' is the dimensionality of the descriptors and t_y
+              is the number of tokens in y.
+    :return: cosine similarity between all descriptors in x and all descriptors in y. Has shape of Bx1x(t_x)x(t_y) """
+    result_list = []
+    num_token_x = x.shape[2]
+    for token_idx in range(num_token_x):
+        token = x[:, :, token_idx, :].unsqueeze(dim=2)  # Bx1x1xd'
+        result_list.append(torch.nn.CosineSimilarity(dim=3)(token, y))  # Bx1xt
+    return torch.stack(result_list, dim=2)  # Bx1x(t_x)x(t_y)
+
+""" taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse """
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+# ImageNet statistics used to normalize inputs for the DINO / DINOv2 models
+mean = (0.485, 0.456, 0.406)
+std = (0.229, 0.224, 0.225)
+prep = transforms.Compose([
+    # inputs to dino_loss are already [B,3,H,W] tensors in [0,1], so ToTensor is not needed here
+    transforms.Resize((224, 224)),
+    transforms.Normalize(mean=mean, std=std)
+])
+
+extractor = None
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+cosine_sim = torch.nn.CosineSimilarity(dim=2)
+
+def dino_loss(img0, img1, model_type='dinov2_vitb14',
+              stride=14, layer=3, facet='token', bin=False):
+    """
+    :param img0: image tensor [batch,3,h,w]
+    :param img1: image tensor [batch,3,h,w]
+    :param model_type: DINOv2 model to load from torch hub.
+    :param stride: stride of the patch-embedding convolution.
+    :param layer: layer from which to extract features [0,11]
+    :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
+    :param bin: apply log binning to the descriptors.
+    """
+    global extractor
+    if extractor is None:
+        extractor = ViTExtractor(model_type, stride, device=device)
+        # freeze the ViT weights; gradients still flow back to the input images through the activations
+        extractor.model.requires_grad_(False)
+
+    descs_a = extractor.extract_descriptors(prep(img0.to(device)),
+                                            layer, facet, bin, include_cls=True)
+
+    descs_b = extractor.extract_descriptors(prep(img1.to(device)),
+                                            layer, facet, bin, include_cls=True)
+
+    similarities = cosine_sim(descs_a, descs_b)
+    similarities = torch.mean(similarities, dim=2)
+
+    return 1 - similarities.mean()
\ No newline at end of file
diff --git a/src/paint.py b/src/paint.py
index bea2790..b21de3f 100755
--- a/src/paint.py
+++ b/src/paint.py
@@ -89,7 +89,8 @@
 
         # Clean paint brush and/or get more paint
         if not painter.opt.ink:
-            color_ind, _ = nearest_color(stroke.color_transform.detach().cpu().numpy(), color_palette)
+            color_ind, _ = nearest_color(stroke.color_transform.detach().cpu().numpy(),
+                                         color_palette.detach().cpu().numpy())
             new_paint_color = color_ind != curr_color
             if new_paint_color or consecutive_strokes_no_clean > 12:
                 painter.clean_paint_brush()
diff --git a/src/paint_utils3.py b/src/paint_utils3.py
index 9971b74..40e91b6 100644
--- a/src/paint_utils3.py
+++ b/src/paint_utils3.py
@@ -188,6 +188,12 @@ def discretize_color(brush_stroke, discrete_colors):
     return discrete_colors[argmin].clone()
 
+def compare_images(img1, img2):
+    ''' Pixel-wise comparison '''
+    # Input images are Lab; returns the CIE delta E map
+    delta_E = colour.delta_E(img1, img2)
+    return delta_E
+
 def nearest_color(color, discrete_colors):
     ''' Get the most similar color to a given color (np.array([3])) '''
     #dist = np.mean(np.abs(discrete_colors - color[None,:])**2, axis=1)
diff --git a/src/painting_optimization.py b/src/painting_optimization.py
index cc3d68b..1700026 100644
--- a/src/painting_optimization.py
+++ b/src/painting_optimization.py
@@ -20,6 +20,7 @@
 
 from losses.stable_diffusion.stable_diffusion_loss2 import stable_diffusion_loss, encode_text_stable_diffusion
 from losses.speech2emotion.speech2emotion import speech2emotion, speech2text
+from losses.dino_loss import dino_loss
 from losses.clip_loss import clip_conv_loss, clip_model, clip_text_loss, clip_fc_loss
 import clip
 
@@ -59,6 +60,8 @@ def parse_objective(objective_type, objective_data, p, weight=1.0, num_augs=30):
         return compute_style_loss(p, objective_data) * weight
     elif objective_type == 'clip_conv_loss':
         return clip_conv_loss(p, objective_data) * weight
+    elif objective_type == 'dino':
+        return dino_loss(p, objective_data) * weight
     elif objective_type == 'l2':
         return ((p - objective_data)**2).mean() * weight
     elif objective_type == 'clip_fc_loss':
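
For reference, a minimal smoke test of the new loss, not part of the patch itself: it assumes src/ is on PYTHONPATH, that the dinov2_vitb14 weights can be fetched via torch.hub, and that inputs are RGB tensors in [0, 1] as the dino_loss docstring states. The target/canvas names below are purely illustrative.

# hypothetical usage sketch for losses/dino_loss.py (not included in the diff)
import torch
from losses.dino_loss import dino_loss

target = torch.rand(1, 3, 256, 256)                       # stand-in for the target image
canvas = torch.rand(1, 3, 256, 256, requires_grad=True)   # stand-in for the rendered painting

loss = dino_loss(canvas, target)   # defaults: dinov2_vitb14, stride 14, layer 3, facet 'token'
loss.backward()                    # the ViT is frozen, but gradients still reach `canvas`
print(float(loss), canvas.grad.abs().mean().item())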
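
The stride / positional-encoding patching is the part most likely to break if the DINOv2 hub models change, so a quick sanity check may be worth keeping around. The sketch below is hypothetical and assumes ViTExtractor is importable as above: halving the stride from 14 to 7 should roughly quadruple the number of spatial tokens while the forward pass still runs.

# hypothetical sanity check for patch_vit_resolution / interpolate_pos_encoding (not in the diff)
import torch
from losses.dino_loss import ViTExtractor

img = torch.rand(1, 3, 224, 224)
for stride in (14, 7):   # 7 divides the 14x14 patch size, so the divisibility assert passes
    ext = ViTExtractor('dinov2_vitb14', stride=stride, device='cpu')
    with torch.no_grad():
        desc = ext.extract_descriptors(img, layer=3, facet='token')
    # num_patches = (1 + (H - p) // stride, 1 + (W - p) // stride): (16, 16) vs (31, 31)
    print(stride, ext.num_patches, desc.shape)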