diff --git a/README.md b/README.md index 5c23997..27792f6 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,177 @@ -# SOMA -[ICCV' 23] Novel Scenes & Classes: Towards Adaptive Open-set Object Detection +# [Novel Scenes & Classes: Towards Adaptive Open-set Object Detection (ICCV-23 ORAL)](assets/paper.pdf) + +By [Wuyang Li](https://wymancv.github.io/wuyang.github.io/) + +Paper link will be updated after the CVF open access. + +
+ +
+
+Domain Adaptive Object Detection (DAOD) strongly assumes a shared class space between the two domains.
+
+This work breaks that assumption and formulates Adaptive Open-set Object Detection (AOOD) by allowing the target domain to contain novel-class objects.
+
+The object detector is trained with the base-class labels in the source domain, and aims to detect base-class objects and identify novel-class objects as unknown in the target domain.
+
+If you have any ideas or problems you would like to discuss, feel free to reach out via [E-mail](mailto:wuyangli2-c@my.cityu.edu.hk).
+
+# 💡 Preparation
+
+## Step 1: Clone and Install the Project
+
+### Clone the repository
+
+```bash
+git clone https://github.com/CityU-AIM-Group/SOMA.git
+```
+
+### Install the project following [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR)
+
+Note that the following matches our experimental environment, which is slightly different from the official one.
+
+```
+# Linux, CUDA>=9.2, GCC>=5.4
+# (ours) CUDA=10.2, GCC=8.4, NVIDIA V100
+# Establish the conda environment
+
+conda create -n aood python=3.7 pip
+conda activate aood
+conda install pytorch=1.5.1 torchvision=0.6.1 cudatoolkit=10.2 -c pytorch
+pip install -r requirements.txt
+
+# Compile the project
+cd ./models/ops
+sh ./make.sh
+
+# Unit test (all checks should print True)
+python test.py
+
+# NOTE: if you meet a permission-denied issue when starting training
+cd ../../
+chmod -R 777 ./
+```
+
+## Step 2: Download Necessary Resources
+
+### Download pre-processed datasets (VOC format) from the following links
+
+|                | (Foggy) Cityscapes | Pascal VOC | Clipart | BDD100K |
+| :------------: | :----------------: | :--------: | :-----: | :-----: |
+| Official Links | [Imgs](https://www.cityscapes-dataset.com/login/) | [Imgs+Labels](https://pjreddie.com/projects/pascal-voc-dataset-mirror/) | - | - |
+| Our Links | [Labels](https://portland-my.sharepoint.com/:u:/g/personal/wuyangli2-c_my_cityu_edu_hk/EVNAjK2JkG9ChREzzqdqJkYBLoZ_VOqkMdhWasN_BETGWw?e=fP9Ae4) | - | [Imgs+Labels](https://portland-my.sharepoint.com/:u:/g/personal/wuyangli2-c_my_cityu_edu_hk/Edz2YcXHuStIqwM_NA7k8FMBGLeyAGQcSjdSR-vYaVx_vw?e=es6KDW) | [Imgs+Labels](https://portland-my.sharepoint.com/:u:/g/personal/wuyangli2-c_my_cityu_edu_hk/EeiO6O36QgZKnTcUZMInACIB0dfWEg4OFyoEZnZCkibKHA?e=6byqBX) |
+
+### Download DINO-pretrained ResNet-50 from this [link](https://portland-my.sharepoint.com/:u:/g/personal/wuyangli2-c_my_cityu_edu_hk/EVnK9IPi91ZPuNmwpeSWGHABqhSFQK52I7xGzroXKeuyzA?e=EnlwgO)
+
+## Step 3: Change the Path
+
+### Change the data path as follows
+
+```
+[DATASET_PATH]
+└─ Cityscapes
+   └─ AOOD_Annotations
+   └─ AOOD_Main
+      └─ train_source.txt
+      └─ train_target.txt
+      └─ val_source.txt
+      └─ val_target.txt
+   └─ leftImg8bit
+      └─ train
+      └─ val
+   └─ leftImg8bit_foggy
+      └─ train
+      └─ val
+└─ bdd_daytime
+   └─ Annotations
+   └─ ImageSets
+   └─ JPEGImages
+└─ clipart
+   └─ Annotations
+   └─ ImageSets
+   └─ JPEGImages
+└─ VOCdevkit
+   └─ VOC2007
+   └─ VOC2012
+```
+
+### Change the data root folder in config files
+
+Replace DATASET.COCO_PATH in all yaml files in [configs](configs) with your data root $DATASET_PATH, e.g., Line 22 of [soma_aood_city_to_foggy_r50.yaml](configs/soma_aood_city_to_foggy_r50.yaml).
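+
+If all your datasets live under one root, the following one-liner is a minimal sketch of that edit (the path `/data/aood/` is only a placeholder for your own $DATASET_PATH):
+
+```bash
+# Rewrite DATASET.COCO_PATH in every provided config (placeholder path; adjust to your data root)
+sed -i 's#COCO_PATH: .*#COCO_PATH: /data/aood/#' configs/*.yaml
+```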
+
+### Change the path of the DINO-pretrained backbone
+
+Replace the backbone loading path at Line 107 of [backbone.py](models/backbone.py).
+
+# 🔥 Start Training
+
+We use two GPUs for training, with 2 source images and 2 target images as input.
+
+```bash
+GPUS_PER_NODE=2 ./tools/run_dist_launch.sh 2 python main.py --config_file {CONFIG_FILE} --opts DATASET.AOOD_SETTING 1
+```
+
+We provide the scripts used in our experiments in [run.sh](./run.sh). The settings listed after "--opts" overwrite the default config values, in the same way as the maskrcnn-benchmark framework.
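+
+For example, a sketch of a complete launch that overrides extra options after "--opts" (the OUTPUT_DIR value here is only illustrative):
+
+```bash
+# Train city-to-foggy under setting 1 and redirect the outputs (illustrative OUTPUT_DIR)
+GPUS_PER_NODE=2 ./tools/run_dist_launch.sh 2 python main.py \
+    --config_file configs/soma_aood_city_to_foggy_r50.yaml \
+    --opts DATASET.AOOD_SETTING 1 OUTPUT_DIR exps/my_c2f_run
+```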
+
+# 📦 Well-trained models
+
+Will be provided later.
+
+# 💬 Notification
+
+- The core idea is to select informative motifs (which can be treated as the mix-up of object queries) for self-training.
+- You can try the DA version of [OW-DETR](https://github.com/akshitac8/OW-DETR) in this repository by setting:
+```
+--opts AOOD.OW_DETR_ON True
+```
+- Adopting SAM to address AOOD may be a good direction.
+- To visualize unknown boxes, post-processing is needed at Line 736 of [PostProcess](models/motif_detr.py).
+
+# 📝 Citation
+
+If you think this work is helpful for your project, please give it a star and a citation. We sincerely appreciate your acknowledgment.
+
+```BibTeX
+@InProceedings{li2023novel,
+    title={Novel Scenes & Classes: Towards Adaptive Open-set Object Detection},
+    author={Li, Wuyang and Guo, Xiaoqing and Yuan, Yixuan},
+    booktitle={ICCV},
+    year={2023}
+}
+```
+
+Relevant project:
+
+Exploring a similar issue for the classification task. [[link]](https://openaccess.thecvf.com/content/CVPR2023/html/Li_Adjustment_and_Alignment_for_Unbiased_Open_Set_Domain_Adaptation_CVPR_2023_paper.html)
+
+```BibTeX
+@InProceedings{Li_2023_CVPR,
+    author    = {Li, Wuyang and Liu, Jie and Han, Bo and Yuan, Yixuan},
+    title     = {Adjustment and Alignment for Unbiased Open Set Domain Adaptation},
+    booktitle = {CVPR},
+    year      = {2023},
+}
+```
+
+# 🤞 Acknowledgements
+
+We greatly appreciate the tremendous efforts behind the following works.
+
+- This work is based on the DAOD framework [AQT](https://github.com/weii41392/AQT).
+- Our work is highly inspired by [OW-DETR](https://github.com/akshitac8/OW-DETR) and [OpenDet](https://github.com/csuhan/opendet2).
+- The implementation of the basic detector is based on [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR).
+
+# 📒 Abstract
+
+Domain Adaptive Object Detection (DAOD) transfers an object detector to a novel domain free of labels. However, in the real world, besides encountering novel scenes, novel domains always contain novel-class objects de facto, which are ignored in existing research. Thus, we formulate and study a more practical setting, Adaptive Open-set Object Detection (AOOD), considering both novel scenes and classes. Directly combining off-the-shelf cross-domain and open-set approaches is sub-optimal because their low-order dependence, such as the confidence score, is insufficient for AOOD with two dimensions of novel information. To address this, we propose a novel Structured Motif Matching (SOMA) framework for AOOD, which models the high-order relation with motifs, i.e., statistically significant subgraphs, and formulates the AOOD solution as motif matching to learn with high-order patterns. In a nutshell, SOMA consists of Structure-aware Novel-class Learning (SNL) and Structure-aware Transfer Learning (STL). As for SNL, we establish an instance-oriented graph to capture the class-independent object features hidden in different base classes. Then, a high-order metric is proposed to match the most significant motif as high-order patterns, serving motif-guided novel-class learning. In STL, we set up a semantic-oriented graph to model the class-dependent relation across domains, and match unlabelled objects with high-order motifs to align the cross-domain distribution with structural awareness. Extensive experiments demonstrate that the proposed SOMA achieves state-of-the-art performance.
+
+![image](./assets/overall.png)
diff --git a/assets/mot.png b/assets/mot.png
new file mode 100644
index 0000000..72a5962
Binary files /dev/null and b/assets/mot.png differ
diff --git a/assets/overall.png b/assets/overall.png
new file mode 100644
index 0000000..a0cd42b
Binary files /dev/null and b/assets/overall.png differ
diff --git a/assets/paper.pdf b/assets/paper.pdf
new file mode 100644
index 0000000..1f1f4ba
Binary files /dev/null and b/assets/paper.pdf differ
diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 0000000..93702da
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,69 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+"""
+Benchmark inference speed of Deformable DETR.
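+
+Usage sketch (paths are illustrative; unrecognized flags such as --config_file are forwarded
+to main.py's parser via parse_known_args):
+    python benchmark.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \
+        --resume exps/r50_aood_c2f/checkpoint.pth --num_iters 300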
+""" +import os +import time +import argparse + +import torch + +from main import get_args_parser as get_main_args_parser +from models import build_model +from datasets import build_dataset +from util.misc import nested_tensor_from_tensor_list + + +def get_benckmark_arg_parser(): + parser = argparse.ArgumentParser('Benchmark inference speed of Deformable DETR.') + parser.add_argument('--num_iters', type=int, default=300, help='total iters to benchmark speed') + parser.add_argument('--warm_iters', type=int, default=5, help='ignore first several iters that are very slow') + parser.add_argument('--batch_size', type=int, default=1, help='batch size in inference') + parser.add_argument('--resume', type=str, help='load the pre-trained checkpoint') + return parser + + +@torch.no_grad() +def measure_average_inference_time(model, inputs, num_iters=100, warm_iters=5): + ts = [] + for iter_ in range(num_iters): + torch.cuda.synchronize() + t_ = time.perf_counter() + model(inputs) + torch.cuda.synchronize() + t = time.perf_counter() - t_ + if iter_ >= warm_iters: + ts.append(t) + print(ts) + return sum(ts) / len(ts) + + +def benchmark(): + args, _ = get_benckmark_arg_parser().parse_known_args() + main_args = get_main_args_parser().parse_args(_) + assert args.warm_iters < args.num_iters and args.num_iters > 0 and args.warm_iters >= 0 + assert args.batch_size > 0 + assert args.resume is None or os.path.exists(args.resume) + dataset = build_dataset('val', main_args) + model, _, _ = build_model(main_args) + model.cuda() + model.eval() + if args.resume is not None: + ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage) + model.load_state_dict(ckpt['model']) + inputs = nested_tensor_from_tensor_list([dataset.__getitem__(0)[0].cuda() for _ in range(args.batch_size)]) + t = measure_average_inference_time(model, inputs, args.num_iters, args.warm_iters) + return 1.0 / t * args.batch_size + + +if __name__ == '__main__': + fps = benchmark() + print(f'Inference Speed: {fps:.1f} FPS') + diff --git a/config.py b/config.py new file mode 100644 index 0000000..93f574c --- /dev/null +++ b/config.py @@ -0,0 +1,167 @@ +from yacs.config import CfgNode as CN +from numpy import pi + +_C = CN() + +# ------------------------------------------------------------------------ +# Training +# ------------------------------------------------------------------------ +_C.TRAIN = CN() +_C.TRAIN.LR = 2e-4 +_C.TRAIN.LR_BACKBONE_NAMES = ["backbone.0"] +_C.TRAIN.LR_BACKBONE = 2e-5 +_C.TRAIN.LR_LINEAR_PROJ_NAMES = ['reference_points', 'sampling_offsets'] +_C.TRAIN.LR_LINEAR_PROJ_MULT = 0.1 +_C.TRAIN.BATCH_SIZE = 2 +_C.TRAIN.WEIGHT_DECAY = 1e-4 +_C.TRAIN.EPOCHS = 50 +_C.TRAIN.LR_DROP = 40 +_C.TRAIN.LR_DROP_EPOCHS = None +_C.TRAIN.CLIP_MAX_NORM = 0.1 # gradient clipping max norm +_C.TRAIN.SGD = False # AdamW is used when setting this false + + +# ------------------------------------------------------------------------ +# Model +# ------------------------------------------------------------------------ +_C.MODEL = CN() + +# Variants of Deformable DETR +_C.MODEL.WITH_BOX_REFINE = False +_C.MODEL.TWO_STAGE = False + +# Model parameters +_C.MODEL.FROZEN_WEIGHTS = None # Path to the pretrained model. 
If set, only the mask head will be trained + +# * Backbone +_C.MODEL.BACKBONE = 'resnet50' # Name of the convolutional backbone to use +_C.MODEL.DILATION = False # If true, we replace stride with dilation in the last convolutional block (DC5) +_C.MODEL.POSITION_EMBEDDING = 'sine' # ('sine', 'learned') Type of positional embedding to use on top of the image features +_C.MODEL.POSITION_EMBEDDING_SCALE = 2 * pi # position / size * scale +_C.MODEL.NUM_FEATURE_LEVELS = 4 # number of feature levels + +# * Transformer +_C.MODEL.ENC_LAYERS = 6 # Number of encoding layers in the transformer +_C.MODEL.DEC_LAYERS = 6 # Number of decoding layers in the transformer +_C.MODEL.DIM_FEEDFORWARD = 1024 # Intermediate size of the feedforward layers in the transformer blocks +_C.MODEL.HIDDEN_DIM = 256 # Size of the embeddings (dimension of the transformer) +_C.MODEL.DROPOUT = 0.1 # Dropout applied in the transformer +_C.MODEL.NHEADS = 8 # Number of attention heads inside the transformer's attentions +_C.MODEL.NUM_QUERIES = 300 # Number of query slots +_C.MODEL.DEC_N_POINTS = 4 +_C.MODEL.ENC_N_POINTS = 4 + +# * Segmentation +_C.MODEL.MASKS = False # Train segmentation head if the flag is provided + +# * Domain Adaptation +_C.MODEL.BACKBONE_ALIGN = False +_C.MODEL.SPACE_ALIGN = False +_C.MODEL.CHANNEL_ALIGN = False +_C.MODEL.INSTANCE_ALIGN = False + +# ------------------------------------------------------------------------ +# Deformable DETR baseline Loss +# ------------------------------------------------------------------------ +_C.LOSS = CN() +_C.LOSS.AUX_LOSS = True # auxiliary decoding losses (loss at each layer) + +# * Matcher +_C.LOSS.SET_COST_CLASS = 2. # Class coefficient in the matching cost +_C.LOSS.SET_COST_BBOX = 5. # L1 box coefficient in the matching cost +_C.LOSS.SET_COST_GIOU = 2. # giou box coefficient in the matching cost + +# * Loss coefficients +_C.LOSS.MASK_LOSS_COEF = 1. +_C.LOSS.DICE_LOSS_COEF = 1. +_C.LOSS.CLS_LOSS_COEF = 2. +_C.LOSS.BBOX_LOSS_COEF = 5. +_C.LOSS.GIOU_LOSS_COEF = 2. 
+ +_C.LOSS.SPACE_QUERY_LOSS_COEF = 0.1 +_C.LOSS.CHANNEL_QUERY_LOSS_COEF = 0.1 +_C.LOSS.INSTANCE_QUERY_LOSS_COEF = 0.1 +_C.LOSS.FOCAL_ALPHA = 0.25 +_C.LOSS.DA_GAMMA = 0 + + +# ------------------------------------------------------------------------ +# dataset parameters +# ------------------------------------------------------------------------ +_C.DATASET = CN() +_C.DATASET.DA_MODE = 'source_only' # ('source_only', 'uda', 'oracle') +_C.DATASET.NUM_CLASSES = 9 # This should be set as max_class_id + 1 +_C.DATASET.DATASET_FILE = 'cityscapes_to_foggy_cityscapes' +_C.DATASET.COCO_PATH = '../datasets' +_C.DATASET.COCO_PANOPTIC_PATH = None +_C.DATASET.REMOVE_DIFFICULT = False + + +# ------------------------------------------------------------------------ +# Distributed +# ------------------------------------------------------------------------ +_C.DIST = CN() +_C.DIST.DISTRIBUTED = False +_C.DIST.RANK = None +_C.DIST.WORLD_SIZE = None +_C.DIST.GPU = None +_C.DIST.DIST_URL = None +_C.DIST.DIST_BACKEND = None + +# ------------------------------------------------------------------------ +# Miscellaneous +# ------------------------------------------------------------------------ +_C.OUTPUT_DIR = '' # path where to save, empty for no saving +_C.DEVICE = 'cuda' # device to use for training / testing +_C.RESUME = '' # resume from checkpoint +_C.START_EPOCH = 0 # start epoch +_C.EVAL = False +_C.NUM_WORKERS = 2 +_C.CACHE_MODE = False # whether to cache images on memory +_C.SEED = 42 # Note this this cannot strictly control the same results. I don not know why + +# ------------------------------------------------------------------------ +# Adaptive Open-set Object Detection (AOOD) +# ------------------------------------------------------------------------ +_C.AOOD = CN() +_C.AOOD.OPEN_SET = CN() +_C.AOOD.CROSS_DOMAIN = CN() + +_C.AOOD.OW_DETR_ON = False +_C.AOOD.MOTIF_ON = False +_C.AOOD.OPENDET_DETR_ON = False +_C.DATASET.AOOD_SETTING = 1 +_C.DATASET.AOOD_TASK = 4 +_C.DATASET.AOOD_SCENE = 'cityscapes' +# _C.DATASET.AOOD_SCENE = 'pascal' + +# global alignment + def-detr baseline +_C.AOOD.CROSS_DOMAIN.BACKBONE_LAMBDA = 1.0 +_C.LOSS.BACKBONE_LOSS_COEF = 0.1 +_C.LOAD_OPTIMIZER = True +_C.EVAL_EPOCH = 29 + +# For novel-class +_C.AOOD.OPEN_SET.MOTIF_ON = False +_C.AOOD.OPEN_SET.KNN = 5 +_C.AOOD.OPEN_SET.TH = 0.5 +_C.AOOD.OPEN_SET.MOTIF_LOSS_COEF = 1.0 +_C.AOOD.OPEN_SET.WITH_SELF_LABELING = False +_C.AOOD.OPEN_SET.UNK_PROB = 0.0 +_C.AOOD.OPEN_SET.WARM_UP = -1 # -1 indicates no warm-up +_C.AOOD.OPEN_SET.MOTIF_UPDATE = True +_C.AOOD.OPEN_SET.ALPHA = 0.01 + +# For novel-scene +_C.AOOD.CROSS_DOMAIN.WARM_UP = -1 +_C.AOOD.CROSS_DOMAIN.MOTIF_ON = False +_C.AOOD.CROSS_DOMAIN.MOTIF_LOSS_COEF = 0.01 +_C.AOOD.CROSS_DOMAIN.KNN = 5 +_C.AOOD.CROSS_DOMAIN.BETA = 1.0 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() diff --git a/configs/soma_aood_city_to_bdd100k_r50.yaml b/configs/soma_aood_city_to_bdd100k_r50.yaml new file mode 100644 index 0000000..e9ee55a --- /dev/null +++ b/configs/soma_aood_city_to_bdd100k_r50.yaml @@ -0,0 +1,93 @@ +CACHE_MODE: False +AOOD: + MOTIF_ON: True + # novel-class + OPEN_SET: + WARM_UP: 8 + TH: 0.5 + MOTIF_ON: True + ALPHA: 0.01 + KNN: 5 + UNK_PROB: 0.5 + MOTIF_LOSS_COEF: 0.05 + # novel-scene + CROSS_DOMAIN: + MOTIF_ON: True + BETA: 1.0 # std scaling +DATASET: + DA_MODE: aood + AOOD_SETTING: 1 # different class splittings 
1-4 + AOOD_SCENE: 'cityscapes' + COCO_PANOPTIC_PATH: None + COCO_PATH: /home/wuyangli2/data/ + DATASET_FILE: cityscapes_to_bdd_daytime + NUM_CLASSES: 4 # 3 known + 1 + REMOVE_DIFFICULT: False +DEVICE: cuda +DIST: + DISTRIBUTED: False + DIST_BACKEND: nccl + DIST_URL: env:// + GPU: 0 + RANK: 0 + WORLD_SIZE: 4 +EVAL: False +LOSS: + AUX_LOSS: True + BACKBONE_LOSS_COEF: 0.1 + BBOX_LOSS_COEF: 5.0 + CHANNEL_QUERY_LOSS_COEF: 0.1 + CLS_LOSS_COEF: 2.0 + DA_GAMMA: 0 + DICE_LOSS_COEF: 1.0 + FOCAL_ALPHA: 0.25 + GIOU_LOSS_COEF: 2.0 + INSTANCE_QUERY_LOSS_COEF: 0.1 + MASK_LOSS_COEF: 1.0 + SET_COST_BBOX: 5.0 + SET_COST_CLASS: 2.0 + SET_COST_GIOU: 2.0 + SPACE_QUERY_LOSS_COEF: 0.1 +MODEL: + BACKBONE: resnet50 + BACKBONE_ALIGN: True + CHANNEL_ALIGN: False + DEC_LAYERS: 3 + DEC_N_POINTS: 4 + DILATION: False + DIM_FEEDFORWARD: 1024 + DROPOUT: 0.1 + ENC_LAYERS: 3 + ENC_N_POINTS: 4 + FROZEN_WEIGHTS: None + HIDDEN_DIM: 256 + INSTANCE_ALIGN: False + MASKS: False + NHEADS: 8 + NUM_FEATURE_LEVELS: 4 + NUM_QUERIES: 100 + POSITION_EMBEDDING: sine + POSITION_EMBEDDING_SCALE: 6.283185307179586 + SPACE_ALIGN: False + TWO_STAGE: False + WITH_BOX_REFINE: False +NUM_WORKERS: 2 +OUTPUT_DIR: exps/r50_aood_c2b +RESUME: +LOAD_OPTIMIZER: False +SEED: 42 # This cannot strictly control the same results. Sorry, I do not know why. +START_EPOCH: 0 +EVAL_EPOCH: 9 +TRAIN: + BATCH_SIZE: 4 # each gpu, 2 for source and 2 for target + CLIP_MAX_NORM: 0.1 + EPOCHS: 20 + LR: 0.0002 + LR_BACKBONE: 2e-05 + LR_BACKBONE_NAMES: ['backbone.0'] + LR_DROP: 15 + LR_DROP_EPOCHS: None + LR_LINEAR_PROJ_MULT: 0.1 + LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets'] + SGD: False + WEIGHT_DECAY: 0.0001 \ No newline at end of file diff --git a/configs/soma_aood_city_to_foggy_r50.yaml b/configs/soma_aood_city_to_foggy_r50.yaml new file mode 100644 index 0000000..4af00ff --- /dev/null +++ b/configs/soma_aood_city_to_foggy_r50.yaml @@ -0,0 +1,93 @@ +CACHE_MODE: False +AOOD: + MOTIF_ON: True + # novel-class + OPEN_SET: + WARM_UP: 9 + TH: 0.5 + MOTIF_ON: True + ALPHA: 0.01 + KNN: 5 + UNK_PROB: 0.5 + MOTIF_LOSS_COEF: 0.1 + # novel-scene + CROSS_DOMAIN: + MOTIF_ON: True + BETA: 1.0 # std scaling +DATASET: + DA_MODE: aood + AOOD_SETTING: 1 # different class splittings 1-4 + AOOD_SCENE: 'cityscapes' + COCO_PANOPTIC_PATH: None + COCO_PATH: /home/wuyangli2/data/ + DATASET_FILE: cityscapes_to_foggy_cityscapes + NUM_CLASSES: 4 # 3 known + 1 + REMOVE_DIFFICULT: False +DEVICE: cuda +DIST: + DISTRIBUTED: False + DIST_BACKEND: nccl + DIST_URL: env:// + GPU: 0 + RANK: 0 + WORLD_SIZE: 4 +EVAL: False +LOSS: + AUX_LOSS: True + BACKBONE_LOSS_COEF: 0.1 + BBOX_LOSS_COEF: 5.0 + CHANNEL_QUERY_LOSS_COEF: 0.1 + CLS_LOSS_COEF: 2.0 + DA_GAMMA: 0 + DICE_LOSS_COEF: 1.0 + FOCAL_ALPHA: 0.25 + GIOU_LOSS_COEF: 2.0 + INSTANCE_QUERY_LOSS_COEF: 0.1 + MASK_LOSS_COEF: 1.0 + SET_COST_BBOX: 5.0 + SET_COST_CLASS: 2.0 + SET_COST_GIOU: 2.0 + SPACE_QUERY_LOSS_COEF: 0.1 +MODEL: + BACKBONE: resnet50 + BACKBONE_ALIGN: True + CHANNEL_ALIGN: False + DEC_LAYERS: 3 + DEC_N_POINTS: 4 + DILATION: False + DIM_FEEDFORWARD: 1024 + DROPOUT: 0.1 + ENC_LAYERS: 3 + ENC_N_POINTS: 4 + FROZEN_WEIGHTS: None + HIDDEN_DIM: 256 + INSTANCE_ALIGN: False + MASKS: False + NHEADS: 8 + NUM_FEATURE_LEVELS: 4 + NUM_QUERIES: 100 + POSITION_EMBEDDING: sine + POSITION_EMBEDDING_SCALE: 6.283185307179586 + SPACE_ALIGN: False + TWO_STAGE: False + WITH_BOX_REFINE: False +NUM_WORKERS: 2 +OUTPUT_DIR: exps/r50_aood_c2f +RESUME: +LOAD_OPTIMIZER: False +SEED: 42 # This cannot strictly control the same results. Sorry, I do not know why. 
+START_EPOCH: 0 +EVAL_EPOCH: 29 +TRAIN: + BATCH_SIZE: 4 # each gpu, 2 for source and 2 for target + CLIP_MAX_NORM: 0.1 + EPOCHS: 65 + LR: 0.0002 + LR_BACKBONE: 2e-05 + LR_BACKBONE_NAMES: ['backbone.0'] + LR_DROP: 40 + LR_DROP_EPOCHS: None + LR_LINEAR_PROJ_MULT: 0.1 + LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets'] + SGD: False + WEIGHT_DECAY: 0.0001 \ No newline at end of file diff --git a/configs/soma_aood_pascal_to_clipart_r50.yaml b/configs/soma_aood_pascal_to_clipart_r50.yaml new file mode 100644 index 0000000..7599bce --- /dev/null +++ b/configs/soma_aood_pascal_to_clipart_r50.yaml @@ -0,0 +1,93 @@ +EVAL_EPOCH: 1 +CACHE_MODE: False +AOOD: + MOTIF_ON: True + # novel-class + OPEN_SET: + WARM_UP: 9 + TH: 1.0 + MOTIF_ON: True + ALPHA: 0.01 + KNN: 5 + UNK_PROB: 0.5 + MOTIF_LOSS_COEF: 0.01 + # novel-scene + CROSS_DOMAIN: + MOTIF_ON: True + BETA: 2.0 # std scaling +DATASET: + DA_MODE: aood + AOOD_SETTING: 1 + AOOD_SCENE: 'pascal' + COCO_PANOPTIC_PATH: None + COCO_PATH: /home/wuyangli2/data/ + DATASET_FILE: pascal_to_clipart + NUM_CLASSES: 11 # 10 known + 1 + REMOVE_DIFFICULT: False +DEVICE: cuda +DIST: + DISTRIBUTED: False + DIST_BACKEND: nccl + DIST_URL: env:// + GPU: 0 + RANK: 0 + WORLD_SIZE: 4 +EVAL: False +LOSS: + AUX_LOSS: True + BACKBONE_LOSS_COEF: 0.1 + BBOX_LOSS_COEF: 5.0 + CHANNEL_QUERY_LOSS_COEF: 0.1 + CLS_LOSS_COEF: 2.0 + DA_GAMMA: 0 + DICE_LOSS_COEF: 1.0 + FOCAL_ALPHA: 0.25 + GIOU_LOSS_COEF: 2.0 + INSTANCE_QUERY_LOSS_COEF: 0.1 + MASK_LOSS_COEF: 1.0 + SET_COST_BBOX: 5.0 + SET_COST_CLASS: 2.0 + SET_COST_GIOU: 2.0 + SPACE_QUERY_LOSS_COEF: 0.1 +MODEL: + BACKBONE: resnet50 + BACKBONE_ALIGN: True + CHANNEL_ALIGN: False + DEC_LAYERS: 3 + DEC_N_POINTS: 4 + DILATION: False + DIM_FEEDFORWARD: 1024 + DROPOUT: 0.1 + ENC_LAYERS: 3 + ENC_N_POINTS: 4 + FROZEN_WEIGHTS: None + HIDDEN_DIM: 256 + INSTANCE_ALIGN: False + MASKS: False + NHEADS: 8 + NUM_FEATURE_LEVELS: 4 + NUM_QUERIES: 100 + POSITION_EMBEDDING: sine + POSITION_EMBEDDING_SCALE: 6.283185307179586 + SPACE_ALIGN: False + TWO_STAGE: False + WITH_BOX_REFINE: False +NUM_WORKERS: 2 +OUTPUT_DIR: exps/r50_aood_p2c +RESUME: +LOAD_OPTIMIZER: False +SEED: 42 # This cannot strictly control the same results. Sorry, I do not know why. 
+START_EPOCH: 0 +TRAIN: + BATCH_SIZE: 4 # each gpu, 2 for source and 2 for target + CLIP_MAX_NORM: 0.1 + EPOCHS: 30 + LR: 0.0002 + LR_BACKBONE: 2e-05 + LR_BACKBONE_NAMES: ['backbone.0'] + LR_DROP: 25 + LR_DROP_EPOCHS: None + LR_LINEAR_PROJ_MULT: 0.1 + LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets'] + SGD: False + WEIGHT_DECAY: 0.0001 \ No newline at end of file diff --git a/datasets/DAOD.py b/datasets/DAOD.py new file mode 100644 index 0000000..337b1c4 --- /dev/null +++ b/datasets/DAOD.py @@ -0,0 +1,226 @@ +# ------------------------------------------------------------------------ +# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection +# Modified by Wuyang Li +# ---------------------------------------------- +# Created by Wei-Jie Huang +# ---------------------------------------------- +from pathlib import Path +from torch.utils.data import Dataset +from datasets.coco import CocoDetection, make_coco_transforms +from datasets.aood import AOODDetection +from util.misc import get_local_rank, get_local_size, nested_tensor_from_tensor_list + +def get_paths(root): + root = Path(root) + return { + 'cityscapes': { + 'train_img': root / 'Cityscapes/leftImg8bit/train', + 'val_img': root / 'Cityscapes/leftImg8bit/val', + 'train_anno': root / 'Cityscapes/cocoAnnotations/cityscapes_train_cocostyle.json', + 'val_img': root / 'Cityscapes/leftImg8bit/val', + 'val_anno': root / 'Cityscapes/cocoAnnotations/cityscapes_foggy_val_cocostyle.json', + + 'train_xml': root / 'Cityscapes/AOOD_Annotations', + 'val_xml': root / 'Cityscapes/AOOD_Annotations', + 'train_data_list': root / 'Cityscapes/AOOD_Main/train_source.txt', + 'val_data_list': root / 'Cityscapes/AOOD_Main/val_source.txt', + }, + 'cityscapes_caronly': { + 'train_img': root / 'Cityscapes/leftImg8bit/train', + 'train_anno': root / 'Cityscapes/annotations/cityscapes_caronly_train.json', + 'val_img': root / 'Cityscapes/leftImg8bit/val', + 'val_anno': root / 'Cityscapes/annotations/cityscapes_caronly_val.json', + }, + 'foggy_cityscapes': { + 'train_img': root / 'Cityscapes/leftImg8bit_foggy/train', + 'train_anno': root / 'Cityscapes/cocoAnnotations/cityscapes_foggy_train_cocostyle.json', + 'val_img': root / 'Cityscapes/leftImg8bit_foggy/val', + 'val_anno': root / 'Cityscapes/cocoAnnotations/cityscapes_foggy_val_cocostyle.json', + + 'train_xml': root / 'Cityscapes/AOOD_Annotations', + 'train_data_list': root / 'Cityscapes/AOOD_Main/train_target.txt', + + 'val_xml': root / 'Cityscapes/AOOD_Annotations', + 'val_data_list': root / 'Cityscapes/AOOD_Main/val_target.txt', + + }, + 'sim10k': { + 'train_img': root / 'sim10k/VOC2012/JPEGImages', + 'train_anno': root / 'sim10k/annotations/sim10k_caronly.json', + }, + 'bdd_daytime': { + 'train_img': root / 'bdd_daytime/JPEGImages', + 'val_img': root / 'bdd_daytime/JPEGImages', + 'train_xml': root / 'bdd_daytime/Annotations', + 'train_data_list': root / 'bdd_daytime/ImageSets/Main/train.txt', + 'val_xml': root / 'bdd_daytime/Annotations', + 'val_data_list': root / 'bdd_daytime/ImageSets/Main/val.txt', + + }, + 'pascal': { + 'train_img': root / 'VOCdevkit/VOC2012/JPEGImages', + 'train_xml': root / 'VOCdevkit/VOC2012/Annotations', + 'train_data_list': root / 'VOCdevkit/VOC2012/ImageSets/Main/trainval.txt', + 'val_img': root / 'VOCdevkit/VOC2012/JPEGImages', + 'val_xml': root / 'VOCdevkit/VOC2012/Annotations', + 'val_data_list': root / 'VOCdevkit/VOC2012/ImageSets/Main/trainval.txt', + }, + 'clipart': { + 'train_img': root / 'clipart/JPEGImages', + 'train_xml': root / 'clipart/Annotations', + 
'train_data_list': root / 'clipart/ImageSets/Main/all.txt', + 'val_img': root / 'clipart/JPEGImages', + 'val_xml': root / 'clipart/Annotations', + 'val_data_list': root / 'clipart/ImageSets/Main/all.txt', + }, + } + +class AOODDataset(Dataset): + def __init__(self, source_img_folder, source_ann_folder, source_data_list, target_img_folder, target_ann_folder, target_data_list, + transforms, setting, scene): + self.source = AOODDetection( + img_folder=source_img_folder, + ann_folder=source_ann_folder, + data_list = source_data_list, + remove_unk = True, + transforms=transforms, + setting=setting, + scene = scene[0], + ) + + self.target = AOODDetection( + img_folder=target_img_folder, + ann_folder=target_ann_folder, + data_list=target_data_list, + transforms=transforms, + remove_unk=False, + setting=setting, + scene = scene[1], + + ) + + def __len__(self): + return max(len(self.source), len(self.target)) + # return min(len(self.source), len(self.target)) + + def __getitem__(self, idx): + source_img, source_target = self.source[idx % len(self.source)] + target_img, _ = self.target[idx % len(self.target)] + return source_img, target_img, source_target + +class DADataset(Dataset): + def __init__(self, source_img_folder, source_ann_file, target_img_folder, target_ann_file, + transforms, return_masks, cache_mode=False, local_rank=0, local_size=1): + self.source = CocoDetection( + img_folder=source_img_folder, + ann_file=source_ann_file, + transforms=transforms, + return_masks=return_masks, + cache_mode=cache_mode, + local_rank=local_rank, + local_size=local_size + ) + + self.target = CocoDetection( + img_folder=target_img_folder, + ann_file=target_ann_file, + transforms=transforms, + return_masks=return_masks, + cache_mode=cache_mode, + local_rank=local_rank, + local_size=local_size + ) + + def __len__(self): + return max(len(self.source), len(self.target)) + + def __getitem__(self, idx): + source_img, source_target = self.source[idx % len(self.source)] + target_img, _ = self.target[idx % len(self.target)] + return source_img, target_img, source_target + + +def collate_fn(batch): + source_imgs, target_imgs, source_targets = list(zip(*batch)) + samples = nested_tensor_from_tensor_list(source_imgs + target_imgs) + return samples, source_targets + + +def build(image_set, cfg, multi_task_eval_id=4): + paths = get_paths(cfg.DATASET.COCO_PATH) + source_domain, target_domain = cfg.DATASET.DATASET_FILE.split('_to_') + if image_set == 'val': + if cfg.DATASET.DA_MODE == 'aood': + return AOODDetection( + img_folder=paths[target_domain]['val_img'], + ann_folder=paths[target_domain]['val_xml'], + data_list=paths[target_domain]['val_data_list'], + transforms=make_coco_transforms(image_set), + remove_unk=False, + setting= cfg.DATASET.AOOD_SETTING, + scene = target_domain, + multi_task_eval_id = multi_task_eval_id, + is_eval =True, + + ) + else: + return CocoDetection( + img_folder=paths[target_domain]['val_img'], + ann_file=paths[target_domain]['val_anno'], + transforms=make_coco_transforms(image_set), + return_masks=cfg.MODEL.MASKS, + cache_mode=cfg.CACHE_MODE, + local_rank=get_local_rank(), + local_size=get_local_size() + ) + elif image_set == 'train': + if cfg.DATASET.DA_MODE == 'source_only': + return CocoDetection( + img_folder=paths[source_domain]['train_img'], + ann_file=paths[source_domain]['train_anno'], + transforms=make_coco_transforms(image_set), + return_masks=cfg.MODEL.MASKS, + cache_mode=cfg.CACHE_MODE, + local_rank=get_local_rank(), + local_size=get_local_size(), + ) + elif cfg.DATASET.DA_MODE == 
'oracle': + return CocoDetection( + img_folder=paths[target_domain]['train_img'], + ann_file=paths[target_domain]['train_anno'], + transforms=make_coco_transforms(image_set), + return_masks=cfg.MODEL.MASKS, + cache_mode=cfg.CACHE_MODE, + local_rank=get_local_rank(), + local_size=get_local_size() + ) + elif cfg.DATASET.DA_MODE == 'uda': + return DADataset( + source_img_folder=paths[source_domain]['train_img'], + source_ann_file=paths[source_domain]['train_anno'], + target_img_folder=paths[target_domain]['train_img'], + target_ann_file=paths[target_domain]['train_anno'], + transforms=make_coco_transforms(image_set), + return_masks=cfg.MODEL.MASKS, + cache_mode=cfg.CACHE_MODE, + local_rank=get_local_rank(), + local_size=get_local_size() + ) + + elif cfg.DATASET.DA_MODE == 'aood': + return AOODDataset( + source_img_folder=paths[source_domain]['train_img'], + source_ann_folder=paths[source_domain]['train_xml'], + source_data_list=paths[source_domain]['train_data_list'], + + target_img_folder=paths[target_domain]['train_img'], + target_ann_folder=paths[target_domain]['train_xml'], + target_data_list=paths[target_domain]['train_data_list'], + + transforms=make_coco_transforms(image_set), + setting=cfg.DATASET.AOOD_SETTING, + scene = [source_domain, target_domain] + ) + else: + raise ValueError(f'Unknown argument cfg.DATASET.DA_MODE {cfg.DATASET.DA_MODE}') + raise ValueError(f'unknown image set {image_set}') diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..1b4787b --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,46 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ + +import torch.utils.data +from .torchvision_datasets import CocoDetection + +from .coco import build as build_coco + + +def get_coco_api_from_dataset(dataset): + for _ in range(10): + # if isinstance(dataset, torchvision.datasets.CocoDetection): + # break + if isinstance(dataset, torch.utils.data.Subset): + dataset = dataset.dataset + if isinstance(dataset, CocoDetection): + return dataset.coco + else: + return dataset + + +def build_dataset(image_set, cfg,multi_task_eval_id=4): + if cfg.DATASET.DATASET_FILE == 'coco': + return build_coco(image_set, cfg) + if cfg.DATASET.DATASET_FILE == 'coco_panoptic': + # to avoid making panopticapi required for coco + from .coco_panoptic import build as build_coco_panoptic + return build_coco_panoptic(image_set, cfg) + DAOD_dataset = [ + 'cityscapes_to_foggy_cityscapes', + 'sim10k_to_cityscapes_caronly', + 'cityscapes_to_bdd_daytime', + 'pascal_to_clipart', + ] + if cfg.DATASET.DATASET_FILE in DAOD_dataset: + from .DAOD import build + return build(image_set, cfg, multi_task_eval_id=multi_task_eval_id) + raise ValueError(f'dataset {cfg.DATASET.DATASET_FILE} not supported') diff --git a/datasets/aood.py b/datasets/aood.py new file mode 100644 index 0000000..5a8f707 --- /dev/null +++ b/datasets/aood.py @@ -0,0 +1,339 @@ +# ------------------------------------------------------------------------ +# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection +# Modified by Wuyang Li +# ------------------------------------------------------------------------ +import functools +import torch +import os +import tarfile +import collections +import logging +import copy +from torchvision.datasets import VisionDataset +import itertools +import util.misc as utils +import xml.etree.ElementTree as ET +from PIL import Image +import datasets.transforms as T + +class AOODDetection(VisionDataset): + + def get_aood_settings_cityscapes(self, setting=2): + + NMES = ['person', 'car', 'train', 'rider', 'truck', 'motorcycle', 'bicycle', 'bus'] + UNK = ["unknown"] + + if setting == 1: # different semantics + BASE_CLASSES = ['car', 'truck', 'bus'] + NOVEL_CLASSES = ['person','motorcycle','train', 'bicycle' , 'rider'] + elif setting == 2: # similar semantics + BASE_CLASSES = ['person', 'bicycle', 'bus'] + NOVEL_CLASSES = ['car', 'truck','train', 'motorcycle', 'rider' ] + elif setting == 3: # frequency down + BASE_CLASSES = ['person', 'car', 'rider'] + NOVEL_CLASSES = [ 'bicycle', 'train', 'truck', 'motorcycle', 'bus'] + elif setting == 4: # frequency top + BASE_CLASSES = [ 'motorcycle', 'truck', 'bus'] + NOVEL_CLASSES = ['person', 'train', 'car','bicycle', 'rider'] + + ALL_CLASSES= tuple(itertools.chain(BASE_CLASSES, NOVEL_CLASSES)) + CLASS_NAMES= tuple(itertools.chain(BASE_CLASSES, UNK)) + + return BASE_CLASSES, NOVEL_CLASSES, ALL_CLASSES, CLASS_NAMES + + + def get_aood_settings_pascal_voc(self, setting=1): + PASCAL_CLASSES = [ + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor"] + UNK = ["unknown"] + + # if setting == 1: + # BASE_CLASSES = ALL_CLASSES[:19] + # NOVEL_CLASSES = ALL_CLASSES[19:] + # elif setting == 2: + # BASE_CLASSES = ALL_CLASSES[:15] + # NOVEL_CLASSES = ALL_CLASSES[15:] + # elif setting == 3: + # BASE_CLASSES = ALL_CLASSES[:10] + # NOVEL_CLASSES = ALL_CLASSES[10:] + # elif setting 
== 4: + # BASE_CLASSES = ALL_CLASSES[:5] + # NOVEL_CLASSES = ALL_CLASSES[5:] + + BASE_CLASSES = PASCAL_CLASSES[:10] + NOVEL_CLASSES = PASCAL_CLASSES[10:] + + ALL_CLASSES= tuple(itertools.chain(PASCAL_CLASSES)) + BASE_CLASSES= tuple(itertools.chain(BASE_CLASSES)) + NOVEL_CLASSES= tuple(itertools.chain(NOVEL_CLASSES)) + + CLASS_NAMES = tuple(itertools.chain(BASE_CLASSES, UNK)) + return BASE_CLASSES, NOVEL_CLASSES, ALL_CLASSES, CLASS_NAMES + + + def __init__(self, + # args, + img_folder, + ann_folder, + data_list, + transforms=None, + remove_unk=False, + setting=1, + scene='pascal', + multi_task_eval_id = 4, + is_eval=False + ): + super(AOODDetection, self).__init__(img_folder) + + self.images = [] + self.annotations = [] + self.imgids = [] + self.imgid2annotations = {} + self.image_set = [] + self.transforms = transforms + self.remove_unk = remove_unk + self.is_eval = is_eval + + self.id_to_task_nme = { + 1: 'het-sem', + 2: 'hom-sem', + 3: 'freq-dec', + 4: 'freq-inc' + } + + self.scene = scene + if self.scene == 'cityscapes' or self.scene == 'bdd_daytime' or self.scene == 'foggy_cityscapes': + self.BASE_CLASSES, self.NOVEL_CLASSES, self.ALL_CLASSES, self.CLASS_NAMES = self.get_aood_settings_cityscapes(setting) + elif self.scene == 'pascal' or self.scene == 'clipart': + self.BASE_CLASSES, self.NOVEL_CLASSES, self.ALL_CLASSES, self.CLASS_NAMES = self.get_aood_settings_pascal_voc(setting) + else: + raise KeyError('undefined aood scenes') + + self.num_classes = len(self.CLASS_NAMES) # K+1 for model training + self.num_base = len(self.BASE_CLASSES) # K + self.unk_id = self.num_classes - 1 + + self.NOVEL_CLASSES_PER_TASK = self.NOVEL_CLASSES[:multi_task_eval_id] + + num_novel_per_task = len(self.NOVEL_CLASSES_PER_TASK ) + + all_classes_id = range(len(self.ALL_CLASSES)) + + self.bdd2city={ + 'bike':'bicycle', + 'motor': 'motorcycle', + } + + self.base_id = all_classes_id[:self.num_base] # 0-k + self.novel_id = all_classes_id[self.num_base:] # k- all + self.per_task_novel_id = all_classes_id[self.num_base:self.num_base + num_novel_per_task] # k- sel + + with open(data_list, "r") as f: + file_names = [x.strip() for x in f.readlines()] + + if self.remove_unk: # remove images without base-class objects + if utils.is_main_process(): + print(''.join(80*['='])) + print('source domain training set:') + if self.scene != 'pascal': + print('AOOD Task: {}'.format(self.id_to_task_nme[setting])) + print("BASE_CLASSES: {}".format(self.BASE_CLASSES)) + print("REALLOCATED CLASSES: {}".format(self.CLASS_NAMES)) + file_names = self.filter_imgs_without_base_objects(ann_folder, file_names) + elif is_eval: # inference: remove images without base-class objects + if utils.is_main_process(): + print(''.join(80*['='])) + print('target domain test set (task {}):'.format(multi_task_eval_id)) + print("BASE_CLASSES: {}".format(self.BASE_CLASSES)) + print("NOVEL_CLASSES: {}".format(self.NOVEL_CLASSES_PER_TASK)) + file_names = self.filter_imgs_without_base_novel_objects(ann_folder, file_names) + else: # target domain training set: preserve all images + if utils.is_main_process(): + print(''.join(80*['='])) + print('target domain training set:') + print('num images: {}'.format(len(file_names))) + print("BASE_CLASSES: {}".format(self.BASE_CLASSES)) + print("NOVEL_CLASSES: {}".format(self.NOVEL_CLASSES)) + + self.image_set.extend(file_names) + + suffix = ".png" if self.scene == 'cityscapes' or self.scene == 'foggy_cityscapes' else '.jpg' + self.images.extend([os.path.join(img_folder, x + suffix) for x in file_names]) + 
self.annotations.extend([os.path.join(ann_folder, x + ".xml") for x in file_names]) + + self.imgids = list(range(len(file_names))) + self.imgids2img = dict(zip(self.imgids, file_names)) + self.imgid2annotations.update(dict(zip(self.imgids, self.annotations))) + + assert (len(self.images) == len(self.annotations) == len(self.imgids)) + + def filter_imgs_without_base_novel_objects(self, ann_folder, file_names): + new_file_names = [] + for x in file_names: + anno = os.path.join(ann_folder, x + ".xml") + tree = ET.parse(anno) + target = self.parse_voc_xml(tree.getroot()) + flag=True + for obj in target['annotation']['object']: + cls = obj["name"] + # if cls in self.bdd2city.keys(): + if cls in self.bdd2city.keys() and self.scene == 'bdd_daytime': + cls = self.bdd2city[cls] + if cls not in self.BASE_CLASSES and cls not in self.NOVEL_CLASSES_PER_TASK: + flag=False + break + if flag: + new_file_names.append(x) + + print('original images: {}, after removing images without base and novel objects: {}.'.format(len(file_names), len(new_file_names))) + return new_file_names + + def filter_imgs_without_base_objects(self, ann_folder, file_names): + new_file_names = [] + for x in file_names: + anno = os.path.join(ann_folder, x + ".xml") + tree = ET.parse(anno) + target = self.parse_voc_xml(tree.getroot()) + + for obj in target['annotation']['object']: + cls = obj["name"] + if cls in self.bdd2city.keys(): + cls = self.bdd2city[cls] + + if cls in self.BASE_CLASSES: + new_file_names.append(x) + break + print('original images: {}, after removing images without base objects: {}.'.format(len(file_names), len(new_file_names))) + return new_file_names + + @functools.lru_cache(maxsize=None) + def load_instances(self, img_id): + tree = ET.parse(self.imgid2annotations[img_id]) + target = self.parse_voc_xml(tree.getroot()) + instances = [] + for obj in target['annotation']['object']: + cls = obj["name"] + + if cls in self.bdd2city.keys(): + cls = self.bdd2city[cls] + bbox = obj["bndbox"] + bbox = [float(bbox[x]) for x in ["xmin", "ymin", "xmax", "ymax"]] + bbox[0] -= 1.0 + bbox[1] -= 1.0 + instance = dict( + category_id=self.ALL_CLASSES.index(cls), + bbox=bbox, + area=(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]), + image_id=img_id + ) + instances.append(instance) + return target, instances + + def remove_novel_instances(self, target): + # for the labelled training + entry = copy.copy(target) + for annotation in copy.copy(entry): + if annotation["category_id"] not in self.base_id: + entry.remove(annotation) + return entry + + def label_all_novel_instances_as_unk(self, target): + # for the unlabelled training + entry = copy.copy(target) + for annotation in copy.copy(entry): + # for annotation in entry: + if annotation["category_id"] not in self.base_id: + annotation["category_id"] = self.unk_id + + return entry + + def label_per_task_novel_instances_as_unk(self, target): + # for the unlabelled training + entry = copy.copy(target) + for annotation in copy.copy(entry): + # for annotation in entry: + if annotation["category_id"] in self.base_id: + continue + elif annotation["category_id"] in self.per_task_novel_id: + annotation["category_id"] = self.unk_id + else: + entry.remove(annotation) + return entry + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is a dictionary of the XML tree. 
+ """ + + img = Image.open(self.images[index]).convert('RGB') + target, instances = self.load_instances(self.imgids[index]) + + if self.remove_unk: + instances = self.remove_novel_instances(instances) + elif self.is_eval: + instances = self.label_per_task_novel_instances_as_unk(instances) + else: + instances = self.label_all_novel_instances_as_unk(instances) + + w, h = map(target['annotation']['size'].get, ['width', 'height']) + target = dict( + image_id=torch.tensor([self.imgids[index]], dtype=torch.int64), + labels=torch.tensor([i['category_id'] for i in instances], dtype=torch.int64), + area=torch.tensor([i['area'] for i in instances], dtype=torch.float32), + boxes=torch.as_tensor([i['bbox'] for i in instances], dtype=torch.float32), + orig_size=torch.as_tensor([int(h), int(w)]), + size=torch.as_tensor([int(h), int(w)]), + iscrowd=torch.zeros(len(instances), dtype=torch.uint8) + ) + + if self.transforms is not None: + img, target = self.transforms(img, target) + return img, target + + def __len__(self): + return len(self.images) + + def parse_voc_xml(self, node): + voc_dict = {} + children = list(node) + if children: + def_dic = collections.defaultdict(list) + for dc in map(self.parse_voc_xml, children): + for ind, v in dc.items(): + def_dic[ind].append(v) + if node.tag == 'annotation': + def_dic['object'] = [def_dic['object']] + voc_dict = { + node.tag: + {ind: v[0] if len(v) == 1 else v + for ind, v in def_dic.items()} + } + if node.text: + text = node.text.strip() + if not children: + voc_dict[node.tag] = text + return voc_dict + diff --git a/datasets/aood_eval.py b/datasets/aood_eval.py new file mode 100644 index 0000000..8417c3f --- /dev/null +++ b/datasets/aood_eval.py @@ -0,0 +1,633 @@ +# ------------------------------------------------------------------------ +# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection +# Modified by Wuyang Li +# ------------------------------------------------------------------------ +import os +import shutil +import datetime +import functools +import subprocess +import xml.etree.ElementTree as ET +import numpy as np +import torch +import logging +from util.misc import all_gather +from numpy import * +from collections import OrderedDict, defaultdict + +class AOODEvaluator: + def __init__(self, voc_gt, iou_types, args=None, use_07_metric=True, ovthresh=list(range(50, 100, 5))): + assert tuple(iou_types) == ('bbox',) + self.use_07_metric = use_07_metric + self.ovthresh = ovthresh + self.voc_gt = voc_gt + self.eps = torch.finfo(torch.float64).eps + self.num_classes = len(self.voc_gt.CLASS_NAMES) # base1, base2 ,.. 
base n, unk + self._class_names = self.voc_gt.CLASS_NAMES + self.AP = torch.zeros(self.num_classes, 1) + self.all_recs = defaultdict(list) + self.all_precs = defaultdict(list) + self.recs = defaultdict(list) + self.precs = defaultdict(list) + self.num_unks = defaultdict(list) + self.unk_det_as_knowns = defaultdict(list) + self.tp_plus_fp_cs = defaultdict(list) + self.fp_os = defaultdict(list) + self.coco_eval = dict(bbox=lambda: None) + # self.coco_eval['bbox'].stats = torch.tensor([]) + self.coco_eval['bbox'].stats = dict() + self.coco_eval['bbox'].eval = dict() + + self.img_ids = [] + self.lines = [] + self.lines_cls = [] + + self.total_num_class = self.voc_gt.num_classes + self.unknown_class_index = self.voc_gt.unk_id + self.num_seen_classes = self.voc_gt.num_base + self.known_classes = self.voc_gt.BASE_CLASSES + self.unknown_classes = self.voc_gt.NOVEL_CLASSES_PER_TASK + + def update(self, predictions): + for img_id, pred in predictions.items(): + pred_boxes, pred_labels, pred_scores = [pred[k].cpu() for k in ['boxes', 'labels', 'scores']] + # print(img_id) + # image_id = self.voc_gt.convert_image_id(int(img_id), to_string=True) + img_name = self.voc_gt.imgids2img[int(img_id)] + self.img_ids.append(img_id) + classes = pred_labels.tolist() + for (xmin, ymin, xmax, ymax), cls, score in zip(pred_boxes.tolist(), classes , pred_scores.tolist()): + xmin += 1 + ymin += 1 + self.lines.append(f"{img_name} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}") + self.lines_cls.append(cls) + + def compute_avg_precision_at_many_recall_level_for_unk(self, precisions, recalls): + precs = {} + for r in range(1, 10): + r = r/10 + p = self.compute_avg_precision_at_a_recall_level_for_unk(precisions, recalls, recall_level=r) + precs[r] = p + return precs + + def compute_avg_precision_at_a_recall_level_for_unk(self, precisions, recalls, recall_level=0.5): + precs = {} + for iou, recall in recalls.items(): + prec = [] + for cls_id, rec in enumerate(recall): + if cls_id == self.unknown_class_index and len(rec)>0: + p = precisions[iou][cls_id][min(range(len(rec)), key=lambda i: abs(rec[i] - recall_level))] + prec.append(p) + if len(prec) > 0: + precs[iou] = np.mean(prec) + else: + precs[iou] = 0 + return precs + + def compute_WI_at_many_recall_level(self, recalls, tp_plus_fp_cs, fp_os): + wi_at_recall = {} + for r in range(1, 10): + r = r/10 + wi = self.compute_WI_at_a_recall_level(recalls, tp_plus_fp_cs, fp_os, recall_level=r) + wi_at_recall[r] = wi + return wi_at_recall + + def compute_WI_at_a_recall_level(self, recalls, tp_plus_fp_cs, fp_os, recall_level=0.5): + wi_at_iou = {} + for iou, recall in recalls.items(): + tp_plus_fps = [] + fps = [] + for cls_id, rec in enumerate(recall): + if cls_id in range(self.num_seen_classes) and len(rec) > 0: + index = min(range(len(rec)), key=lambda i: abs(rec[i] - recall_level)) + tp_plus_fp = tp_plus_fp_cs[iou][cls_id][index] + tp_plus_fps.append(tp_plus_fp) + fp = fp_os[iou][cls_id][index] + fps.append(fp) + if len(tp_plus_fps) > 0: + wi_at_iou[iou] = np.mean(fps) / np.mean(tp_plus_fps) + else: + wi_at_iou[iou] = 0 + return wi_at_iou + + def synchronize_between_processes(self): + self.img_ids = torch.tensor(self.img_ids, dtype=torch.int64) + self.lines_cls = torch.tensor(self.lines_cls, dtype=torch.int64) + self.img_ids, self.lines, self.lines_cls = self.merge(self.img_ids, self.lines, self.lines_cls) + + def merge(self, img_ids, lines, lines_cls): + flatten = lambda ls: [s for l in ls for s in l] + all_img_ids = torch.cat(all_gather(img_ids)) + all_lines_cls = 
torch.cat(all_gather(lines_cls)) + all_lines = flatten(all_gather(lines)) + return all_img_ids, all_lines, all_lines_cls + + def accumulate(self): + for class_label_ind, class_label in enumerate(self.voc_gt.CLASS_NAMES): + + # lines_by_class = [l + '\n' for l, c in zip(self.lines, self.lines_cls.tolist()) if c == class_label_ind+1] + lines_by_class = [l + '\n' for l, c in zip(self.lines, self.lines_cls.tolist()) if c == class_label_ind] # coco start from 1 + if len(lines_by_class) == 0: + lines_by_class = [] + print(class_label + " has " + str(len(lines_by_class)) + " predictions.") + ovthresh = 50 + ovthresh_ind, _ = map(self.ovthresh.index, [50, 75]) + self.rec, self.prec, self.AP[class_label_ind, ovthresh_ind], self.unk_det_as_known, \ + self.num_unk, self.tp_plus_fp_closed_set, self.fp_open_set = voc_eval( + lines_by_class, + self.voc_gt.annotations, + self.voc_gt.image_set, + class_label, + ovthresh=ovthresh / 100.0, + use_07_metric=self.use_07_metric, + known_classes=self.known_classes, + unknown_classes=self.unknown_classes) #[-1] + + self.AP[class_label_ind, ovthresh_ind] = self.AP[class_label_ind, ovthresh_ind] * 100 + self.all_recs[ovthresh].append(self.rec) + self.all_precs[ovthresh].append(self.prec) + self.num_unks[ovthresh].append(self.num_unk) + self.unk_det_as_knowns[ovthresh].append(self.unk_det_as_known) + self.tp_plus_fp_cs[ovthresh].append(self.tp_plus_fp_closed_set) + self.fp_os[ovthresh].append(self.fp_open_set) + try: + self.recs[ovthresh].append(self.rec[-1] * 100) + self.precs[ovthresh].append(self.prec[-1] * 100) + except: + self.recs[ovthresh].append(0.) + self.precs[ovthresh].append(0.) + + def summarize(self, fmt='{:.06f}'): + o50, _ = map(self.ovthresh.index, [50, 75]) + mAP = float(self.AP.mean()) + mAP50 = float(self.AP[:, o50].mean()) + # print('detection mAP50:', fmt.format(mAP50)) + # print('detection mAP:', fmt.format(mAP)) + # print('---AP50---') + wi = self.compute_WI_at_many_recall_level(self.all_recs, self.tp_plus_fp_cs, self.fp_os) + # print('Wilderness Impact: ' + str(wi)) + avg_precision_unk = self.compute_avg_precision_at_many_recall_level_for_unk(self.all_precs, self.all_recs) + # print('avg_precision: ' + str(avg_precision_unk)) + total_num_unk_det_as_known = {iou: np.sum(x) for iou, x in self.unk_det_as_knowns.items()} #torch.sum(self.unk_det_as_knowns[:, o50]) #[np.sum(x) for x in self.unk_det_as_knowns[:, o50]] + total_num_unk = self.num_unks[50][0] + # print('Absolute OSE (total_num_unk_det_as_known): ' + str(total_num_unk_det_as_known)) + # print('total_num_unk ' + str(total_num_unk)) + # print("AP50: " + str(['%.1f' % x for x in self.AP[:, o50]])) + # print("Precisions50: " + str(['%.1f' % x for x in self.precs[50]])) + # print("Recall50: " + str(['%.1f' % x for x in self.recs[50]])) + # print("Unknown AP50: " + str(self.AP[:, o50][-1])) + # print("Unknown Precisions50: " + str(self.precs[50][-1])) + # print("Unknown Recall50: " + str(self.recs[50][-1])) + + # import ipdb; ipdb.set_trace() + + + AP50 = [round(x.item(),2) for x in self.AP[:, o50]] + + base_AP50 = AP50[:-1] + # base_mAP = mean(base_AP50) + base_mAP = round(mean(base_AP50), 2) + novel_AP50 = round(AP50[-1], 5) + + novel_recall = round(self.recs[50][-1], 2) + wi_08 = round(wi[0.8][50] *100, 3) + + AOSE = total_num_unk_det_as_known[50] + + + class_name = list(self.voc_gt.CLASS_NAMES) + base_name = class_name[:-1] + + title = base_name + ['mAP_base','WI_08', 'AOSE', 'Recall_novel', 'nodel_AP50' ] + ap_map_wi_aose_ar = base_AP50 + 
[base_mAP]+[wi_08]+[AOSE]+[novel_recall]+[novel_AP50] + report_results = [str(base_mAP)]+[str(novel_recall)] + [str(wi_08)]+[str(round(AOSE))] + + ap_map_wi_aose_ar = [str(i) for i in ap_map_wi_aose_ar] + ap_map_wi_aose_ar=' & '.join(ap_map_wi_aose_ar) + + title = ' & '.join(title) + # "Wilderness Impact: ": str(wi), + # "avg_unk_precision: ": str(avg_precision_unk), + # for class_name, ap in zip(self.voc_gt.CLASS_NAMES, self.AP[:, o50].cpu().tolist()): + # print(class_name, fmt.format(ap)) + # for class_name, ap in zip(self.voc_gt.CLASS_NAMES, self.AP[:, o50].cpu().tolist()): + # print(class_name, fmt.format(ap)) + self.coco_eval['bbox'].stats = { + "detection_mAP50: ": float(self.AP[:, o50].mean()), + "Absolute OSE (total_num_unk_det_as_known): ": str(total_num_unk_det_as_known), + "total_num_unk: ": str(total_num_unk), + "base_mAP: ": base_mAP, + "AP50: ": str(['%.1f' % x for x in self.AP[:, o50]]), + "Recall50: ": str(['%.1f' % x for x in self.recs[50]]), + "Unknown AP50: ": str(self.AP[:, o50][-1]), + "Unknown Recall50: ": str(self.recs[50][-1]), + "ap_map_wi_aose_ar": ap_map_wi_aose_ar, + "report_results": report_results, + "title": title, + + } + print(self.coco_eval['bbox'].stats ) + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + + +@functools.lru_cache(maxsize=None) +def parse_rec(filename, known_classes, unknown_classes): + """ Parse a PASCAL VOC xml file """ + VOC_CLASS_NAMES_COCOFIED = [ + "bike", "motor" + ] + BASE_VOC_CLASS_NAMES = [ + "bicycle", "motorcycle" + ] + + # VOC_CLASS_NAMES_COCOFIED = [ + # "airplane", "dining table", "motorcycle", + # "potted plant", "couch", "tv" + # ] + # BASE_VOC_CLASS_NAMES = [ + # "aeroplane", "diningtable", "motorbike", + # "pottedplant", "sofa", "tvmonitor" + # ] + + tree = ET.parse(filename) + + objects = [] + for obj in tree.findall('object'): + obj_struct = {} + cls_name = obj.find('name').text + + if cls_name in VOC_CLASS_NAMES_COCOFIED: + cls_name = BASE_VOC_CLASS_NAMES[VOC_CLASS_NAMES_COCOFIED.index(cls_name)] + + # if cls_name not in known_classes: + if cls_name in known_classes: + cls_name = cls_name + elif cls_name in unknown_classes: + cls_name = 'unknown' + else: + continue + obj_struct['name'] = cls_name + obj_struct['difficult'] = int(obj.find('difficult').text) + bbox = obj.find('bndbox') + obj_struct['bbox'] = [int(bbox.find('xmin').text), + int(bbox.find('ymin').text), + int(bbox.find('xmax').text), + int(bbox.find('ymax').text)] + objects.append(obj_struct) + + return objects + + +def voc_eval(detpath, + annopath, + imagesetfile, + classname, + ovthresh=0.5, + use_07_metric=False, + known_classes=None, + unknown_classes= None): + # 
-------------------------------------------------------- + # Fast/er R-CNN + # Licensed under The MIT License [see LICENSE for details] + # Written by Bharath Hariharan + # -------------------------------------------------------- + + """rec, prec, ap = voc_eval(detpath, + annopath, + imagesetfile, + classname, + [ovthresh], + [use_07_metric]) + Top level function that does the PASCAL VOC evaluation. + detpath: Path to detections + detpath.format(classname) should produce the detection results file. + annopath: Path to annotations + annopath.format(imagename) should be the xml annotations file. + imagesetfile: Text file containing the list of images, one image per line. + classname: Category name (duh) + cachedir: Directory for caching the annotations + [ovthresh]: Overlap threshold (default = 0.5) + [use_07_metric]: Whether to use VOC07's 11 point AP computation + (default False) + """ + + + def iou(BBGT, bb): + ixmin = np.maximum(BBGT[:, 0], bb[0]) + iymin = np.maximum(BBGT[:, 1], bb[1]) + ixmax = np.minimum(BBGT[:, 2], bb[2]) + iymax = np.minimum(BBGT[:, 3], bb[3]) + iw = np.maximum(ixmax - ixmin + 1., 0.) + ih = np.maximum(iymax - iymin + 1., 0.) + inters = iw * ih + + # union + uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) + + (BBGT[:, 2] - BBGT[:, 0] + 1.) * + (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters) + + overlaps = inters / uni + ovmax = np.max(overlaps) + jmax = np.argmax(overlaps) + return ovmax, jmax + + # assumes detections are in detpath.format(classname) + # assumes annotations are in annopath.format(imagename) + # assumes imagesetfile is a text file with each line an image name + # cachedir caches the annotations in a pickle file + + # read list of images + if isinstance(imagesetfile, list): + lines = imagesetfile + else: + with open(imagesetfile, 'r') as f: + lines = f.readlines() + # imagenames = [x.strip().split('/')[-1] for x in lines] + # imagenames = [x.strip() for x in lines] + imagenames = [x.split('/')[-1] for x in lines] + + # load annots + recs = {} + if isinstance(annopath, list): + for a in annopath: + imagename = os.path.splitext(os.path.basename(a))[0] + recs[imagename] = parse_rec(a, tuple(known_classes), tuple(unknown_classes)) + else: + for i, imagename in enumerate(imagenames): + recs[imagename] = parse_rec(annopath.format(imagename), tuple(known_classes), tuple(unknown_classes)) + + # extract gt objects for this class + class_recs = {} + npos = 0 + + for imagename in imagenames: + R = [obj for obj in recs[imagename] if obj['name'] == classname] + bbox = np.array([x['bbox'] for x in R]) + difficult = np.array([x['difficult'] for x in R]).astype(np.bool) + det = [False] * len(R) + npos = npos + sum(~difficult) + class_recs[imagename] = {'bbox': bbox, + 'difficult': difficult, + 'det': det} + + # read dets + if isinstance(detpath, list): + lines = detpath + else: + detfile = detpath.format(classname) + with open(detfile, 'r') as f: + lines = f.readlines() + + + splitlines = [x.strip().split(' ') for x in lines] + image_ids = [x[0] for x in splitlines] + confidence = np.array([float(x[1]) for x in splitlines]) + if len(splitlines) == 0: + BB = np.array([[float(z) for z in x[2:]] for x in splitlines]).reshape(-1, 4) + else: + BB = np.array([[float(z) for z in x[2:]] for x in splitlines])#.reshape(-1, 4) + + # if BB.size == 0: + # return 0, 0, 0 + + # sort by confidence + sorted_ind = np.argsort(-confidence) + BB = BB[sorted_ind, :] + + image_ids = [image_ids[x].split('/')[-1] for x in sorted_ind] + + # go down dets and mark TPs and FPs + nd = len(image_ids) 
+ tp = np.zeros(nd) + fp = np.zeros(nd) + for d in range(nd): + R = class_recs[image_ids[d]] + bb = BB[d, :].astype(float) + ovmax = -np.inf + BBGT = R['bbox'].astype(float) + + if BBGT.size > 0: + ovmax, jmax = iou(BBGT, bb) + + if ovmax > ovthresh: + if not R['difficult'][jmax]: + if not R['det'][jmax]: + tp[d] = 1. + R['det'][jmax] = 1 + else: + fp[d] = 1. + else: + fp[d] = 1. + + # compute precision recall + fp = np.cumsum(fp) + tp = np.cumsum(tp) + rec = tp / float(npos) + # avoid divide by zero in case the first detection matches a difficult + # ground truth + prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) + ap = voc_ap(rec, prec, use_07_metric) + + ''' + Computing Absolute Open-Set Error (A-OSE) and Wilderness Impact (WI) + =========== + Absolute OSE = # of unknown objects classified as known objects of class 'classname' + WI = FP_openset / (TP_closed_set + FP_closed_set) + ''' + # logger = logging.getLogger(__name__) + + # Finding GT of unknown objects + unknown_class_recs = {} + n_unk = 0 + for imagename in imagenames: + R = [obj for obj in recs[imagename] if obj["name"] == 'unknown'] + bbox = np.array([x["bbox"] for x in R]) + difficult = np.array([x["difficult"] for x in R]).astype(np.bool) + det = [False] * len(R) + n_unk = n_unk + sum(~difficult) + unknown_class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det} + + if classname == 'unknown': + return rec, prec, ap, 0., n_unk, None, None + + # Go down each detection and see if it has an overlap with an unknown object. + # If so, it is an unknown object that was classified as known. + is_unk = np.zeros(nd) + for d in range(nd): + R = unknown_class_recs[image_ids[d]] + bb = BB[d, :].astype(float) + ovmax = -np.inf + BBGT = R["bbox"].astype(float) + + if BBGT.size > 0: + # compute overlaps + # intersection + ixmin = np.maximum(BBGT[:, 0], bb[0]) + iymin = np.maximum(BBGT[:, 1], bb[1]) + ixmax = np.minimum(BBGT[:, 2], bb[2]) + iymax = np.minimum(BBGT[:, 3], bb[3]) + iw = np.maximum(ixmax - ixmin + 1.0, 0.0) + ih = np.maximum(iymax - iymin + 1.0, 0.0) + inters = iw * ih + + # union + uni = ( + (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0) + + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0) + - inters + ) + + overlaps = inters / uni + ovmax = np.max(overlaps) + jmax = np.argmax(overlaps) + + if ovmax > ovthresh: + is_unk[d] = 1.0 + + is_unk_sum = np.sum(is_unk) + tp_plus_fp_closed_set = tp+fp + fp_open_set = np.cumsum(is_unk) + + + return rec, prec, ap, is_unk_sum, n_unk, tp_plus_fp_closed_set, fp_open_set + + +def bbox_nms(boxes, scores, overlap_threshold=0.4, score_threshold=0.0, mask=False): + def overlap(box1, box2=None, rectint=False, eps=1e-6): + area = lambda boxes=None, x1=None, y1=None, x2=None, y2=None: (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1]) if boxes is not None else (x2 - x1).clamp(min=0) * (y2 - y1).clamp( + min=0) + + if box2 is None and not isinstance(box1, list) and box1.dim() == 3: + return torch.stack(list(map(overlap, box1))) + b1, b2 = [(b if b.dim() == 2 else b.unsqueeze(0)).t().contiguous() for b in + [box1, (box2 if box2 is not None else box1)]] + + xx1 = torch.max(b1[0].unsqueeze(1), b2[0].unsqueeze(0)) + yy1 = torch.max(b1[1].unsqueeze(1), b2[1].unsqueeze(0)) + xx2 = torch.min(b1[2].unsqueeze(1), b2[2].unsqueeze(0)) + yy2 = torch.min(b1[3].unsqueeze(1), b2[3].unsqueeze(0)) + + inter = area(x1=xx1, y1=yy1, x2=xx2, y2=yy2) + return inter / (area(b1.t()).unsqueeze(1) + area(b2.t()).unsqueeze(0) - inter + eps) if not rectint else inter + 
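+    # Greedy per-class NMS using the pairwise IoU matrix O computed by overlap():
+    # for each class column, sort the candidate indices by ascending score, drop
+    # those below score_threshold, then repeatedly keep the highest-scoring
+    # remaining box and discard every box whose IoU with it is >= overlap_threshold.
+    # The surviving indices per class are collected in `pick` (or returned as a
+    # boolean mask when mask=True).
+    # Illustrative usage only: for boxes of shape [N, 4] and per-class scores of
+    # shape [N, C],
+    #     keep = bbox_nms(boxes, scores, overlap_threshold=0.4, score_threshold=0.05)
+    # yields a list of C index tensors, one per class.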
+ O = overlap(boxes) + I = scores.sort(0)[1] + M = scores.gather(0, I).ge(score_threshold) + M = M if M.any() else M.fill_(1) + pick = [] + + for i, m in zip(I.t(), M.t()): + p = [] + i = i[m] + while len(i) > 1: + p.append(i[-1]) + m = O[:, i[-1]][i].lt(overlap_threshold) + m[-1] = 0 + i = i[m] + pick.append(torch.tensor(p + i.tolist(), dtype=torch.int64)) + + return pick if not mask else torch.stack( + [torch.zeros(len(scores), dtype=torch.bool).scatter_(0, p, 1) for p in pick]) + + +def package_submission(out_dir, image_file_name, class_labels, VOCYEAR, SUBSET, TASK, tar=True, **kwargs): + def cls(file_path, class_label_ind, scores): + with open(file_path, 'w') as f: + f.writelines(map('{} {}\n'.format, image_file_name, scores[:, class_label_ind].tolist())) + + def det(file_path, class_label_ind, scores, proposals, keep): + zipped = [] + for example_idx, basename in enumerate(image_file_name): + I = keep[example_idx][class_label_ind] + zipped.extend((basename, s) + tuple(p) for s, p in zip(scores[example_idx][I, class_label_ind].tolist(), + proposals[example_idx][I, :4].add(1).tolist())) + with open(file_path, 'w') as f: + f.writelines(map('{} {} {:.0f} {:.0f} {:.0f} {:.0f} \n'.format, *zip(*zipped))) + + task_a, task_b = TASK.split('_') + resdir = os.path.join(out_dir, 'results') + respath = os.path.join(resdir, VOCYEAR, 'Main', '%s_{}_{}_%s.txt'.format(task_b, SUBSET)) + + if os.path.exists(resdir): + shutil.rmtree(resdir) + os.makedirs(os.path.join(resdir, VOCYEAR, 'Main')) + + for class_label_ind, class_label in enumerate(class_labels): + dict(det=det, cls=cls)[task_b](respath.replace('%s', '{}').format(task_a, class_label), class_label_ind, + **kwargs) + + if tar: + subprocess.check_call(['tar', '-czf', 'results-{}-{}-{}.tar.gz'.format(VOCYEAR, TASK, SUBSET), 'results'], + cwd=out_dir) + + return respath + + +def detection_mean_ap(out_dir, image_file_name, class_labels, VOCYEAR, SUBSET, VOC_DEVKIT_VOCYEAR, scores=None, + boxes=None, nms_score_threshold=1e-4, nms_overlap_threshold=0.4, tar=False, octave=False, + cmd='octave --eval', env=None, stdout_stderr=open(os.devnull, 'wb'), do_nms=True): + if scores is not None: + nms = list(map(lambda s, p: bbox_nms(p, s, overlap_threshold=nms_overlap_threshold, + score_threshold=nms_score_threshold), scores, boxes)) if do_nms else [ + torch.arange(len(p)) for p in boxes] + + else: + nms = torch.arange(len(class_labels)).unsqueeze(0).unsqueeze(-1).expand(len(image_file_name), len(class_labels), + 1) + scores = torch.zeros(len(image_file_name), len(class_labels), len(class_labels)) + + imgsetpath = os.path.join(VOC_DEVKIT_VOCYEAR, 'ImageSets', 'Main', SUBSET + '.txt') + detrespath = package_submission(out_dir, image_file_name, class_labels, VOCYEAR, SUBSET, 'comp4_det', tar=tar, + scores=scores, proposals=boxes, nms=nms) + + if octave: + imgsetpath_fix = os.path.join(out_dir, detection_mean_ap.__name__ + '.txt') + with open(imgsetpath_fix, 'w') as f: + f.writelines([line[:-1] + ' -1\n' for line in open(imgsetpath)]) + procs = [subprocess.Popen(cmd.split() + [ + "oldpwd = pwd; cd('{}/..'); addpath(fullfile(pwd, 'VOCcode')); VOCinit; cd(oldpwd); VOCopts.testset = '{}'; VOCopts.detrespath = '{}'; VOCopts.imgsetpath = '{}'; classlabel = '{}'; warning('off', 'Octave:possible-matlab-short-circuit-operator'); warning('off', 'Octave:num-to-str'); [rec, prec, ap] = VOCevaldet(VOCopts, 'comp4', classlabel, false); dlmwrite(sprintf(VOCopts.detrespath, 'resu4', classlabel), ap); quit;".format( + VOC_DEVKIT_VOCYEAR, SUBSET, detrespath, imgsetpath_fix, 
class_label)], stdout=stdout_stderr, + stderr=stdout_stderr, env=env) for class_label in class_labels] + res = list(map(lambda class_label, proc: proc.wait() or float(open(detrespath % ('resu4', class_label)).read()), + class_labels, procs)) + + else: + res = [voc_eval(detrespath.replace('%s', '{}').format('comp4', '{}'), + os.path.join(VOC_DEVKIT_VOCYEAR, 'Annotations', '{}.xml'), imgsetpath, class_label, + cachedir=os.path.join(out_dir, 'cache_detection_mean_ap_' + SUBSET), use_07_metric=True)[-1] for + class_label in class_labels] + + return torch.tensor(res).mean(), res \ No newline at end of file diff --git a/datasets/coco.py b/datasets/coco.py new file mode 100644 index 0000000..2415895 --- /dev/null +++ b/datasets/coco.py @@ -0,0 +1,163 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ +from pathlib import Path +import torch +import torch.utils.data +from pycocotools import mask as coco_mask +from .torchvision_datasets import CocoDetection as TvCocoDetection +from util.misc import get_local_rank, get_local_size +import datasets.transforms as T + + +class CocoDetection(TvCocoDetection): + def __init__(self, img_folder, ann_file, transforms, return_masks, cache_mode=False, local_rank=0, local_size=1): + super(CocoDetection, self).__init__(img_folder, ann_file, + cache_mode=cache_mode, local_rank=local_rank, local_size=local_size) + self._transforms = transforms + self.prepare = ConvertCocoPolysToMask(return_masks) + + def __getitem__(self, idx): + img, target = super(CocoDetection, self).__getitem__(idx) + image_id = self.ids[idx] + target = {'image_id': image_id, 'annotations': target} + img, target = self.prepare(img, target) + if self._transforms is not None: + img, target = self._transforms(img, target) + return img, target + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False): + self.return_masks = return_masks + + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + + anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + if 
self.return_masks: + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + if self.return_masks: + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + if self.return_masks: + target["masks"] = masks + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + target["area"] = area[keep] + target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + target["size"] = torch.as_tensor([int(h), int(w)]) + + return image, target + + +def make_coco_transforms(image_set): + + normalize = T.Compose([ + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] + + if image_set == 'train': + return T.Compose([ + T.RandomHorizontalFlip(), + T.RandomSelect( + T.RandomResize(scales, max_size=1333), + T.Compose([ + T.RandomResize([400, 500, 600]), + T.RandomSizeCrop(384, 600), + T.RandomResize(scales, max_size=1333), + ]) + ), + normalize, + ]) + + if image_set == 'val': + return T.Compose([ + T.RandomResize([800], max_size=1333), + normalize, + ]) + + raise ValueError(f'unknown {image_set}') + + +def build(image_set, cfg): + root = Path(cfg.DATASET.COCO_PATH) + assert root.exists(), f'provided COCO path {root} does not exist' + mode = 'instances' + PATHS = { + "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), + "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), + } + + img_folder, ann_file = PATHS[image_set] + dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=cfg.MODEL.MASKS, + cache_mode=cfg.CACHE_MODE, local_rank=get_local_rank(), local_size=get_local_size()) + return dataset diff --git a/datasets/coco_eval.py b/datasets/coco_eval.py new file mode 100644 index 0000000..011ee53 --- /dev/null +++ b/datasets/coco_eval.py @@ -0,0 +1,267 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +COCO evaluator that works in distributed mode. 
+ +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" +import os +import contextlib +import copy +import numpy as np +import torch + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +import pycocotools.mask as mask_util + +from util.misc import all_gather + + +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + + # suppress pycocotools prints + with open(os.devnull, 'w') as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in 
enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + 'keypoints': keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = all_gather(img_ids) + all_eval_imgs = all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + + +def evaluate(self): + ''' + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + ''' + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == 'segm' or p.iouType == 'bbox': + computeIoU = self.computeIoU + elif p.iouType == 'keypoints': + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# diff --git a/datasets/coco_panoptic.py b/datasets/coco_panoptic.py new file mode 100644 index 0000000..176fcf5 --- /dev/null +++ b/datasets/coco_panoptic.py @@ -0,0 +1,109 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ + +import json +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + +from panopticapi.utils import rgb2id +from util.box_ops import masks_to_boxes + +from .coco import make_coco_transforms + + +class CocoPanoptic: + def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): + with open(ann_file, 'r') as f: + self.coco = json.load(f) + + # sort 'images' field so that they are aligned with 'annotations' + # i.e., in alphabetical order + self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) + # sanity check + if "annotations" in self.coco: + for img, ann in zip(self.coco['images'], self.coco['annotations']): + assert img['file_name'][:-4] == ann['file_name'][:-4] + + self.img_folder = img_folder + self.ann_folder = ann_folder + self.ann_file = ann_file + self.transforms = transforms + self.return_masks = return_masks + + def __getitem__(self, idx): + ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] + img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg') + ann_path = Path(self.ann_folder) / ann_info['file_name'] + + img = Image.open(img_path).convert('RGB') + w, h = img.size + if "segments_info" in ann_info: + masks = np.asarray(Image.open(ann_path), dtype=np.uint32) + masks = rgb2id(masks) + + ids = np.array([ann['id'] for ann in ann_info['segments_info']]) + masks = masks == ids[:, None, None] + + masks = torch.as_tensor(masks, dtype=torch.uint8) + labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64) + + target = {} + target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) + if self.return_masks: + target['masks'] = masks + target['labels'] = labels + + target["boxes"] = masks_to_boxes(masks) + + target['size'] = torch.as_tensor([int(h), int(w)]) + target['orig_size'] = torch.as_tensor([int(h), int(w)]) + if "segments_info" in ann_info: + for name in ['iscrowd', 'area']: + target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.coco['images']) + + def get_height_and_width(self, idx): + img_info = self.coco['images'][idx] + height = img_info['height'] + width = img_info['width'] + return height, width + + +def build(image_set, cfg): + img_folder_root = Path(cfg.DATASET.COCO_PATH) + ann_folder_root = Path(cfg.COCO_PANOPTIC_PATH) + assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist' + assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist' + mode = 'panoptic' + PATHS = { + "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'), + "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'), + } + + img_folder, ann_file = PATHS[image_set] + img_folder_path = img_folder_root / img_folder + ann_folder = ann_folder_root / f'{mode}_{img_folder}' + ann_file = ann_folder_root / ann_file + + dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file, + transforms=make_coco_transforms(image_set), return_masks=cfg.MODEL.MASKS) + + return dataset diff --git a/datasets/data_prefetcher.py b/datasets/data_prefetcher.py new file mode 100644 index 0000000..fbe12ff --- /dev/null +++ b/datasets/data_prefetcher.py 
@@ -0,0 +1,71 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +import torch + +def to_cuda(samples, targets, device): + samples = samples.to(device, non_blocking=True) + targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets] + return samples, targets + +class data_prefetcher(): + def __init__(self, loader, device, prefetch=True): + self.loader = iter(loader) + self.prefetch = prefetch + self.device = device + if prefetch: + self.stream = torch.cuda.Stream() + self.preload() + + def preload(self): + try: + self.next_samples, self.next_targets = next(self.loader) + except StopIteration: + self.next_samples = None + self.next_targets = None + return + # if record_stream() doesn't work, another option is to make sure device inputs are created + # on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') + # Need to make sure the memory allocated for next_* is not still in use by the main stream + # at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.next_samples, self.next_targets = to_cuda(self.next_samples, self.next_targets, self.device) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this side stream. + # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + # With Amp, it isn't necessary to manually convert data to half. + # if args.fp16: + # self.next_input = self.next_input.half() + # else: + + def next(self): + if self.prefetch: + torch.cuda.current_stream().wait_stream(self.stream) + samples = self.next_samples + targets = self.next_targets + if samples is not None: + samples.record_stream(torch.cuda.current_stream()) + if targets is not None: + for t in targets: + for k, v in t.items(): + v.record_stream(torch.cuda.current_stream()) + self.preload() + else: + try: + samples, targets = next(self.loader) + samples, targets = to_cuda(samples, targets, self.device) + except StopIteration: + samples = None + targets = None + return samples, targets diff --git a/datasets/panoptic_eval.py b/datasets/panoptic_eval.py new file mode 100644 index 0000000..d63adc9 --- /dev/null +++ b/datasets/panoptic_eval.py @@ -0,0 +1,54 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ + +import json +import os + +import util.misc as utils + +try: + from panopticapi.evaluation import pq_compute +except ImportError: + pass + + +class PanopticEvaluator(object): + def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): + self.gt_json = ann_file + self.gt_folder = ann_folder + if utils.is_main_process(): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + self.output_dir = output_dir + self.predictions = [] + + def update(self, predictions): + for p in predictions: + with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: + f.write(p.pop("png_string")) + + self.predictions += predictions + + def synchronize_between_processes(self): + all_predictions = utils.all_gather(self.predictions) + merged_predictions = [] + for p in all_predictions: + merged_predictions += p + self.predictions = merged_predictions + + def summarize(self): + if utils.is_main_process(): + json_data = {"annotations": self.predictions} + predictions_json = os.path.join(self.output_dir, "predictions.json") + with open(predictions_json, "w") as f: + f.write(json.dumps(json_data)) + return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir) + return None diff --git a/datasets/samplers.py b/datasets/samplers.py new file mode 100644 index 0000000..fe5ce38 --- /dev/null +++ b/datasets/samplers.py @@ -0,0 +1,141 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from codes in torch.utils.data.distributed +# ------------------------------------------------------------------------ + +import os +import math +import torch +import torch.distributed as dist +from torch.utils.data.sampler import Sampler + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. 
+ """ + + def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset : offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class NodeDistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. 
+ """ + + def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if local_rank is None: + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + if local_size is None: + local_size = int(os.environ.get('LOCAL_SIZE', 1)) + self.dataset = dataset + self.shuffle = shuffle + self.num_replicas = num_replicas + self.num_parts = local_size + self.rank = rank + self.local_rank = local_rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + indices = [i for i in indices if i % self.num_parts == self.local_rank] + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size_parts - len(indices))] + assert len(indices) == self.total_size_parts + + # subsample + indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/datasets/torchvision_datasets/__init__.py b/datasets/torchvision_datasets/__init__.py new file mode 100644 index 0000000..47502cc --- /dev/null +++ b/datasets/torchvision_datasets/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from .coco import CocoDetection diff --git a/datasets/torchvision_datasets/coco.py b/datasets/torchvision_datasets/coco.py new file mode 100644 index 0000000..9bf70b9 --- /dev/null +++ b/datasets/torchvision_datasets/coco.py @@ -0,0 +1,86 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from torchvision +# ------------------------------------------------------------------------ + +""" +Copy-Paste from torchvision, but add utility of caching images on memory +""" +from torchvision.datasets.vision import VisionDataset +from PIL import Image +import os +import os.path +import tqdm +from io import BytesIO + + +class CocoDetection(VisionDataset): + """`MS Coco Detection `_ Dataset. + Args: + root (string): Root directory where images are downloaded to. 
+ annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry + and returns a transformed version. + """ + + def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None, + cache_mode=False, local_rank=0, local_size=1): + super(CocoDetection, self).__init__(root, transforms, transform, target_transform) + from pycocotools.coco import COCO + self.coco = COCO(annFile) + self.ids = list(sorted(self.coco.imgs.keys())) + self.cache_mode = cache_mode + self.local_rank = local_rank + self.local_size = local_size + if cache_mode: + self.cache = {} + self.cache_images() + + def cache_images(self): + self.cache = {} + for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids): + if index % self.local_size != self.local_rank: + continue + path = self.coco.loadImgs(img_id)[0]['file_name'] + with open(os.path.join(self.root, path), 'rb') as f: + self.cache[path] = f.read() + + def get_image(self, path): + if self.cache_mode: + if path not in self.cache.keys(): + with open(os.path.join(self.root, path), 'rb') as f: + self.cache[path] = f.read() + return Image.open(BytesIO(self.cache[path])).convert('RGB') + return Image.open(os.path.join(self.root, path)).convert('RGB') + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. + """ + coco = self.coco + img_id = self.ids[index] + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + + path = coco.loadImgs(img_id)[0]['file_name'] + + img = self.get_image(path) + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.ids) diff --git a/datasets/transforms.py b/datasets/transforms.py new file mode 100644 index 0000000..50738ae --- /dev/null +++ b/datasets/transforms.py @@ -0,0 +1,286 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Transforms and data augmentation for both image + bbox. +""" +import random + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F + +from util.box_ops import box_xyxy_to_cxcywh +from util.misc import interpolate + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? 
+ target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target["boxes"] = boxes + + if "masks" in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? 
+ target["size"] = torch.tensor(padded_image[::-1]) + if "masks" in target: + target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop(object): + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomPad(object): + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing(object): + + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " 
{0}".format(t) + format_string += "\n)" + return format_string diff --git a/engine.py b/engine.py new file mode 100644 index 0000000..b41fd43 --- /dev/null +++ b/engine.py @@ -0,0 +1,168 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Train and eval functions used in main.py +""" +import math +import os +import sys +from typing import Iterable + +import torch +import util.misc as utils +from datasets.coco_eval import CocoEvaluator +from datasets.panoptic_eval import PanopticEvaluator +from datasets.data_prefetcher import data_prefetcher + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, max_norm: float = 0, logger=None): + model.train() + criterion.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = 10 + + prefetcher = data_prefetcher(data_loader, device, prefetch=True) + samples, targets = prefetcher.next() + + # for samples, targets in metric_logger.log_every(data_loader, print_freq, header): + for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header,logger=logger): + outputs = model(samples) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_unscaled = {f'{k}_unscaled': v + for k, v in loss_dict_reduced.items()} + loss_dict_reduced_scaled = {k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() if k in weight_dict} + losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) + + loss_value = losses_reduced_scaled.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + print(loss_dict_reduced) + sys.exit(1) + + optimizer.zero_grad() + losses.backward() + if max_norm > 0: + grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + else: + grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) + optimizer.step() + + metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) + metric_logger.update(class_error=loss_dict_reduced['class_error']) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + metric_logger.update(grad_norm=grad_total_norm) + + samples, targets = prefetcher.next() + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in 
metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir): + model.eval() + criterion.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = 'Test:' + + iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) + coco_evaluator = CocoEvaluator(base_ds, iou_types) + # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] + + panoptic_evaluator = None + if 'panoptic' in postprocessors.keys(): + panoptic_evaluator = PanopticEvaluator( + data_loader.dataset.ann_file, + data_loader.dataset.ann_folder, + output_dir=os.path.join(output_dir, "panoptic_eval"), + ) + + for samples, targets in metric_logger.log_every(data_loader, 10, header,logger=logger): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + outputs = model(samples) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_scaled = {k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() if k in weight_dict} + loss_dict_reduced_unscaled = {f'{k}_unscaled': v + for k, v in loss_dict_reduced.items()} + metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), + **loss_dict_reduced_scaled, + **loss_dict_reduced_unscaled) + metric_logger.update(class_error=loss_dict_reduced['class_error']) + + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + results = postprocessors['bbox'](outputs, orig_target_sizes) + if 'segm' in postprocessors.keys(): + target_sizes = torch.stack([t["size"] for t in targets], dim=0) + results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) + res = {target['image_id'].item(): output for target, output in zip(targets, results)} + if coco_evaluator is not None: + coco_evaluator.update(res) + + if panoptic_evaluator is not None: + res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) + for i, target in enumerate(targets): + image_id = target["image_id"].item() + file_name = f"{image_id:012d}.png" + res_pano[i]["image_id"] = image_id + res_pano[i]["file_name"] = file_name + + panoptic_evaluator.update(res_pano) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + if panoptic_evaluator is not None: + panoptic_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + if coco_evaluator is not None: + coco_evaluator.accumulate() + coco_evaluator.summarize() + panoptic_res = None + if panoptic_evaluator is not None: + panoptic_res = panoptic_evaluator.summarize() + stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + if coco_evaluator is not None: + if 'bbox' in postprocessors.keys(): + stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() + if 'segm' in postprocessors.keys(): + stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() + if panoptic_res is not None: + stats['PQ_all'] = panoptic_res["All"] + stats['PQ_th'] = panoptic_res["Things"] + stats['PQ_st'] = panoptic_res["Stuff"] + return stats, coco_evaluator diff --git a/engine_aood.py 
b/engine_aood.py new file mode 100644 index 0000000..613898c --- /dev/null +++ b/engine_aood.py @@ -0,0 +1,172 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Train and eval functions used in main.py +""" +import math +import os +import sys +from typing import Iterable + +import torch +import util.misc as utils +from datasets.aood_eval import AOODEvaluator +from datasets.panoptic_eval import PanopticEvaluator +from datasets.data_prefetcher import data_prefetcher + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, max_norm: float = 0, logger=None): + model.train() + criterion.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = 10 + + prefetcher = data_prefetcher(data_loader, device, prefetch=True) + samples, targets = prefetcher.next() + + # for samples, targets in metric_logger.log_every(data_loader, print_freq, header): + for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header, logger=logger): + outputs = model(samples) + # loss_dict = criterion(outputs, targets, epoch) + loss_dict = criterion(samples, outputs, targets, epoch) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_unscaled = {f'{k}_unscaled': v + for k, v in loss_dict_reduced.items()} + loss_dict_reduced_scaled = {k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() if k in weight_dict} + losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) + + loss_value = losses_reduced_scaled.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + print(loss_dict_reduced) + sys.exit(1) + + optimizer.zero_grad() + losses.backward() + if max_norm > 0: + grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + else: + grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) + optimizer.step() + + metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) + metric_logger.update(class_error=loss_dict_reduced['class_error']) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + metric_logger.update(grad_norm=grad_total_norm) + + samples, targets = prefetcher.next() + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in 
metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir): + model.eval() + criterion.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = 'Test:' + + iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) + coco_evaluator = AOODEvaluator(base_ds, iou_types) + # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] + + panoptic_evaluator = None + if 'panoptic' in postprocessors.keys(): + panoptic_evaluator = PanopticEvaluator( + data_loader.dataset.ann_file, + data_loader.dataset.ann_folder, + output_dir=os.path.join(output_dir, "panoptic_eval"), + ) + + for samples, targets in metric_logger.log_every(data_loader, 10, header): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + outputs = model(samples) + loss_dict = criterion(samples, outputs, targets,epoch=0) + weight_dict = criterion.weight_dict + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_scaled = {k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() if k in weight_dict} + loss_dict_reduced_unscaled = {f'{k}_unscaled': v + for k, v in loss_dict_reduced.items()} + metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), + **loss_dict_reduced_scaled, + **loss_dict_reduced_unscaled) + metric_logger.update(class_error=loss_dict_reduced['class_error']) + + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + + results = postprocessors['bbox'](outputs, orig_target_sizes, show_box=False) + + if 'segm' in postprocessors.keys(): + target_sizes = torch.stack([t["size"] for t in targets], dim=0) + results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) + res = {target['image_id'].item(): output for target, output in zip(targets, results)} + if coco_evaluator is not None: + coco_evaluator.update(res) + + if panoptic_evaluator is not None: + res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) + for i, target in enumerate(targets): + image_id = target["image_id"].item() + file_name = f"{image_id:012d}.png" + res_pano[i]["image_id"] = image_id + res_pano[i]["file_name"] = file_name + + panoptic_evaluator.update(res_pano) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + if panoptic_evaluator is not None: + panoptic_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + if coco_evaluator is not None: + coco_evaluator.accumulate() + coco_evaluator.summarize() + panoptic_res = None + if panoptic_evaluator is not None: + panoptic_res = panoptic_evaluator.summarize() + stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + if coco_evaluator is not None: + # if 'bbox' in postprocessors.keys(): + # stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() + if 'segm' in postprocessors.keys(): + stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() + if panoptic_res is not None: + stats['PQ_all'] = panoptic_res["All"] + stats['PQ_th'] = panoptic_res["Things"] + stats['PQ_st'] = panoptic_res["Stuff"] + # return stats, coco_evaluator + 
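Unlike the closed-set engine, the AOOD `evaluate()` ends with the return statement that follows: it hands back the dictionary that `AOODEvaluator` stores in `coco_evaluator.coco_eval['bbox'].stats` rather than the averaged meter values. A minimal caller sketch, assuming the key names accessed later in `main.py` and `main_multi_eval.py` (`'title'`, `'ap_map_wi_aose_ar'`, `'base_mAP: '`) are the ones produced by `datasets/aood_eval.py`:

```python
# Hedged sketch: key names mirror the accesses in main.py / main_multi_eval.py and are
# assumed to come from datasets/aood_eval.py; adjust them if that evaluator changes.
test_stats, coco_evaluator = evaluate(
    model, criterion, postprocessors, data_loader_val, base_ds, device, output_dir)

print(test_stats['title'])              # header line naming the reported metrics
print(test_stats['ap_map_wi_aose_ar'])  # AP / mAP / WI / AOSE / AR string (per the key name)
base_map = test_stats['base_mAP: ']     # scalar later used for best-checkpoint selection
```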
return coco_evaluator.coco_eval['bbox'].stats, coco_evaluator diff --git a/main.py b/main.py new file mode 100644 index 0000000..94aaf74 --- /dev/null +++ b/main.py @@ -0,0 +1,323 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + + +import os +import argparse +import datetime +import json +import random +import time +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import DataLoader +import datasets +import datasets.DAOD as DAOD +import util.misc as utils +import datasets.samplers as samplers +from datasets import build_dataset, get_coco_api_from_dataset +from models import build_model +from config import get_cfg_defaults + + +def setup(args): + cfg = get_cfg_defaults() + if args.config_file: + cfg.merge_from_file(args.config_file) + if args.opts: + cfg.merge_from_list(args.opts) + utils.init_distributed_mode(cfg) + cfg.freeze() + if cfg.OUTPUT_DIR: + Path(cfg.OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + os.system(f'cp {args.config_file} {cfg.OUTPUT_DIR}') + ddetr_src = 'models/deformable_detr.py' + ddetr_des = Path(cfg.OUTPUT_DIR) / 'deformable_detr.py.backup' + dtrans_src = 'models/deformable_transformer.py' + dtrans_des = Path(cfg.OUTPUT_DIR) / 'deformable_transformer.py.backup' + main_src = 'main.py' + main_des = Path(cfg.OUTPUT_DIR) / 'main.py.backup' + os.system(f'cp {ddetr_src} {ddetr_des}') + os.system(f'cp {dtrans_src} {dtrans_des}') + os.system(f'cp {main_src} {main_des}') + return cfg + +def main(cfg): + # align = cfg.MODEL.BACKBONE_ALIGN or cfg.MODEL.SPACE_ALIGN or cfg.MODEL.CHANNEL_ALIGN or cfg.MODEL.INSTANCE_ALIGN + # assert align == (cfg.DATASET.DA_MODE == 'uda') + # print("git:\n {}\n".format(utils.get_sha())) + print(cfg) + if cfg.DATASET.DA_MODE == 'osda': + from engine_aood import evaluate, train_one_epoch + else: + from engine import evaluate, train_one_epoch + if cfg.MODEL.FROZEN_WEIGHTS is not None: + assert cfg.MODEL.MASKS, "Frozen training is meant for segmentation only" + + device = torch.device(cfg.DEVICE) + # fix the seed for reproducibility + seed = cfg.SEED + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + model, criterion, postprocessors = build_model(cfg) + model.to(device) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('number of params:', n_parameters) + + dataset_train = build_dataset(image_set='train', cfg=cfg) + dataset_val = build_dataset(image_set='val', cfg=cfg) + + if cfg.DIST.DISTRIBUTED: + if cfg.CACHE_MODE: + sampler_train = samplers.NodeDistributedSampler(dataset_train) + sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False) + else: + sampler_train = samplers.DistributedSampler(dataset_train) + sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False) + 
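# Samplers: when CACHE_MODE is on, NodeDistributedSampler is used for both the train and val
# datasets, otherwise the standard DistributedSampler shards the data across ranks; the
# else-branch below covers single-process runs with Random/Sequential samplers.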
else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + if cfg.DATASET.DA_MODE == 'uda' or cfg.DATASET.DA_MODE == 'osda': + assert cfg.TRAIN.BATCH_SIZE % 2 == 0, f'cfg.TRAIN.BATCH_SIZE {cfg.TRAIN.BATCH_SIZE} should be a multiple of 2' + batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, cfg.TRAIN.BATCH_SIZE//2, drop_last=True) + data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, + collate_fn=DAOD.collate_fn, num_workers=cfg.NUM_WORKERS, + pin_memory=True) + else: + batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, cfg.TRAIN.BATCH_SIZE, drop_last=True) + data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, + collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS, + pin_memory=True) + data_loader_val = DataLoader(dataset_val, cfg.TRAIN.BATCH_SIZE, sampler=sampler_val, + drop_last=False, collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS, + pin_memory=True) + + # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"] + def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + param_dicts = [ + { + "params": + [p for n, p in model_without_ddp.named_parameters() + if not match_name_keywords(n, cfg.TRAIN.LR_BACKBONE_NAMES) and not match_name_keywords(n, cfg.TRAIN.LR_LINEAR_PROJ_NAMES) and p.requires_grad], + "lr": cfg.TRAIN.LR, + }, + { + "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, cfg.TRAIN.LR_BACKBONE_NAMES) and p.requires_grad], + "lr": cfg.TRAIN.LR_BACKBONE, + }, + { + "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, cfg.TRAIN.LR_LINEAR_PROJ_NAMES) and p.requires_grad], + "lr": cfg.TRAIN.LR * cfg.TRAIN.LR_LINEAR_PROJ_MULT, + } + ] + if cfg.TRAIN.SGD: + optimizer = torch.optim.SGD(param_dicts, lr=cfg.TRAIN.LR, momentum=0.9, + weight_decay=cfg.TRAIN.WEIGHT_DECAY) + else: + optimizer = torch.optim.AdamW(param_dicts, lr=cfg.TRAIN.LR, + weight_decay=cfg.TRAIN.WEIGHT_DECAY) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.TRAIN.LR_DROP) + + if cfg.DIST.DISTRIBUTED: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.DIST.GPU]) + model_without_ddp = model.module + + if cfg.DATASET.DATASET_FILE == "coco_panoptic": + # We also evaluate AP during panoptic training, on original coco DS + coco_val = datasets.coco.build("val", cfg) + base_ds = get_coco_api_from_dataset(coco_val) + else: + base_ds = get_coco_api_from_dataset(dataset_val) + + if cfg.MODEL.FROZEN_WEIGHTS is not None: + checkpoint = torch.load(cfg.MODEL.FROZEN_WEIGHTS, map_location='cpu') + model_without_ddp.detr.load_state_dict(checkpoint['model']) + + output_dir = Path(cfg.OUTPUT_DIR) + + import logging + LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" + logging.basicConfig( + filename=cfg.OUTPUT_DIR +'/_rank_{}_'.format(utils.get_rank())+str(__file__)[:-3] + '_' + time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + '.log', + level=logging.INFO, format=LOG_FORMAT, filemode='w') + console = logging.StreamHandler() + console.setLevel(logging.INFO) + logging.getLogger('').addHandler(console) + logger = logging.getLogger("main_da") + + if cfg.RESUME: # [BUG] write after freezing cfgs + if cfg.RESUME.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + cfg.RESUME, map_location='cpu', check_hash=True) + 
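# Resuming (continued below): a local checkpoint is loaded with strict=False, so mismatched
# weights only surface as Missing/Unexpected key warnings; if not in EVAL mode and
# LOAD_OPTIMIZER is set, the optimizer and LR scheduler states are restored as well, keeping
# the current param-group learning rates and letting cfg.TRAIN.LR_DROP override the resumed
# scheduler's step size (the override_resumed_lr_drop hack).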
else: + checkpoint = torch.load(cfg.RESUME, map_location='cpu') + missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) + unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] + if len(missing_keys) > 0: + print('Missing Keys: {}'.format(missing_keys)) + if len(unexpected_keys) > 0: + print('Unexpected Keys: {}'.format(unexpected_keys)) + if not cfg.EVAL and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint and cfg.LOAD_OPTIMIZER: + import copy + p_groups = copy.deepcopy(optimizer.param_groups) + optimizer.load_state_dict(checkpoint['optimizer']) + for pg, pg_old in zip(optimizer.param_groups, p_groups): + pg['lr'] = pg_old['lr'] + pg['initial_lr'] = pg_old['initial_lr'] + # print(optimizer.param_groups) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance). + override_resumed_lr_drop = True + if override_resumed_lr_drop: + print('Warning: (hack) override_resumed_lr_drop is set to True, so cfg.TRAIN.LR_DROP would override lr_drop in resumed lr_scheduler.') + lr_scheduler.step_size = cfg.TRAIN.LR_DROP + lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + lr_scheduler.step(lr_scheduler.last_epoch) + cfg.START_EPOCH = checkpoint['epoch'] + 1 + + + + # check the resumed model + # if not cfg.EVAL: + # test_stats, coco_evaluator = evaluate( + # model, criterion, postprocessors, data_loader_val, base_ds, device, cfg.OUTPUT_DIR + # ) + # if utils.is_main_process(): + # for key in test_stats.keys(): + # logger.info('{} {}'.format(key, test_stats[key])) + + if cfg.EVAL: + test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, + data_loader_val, base_ds, device, cfg.OUTPUT_DIR) + + if cfg.OUTPUT_DIR: + title = test_stats['title'] + '\n' + results = test_stats['ap_map_wi_aose_ar'] + results = 'Epoch : ' + results + '\n' + if utils.is_main_process(): + results_dir = output_dir /'eval_results.txt' + + with open(results_dir, 'a') as f: + f.write(title) + + + + + utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") + return + + print("Start training") + start_time = time.time() + + best_mAP = 0 + checkpoint_dir = 'check' + for epoch in range(cfg.START_EPOCH, cfg.TRAIN.EPOCHS): + if cfg.DIST.DISTRIBUTED: + sampler_train.set_epoch(epoch) + train_stats = train_one_epoch( + model, criterion, data_loader_train, optimizer, device, epoch, cfg.TRAIN.CLIP_MAX_NORM, logger=logger) + lr_scheduler.step() + if cfg.OUTPUT_DIR: + checkpoint_paths = [output_dir / 'checkpoint.pth'] + # extra checkpoint before LR drop and every 5 epochs + if (epoch + 1) % cfg.TRAIN.LR_DROP == 0: + checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') + + for checkpoint_path in checkpoint_paths: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'cfg': cfg, + }, checkpoint_path) + + test_stats, coco_evaluator = evaluate( + model, criterion, postprocessors, data_loader_val, base_ds, device, cfg.OUTPUT_DIR + ) + mAP_tmp = test_stats['base_mAP: '] + if mAP_tmp > best_mAP and utils.is_main_process(): + + if os.path.exists(checkpoint_dir): + os.remove(checkpoint_dir) + + checkpoint_dir = output_dir / f'best_{epoch:02}_{round(mAP_tmp,3)}.pth' + + 
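# Best-checkpoint bookkeeping: the previous best file tracked in checkpoint_dir was removed
# just above; the save_on_master call below writes the new best weights, named after the
# epoch and its base-class mAP, and best_mAP is then updated so later epochs compare
# against it.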
utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'cfg': cfg, + }, checkpoint_dir) + best_mAP = mAP_tmp + + + if cfg.OUTPUT_DIR and utils.is_main_process(): + title = test_stats['title'] + '\n' + results = test_stats['ap_map_wi_aose_ar'] + results = 'Epoch {}: '.format(epoch) + results + '\n' + results_dir = output_dir /'eval_results.txt' + if utils.is_main_process(): + if epoch == 0: + with open(results_dir, 'a') as f: + f.write(title) + + with open(results_dir, 'a') as f: + f.write(results) + + + # if utils.is_main_process(): + for key in test_stats.keys(): + logger.info('{} {}'.format(key, test_stats[key])) + # for evaluation logs + if coco_evaluator is not None: + (output_dir / 'eval').mkdir(exist_ok=True) + if "bbox" in coco_evaluator.coco_eval: + filenames = ['latest.pth'] + if epoch % 50 == 0: + filenames.append(f'{epoch:03}.pth') + for name in filenames: + torch.save(coco_evaluator.coco_eval["bbox"].eval, + output_dir / "eval" / name) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('Deformable DETR Detector') + parser.add_argument('--config_file', default='', type=str) + parser.add_argument("--opts", default=None, nargs=argparse.REMAINDER) + args = parser.parse_args() + cfg = setup(args) + main(cfg) diff --git a/main_multi_eval.py b/main_multi_eval.py new file mode 100644 index 0000000..bce75de --- /dev/null +++ b/main_multi_eval.py @@ -0,0 +1,341 @@ +# ------------------------------------------------------------------------ +# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection +# Modified by Wuyang Li +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ +import os +import argparse +import datetime +import json +import random +import time +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import DataLoader +import datasets +import datasets.DAOD as DAOD +import util.misc as utils +import datasets.samplers as samplers +from datasets import build_dataset, get_coco_api_from_dataset + +from models import build_model +from config import get_cfg_defaults +import logging + +def setup(args): + cfg = get_cfg_defaults() + if args.config_file: + cfg.merge_from_file(args.config_file) + if args.opts: + cfg.merge_from_list(args.opts) + + utils.init_distributed_mode(cfg) + cfg.freeze() + + if cfg.OUTPUT_DIR: + Path(cfg.OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + os.system(f'cp {args.config_file} {cfg.OUTPUT_DIR}') + ddetr_src = 'models/motif_detr.py' + ddetr_des = Path(cfg.OUTPUT_DIR) / 'motif_detr.py.backup' + dtrans_src = 'models/deformable_transformer.py' + dtrans_des = Path(cfg.OUTPUT_DIR) / 'deformable_transformer.py.backup' + main_src = 'main.py' + main_des = Path(cfg.OUTPUT_DIR) / 'main.py.backup' + os.system(f'cp {ddetr_src} {ddetr_des}') + os.system(f'cp {dtrans_src} {dtrans_des}') + os.system(f'cp {main_src} {main_des}') + + return cfg + + +def main(cfg): + + # align = cfg.MODEL.BACKBONE_ALIGN or cfg.MODEL.SPACE_ALIGN or cfg.MODEL.CHANNEL_ALIGN or cfg.MODEL.INSTANCE_ALIGN + # assert align == (cfg.DATASET.DA_MODE == 'uda') + # print("git:\n {}\n".format(utils.get_sha())) + + print(cfg) + + if cfg.DATASET.DA_MODE == 'aood': + from engine_aood import evaluate, train_one_epoch + else: + from engine import evaluate, train_one_epoch + + if cfg.MODEL.FROZEN_WEIGHTS is not None: + assert cfg.MODEL.MASKS, "Frozen training is meant for segmentation only" + + device = torch.device(cfg.DEVICE) + # fix the seed for reproducibility + seed = cfg.SEED + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + model, criterion, postprocessors = build_model(cfg) + model.to(device) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('number of params:', n_parameters) + + dataset_train = build_dataset(image_set='train', cfg=cfg) + + if not cfg.DATASET.DATASET_FILE == 'pascal_to_clipart': + num_eval_novel_classes = [3,4,5] # eval with 3/4/5 novel classes + else: + num_eval_novel_classes = [6,8,10] # eval with 6/8/10 novel classes + num_sub_tasks = len(num_eval_novel_classes) + + dataset_val_list = [] + sampler_val_list = [] + + for i in range(num_sub_tasks): + dataset_val_list.append(build_dataset(image_set='val', cfg=cfg, multi_task_eval_id=num_eval_novel_classes[i])) + + if cfg.DIST.DISTRIBUTED: + if cfg.CACHE_MODE: + sampler_train = samplers.NodeDistributedSampler(dataset_train) + # sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False) + for dataloader in dataset_val_list: + sampler_val_list.append(samplers.NodeDistributedSampler(dataloader, shuffle=False)) + else: + sampler_train = samplers.DistributedSampler(dataset_train) + # sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False) + for dataloader in dataset_val_list: + sampler_val_list.append(samplers.DistributedSampler(dataloader, shuffle=False)) + else: + sampler_train = 
torch.utils.data.RandomSampler(dataset_train) + # sampler_val = torch.utils.data.SequentialSampler(dataset_val) + for dataloader in dataset_val_list: + sampler_val_list.append(torch.utils.data.SequentialSampler(dataloader)) + + if cfg.DATASET.DA_MODE == 'uda' or cfg.DATASET.DA_MODE == 'aood': + assert cfg.TRAIN.BATCH_SIZE % 2 == 0, f'cfg.TRAIN.BATCH_SIZE {cfg.TRAIN.BATCH_SIZE} should be a multiple of 2' + batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, cfg.TRAIN.BATCH_SIZE//2, drop_last=True) + data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, + collate_fn=DAOD.collate_fn, num_workers=cfg.NUM_WORKERS, + pin_memory=True) + else: + batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, cfg.TRAIN.BATCH_SIZE, drop_last=True) + data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, + collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS, + pin_memory=True) + + data_loader_val_list = [] + for i in range(num_sub_tasks): + data_loader_val_list.append( + DataLoader(dataset_val_list[i], cfg.TRAIN.BATCH_SIZE, sampler=sampler_val_list[i], + drop_last=False, collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS, + pin_memory=True)) + + # data_loader_val = DataLoader(dataset_val, cfg.TRAIN.BATCH_SIZE, sampler=sampler_val, + # drop_last=False, collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS, + # pin_memory=True) + # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"] + def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + param_dicts = [ + { + "params": + [p for n, p in model_without_ddp.named_parameters() + if not match_name_keywords(n, cfg.TRAIN.LR_BACKBONE_NAMES) and not match_name_keywords(n, cfg.TRAIN.LR_LINEAR_PROJ_NAMES) and p.requires_grad], + "lr": cfg.TRAIN.LR, + }, + { + "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, cfg.TRAIN.LR_BACKBONE_NAMES) and p.requires_grad], + "lr": cfg.TRAIN.LR_BACKBONE, + }, + { + "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, cfg.TRAIN.LR_LINEAR_PROJ_NAMES) and p.requires_grad], + "lr": cfg.TRAIN.LR * cfg.TRAIN.LR_LINEAR_PROJ_MULT, + } + ] + if cfg.TRAIN.SGD: + optimizer = torch.optim.SGD(param_dicts, lr=cfg.TRAIN.LR, momentum=0.9, + weight_decay=cfg.TRAIN.WEIGHT_DECAY) + else: + optimizer = torch.optim.AdamW(param_dicts, lr=cfg.TRAIN.LR, + weight_decay=cfg.TRAIN.WEIGHT_DECAY) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.TRAIN.LR_DROP) + + if cfg.DIST.DISTRIBUTED: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.DIST.GPU]) + model_without_ddp = model.module + + if cfg.DATASET.DATASET_FILE == "coco_panoptic": + # We also evaluate AP during panoptic training, on original coco DS + coco_val = datasets.coco.build("val", cfg) + base_ds = get_coco_api_from_dataset(coco_val) + else: + # base_ds = get_coco_api_from_dataset(dataset_val) + base_ds_list =[] + for i in range(num_sub_tasks): + base_ds_list.append(get_coco_api_from_dataset(dataset_val_list[i])) + + if cfg.MODEL.FROZEN_WEIGHTS is not None: + checkpoint = torch.load(cfg.MODEL.FROZEN_WEIGHTS, map_location='cpu') + model_without_ddp.detr.load_state_dict(checkpoint['model']) + output_dir = Path(cfg.OUTPUT_DIR) + + LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" + logging.basicConfig( + filename=cfg.OUTPUT_DIR 
+'/_rank_{}_'.format(utils.get_rank())+str(__file__)[:-3] + '_' + time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + '.log', + level=logging.INFO, format=LOG_FORMAT, filemode='w') + console = logging.StreamHandler() + console.setLevel(logging.INFO) + logging.getLogger('').addHandler(console) + logger = logging.getLogger("main_da") + + if cfg.RESUME: # [BUG] write after freezing cfgs + if cfg.RESUME.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + cfg.RESUME, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(cfg.RESUME, map_location='cpu') + missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) + unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] + if len(missing_keys) > 0: + print('Missing Keys: {}'.format(missing_keys)) + if len(unexpected_keys) > 0: + print('Unexpected Keys: {}'.format(unexpected_keys)) + if not cfg.EVAL and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint and cfg.LOAD_OPTIMIZER: + import copy + p_groups = copy.deepcopy(optimizer.param_groups) + optimizer.load_state_dict(checkpoint['optimizer']) + for pg, pg_old in zip(optimizer.param_groups, p_groups): + pg['lr'] = pg_old['lr'] + pg['initial_lr'] = pg_old['initial_lr'] + # print(optimizer.param_groups) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance). + override_resumed_lr_drop = True + if override_resumed_lr_drop: + print('Warning: (hack) override_resumed_lr_drop is set to True, so cfg.TRAIN.LR_DROP would override lr_drop in resumed lr_scheduler.') + lr_scheduler.step_size = cfg.TRAIN.LR_DROP + lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + lr_scheduler.step(lr_scheduler.last_epoch) + cfg.START_EPOCH = checkpoint['epoch'] + 1 + if cfg.EVAL: + epoch = 0 + per_task_results = [] + for i in range(num_sub_tasks): + test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, + data_loader_val_list[i], base_ds_list[i], device, cfg.OUTPUT_DIR) + title = test_stats['title'] + '\n' + + per_task_results += test_stats['report_results'] + # import ipdb; ipdb.set_trace() + results = 'Epoch {}: '.format(epoch) + test_stats['ap_map_wi_aose_ar'] + '\n' + results_dir = output_dir /'eval_results.txt' + if utils.is_main_process(): + if epoch == 0: + with open(results_dir, 'a') as f: + f.write(title) + + with open(results_dir, 'a') as f: + f.write(results) + if utils.is_main_process(): + report_results = 'Epoch {}: '.format(epoch) + ' & '.join(per_task_results) + '\n' + with open(results_dir, 'a') as f: + f.write(report_results) + f.write(''.join(100*['='])) + + utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") + return + + print("Start training") + start_time = time.time() + + best_mAP = 0 + checkpoint_dir = 'initial_dir' + + for epoch in range(cfg.START_EPOCH, cfg.TRAIN.EPOCHS): + if cfg.DIST.DISTRIBUTED: + sampler_train.set_epoch(epoch) + train_stats = train_one_epoch( + model, criterion, data_loader_train, optimizer, device, epoch, cfg.TRAIN.CLIP_MAX_NORM, logger=logger) + lr_scheduler.step() + + if epoch>cfg.EVAL_EPOCH: + per_task_results = [] + results_dir = output_dir /'eval_results.txt' + for i in range(num_sub_tasks): + test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, + 
data_loader_val_list[i], base_ds_list[i], device, cfg.OUTPUT_DIR) + if utils.is_main_process(): + title = test_stats['title'] + '\n' + per_task_results +=test_stats['report_results'] + results = '[Epoch {}] [Task {}]'.format(epoch,i) + test_stats['ap_map_wi_aose_ar'] + '\n' + + if epoch == 0 and i ==0: + with open(results_dir, 'a') as f: + f.write(title) + + with open(results_dir, 'a') as f: + f.write(results) + if utils.is_main_process(): + report_results = 'Epoch {}: '.format(epoch) + ' & '.join(per_task_results) + '\n' + with open(results_dir, 'a') as f: + f.write(report_results) + f.write(''.join(100*['=']) + '\n') + + mAP_tmp = test_stats['base_mAP: '] + if mAP_tmp > best_mAP: + + if os.path.exists(checkpoint_dir): + os.remove(checkpoint_dir) + + checkpoint_dir = output_dir / f'best_{epoch:02}_{round(mAP_tmp,3)}.pth' + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'cfg': cfg, + }, checkpoint_dir) + best_mAP = mAP_tmp + # saveing more checkpoints + if epoch > cfg.TRAIN.LR_DROP-6 and epoch % 2 == 0: + model_dir = output_dir / f'model_{epoch:02}.pth' + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'cfg': cfg, + }, model_dir) + + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + +if __name__ == '__main__': + parser = argparse.ArgumentParser('Deformable DETR Detector') + parser.add_argument('--config_file', default='', type=str) + parser.add_argument("--opts", default=None, nargs=argparse.REMAINDER) + args = parser.parse_args() + cfg = setup(args) + main(cfg) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..17bade6 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,25 @@ +# ------------------------------------------------------------------------ +# Modified by Wuyang LI +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ +def build_model(cfg): + if cfg.AOOD.OW_DETR_ON: + print('using ow detr!') + from .ow_detr import build + elif cfg.AOOD.MOTIF_ON: + print('using motif detr!') + from .motif_detr import build + else: + print('using def detr!') + from .deformable_detr import build + return build(cfg) + + diff --git a/models/backbone.py b/models/backbone.py new file mode 100644 index 0000000..b9d488e --- /dev/null +++ b/models/backbone.py @@ -0,0 +1,139 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
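Before the backbone definition continues, it is worth noting how `models/__init__.py` (shown just above) dispatches between the three detector variants. A small sketch of toggling them programmatically, assuming the yacs-style config returned by `config.get_cfg_defaults()`; the flag names are taken from `build_model`, and the same switch can be applied via `--opts`, which `setup()` merges with `cfg.merge_from_list(args.opts)`:

```python
from config import get_cfg_defaults
from models import build_model

cfg = get_cfg_defaults()
cfg.merge_from_file('configs/soma_aood_city_to_foggy_r50.yaml')  # any provided config
cfg.AOOD.OW_DETR_ON = False   # True  -> models.ow_detr.build        ("using ow detr!")
cfg.AOOD.MOTIF_ON = True      # True  -> models.motif_detr.build     ("using motif detr!")
                              # both False -> models.deformable_detr.build ("using def detr!")
model, criterion, postprocessors = build_model(cfg)
```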
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Backbone modules. +""" +from collections import OrderedDict +from torchvision.models.resnet import resnet50 +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List +from util.misc import NestedTensor, is_main_process +from .position_encoding import build_position_encoding + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n, eps=1e-5): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = self.eps + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + self.strides = [8, 16, 32] + self.num_channels = [512, 1024, 2048] + else: + return_layers = {'layer4': "0"} + self.strides = [32] + self.num_channels = [2048] + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + norm_layer = FrozenBatchNorm2d + # backbone = getattr(torchvision.models, name)( + # replace_stride_with_dilation=[False, False, dilation], 
+ # pretrained=is_main_process(), norm_layer=norm_layer) + backbone = resnet50(pretrained=False, replace_stride_with_dilation=[False, False, dilation], + norm_layer=norm_layer) + state_dict = torch.load('./dino_resnet50_pretrain.pth') + backbone.load_state_dict(state_dict, strict=False) + assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" + super().__init__(backbone, train_backbone, return_interm_layers) + if dilation: + self.strides[-1] = self.strides[-1] // 2 + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + self.strides = backbone.strides + self.num_channels = backbone.num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in sorted(xs.items()): + out.append(x) + + # position encoding + for x in out: + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + +def build_backbone(cfg): + position_embedding = build_position_encoding(cfg) + train_backbone = cfg.TRAIN.LR_BACKBONE > 0 + return_interm_layers = cfg.MODEL.MASKS or (cfg.MODEL.NUM_FEATURE_LEVELS > 1) + backbone = Backbone(cfg.MODEL.BACKBONE, train_backbone, return_interm_layers, cfg.MODEL.DILATION) + model = Joiner(backbone, position_embedding) + return model diff --git a/models/deformable_detr.py b/models/deformable_detr.py new file mode 100644 index 0000000..14d5e97 --- /dev/null +++ b/models/deformable_detr.py @@ -0,0 +1,580 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Deformable DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn +import math + +from util import box_ops +from util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, interpolate, + is_dist_avail_and_initialized, inverse_sigmoid) + +from .backbone import build_backbone +from .matcher import build_matcher +from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm, + dice_loss, sigmoid_focal_loss) +from .deformable_transformer import build_deforamble_transformer +from .utils import GradientReversal +import copy + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DeformableDETR(nn.Module): + """ This is the Deformable DETR module that performs object detection """ + def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels, + aux_loss=True, with_box_refine=False, two_stage=False, + backbone_align=False, space_align=False, channel_align=False, instance_align=False): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. 
This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.num_feature_levels = num_feature_levels + if not two_stage: + self.query_embed = nn.Embedding(num_queries, hidden_dim*2) + if num_feature_levels > 1: + num_backbone_outs = len(backbone.strides) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, hidden_dim), + )) + in_channels = hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )]) + self.backbone = backbone + self.aux_loss = aux_loss + self.with_box_refine = with_box_refine + self.two_stage = two_stage + # self.uda = backbone_align or space_align or channel_align or instance_align + self.uda = True + self.backbone_align = backbone_align + self.space_align = space_align + self.channel_align = channel_align + self.instance_align = instance_align + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + # hack implementation for two-stage + self.transformer.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + if backbone_align: + self.grl = GradientReversal() + self.backbone_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.backbone_D.layers: + nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + if space_align: + self.space_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.space_D.layers: + 
nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + if channel_align: + self.channel_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.channel_D.layers: + nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + if instance_align: + self.instance_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.instance_D.layers: + nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + + def forward(self, samples: NestedTensor): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + query_embeds = None + if not self.two_stage: + query_embeds = self.query_embed.weight + hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, da_output, _ = self.transformer(srcs, masks, pos, query_embeds) + + outputs_classes = [] + outputs_coords = [] + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](hs[lvl]) + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + + if self.training and self.uda: + B = outputs_class.shape[1] + outputs_class = outputs_class[:, :B//2] + outputs_coord = outputs_coord[:, :B//2] + if self.two_stage: + enc_outputs_class = enc_outputs_class[:B//2] + enc_outputs_coord_unact = enc_outputs_coord_unact[:B//2] + if self.backbone_align: + da_output['backbone'] = torch.cat([self.backbone_D(self.grl(src.flatten(2).transpose(1, 2))) for src in srcs], dim=1) + if self.space_align: + da_output['space_query'] = 
self.space_D(da_output['space_query']) + if self.channel_align: + da_output['channel_query'] = self.channel_D(da_output['channel_query']) + if self.instance_align: + da_output['instance_query'] = self.instance_D(da_output['instance_query']) + + out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) + + if self.two_stage: + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord} + + if self.training and self.uda: + out['da_output'] = da_output + + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25, da_gamma=2,cfg=None): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ focal_alpha: alpha in Focal Loss + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.focal_alpha = focal_alpha + self.da_gamma = da_gamma + self.placeholder = cfg.OPEN_SET.PLACEHOLDER + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], + dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:,:,:-1] + + + if self.placeholder > 0: + obj_idx = target_classes_onehot.sum(-1) > 0 + tmp = target_classes_onehot[obj_idx] + tmp[:,-1] = self.placeholder + target_classes_onehot[obj_idx] = tmp + + loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. 
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + + src_masks = outputs["pred_masks"] + + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose() + target_masks = target_masks.to(src_masks) + + src_masks = src_masks[src_idx] + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks[tgt_idx].flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def loss_da(self, outputs, use_focal=False): + B = outputs.shape[0] + assert B % 2 == 0 + + targets = torch.empty_like(outputs) + targets[:B//2] = 0 + targets[B//2:] = 1 + + loss = F.binary_cross_entropy_with_logits(outputs, targets, reduction='none') + + if use_focal: + prob = outputs.sigmoid() + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = loss * ((1 - p_t) ** self.da_gamma) + + return loss.mean() + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, samples, outputs, targets, epoch=0): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + kwargs = {} + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
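# Note: auxiliary-layer losses are re-keyed below with an _{i} suffix, two-stage encoder
# proposal losses with an _enc suffix, and the domain-adaptation terms gathered in
# outputs['da_output'] become loss_{name}; these keys must match the entries that build()
# puts into weight_dict, otherwise the terms are excluded from the weighted sum computed in
# train_one_epoch (they would still appear in the logs as *_unscaled).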
+ if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + if 'enc_outputs' in outputs: + enc_outputs = outputs['enc_outputs'] + bin_targets = copy.deepcopy(targets) + for bt in bin_targets: + bt['labels'] = torch.zeros_like(bt['labels']) + indices = self.matcher(enc_outputs, bin_targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_enc': v for k, v in l_dict.items()} + losses.update(l_dict) + + if 'da_output' in outputs: + for k, v in outputs['da_output'].items(): + losses[f'loss_{k}'] = self.loss_da(v, use_focal='query' in k) + + return losses + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] + + return results + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +def build(cfg): + device = torch.device(cfg.DEVICE) + + backbone = build_backbone(cfg) + + transformer = build_deforamble_transformer(cfg) + model = DeformableDETR( + backbone, + transformer, + num_classes=cfg.DATASET.NUM_CLASSES, + num_queries=cfg.MODEL.NUM_QUERIES, + num_feature_levels=cfg.MODEL.NUM_FEATURE_LEVELS, + 
aux_loss=cfg.LOSS.AUX_LOSS, + with_box_refine=cfg.MODEL.WITH_BOX_REFINE, + two_stage=cfg.MODEL.TWO_STAGE, + backbone_align=cfg.MODEL.BACKBONE_ALIGN, + space_align=cfg.MODEL.SPACE_ALIGN, + channel_align=cfg.MODEL.CHANNEL_ALIGN, + instance_align=cfg.MODEL.INSTANCE_ALIGN + ) + if cfg.MODEL.MASKS: + model = DETRsegm(model, freeze_detr=(cfg.MODEL.FROZEN_WEIGHTS is not None)) + matcher = build_matcher(cfg) + weight_dict = {'loss_ce': cfg.LOSS.CLS_LOSS_COEF, 'loss_bbox': cfg.LOSS.BBOX_LOSS_COEF} + weight_dict['loss_giou'] = cfg.LOSS.GIOU_LOSS_COEF + if cfg.MODEL.MASKS: + weight_dict["loss_mask"] = cfg.LOSS.MASK_LOSS_COEF + weight_dict["loss_dice"] = cfg.LOSS.DICE_LOSS_COEF + # TODO this is a hack + if cfg.LOSS.AUX_LOSS: + aux_weight_dict = {} + for i in range(cfg.MODEL.DEC_LAYERS - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + weight_dict['loss_backbone'] = cfg.LOSS.BACKBONE_LOSS_COEF + weight_dict['loss_space_query'] = cfg.LOSS.SPACE_QUERY_LOSS_COEF + weight_dict['loss_channel_query'] = cfg.LOSS.CHANNEL_QUERY_LOSS_COEF + weight_dict['loss_instance_query'] = cfg.LOSS.INSTANCE_QUERY_LOSS_COEF + + # losses = ['labels', 'boxes', 'cardinality'] + losses = ['labels', 'boxes'] + if cfg.MODEL.MASKS: + losses += ["masks"] + # num_classes, matcher, weight_dict, losses, focal_alpha=0.25 + criterion = SetCriterion(cfg.DATASET.NUM_CLASSES, matcher, weight_dict, losses, focal_alpha=cfg.LOSS.FOCAL_ALPHA, da_gamma=cfg.LOSS.DA_GAMMA,cfg=cfg) + criterion.to(device) + postprocessors = {'bbox': PostProcess()} + if cfg.MODEL.MASKS: + postprocessors['segm'] = PostProcessSegm() + if cfg.DATASET.DATASET_FILE == "coco_panoptic": + is_thing_map = {i: i <= 90 for i in range(201)} + postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) + + return model, criterion, postprocessors \ No newline at end of file diff --git a/models/deformable_transformer.py b/models/deformable_transformer.py new file mode 100644 index 0000000..0f57d4f --- /dev/null +++ b/models/deformable_transformer.py @@ -0,0 +1,476 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ + +import copy +from typing import Optional, List +import math + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ + +from util.misc import inverse_sigmoid +from models.ops.modules import MSDeformAttn +from models.utils import DomainAttention, GradientReversal, remove_mask_and_warp + +class DeformableTransformer(nn.Module): + def __init__(self, d_model=256, nhead=8, + num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, + activation="relu", return_intermediate_dec=False, + num_feature_levels=4, dec_n_points=4, enc_n_points=4, + two_stage=False, two_stage_num_proposals=300, + space_align=False, channel_align=False, instance_align=False): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + self.two_stage = two_stage + self.two_stage_num_proposals = two_stage_num_proposals + + self.space_align = space_align + self.channel_align = channel_align + self.instance_align = instance_align + + encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, enc_n_points, space_align, channel_align) + self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, dec_n_points, instance_align) + self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec) + + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if two_stage: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + self.pos_trans = nn.Linear(d_model * 2, d_model * 2) + self.pos_trans_norm = nn.LayerNorm(d_model * 2) + else: + self.reference_points = nn.Linear(d_model, 2) + + if space_align: + self.space_query = nn.Parameter(torch.empty(1, 1, d_model)) + if channel_align: + # self.channel_query is actually an embedding layer for channel query + # We keep the name for consistency + self.channel_query = nn.Linear(d_model, 1) + self.grl = GradientReversal() + if instance_align: + self.instance_query = nn.Parameter(torch.empty(1, 1, d_model)) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + if not self.two_stage: + xavier_uniform_(self.reference_points.weight.data, gain=1.0) + constant_(self.reference_points.bias.data, 0.) 
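+ # Descriptive note (added): level_embed holds one learnable vector per feature level, added to
+ # the positional encodings in forward(); it is initialised from a standard normal below.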
+ normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + N_, S_, C_ = memory.shape + base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += (H_ * W_) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, pos_embeds, query_embed=None): + assert self.two_stage or query_embed is not None + + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), 
spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + space_query, channel_query, instance_query = None, None, None + if self.training: + if self.space_align: + space_query = self.space_query.expand(src_flatten.shape[0], -1, -1) + if self.channel_align: + src_warped, pos_warped = remove_mask_and_warp( + src_flatten, lvl_pos_embed_flatten, mask_flatten, level_start_index, spatial_shapes + ) + channel_query = self.channel_query(self.grl(src_warped+pos_warped)).flatten(0, 1).transpose(1, 2) + + # encoder + memory, space_query, channel_query = self.encoder( + src_flatten, space_query, channel_query, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten + ) + + da_output = {} + if self.training: + if self.space_align: + da_output['space_query'] = torch.cat(space_query, dim=1) + if self.channel_align: + da_output['channel_query'] = torch.cat(channel_query, dim=1) + + # prepare input for decoder + bs, _, c = memory.shape + if self.two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + + # hack implementation for two-stage Deformable DETR + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_embed, tgt = torch.split(pos_trans_out, c, dim=2) + else: + query_embed, tgt = torch.split(query_embed, c, dim=1) + query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) + tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_embed).sigmoid() + init_reference_out = reference_points + + if self.training and self.instance_align: + instance_query = self.instance_query.expand(tgt.shape[0], -1, -1) + + # decoder + hs, inter_references, instance_query = self.decoder( + tgt, instance_query, reference_points, memory, spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten + ) + + if self.training and self.instance_align: + da_output['instance_query'] = instance_query + + inter_references_out = inter_references + if self.two_stage: + return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact, da_output, memory + return hs, init_reference_out, inter_references_out, None, None, da_output, memory + + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__(self, + d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4, + space_align=False, channel_align=False): + super().__init__() + + self.space_align = space_align + self.channel_align = channel_align + if space_align: + self.space_attn = DomainAttention(d_model, n_heads, dropout) + if channel_align: + self.channel_attn = DomainAttention(d_model, n_heads, dropout) + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = 
nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, src, space_query, channel_query, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): + # self attention + src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + + if self.training: + if self.space_align: + space_query = self.space_attn(space_query, src, pos, padding_mask) + if self.channel_align: + src_warped, pos_warped = remove_mask_and_warp(src, pos, padding_mask, level_start_index, spatial_shapes) + channel_query = self.channel_attn( + channel_query, # bsz * num_feature_levels, 1, H*W + src_warped.flatten(0, 1).transpose(1, 2), # bsz * num_feature_levels, C, H*W + pos_warped.flatten(0, 1).transpose(1, 2) + ) + + # ffn + src = self.forward_ffn(src) + + return src, space_query, channel_query + + +class DeformableTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward(self, src, space_query, channel_query, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): + output = src + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) + space_querys = [] + channel_querys = [] + for _, layer in enumerate(self.layers): + output, space_query, channel_query = layer( + output, space_query, channel_query, pos, reference_points, spatial_shapes, level_start_index, padding_mask + ) + space_querys.append(space_query) + channel_querys.append(channel_query) + + return output, space_querys, channel_querys + + +class DeformableTransformerDecoderLayer(nn.Module): + def __init__(self, d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4, + instance_align=False): + super().__init__() + + self.instance_align = instance_align + if instance_align: + self.instance_attn = DomainAttention(d_model, n_heads, dropout) + + # cross attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + 
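+ # Descriptive note (added): self-attention over the object queries uses plain
+ # nn.MultiheadAttention, while the cross-attention to the flattened image memory above
+ # is multi-scale deformable attention (MSDeformAttn).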
self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward(self, tgt, instance_query, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None): + # self attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), + reference_points, + src, src_spatial_shapes, level_start_index, src_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + if self.training and self.instance_align: + instance_query = self.instance_attn(instance_query, tgt, query_pos) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt, instance_query + + +class DeformableTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + def forward(self, tgt, instance_query, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios, + query_pos=None, src_padding_mask=None): + output = tgt + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = reference_points[:, :, None] \ + * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] + output, instance_query = layer( + output, instance_query, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask + ) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points), instance_query + + # https://github.com/fundamentalvision/Deformable-DETR/issues/43 + return [output], [reference_points], instance_query + + +def _get_clones(module, N): + return 
nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +def build_deforamble_transformer(cfg): + return DeformableTransformer( + d_model=cfg.MODEL.HIDDEN_DIM, + nhead=cfg.MODEL.NHEADS, + num_encoder_layers=cfg.MODEL.ENC_LAYERS, + num_decoder_layers=cfg.MODEL.DEC_LAYERS, + dim_feedforward=cfg.MODEL.DIM_FEEDFORWARD, + dropout=cfg.MODEL.DROPOUT, + activation="relu", + return_intermediate_dec=True, + num_feature_levels=cfg.MODEL.NUM_FEATURE_LEVELS, + dec_n_points=cfg.MODEL.DEC_N_POINTS, + enc_n_points=cfg.MODEL.ENC_N_POINTS, + two_stage=cfg.MODEL.TWO_STAGE, + two_stage_num_proposals=cfg.MODEL.NUM_QUERIES, + space_align=cfg.MODEL.SPACE_ALIGN, + channel_align=cfg.MODEL.CHANNEL_ALIGN, + instance_align=cfg.MODEL.INSTANCE_ALIGN) + + diff --git a/models/matcher.py b/models/matcher.py new file mode 100644 index 0000000..9fed154 --- /dev/null +++ b/models/matcher.py @@ -0,0 +1,104 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). 
+ """ + + def __init__(self, + cost_class: float = 1, + cost_bbox: float = 1, + cost_giou: float = 1): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" + + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + with torch.no_grad(): + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. 
+ alpha = 0.25 + gamma = 2.0 + neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), + box_cxcywh_to_xyxy(tgt_bbox)) + + # Final cost matrix + C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +def build_matcher(cfg): + return HungarianMatcher(cost_class=cfg.LOSS.SET_COST_CLASS, + cost_bbox=cfg.LOSS.SET_COST_BBOX, + cost_giou=cfg.LOSS.SET_COST_GIOU) diff --git a/models/motif_detr.py b/models/motif_detr.py new file mode 100644 index 0000000..9a205fd --- /dev/null +++ b/models/motif_detr.py @@ -0,0 +1,854 @@ +# ------------------------------------------------------------------------ +# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection +# Modified by Wuyang Li +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Deformable DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn +import math + +from util import box_ops +from util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, interpolate, + is_dist_avail_and_initialized, inverse_sigmoid) + +from .backbone import build_backbone +from .matcher import build_matcher +from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm, + dice_loss, sigmoid_focal_loss) +from .deformable_transformer import build_deforamble_transformer +from .utils import GradientReversal +import copy +from copy import deepcopy + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.clamp(a_n, min=eps) + b_norm = b / torch.clamp(b_n, min=eps) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DeformableDETR(nn.Module): + """ This is the Deformable DETR module that performs object detection """ + def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels, + aux_loss=True, with_box_refine=False, two_stage=False, from_cfg=None): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. 
See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR + """ + super().__init__() + + self.from_cfg = from_cfg + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.num_feature_levels = num_feature_levels + if not two_stage: + self.query_embed = nn.Embedding(num_queries, hidden_dim*2) + if num_feature_levels > 1: + num_backbone_outs = len(backbone.strides) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, hidden_dim), + )) + in_channels = hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )]) + self.backbone = backbone + self.aux_loss = aux_loss + self.with_box_refine = with_box_refine + self.two_stage = two_stage + + self.da = self.from_cfg['da'] + self.backbone_align = self.from_cfg['backbone_align'] + self.space_align = self.from_cfg['space_align'] + self.channel_align = self.from_cfg['channel_align'] + self.instance_align = self.from_cfg['instance_align'] + + self.register_buffer('cls_means', torch.zeros(num_classes, 256)) + self.register_buffer('cls_stds', torch.zeros(num_classes, 256)) + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers + + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + # hack implementation for two-stage + self.transformer.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + if self.backbone_align: + self.grl = 
GradientReversal(lambda_=self.from_cfg['backbone_adv_lambda']) + self.backbone_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.backbone_D.layers: + nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + if self.space_align: + self.space_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.space_D.layers: + nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + if self.channel_align: + self.channel_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.channel_D.layers: + nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + if self.instance_align: + self.instance_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.instance_D.layers: + nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + + def forward(self, samples: NestedTensor, targets=None): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. 
+ """ + out = {} + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + query_embeds = None + if not self.two_stage: + query_embeds = self.query_embed.weight + + # send to def-transformer + hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, da_output, memory = self.transformer(srcs, masks, pos, query_embeds) + # hs: lvl, bs, 100, 256 + + outputs_classes = [] + outputs_coords = [] + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](hs[lvl]) + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + + out['pred_logits_both'] = outputs_class[-1] + out['is_training'] = self.training + + out['cls_means'] = self.cls_means + out['cls_stds'] = self.cls_stds + out['final_classifier'] = self.class_embed[-1] + out['first_classifier'] = self.class_embed[0] + + if self.training and self.da: + B = outputs_class.shape[1] + outputs_class = outputs_class[:, :B//2] + outputs_coord = outputs_coord[:, :B//2] + if self.two_stage: + enc_outputs_class = enc_outputs_class[:B//2] + enc_outputs_coord_unact = enc_outputs_coord_unact[:B//2] + if self.backbone_align: + da_output['backbone'] = torch.cat([self.backbone_D(self.grl(src.flatten(2).transpose(1, 2))) for src in srcs], dim=1) + if self.space_align: + da_output['space_query'] = self.space_D(da_output['space_query']) + if self.channel_align: + da_output['channel_query'] = self.channel_D(da_output['channel_query']) + if self.instance_align: + da_output['instance_query'] = self.instance_D(da_output['instance_query']) + + + out.update({'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'object_embedding': hs[-1], 'first_embedding': hs[0]}) + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) + + if self.two_stage: + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord} + if self.training and self.da: + out['da_output'] = da_output + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
+ return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25, da_gamma=2, from_cfg = None): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + losses: list of all the losses to be applied. See get_loss for list of available losses. + focal_alpha: alpha in Focal Loss + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.focal_alpha = focal_alpha + self.da_gamma = da_gamma + self.from_cfg = from_cfg + self.unk_prob = from_cfg['unk_prob'] + self.bce_loss = nn.BCELoss() + self.pretrain_th = from_cfg['pretrain_th'] + self.std_scaling = from_cfg['std_scaling'] + self.alpha = from_cfg['alpha'] + self.with_openset = from_cfg['with_openset'] + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], + dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + target_classes_onehot = target_classes_onehot[:,:,:-1] + + if self.unk_prob > 0 and self.with_openset: + obj_idx = target_classes_onehot.sum(-1) > 0 + tmp = target_classes_onehot[obj_idx] + tmp[:,-1] = self.unk_prob + target_classes_onehot[obj_idx] = tmp + + loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. 
It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. + targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + + src_masks = outputs["pred_masks"] + + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose() + target_masks = target_masks.to(src_masks) + + src_masks = src_masks[src_idx] + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks[tgt_idx].flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def loss_da(self, outputs, use_focal=False): + + B = outputs.shape[0] + assert B % 2 == 0 + + targets = torch.empty_like(outputs) + targets[:B//2] = 0 + targets[B//2:] = 1 + + loss = F.binary_cross_entropy_with_logits(outputs, targets, reduction='none') + if use_focal: + prob = outputs.sigmoid() + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = loss * ((1 - p_t) ** self.da_gamma) + + return loss.mean() + + def loss_openset(self, outputs, indices, targets): + + ctrs = outputs['cls_means'][:-1] + obj_emb = outputs['object_embedding'] # bs, 100, 256 + + ctrs_labels = torch.arange(self.num_classes-1).to(ctrs.device) + # mtch_idx = self._get_src_permutation_idx(indices) # bs, idx + unmtch_idx = self._get_src_unmatched_permutation_idx(indices, num_query=100) + unmtch_emb = obj_emb[unmtch_idx] + + pair_dis = self.eu_dis(ctrs, ctrs) + top_k_idx = torch.sort(pair_dis, descending=True, dim=-1)[1][:,0] # k far nei + + ctrs_1 = ctrs + ctrs_2 = ctrs_1[top_k_idx] + + ctrs_1_labels = ctrs_labels + ctrs_2_labels = ctrs_labels[top_k_idx] + + motif_embeds_list 
= [] + angle_list = [] + calss_1 = [] + calss_2 = [] + + for i in range(len(unmtch_emb)): + + vct1 = (ctrs_1 - unmtch_emb[i]) + vct2 = (ctrs_2 - unmtch_emb[i]) + + dis1 = (ctrs_1 - unmtch_emb[i]).norm(dim=-1) + dis2 = (ctrs_2 - unmtch_emb[i]).norm(dim=-1) + delta_dis = (dis1-dis2).abs() + dis_base = (ctrs_1 - ctrs_2).norm(dim=-1) + + angle = torch.nn.functional.cosine_similarity(vct1,vct2) + delta_dis/dis_base + + motif_idx = angle.argmin() + angle_list.append(angle.min().unsqueeze(dim=0)) + + motif_emb = torch.stack([ctrs_1[motif_idx], ctrs_2[motif_idx], unmtch_emb[i]], dim=0) + calss_1.append(ctrs_1_labels[motif_idx].unsqueeze(0)) + calss_2.append(ctrs_2_labels[motif_idx].unsqueeze(0)) + + motif_embeds_list.append(motif_emb.mean(dim=0)[None,:]) + + motif_embeds = torch.cat(motif_embeds_list,dim=0) + neg_angles = -torch.cat(angle_list) + + assert motif_embeds.size(0) == neg_angles.size(0) + + select_idx = neg_angles.topk(self.from_cfg['os_KNN'])[1] + + calss_1 = torch.cat(calss_1)[select_idx].unsqueeze(-1) + calss_2 = torch.cat(calss_2)[select_idx].unsqueeze(-1) + + motif_embeds_topk = motif_embeds[select_idx] + + classifier = outputs['final_classifier'] + motif_prob = classifier(motif_embeds_topk).sigmoid() + + target = torch.full_like(motif_prob, 0.0).detach() + target[:,-1]=1.0 + + loss = self.bce_loss(motif_prob, target) + + # update memory bank + with torch.no_grad(): + ctrs = outputs['cls_means'] + stds = outputs['cls_stds'] + + ema = self.alpha + avg_emb_base = motif_embeds_topk.mean(0) + ctrs[-1] = (1. - ema) * ctrs[-1] + ema * avg_emb_base + + std_emb_base = motif_embeds_topk.std(0) + stds[-1] = (1. - ema) * stds[-1] + ema * std_emb_base + + outputs['cls_means'] = ctrs + outputs['cls_stds'] = stds + + return loss + + def loss_crossdomain(self, outputs, targets, indices): + q_embs = outputs['object_embedding'] # bs, 100, 256 # source: 0: bs//2; target: bs//2: + + B = q_embs.shape[0] + assert B % 2 == 0 + + q_tg_pred = outputs['pred_logits_both'][B//2:] + q_tg_scores = q_tg_pred.view(-1, q_tg_pred.size(-1)).sigmoid() + + ctrs = outputs['cls_means'] + stds = outputs['cls_stds'] + + ctrs_labels = torch.arange(self.num_classes).to(ctrs.device) + + scaling_factor = ctrs.new_ones(ctrs.size(0)) * self.std_scaling + scaling_factor = scaling_factor[:,None] + + a = ctrs + scaling_factor * stds + b = ctrs - scaling_factor * stds + + # add centers + ctrs_1 = torch.cat([a, b], dim=0) + ctrs_1_labels = torch.cat([ctrs_labels, ctrs_labels]) + ctrs_2 = torch.cat([b, a], dim=0) + + q_tg_raw = q_embs[B//2:].view(-1, q_embs.size(-1)) + # score_mask = q_tg_scores.max(-1)[0] > self.pretrain_th + score_mask = q_tg_scores.sum(-1) > self.pretrain_th + q_tg = q_tg_raw[score_mask] + + if len(q_tg)< self.from_cfg['da_KNN']: + return q_tg_scores.sum()*0 + + sr_label = [] + motif_embeds_list =[] + for i in range(len(q_tg)): + vct1 = (ctrs_1 - q_tg[i]) + vct2 = (ctrs_2 - q_tg[i]) + angle = torch.nn.functional.cosine_similarity(vct1,vct2) + motif_idx = angle.argmin(-1) + + ctr_1 = ctrs_1[motif_idx] + ctr_2 = ctrs_2[motif_idx] + + motif_emb = torch.stack([ctr_1, q_tg[i], ctr_2], dim=0) + sr_label.append(ctrs_1_labels[motif_idx].unsqueeze(dim=0)) + motif_embeds_list.append(motif_emb.mean(dim=0)[None,:]) + + motif_embeds = torch.cat(motif_embeds_list,dim=0) + + sr_label = torch.cat(sr_label) + tg_label = self.eu_dis(q_tg, ctrs).argmin(-1) + + prob = outputs['final_classifier'](motif_embeds) + target_motif = torch.zeros(prob.size()).to(prob.device) + + prob_tmp = 0.5 + tg = torch.full_like(sr_label[:,None].float(), 
prob_tmp) + target_motif.scatter_(1,sr_label[:,None], tg) + target_motif.scatter_(1,tg_label[:,None], tg) + target_motif[target_motif.sum(-1) == 0.5] *=2 + + # loss = sigmoid_focal_loss(prob, target_motif, prob.size(0), alpha=self.focal_alpha, gamma=2) + loss = self.bce_loss(prob.sigmoid(), target_motif.detach()) + return loss + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_src_unmatched_permutation_idx(self, indices, num_query=100): + # permute predictions following indices + bs = len(indices) + queries = torch.arange(num_query) + batch_idx = [] + src_idx = [] + for i, (src, _) in enumerate(indices): + combined = torch.cat( + (queries, src)) + uniques, counts = combined.unique(return_counts=True) + unmatched_box = uniques[counts == 1] + batch_idx.append(torch.full_like(unmatched_box, i)) + src_idx.append(unmatched_box) + batch_idx = torch.cat(batch_idx) + src_idx = torch.cat(src_idx) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def eu_dis(self, a,b,p=2): + return torch.norm(a[:,None]-b,dim=2,p=p) + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + @torch.no_grad() + def update_class_centers(self, outputs, targets, indices, ema= 0.01): + + ctrs = outputs['cls_means'] + stds = outputs['cls_stds'] + q_embs = outputs['object_embedding'] # bs, 100, 256 + + matched_idx = self._get_src_permutation_idx(indices) # bs, idx + matched_q = q_embs[matched_idx] + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + + for i in target_classes_o.unique(): + per_cls_q = matched_q[target_classes_o==i] + avg_emb = per_cls_q.mean(dim=0) + ctrs[i] = (1. - ema) * ctrs[i] + ema * avg_emb.detach() + + if per_cls_q.size(0) > 2: + std_emb = per_cls_q.std(dim=0) + stds[i] = (1. - ema) * stds[i] + ema * std_emb.detach() + + avg_emb_base = ctrs[:-1].mean(0) + ctrs[-1] = (1. - ema) * ctrs[-1] + ema * avg_emb_base + + + std_emb_base = stds[:-1].mean(0) + stds[-1] = (1. - ema) * stds[-1] + ema * std_emb_base + + outputs['cls_means'] = ctrs + outputs['cls_stds'] = stds + return outputs + + def forward(self, samples, outputs, targets, epoch=0): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. 
+ The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'} + # Compute all the requested losses + losses = {} + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + if self.training: + outputs = self.update_class_centers(outputs, targets, indices, ema=self.alpha) + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + + for loss in self.losses: + kwargs = {} + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)) + + if self.training and self.from_cfg['with_openset'] and epoch > self.from_cfg['warm_up_epoch']: + losses['loss_openset'] = self.loss_openset(outputs, indices, targets) + + if self.training and self.from_cfg['with_crossdomain'] and epoch > self.from_cfg['warm_up_epoch']: + losses['loss_crossdomain'] = self.loss_crossdomain(outputs, targets, indices) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + if 'enc_outputs' in outputs: + enc_outputs = outputs['enc_outputs'] + bin_targets = copy.deepcopy(targets) + for bt in bin_targets: + bt['labels'] = torch.zeros_like(bt['labels']) + indices = self.matcher(enc_outputs, bin_targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. 
+ continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_enc': v for k, v in l_dict.items()} + losses.update(l_dict) + + if 'da_output' in outputs: + for k, v in outputs['da_output'].items(): + losses[f'loss_{k}'] = self.loss_da(v, use_focal='query' in k) + + return losses + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + + @torch.no_grad() + def forward(self, outputs, target_sizes, show_box=False): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + if show_box: + # for qualitative visualization to surpress unk preds + #TODO may be different from the old implementation, need to check + bs, num_q, num_class = prob.size() + unk_mask = prob.argmax(-1) != num_class - 1 + prob[unk_mask] = 0.0 + + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] + return results + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +def build(cfg): + + device = torch.device(cfg.DEVICE) + backbone = build_backbone(cfg) + transformer = build_deforamble_transformer(cfg) + + from_cfg = dict( + backbone_align=cfg.MODEL.BACKBONE_ALIGN, + space_align=cfg.MODEL.SPACE_ALIGN, + channel_align=cfg.MODEL.CHANNEL_ALIGN, + instance_align=cfg.MODEL.INSTANCE_ALIGN, + da=cfg.DATASET.DA_MODE == 'uda' or cfg.DATASET.DA_MODE == 'aood', + batch_size=cfg.TRAIN.BATCH_SIZE, + with_openset=cfg.AOOD.OPEN_SET.MOTIF_ON, + os_KNN=cfg.AOOD.OPEN_SET.KNN, + pretrain_th=cfg.AOOD.OPEN_SET.TH, + with_crossdomain=cfg.AOOD.CROSS_DOMAIN.MOTIF_ON, + da_KNN=cfg.AOOD.CROSS_DOMAIN.KNN, + unk_prob=cfg.AOOD.OPEN_SET.UNK_PROB, + backbone_adv_lambda=cfg.AOOD.CROSS_DOMAIN.BACKBONE_LAMBDA, + warm_up_epoch=cfg.AOOD.OPEN_SET.WARM_UP, + std_scaling=cfg.AOOD.CROSS_DOMAIN.BETA, + motif_update=cfg.AOOD.OPEN_SET.MOTIF_UPDATE, + alpha=cfg.AOOD.OPEN_SET.ALPHA, + ) + print(from_cfg) + + model = DeformableDETR( + backbone, + transformer, + 
num_classes=cfg.DATASET.NUM_CLASSES, + num_queries=cfg.MODEL.NUM_QUERIES, + num_feature_levels=cfg.MODEL.NUM_FEATURE_LEVELS, + aux_loss=cfg.LOSS.AUX_LOSS, + with_box_refine=cfg.MODEL.WITH_BOX_REFINE, + two_stage=cfg.MODEL.TWO_STAGE, + from_cfg = from_cfg, + ) + if cfg.MODEL.MASKS: + model = DETRsegm(model, freeze_detr=(cfg.MODEL.FROZEN_WEIGHTS is not None)) + matcher = build_matcher(cfg) + weight_dict = {'loss_ce': cfg.LOSS.CLS_LOSS_COEF, 'loss_bbox': cfg.LOSS.BBOX_LOSS_COEF} + weight_dict['loss_giou'] = cfg.LOSS.GIOU_LOSS_COEF + if cfg.MODEL.MASKS: + weight_dict["loss_mask"] = cfg.LOSS.MASK_LOSS_COEF + weight_dict["loss_dice"] = cfg.LOSS.DICE_LOSS_COEF + # TODO this is a hack + if cfg.LOSS.AUX_LOSS: + aux_weight_dict = {} + for i in range(cfg.MODEL.DEC_LAYERS - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + weight_dict['loss_backbone'] = cfg.LOSS.BACKBONE_LOSS_COEF + weight_dict['loss_space_query'] = cfg.LOSS.SPACE_QUERY_LOSS_COEF + weight_dict['loss_channel_query'] = cfg.LOSS.CHANNEL_QUERY_LOSS_COEF + weight_dict['loss_instance_query'] = cfg.LOSS.INSTANCE_QUERY_LOSS_COEF + + weight_dict['loss_crossdomain'] = cfg.AOOD.CROSS_DOMAIN.MOTIF_LOSS_COEF + weight_dict['loss_openset'] = cfg.AOOD.OPEN_SET.MOTIF_LOSS_COEF + + losses = ['labels', 'boxes'] + if cfg.MODEL.MASKS: + losses += ["masks"] + # num_classes, matcher, weight_dict, losses, focal_alpha=0.25 + criterion = SetCriterion( + cfg.DATASET.NUM_CLASSES, + matcher, + weight_dict, + losses, + focal_alpha=cfg.LOSS.FOCAL_ALPHA, + da_gamma=cfg.LOSS.DA_GAMMA, + from_cfg=from_cfg, + + ) + criterion.to(device) + postprocessors = {'bbox': PostProcess()} + if cfg.MODEL.MASKS: + postprocessors['segm'] = PostProcessSegm() + if cfg.DATASET.DATASET_FILE == "coco_panoptic": + is_thing_map = {i: i <= 90 for i in range(201)} + postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) + + return model, criterion, postprocessors diff --git a/models/ops/functions/__init__.py b/models/ops/functions/__init__.py new file mode 100644 index 0000000..8a2197b --- /dev/null +++ b/models/ops/functions/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction + diff --git a/models/ops/functions/ms_deform_attn_func.py b/models/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000..8c5df8c --- /dev/null +++ b/models/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = \ + MSDA.ms_deform_attn_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/models/ops/make.sh b/models/ops/make.sh new file mode 100644 index 0000000..106b685 --- /dev/null +++ b/models/ops/make.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +python setup.py build install diff --git a/models/ops/modules/__init__.py b/models/ops/modules/__init__.py new file mode 100644 index 0000000..f82cb1a --- /dev/null +++ b/models/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/models/ops/modules/ms_deform_attn.py b/models/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000..663d64a --- /dev/null +++ b/models/ops/modules/ms_deform_attn.py @@ -0,0 +1,115 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n-1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + 
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + output = MSDeformAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) + output = self.output_proj(output) + return output diff --git a/models/ops/setup.py b/models/ops/setup.py new file mode 100644 index 0000000..a0131bc --- 
/dev/null
+++ b/models/ops/setup.py
@@ -0,0 +1,71 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "src")
+
+    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+    sources = main_file + source_cpu
+    extension = CppExtension
+    extra_compile_args = {"cxx": []}
+    define_macros = []
+
+    if torch.cuda.is_available() and CUDA_HOME is not None:
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [("WITH_CUDA", None)]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ]
+    else:
+        raise NotImplementedError('CUDA is not available')
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+    include_dirs = [extensions_dir]
+    ext_modules = [
+        extension(
+            "MultiScaleDeformableAttention",
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+    return ext_modules
+
+setup(
+    name="MultiScaleDeformableAttention",
+    version="1.0",
+    author="Weijie Su",
+    url="https://github.com/fundamentalvision/Deformable-DETR",
+    description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+    packages=find_packages(exclude=("configs", "tests",)),
+    ext_modules=get_extensions(),
+    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/models/ops/src/cpu/ms_deform_attn_cpu.cpp b/models/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000..e1bf854
--- /dev/null
+++ b/models/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,41 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/models/ops/src/cpu/ms_deform_attn_cpu.h b/models/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000..81b7b58
--- /dev/null
+++ b/models/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,33 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step);
+
+
diff --git a/models/ops/src/cuda/ms_deform_attn_cuda.cu b/models/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000..d6d5836
--- /dev/null
+++ b/models/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,153 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), 
"spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} \ No newline at end of file diff --git a/models/ops/src/cuda/ms_deform_attn_cuda.h b/models/ops/src/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000..c7ae53f --- /dev/null +++ b/models/ops/src/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,30 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/models/ops/src/cuda/ms_deform_im2col_cuda.cuh b/models/ops/src/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000..6bc2acb --- /dev/null +++ b/models/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + 
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * 
lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + 
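The im2col kernel above is what `MSDeformAttnFunction` ultimately dispatches to on the forward pass. As a quick sanity check of the compiled extension against the pure-PyTorch reference `ms_deform_attn_core_pytorch` defined earlier in this patch, a minimal forward comparison can look like the sketch below. It assumes it is run from `models/ops` on a CUDA machine after `sh ./make.sh`; the tensor shapes and the import path are illustrative assumptions in the spirit of the bundled `python test.py` check, not part of the original sources.

```python
# Hedged sketch: forward cross-check of the compiled CUDA op vs. the PyTorch reference.
# Assumed to be run from models/ops after compiling the extension; shapes are illustrative.
import torch
from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch

N, M, D = 1, 2, 2            # batch size, attention heads, channels per head
Lq, L, P = 2, 2, 2           # queries, feature levels, sampling points per level
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()   # (H, W) per level
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
S = int(shapes.prod(1).sum())  # total number of flattened key positions

value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)  # normalize over L*P
im2col_step = 2

out_ref = ms_deform_attn_core_pytorch(
    value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
out_cuda = MSDeformAttnFunction.apply(
    value.double(), shapes, level_start_index,
    sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()

print('forward max abs diff:', (out_cuda - out_ref).abs().max().item())
```

The unit test shipped with the ops package (`python test.py`, as mentioned in the installation notes) performs this kind of forward comparison, plus gradient checks, more systematically.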
+template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void 
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template 
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void 
ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + 
*grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += 
cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int 
batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + 
num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + 
num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/models/ops/src/ms_deform_attn.h b/models/ops/src/ms_deform_attn.h new file mode 100644 index 0000000..ac0ef2e --- /dev/null +++ b/models/ops/src/ms_deform_attn.h @@ -0,0 +1,62 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/models/ops/src/vision.cpp b/models/ops/src/vision.cpp new file mode 100644 index 0000000..2201f63 --- /dev/null +++ b/models/ops/src/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/models/ops/test.py b/models/ops/test.py new file mode 100644 index 0000000..8dbf6d5 --- /dev/null +++ b/models/ops/test.py @@ -0,0 +1,89 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. 
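# ------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the committed files: how the two ops bound in vision.cpp above
# are typically wrapped in a torch.autograd.Function so autograd can reach the CUDA col2im kernels
# dispatched in ms_deformable_col2im_cuda. It assumes the compiled extension is importable as
# `MultiScaleDeformableAttention` (the module name used by the upstream Deformable DETR build);
# the wrapper that test.py below actually imports lives in models/ops/functions/ms_deform_attn_func.py.
# ------------------------------------------------------------------------------------------------
import torch
import MultiScaleDeformableAttention as MSDA  # built via models/ops/setup.py (assumed name)

class MSDeformAttnFunctionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index,
                sampling_locations, attention_weights, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index,
            sampling_locations, attention_weights, im2col_step)
        ctx.save_for_backward(value, spatial_shapes, level_start_index,
                              sampling_locations, attention_weights)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (value, spatial_shapes, level_start_index,
         sampling_locations, attention_weights) = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value, spatial_shapes, level_start_index, sampling_locations,
            attention_weights, grad_output.contiguous(), ctx.im2col_step)
        # no gradients for the integer shape/index tensors or for im2col_step
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None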
All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H*W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck(func, (value.double(), shapes, level_start_index, 
sampling_locations.double(), attention_weights.double(), im2col_step)) + + print(f'* {gradok} check_gradient_numerical(D={channels})') + + +if __name__ == '__main__': + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) + + + diff --git a/models/ow_detr.py b/models/ow_detr.py new file mode 100644 index 0000000..dd4be54 --- /dev/null +++ b/models/ow_detr.py @@ -0,0 +1,763 @@ +# ------------------------------------------------------------------------ +# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection +# Modified by Wuyang Li +# ------------------------------------------------------------------------ +# OW-DETR: Open-world Detection Transformer +# Akshita Gupta^, Sanath Narayan^, K J Joseph, Salman Khan, Fahad Shahbaz Khan, Mubarak Shah +# https://arxiv.org/pdf/2112.01513.pdf +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +Deformable DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn +import math +import pickle +from util import box_ops +from util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, interpolate, + is_dist_avail_and_initialized, inverse_sigmoid) +from .backbone import build_backbone +from .matcher import build_matcher +from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm, + dice_loss, sigmoid_focal_loss) #, sigmoid_focal_loss_CA) +from .deformable_transformer import build_deforamble_transformer +import copy +import heapq +import operator +import os +from copy import deepcopy +from .utils import GradientReversal +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DeformableDETR(nn.Module): + """ This is the Deformable DETR module that performs object detection """ + def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels, + aux_loss=True, with_box_refine=False, two_stage=False, + unmatched_boxes=False, novelty_cls=False, featdim=1024): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR + """ + super().__init__() + + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.num_feature_levels = num_feature_levels + + ### + self.featdim = featdim + self.unmatched_boxes = unmatched_boxes + self.novelty_cls = novelty_cls + if self.novelty_cls: + self.nc_class_embed = nn.Linear(hidden_dim, 1) + + if not two_stage: + self.query_embed = nn.Embedding(num_queries, hidden_dim*2) + if num_feature_levels > 1: + num_backbone_outs = len(backbone.strides) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, hidden_dim), + )) + in_channels = hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )]) + self.backbone = backbone + self.aux_loss = aux_loss + self.with_box_refine = with_box_refine + self.two_stage = two_stage + + self.grl = GradientReversal(lambda_=1.0) + self.backbone_D = MLP(hidden_dim, hidden_dim, 1, 3) + for layer in self.backbone_D.layers: + nn.init.xavier_uniform_(layer.weight, gain=1) + nn.init.constant_(layer.bias, 0) + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + if self.novelty_cls: + self.nc_class_embed.bias.data = torch.ones(1) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + + if self.novelty_cls: + self.nc_class_embed = nn.ModuleList([self.nc_class_embed for _ in range(num_pred)]) + + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + # hack implementation for two-stage + self.transformer.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + + def forward(self, samples: NestedTensor): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape 
[batch_size x H x W], containing 1 on padded pixels + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + if self.featdim == 512: + dim_index = 0 + elif self.featdim == 1024: + dim_index = 1 + else: + dim_index = 2 + + for l, feat in enumerate(features): + src, mask = feat.decompose() + ## [Info] extracting the resnet features which are used for selecting unmatched queries + if self.unmatched_boxes: + if l == dim_index: + resnet_1024_feature = src.clone() # 2X1024X61X67 + else: + resnet_1024_feature = None + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + query_embeds = None + if not self.two_stage: + query_embeds = self.query_embed.weight + + # hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, query_embeds) + + hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, da_output, memory = self.transformer(srcs, masks, pos, query_embeds) + + outputs_classes = [] + outputs_coords = [] + outputs_classes_nc = [] + + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](hs[lvl]) + + ## novelty classification + if self.novelty_cls: + outputs_class_nc = self.nc_class_embed[lvl](hs[lvl]) + + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + if self.novelty_cls: + outputs_classes_nc.append(outputs_class_nc) + outputs_coords.append(outputs_coord) + + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + if self.novelty_cls: + output_class_nc = torch.stack(outputs_classes_nc) + out= {} + if self.training: + B = outputs_class.shape[1] + outputs_class = outputs_class[:, :B//2] + outputs_coord = outputs_coord[:, :B//2] + output_class_nc = output_class_nc[:, :B//2] + if self.two_stage: + enc_outputs_class = enc_outputs_class[:B//2] + enc_outputs_coord_unact = enc_outputs_coord_unact[:B//2] + + out['da_backbone'] = torch.cat([self.backbone_D(self.grl(src.flatten(2).transpose(1, 
2))) for src in srcs], dim=1) + + out.update({'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'resnet_1024_feat': resnet_1024_feature}) + + if self.novelty_cls: + out['pred_nc_logits'] = output_class_nc[-1] + + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, output_class_nc=None) + if self.novelty_cls: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, output_class_nc=output_class_nc) + if self.two_stage: + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord} + + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord, output_class_nc=None): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + # import pdb;pdb.set_trace() + if output_class_nc is not None: + xx = [{'pred_logits': a, 'pred_nc_logits': c, 'pred_boxes': b} + for a, c, b in zip(outputs_class[:-1], output_class_nc[:-1], outputs_coord[:-1])] + else: + xx = [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + return xx + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, args, num_classes, matcher, weight_dict, losses, invalid_cls_logits, focal_alpha=0.25): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ focal_alpha: alpha in Focal Loss + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.focal_alpha = focal_alpha + self.nc_epoch = -1 + self.invalid_cls_logits = invalid_cls_logits + self.unmatched_boxes = True + self.top_unk = 5 + self.num_seen_classes = 4 + + + def loss_NC_labels(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices, log=True): + """Novelty classification loss + target labels will contain class as 1 + owod_indices -> indices combining matched indices + psuedo labeled indices + owod_targets -> targets combining GT targets + psuedo labeled unknown targets + target_classes_o -> contains all 1's + """ + assert 'pred_nc_logits' in outputs + src_logits = outputs['pred_nc_logits'] + + idx = self._get_src_permutation_idx(owod_indices) + target_classes_o = torch.cat([torch.full_like(t["labels"][J], 0) for t, (_, J) in zip(owod_targets, owod_indices)]) + target_classes = torch.full(src_logits.shape[:2], 1, dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:,:,:-1] + loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + + losses = {'loss_NC': loss_ce} + return losses + + def loss_labels(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + ## comment lines from 317-320 when running for oracle settings + temp_src_logits = outputs['pred_logits'].clone() + temp_src_logits[:,:, self.invalid_cls_logits] = -10e10 + src_logits = temp_src_logits + + if self.unmatched_boxes: + idx = self._get_src_permutation_idx(owod_indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(owod_targets, owod_indices)]) + else: + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + + target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], + dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + target_classes_onehot = target_classes_onehot[:,:,:-1] + loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging 
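# ------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the committed file: the one-hot target construction used by
# loss_labels / loss_NC_labels above. Labels are scattered into (num_classes + 1) columns so the
# background index lands in an extra last column, which is then sliced off before the sigmoid
# focal loss. Shapes below are made up for the example.
# ------------------------------------------------------------------------------------------------
import torch

num_queries, num_classes = 6, 4
src_logits = torch.randn(1, num_queries, num_classes)                     # [B, Q, C] the loss sees
target_classes = torch.full((1, num_queries), num_classes, dtype=torch.long)  # background index = C
target_classes[0, :2] = torch.tensor([1, 3])                              # two matched queries

onehot = torch.zeros(1, num_queries, num_classes + 1)
onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
onehot = onehot[:, :, :-1]                                                # drop the background column
# matched rows are one-hot, unmatched rows are all zeros, so background queries
# only contribute through the negative side of the focal loss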
purposes only. It doesn't propagate gradients + """ + temp_pred_logits = outputs['pred_logits'].clone() + temp_pred_logits[:,:, self.invalid_cls_logits] = -10e10 + pred_logits = temp_pred_logits + + device = pred_logits.device + if self.unmatched_boxes: + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in owod_targets], device=device) + else: + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices): + # def loss_boxes(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices, ca_owod_targets, ca_owod_indices): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. + """ + + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices): + """Compute the losses related to the masks: the focal loss and the dice loss. 
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + + src_masks = outputs["pred_masks"] + + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose() + target_masks = target_masks.to(src_masks) + + src_masks = src_masks[src_idx] + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks[tgt_idx].flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def save_dict(self, di_, filename_): + with open(filename_, 'wb') as f: + pickle.dump(di_, f) + + def load_dict(self, filename_): + with open(filename_, 'rb') as f: + ret_dict = pickle.load(f) + return ret_dict + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_src_single_permutation_idx(self, indices, index): + ## Only need the src query index selection from this function for attention feature selection + batch_idx = [torch.full_like(src, i) for i, src in enumerate(indices)][0] + src_idx = indices[0] + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, epoch, owod_targets, owod_indices, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'NC_labels': self.loss_NC_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, epoch, owod_targets, owod_indices, **kwargs) + + def loss_da(self, outputs, use_focal=False): + + B = outputs.shape[0] + assert B % 2 == 0 + + targets = torch.empty_like(outputs) + targets[:B//2] = 0 + targets[B//2:] = 1 + + loss = F.binary_cross_entropy_with_logits(outputs, targets, reduction='none') + + if use_focal: + prob = outputs.sigmoid() + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = loss * ((1 - p_t) ** self.da_gamma) + + return loss.mean() + def forward(self, samples, outputs, targets, epoch): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. 
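# ------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the committed file: loss_da above assumes the batch is stacked
# as [source | target], so the first half of the discriminator outputs gets domain label 0 and the
# second half label 1. Shapes below are made up. Note that the use_focal branch reads
# self.da_gamma, which is never set in __init__; this is only safe because loss_da is invoked with
# use_focal=False in this file.
# ------------------------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

logits = torch.randn(4, 300, 1)      # e.g. [B, tokens, 1] from backbone_D in ow_detr.py
targets = torch.empty_like(logits)
targets[:2] = 0.0                    # source half of the batch
targets[2:] = 1.0                    # target half of the batch
da_loss = F.binary_cross_entropy_with_logits(logits, targets)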
+ The expected keys in each dict depends on the losses applied, see each loss' doc + """ + # if self.nc_epoch > 0: + # loss_epoch = 9 + # else: + # loss_epoch = 0 + loss_epoch = -1 + + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'} + indices = self.matcher(outputs_without_aux, targets) + + owod_targets = deepcopy(targets) + owod_indices = deepcopy(indices) + + + owod_outputs = outputs_without_aux.copy() + owod_device = owod_outputs["pred_boxes"].device + + if self.unmatched_boxes and epoch >= loss_epoch and self.training: + ## get pseudo unmatched boxes from this section + res_feat = torch.mean(outputs['resnet_1024_feat'], 1) + queries = torch.arange(outputs['pred_logits'].shape[1]) + for i in range(len(indices)): + combined = torch.cat((queries, self._get_src_single_permutation_idx(indices[i], i)[-1])) ## need to fix the indexing + uniques, counts = combined.unique(return_counts=True) + unmatched_indices = uniques[counts == 1] + boxes = outputs_without_aux['pred_boxes'][i] #[unmatched_indices,:] + img = samples.tensors[i].cpu().permute(1,2,0).numpy() + h, w = img.shape[:-1] + img_w = torch.tensor(w, device=owod_device) + img_h = torch.tensor(h, device=owod_device) + unmatched_boxes = box_ops.box_cxcywh_to_xyxy(boxes) + unmatched_boxes = unmatched_boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(owod_device) + means_bb = torch.zeros(queries.shape[0]).to(unmatched_boxes) + bb = unmatched_boxes + for j, _ in enumerate(means_bb): + if j in unmatched_indices: + upsaple = nn.Upsample(size=(img_h,img_w), mode='bilinear') + img_feat = upsaple(res_feat[i].unsqueeze(0).unsqueeze(0)) + img_feat = img_feat.squeeze(0).squeeze(0) + xmin = bb[j,:][0].long() + ymin = bb[j,:][1].long() + xmax = bb[j,:][2].long() + ymax = bb[j,:][3].long() + means_bb[j] = torch.mean(img_feat[ymin:ymax,xmin:xmax]) + if torch.isnan(means_bb[j]): + means_bb[j] = -10e10 + else: + means_bb[j] = -10e10 + + _, topk_inds = torch.topk(means_bb, self.top_unk) + topk_inds = torch.as_tensor(topk_inds) + + topk_inds = topk_inds.cpu() + + unk_label = torch.as_tensor([self.num_classes-1], device=owod_device) + owod_targets[i]['labels'] = torch.cat((owod_targets[i]['labels'], unk_label.repeat_interleave(self.top_unk))) + owod_indices[i] = (torch.cat((owod_indices[i][0], topk_inds)), torch.cat((owod_indices[i][1], (owod_targets[i]['labels'] == unk_label).nonzero(as_tuple=True)[0].cpu()))) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + kwargs = {} + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, epoch, owod_targets, owod_indices, **kwargs)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
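# ------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the committed file: a simplified, self-contained version of the
# pseudo "unknown" selection performed above and repeated per decoder layer below. Queries that
# the Hungarian matcher left unmatched are scored by the mean backbone activation inside their
# predicted box, and the top_unk highest-scoring ones are appended to the targets with the unknown
# label. Feature upsampling and batching are omitted here for clarity.
# ------------------------------------------------------------------------------------------------
import torch

def select_pseudo_unknowns(mean_feat, boxes_xyxy, matched_query_ids, top_unk=5):
    """mean_feat: [H, W] channel-averaged backbone map at image resolution,
    boxes_xyxy: [Q, 4] predicted boxes in pixels, matched_query_ids: 1-D LongTensor."""
    num_queries = boxes_xyxy.shape[0]
    scores = torch.full((num_queries,), -10e10)
    matched = set(matched_query_ids.tolist())
    for q in range(num_queries):
        if q in matched:
            continue
        x0, y0, x1, y1 = boxes_xyxy[q].long().tolist()
        patch = mean_feat[y0:y1, x0:x1]
        if patch.numel() > 0 and not torch.isnan(patch.mean()):
            scores[q] = patch.mean()
    # indices of the queries that will be pseudo-labeled as the unknown class
    return torch.topk(scores, min(top_unk, num_queries)).indices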
+ if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + + owod_targets = deepcopy(targets) + owod_indices = deepcopy(indices) + + aux_owod_outputs = aux_outputs.copy() + owod_device = aux_owod_outputs["pred_boxes"].device + + if self.unmatched_boxes and epoch >= loss_epoch and self.training: + ## get pseudo unmatched boxes from this section + res_feat = torch.mean(outputs['resnet_1024_feat'], 1) #2 X 67 X 50 + queries = torch.arange(aux_owod_outputs['pred_logits'].shape[1]) + for i in range(len(indices)): + combined = torch.cat((queries, self._get_src_single_permutation_idx(indices[i], i)[-1])) ## need to fix the indexing + uniques, counts = combined.unique(return_counts=True) + unmatched_indices = uniques[counts == 1] + boxes = aux_owod_outputs['pred_boxes'][i] #[unmatched_indices,:] + img = samples.tensors[i].cpu().permute(1,2,0).numpy() + h, w = img.shape[:-1] + img_w = torch.tensor(w, device=owod_device) + img_h = torch.tensor(h, device=owod_device) + unmatched_boxes = box_ops.box_cxcywh_to_xyxy(boxes) + unmatched_boxes = unmatched_boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(owod_device) + means_bb = torch.zeros(queries.shape[0]).to(unmatched_boxes) #torch.zeros(unmatched_boxes.shape[0]) + bb = unmatched_boxes + ## [INFO]: iterating over the full list of boxes and then selecting the unmatched ones + for j, _ in enumerate(means_bb): + if j in unmatched_indices: + upsaple = nn.Upsample(size=(img_h,img_w), mode='bilinear') + img_feat = upsaple(res_feat[i].unsqueeze(0).unsqueeze(0)) + img_feat = img_feat.squeeze(0).squeeze(0) + xmin = bb[j,:][0].long() + ymin = bb[j,:][1].long() + xmax = bb[j,:][2].long() + ymax = bb[j,:][3].long() + means_bb[j] = torch.mean(img_feat[ymin:ymax,xmin:xmax]) + if torch.isnan(means_bb[j]): + means_bb[j] = -10e10 + else: + means_bb[j] = -10e10 + + _, topk_inds = torch.topk(means_bb, self.top_unk) + topk_inds = torch.as_tensor(topk_inds) + + topk_inds = topk_inds.cpu() + unk_label = torch.as_tensor([self.num_classes-1], device=owod_device) + owod_targets[i]['labels'] = torch.cat((owod_targets[i]['labels'], unk_label.repeat_interleave(self.top_unk))) + owod_indices[i] = (torch.cat((owod_indices[i][0], topk_inds)), torch.cat((owod_indices[i][1], (owod_targets[i]['labels'] == unk_label).nonzero(as_tuple=True)[0].cpu()))) + + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, epoch, owod_targets, owod_indices, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + if 'enc_outputs' in outputs: + enc_outputs = outputs['enc_outputs'] + bin_targets = copy.deepcopy(targets) + for bt in bin_targets: + bt['labels'] = torch.zeros_like(bt['labels']) + indices = self.matcher(enc_outputs, bin_targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. 
+ continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_enc': v for k, v in l_dict.items()} + losses.update(l_dict) + + # import ipdb; ipdb.set_trace() + if 'da_backbone' in outputs: + # for k, v in outputs['da_output'].items(): + losses['loss_backbone'] = self.loss_da(outputs['da_backbone'], use_focal=False) + # print(losses) + return losses + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] + + return results + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +# def build_ow_detr(cfg): +def build(cfg): + num_classes = 11 + # num_classes = 4 + print(num_classes) + # if args.dataset == "coco_panoptic": + # num_classes = 250 + device = torch.device(cfg.DEVICE) + + backbone = build_backbone(cfg) + + transformer = build_deforamble_transformer(cfg) + + prev_intro_cls = 0 + curr_intro_cls = 10 + # curr_intro_cls = 3 + seen_classes = prev_intro_cls + curr_intro_cls + invalid_cls_logits = list(range(seen_classes, num_classes-1)) #unknown class indx will not be included in the invalid class range + print("Invalid class rangw: " + str(invalid_cls_logits)) + + model = DeformableDETR( + backbone, + transformer, + num_classes=cfg.DATASET.NUM_CLASSES, + num_queries=cfg.MODEL.NUM_QUERIES, + num_feature_levels=cfg.MODEL.NUM_FEATURE_LEVELS, + aux_loss=cfg.LOSS.AUX_LOSS, + with_box_refine=cfg.MODEL.WITH_BOX_REFINE, + two_stage=cfg.MODEL.TWO_STAGE, + unmatched_boxes=True, + novelty_cls=True, + featdim=1024, + ) + if cfg.MODEL.MASKS: + model = DETRsegm(model, freeze_detr=(cfg.MODEL.FROZEN_WEIGHTS is not None)) + matcher = build_matcher(cfg) + weight_dict = {'loss_ce': cfg.LOSS.CLS_LOSS_COEF, 'loss_bbox': 
cfg.LOSS.BBOX_LOSS_COEF} + weight_dict['loss_giou'] = cfg.LOSS.GIOU_LOSS_COEF + if cfg.MODEL.MASKS: + weight_dict["loss_mask"] = cfg.LOSS.MASK_LOSS_COEF + weight_dict["loss_dice"] = cfg.LOSS.DICE_LOSS_COEF + + weight_dict['loss_backbone'] = cfg.LOSS.BACKBONE_LOSS_COEF + weight_dict['loss_NC'] = 0.1 + + # TODO this is a hack + if cfg.LOSS.AUX_LOSS: + aux_weight_dict = {} + for i in range(cfg.MODEL.DEC_LAYERS - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ['labels', 'NC_labels', 'boxes', 'cardinality'] + # if args.masks: + # losses += ["masks"] + criterion = SetCriterion(cfg, num_classes, matcher, weight_dict, losses, invalid_cls_logits, focal_alpha=cfg.LOSS.FOCAL_ALPHA) + criterion.to(device) + postprocessors = {'bbox': PostProcess()} + return model, criterion, postprocessors \ No newline at end of file diff --git a/models/position_encoding.py b/models/position_encoding.py new file mode 100644 index 0000000..fa09903 --- /dev/null +++ b/models/position_encoding.py @@ -0,0 +1,99 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
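# ------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the committed file: the aux_weight_dict expansion in build()
# above simply replicates every loss weight once per intermediate decoder layer and once for the
# encoder proposals, matching the '_{i}' / '_enc' key suffixes produced in SetCriterion.
# ------------------------------------------------------------------------------------------------
weight_dict = {'loss_ce': 2.0, 'loss_bbox': 5.0}
dec_layers = 3
aux_weight_dict = {}
for i in range(dec_layers - 1):
    aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
aux_weight_dict.update({k + '_enc': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
# -> keys: loss_ce, loss_bbox, loss_ce_0, loss_bbox_0, loss_ce_1, loss_bbox_1,
#          loss_ce_enc, loss_bbox_enc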
+ """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(cfg): + N_steps = cfg.MODEL.HIDDEN_DIM // 2 + if cfg.MODEL.POSITION_EMBEDDING in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif cfg.MODEL.POSITION_EMBEDDING in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {cfg.MODEL.POSITION_EMBEDDING}") + + return position_embedding diff --git a/models/segmentation.py b/models/segmentation.py new file mode 100644 index 0000000..b1460f0 --- /dev/null +++ b/models/segmentation.py @@ -0,0 +1,371 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ + +""" +This file provides the definition of the convolutional heads used to predict masks, as well as the losses +""" +import io +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +import util.box_ops as box_ops +from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list + +try: + from panopticapi.utils import id2rgb, rgb2id +except ImportError: + pass + + +class DETRsegm(nn.Module): + def __init__(self, detr, freeze_detr=False): + super().__init__() + self.detr = detr + + if freeze_detr: + for p in self.parameters(): + p.requires_grad_(False) + + hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead + self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0) + self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) + + def forward(self, samples: NestedTensor): + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.detr.backbone(samples) + + bs = features[-1].tensors.shape[0] + + src, mask = features[-1].decompose() + src_proj = self.detr.input_proj(src) + hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) + + outputs_class = self.detr.class_embed(hs) + outputs_coord = self.detr.bbox_embed(hs).sigmoid() + out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} + if self.detr.aux_loss: + out["aux_outputs"] = [ + {"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + ] + + # FIXME h_boxes takes the last one computed, keep this in mind + bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) + + seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors]) + outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) + + out["pred_masks"] = outputs_seg_masks + return out + + +class MaskHeadSmallConv(nn.Module): + """ + Simple convolutional head, using group norm. 
+ Upsampling is done using a FPN approach + """ + + def __init__(self, dim, fpn_dims, context_dim): + super().__init__() + + inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] + self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, dim) + self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) + + self.dim = dim + + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x, bbox_mask, fpns): + def expand(tensor, length): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) + + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + +class MHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + nn.init.zeros_(self.k_linear.bias) + nn.init.zeros_(self.q_linear.bias) + nn.init.xavier_uniform_(self.k_linear.weight) + nn.init.xavier_uniform_(self.q_linear.weight) + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask=None): + q = self.q_linear(q) + k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) + qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) + kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) + weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) + + if mask is not None: + 
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) + weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) + weights = self.dropout(weights) + return weights + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +class PostProcessSegm(nn.Module): + def __init__(self, threshold=0.5): + super().__init__() + self.threshold = threshold + + @torch.no_grad() + def forward(self, results, outputs, orig_target_sizes, max_target_sizes): + assert len(orig_target_sizes) == len(max_target_sizes) + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs["pred_masks"].squeeze(2) + outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) + outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() + + for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = F.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ).byte() + + return results + + +class PostProcessPanoptic(nn.Module): + """This class converts the output of the model to the final panoptic result, in the format expected by the + coco panoptic API """ + + def __init__(self, is_thing_map, threshold=0.85): + """ + Parameters: + is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether + the class is a thing (True) or a stuff (False) class + threshold: confidence threshold: segments with confidence lower than this will be deleted + """ + super().__init__() + self.threshold = threshold + self.is_thing_map = is_thing_map + + def forward(self, outputs, processed_sizes, target_sizes=None): + """ This function computes the panoptic prediction from the model's predictions. 
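# ------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the committed file: a quick numeric check of the modulation in
# sigmoid_focal_loss above. With gamma = 2, a confidently correct positive contributes a far
# smaller loss than a poorly classified one, which is what the (1 - p_t)^gamma factor is for.
# ------------------------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

def focal_term(logit, target, alpha=0.25, gamma=2.0):
    prob = torch.sigmoid(logit)
    ce = F.binary_cross_entropy_with_logits(logit, target, reduction='none')
    p_t = prob * target + (1 - prob) * (1 - target)
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    return alpha_t * ce * (1 - p_t) ** gamma

easy = focal_term(torch.tensor(4.0), torch.tensor(1.0))    # p ~ 0.98 -> loss ~ 1.5e-6
hard = focal_term(torch.tensor(-2.0), torch.tensor(1.0))   # p ~ 0.12 -> loss ~ 0.41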
+ Parameters: + outputs: This is a dict coming directly from the model. See the model doc for the content. + processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the + model, ie the size after data augmentation but before batching. + target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size + of each prediction. If left to None, it will default to the processed_sizes + """ + if target_sizes is None: + target_sizes = processed_sizes + assert len(processed_sizes) == len(target_sizes) + out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] + assert len(out_logits) == len(raw_masks) == len(target_sizes) + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.cpu().tolist()) + + for cur_logits, cur_masks, cur_boxes, size, target_size in zip( + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes + ): + # we filter empty queries and detection below threshold + scores, labels = cur_logits.softmax(-1).max(-1) + keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) + cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) + cur_scores = cur_scores[keep] + cur_classes = cur_classes[keep] + cur_masks = cur_masks[keep] + cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0) + cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) + + h, w = cur_masks.shape[-2:] + assert len(cur_boxes) == len(cur_classes) + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_classes): + if not self.is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) + + np_seg_img = ( + torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy() + ) + m_id = torch.from_numpy(rgb2id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_classes.numel() > 0: + # We know filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor( + [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device + ) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_classes = cur_classes[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, 
cur_scores) + else: + break + + else: + cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_classes[i].item() + segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a}) + del cur_classes + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} + preds.append(predictions) + return preds diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000..c298d89 --- /dev/null +++ b/models/utils.py @@ -0,0 +1,124 @@ +# ---------------------------------------------- +# Created by Wei-Jie Huang +# ---------------------------------------------- + + +import torch +from torch import nn +import torch.nn.functional as F + + +def remove_mask_and_warp(src, pos, padding_mask, level_start_index, spatial_shapes): + """ Removes padding mask in sequence and warps each level of tokens into fixed-sized sequences. + + Args: + src, pos (batch_size, sequence_length, d_model): patch tokens and position encodings + padding_mask (batch_size, sequence_length): key padding mask + level_start_index (num_feature_levels): start index of each feature level + spatial_shapes (num_feature_levels, 2): spatial shape (H, W) of each feature level + + Returns: + src_warped, pos_warped (batch_size, num_feature_levels, C, C): warped patch tokens and + position encodings. The last two dimensions indicate sequence length (i.e., H*W) and model + dimension, respectively. + """ + + B, _, C = src.shape + sqrt_C = int(C ** 0.5) + src_warped = [] + pos_warped = [] + for start, shape in zip(level_start_index, spatial_shapes): + H, W = shape + s = src[:, start:start+H*W].view(B, H, W, C).permute(0, 3, 1, 2) + p = pos[:, start:start+H*W].view(B, H, W, C).permute(0, 3, 1, 2) + m = padding_mask[:, start:start+H*W].view(B, H, W) + + not_m = ~m + real_H = not_m.sum(1).max(1).values + real_W = not_m.sum(2).max(1).values + + src_warped.append(torch.stack([F.adaptive_avg_pool2d(s_i[:, :real_H[i], :real_W[i]], sqrt_C) for i, s_i in enumerate(s)])) + pos_warped.append(torch.stack([F.adaptive_avg_pool2d(p_i[:, :real_H[i], :real_W[i]], sqrt_C) for i, p_i in enumerate(p)])) + + src_warped = torch.stack(src_warped, dim=1).flatten(-2).transpose(-2, -1) + pos_warped = torch.stack(pos_warped, dim=1).flatten(-2).transpose(-2, -1) + return src_warped, pos_warped + + +class DomainAttention(nn.Module): + """ Wraps domain-adapting cross attention and MLP into a module. + The operations are similar to those in Transformer, including normalization + layers and dropout layers, while MLP is simplified as a linear layer. + + Args: + d_model: total dimension of the model. + n_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. 
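+
+    A minimal usage sketch (the shapes below are illustrative only, not taken from the SOMA configs):
+
+        attn = DomainAttention(d_model=256, n_heads=8, dropout=0.1)
+        query = torch.randn(2, 1, 256)      # one discriminator query per image
+        src = torch.randn(2, 1300, 256)     # flattened multi-scale patch tokens
+        pos = torch.randn(2, 1300, 256)     # corresponding position encodings
+        query = attn(query, src, pos)       # -> (2, 1, 256); gradients w.r.t. src/pos are reversed by GRL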
+ """ + + def __init__(self, d_model, n_heads, dropout): + super(DomainAttention, self).__init__() + self.grl = GradientReversal() + self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + self.linear = nn.Linear(d_model, d_model) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward(self, query, src, pos=None, padding_mask=None): + """ Args: + query (batch_size, num_queries, d_model): discriminator query + src, pos (batch_size, sequence_length, d_model): patch tokens and position encodings + padding_mask (batch_size, sequence_length): key padding mask + """ + r_query, _ = self.cross_attn( + query=query.transpose(0, 1), + key=self.grl(self.with_pos_embed(src, pos)).transpose(0, 1), + value=self.grl(src).transpose(0, 1), + key_padding_mask=padding_mask, + ) + query = query + self.dropout1(r_query.transpose(0, 1)) + query = self.norm1(query) + query = query + self.dropout2(self.linear(query)) + query = self.norm2(query) + return query + + +# ------------------------------------------------------------------------------------------------------------------------------ +# Copy-paste from https://github.com/jvanvugt/pytorch-domain-adaptation/blob/35ac3a5a04b5e1cf5b2145b6c442c2d678362eef/utils.py +# ------------------------------------------------------------------------------------------------------------------------------ + + +class GradientReversalFunction(torch.autograd.Function): + """ + Gradient Reversal Layer from: + Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) + Forward pass is the identity function. In the backward pass, + the upstream gradients are multiplied by -lambda (i.e. 
gradient is reversed) + """ + + @staticmethod + def forward(ctx, x, lambda_): + ctx.lambda_ = lambda_ + return x.clone() + + @staticmethod + def backward(ctx, grads): + lambda_ = ctx.lambda_ + lambda_ = grads.new_tensor(lambda_) + dx = -lambda_ * grads + return dx, None + + +class GradientReversal(nn.Module): + def __init__(self, lambda_=1): + super(GradientReversal, self).__init__() + self.lambda_ = lambda_ + + def forward(self, x): + return GradientReversalFunction.apply(x, self.lambda_) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..96e9e35 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +pycocotools +tqdm +cython +scipy +yacs diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..91f37be --- /dev/null +++ b/run.sh @@ -0,0 +1,29 @@ +# Cityscapes to Foggy Cityscapes +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \ +--opts DATASET.AOOD_SETTING 1 OUTPUT_DIR experiments/city_to_foggy/setting1 + +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \ +--opts DATASET.AOOD_SETTING 2 OUTPUT_DIR experiments/city_to_foggy/setting2 + +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \ +--opts DATASET.AOOD_SETTING 3 OUTPUT_DIR experiments/city_to_foggy/setting3 + +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \ +--opts DATASET.AOOD_SETTING 4 OUTPUT_DIR experiments/city_to_foggy/setting4 + +# Pascal to CLipart +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_pascal_to_clipart_r50.yaml \ +--opts OUTPUT_DIR experiments/pascal_to_clipart + +# Cityscapes to BDD00k_daytime +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_bdd100k_r50.yaml \ +--opts DATASET.AOOD_SETTING 1 OUTPUT_DIR experiments/city_to_bdd100k/setting1 + +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_bdd100k_r50.yaml \ +--opts DATASET.AOOD_SETTING 2 OUTPUT_DIR experiments/city_to_bdd100k/setting2 + +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_bdd100k_r50.yaml \ +--opts DATASET.AOOD_SETTING 3 OUTPUT_DIR experiments/city_to_bdd100k/setting3 + +GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_bdd100k_r50.yaml \ +--opts DATASET.AOOD_SETTING 4 OUTPUT_DIR experiments/city_to_bdd100k/setting4 \ No newline at end of file diff --git a/tools/launch.py b/tools/launch.py new file mode 100644 index 0000000..2b3ceaa --- /dev/null +++ b/tools/launch.py @@ -0,0 +1,192 @@ +# -------------------------------------------------------------------------------------------------------------------------- +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# -------------------------------------------------------------------------------------------------------------------------- +# Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py +# -------------------------------------------------------------------------------------------------------------------------- + +r""" +`torch.distributed.launch` is a module that spawns up multiple distributed +training processes on each of the training nodes. +The utility can be used for single-node distributed training, in which one or +more processes per node will be spawned. The utility can be used for either +CPU training or GPU training. If the utility is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. It can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be benefitial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. +In both cases of single-node distributed training or multi-node distributed +training, this utility will launch the given number of processes per node +(``--nproc_per_node``). If used for GPU training, this number needs to be less +or euqal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. +**How to use this module:** +1. Single-Node multi-process distributed training +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) +2. Multi-Node multi-process distributed training: (e.g. two nodes) +Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" + --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) +Node 2: +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" + --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) +3. To look up what optional arguments this module offers: +:: + >>> python -m torch.distributed.launch --help +**Important Notices:** +1. This utilty and multi-process distributed (single-node or +multi-node) GPU training currently only achieves the best performance using +the NCCL distributed backend. Thus NCCL backend is the recommended backend to +use for GPU training. +2. In your training program, you must parse the command-line argument: +``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module. +If your training program uses GPUs, you should ensure that your code only +runs on the GPU device of LOCAL_PROCESS_RANK. 
This can be done by: +Parsing the local_rank argument +:: + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local_rank", type=int) + >>> args = parser.parse_args() +Set your device to local rank using either +:: + >>> torch.cuda.set_device(arg.local_rank) # before your code runs +or +:: + >>> with torch.cuda.device(arg.local_rank): + >>> # your code to run +3. In your training program, you are supposed to call the following function +at the beginning to start the distributed backend. You need to make sure that +the init_method uses ``env://``, which is the only supported ``init_method`` +by this module. +:: + torch.distributed.init_process_group(backend='YOUR BACKEND', + init_method='env://') +4. In your training program, you can either use regular distributed functions +or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your +training program uses GPUs for training and you would like to use +:func:`torch.nn.parallel.DistributedDataParallel` module, +here is how to configure it. +:: + model = torch.nn.parallel.DistributedDataParallel(model, + device_ids=[arg.local_rank], + output_device=arg.local_rank) +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, +and ``output_device`` needs to be ``args.local_rank`` in order to use this +utility +5. Another way to pass ``local_rank`` to the subprocesses via environment variable +``LOCAL_RANK``. This behavior is enabled when you launch the script with +``--use_env=True``. You must adjust the subprocess example above to replace +``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher +will not pass ``--local_rank`` when you specify this flag. +.. warning:: + ``local_rank`` is NOT globally unique: it is only unique per process + on a machine. Thus, don't use it to decide if you should, e.g., + write to a networked filesystem. See + https://github.com/pytorch/pytorch/issues/12042 for an example of + how things can go wrong if you don't do this correctly. 
+""" + + +import sys +import subprocess +import os +import socket +from argparse import ArgumentParser, REMAINDER + +import torch + + +def parse_args(): + """ + Helper function parsing the command line options + @retval ArgumentParser + """ + parser = ArgumentParser(description="PyTorch distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes") + + # Optional arguments for the launch helper + parser.add_argument("--nnodes", type=int, default=1, + help="The number of nodes to use for distributed " + "training") + parser.add_argument("--node_rank", type=int, default=0, + help="The rank of the node for multi-node distributed " + "training") + parser.add_argument("--nproc_per_node", type=int, default=1, + help="The number of processes to launch on each node, " + "for GPU training, this is recommended to be set " + "to the number of GPUs in your system so that " + "each process can be bound to a single GPU.") + parser.add_argument("--master_addr", default="127.0.0.1", type=str, + help="Master node (rank 0)'s address, should be either " + "the IP address or the hostname of node 0, for " + "single node multi-proc training, the " + "--master_addr can simply be 127.0.0.1") + parser.add_argument("--master_port", default=29500, type=int, + help="Master node (rank 0)'s free port that needs to " + "be used for communciation during distributed " + "training") + + # positional + parser.add_argument("training_script", type=str, + help="The full path to the single GPU training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script") + + # rest from the training program + parser.add_argument('training_script_args', nargs=REMAINDER) + return parser.parse_args() + + +def main(): + args = parse_args() + + # world size in terms of number of processes + dist_world_size = args.nproc_per_node * args.nnodes + + # set PyTorch distributed related environmental variables + current_env = os.environ.copy() + current_env["MASTER_ADDR"] = args.master_addr + current_env["MASTER_PORT"] = str(args.master_port) + current_env["WORLD_SIZE"] = str(dist_world_size) + + processes = [] + + for local_rank in range(0, args.nproc_per_node): + # each process's rank + dist_rank = args.nproc_per_node * args.node_rank + local_rank + current_env["RANK"] = str(dist_rank) + current_env["LOCAL_RANK"] = str(local_rank) + + cmd = [args.training_script] + args.training_script_args + + process = subprocess.Popen(cmd, env=current_env) + processes.append(process) + + for process in processes: + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError(returncode=process.returncode, + cmd=process.args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tools/run_dist_launch.sh b/tools/run_dist_launch.sh new file mode 100644 index 0000000..f6f6c4f --- /dev/null +++ b/tools/run_dist_launch.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +set -x + +GPUS=$1 +RUN_COMMAND=${@:2} +if [ $GPUS -lt 8 ]; then + GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} +else + GPUS_PER_NODE=${GPUS_PER_NODE:-8} +fi +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +MASTER_PORT=${MASTER_PORT:-"29500"} +NODE_RANK=${NODE_RANK:-0} + +let "NNODES=GPUS/GPUS_PER_NODE" + +python ./tools/launch.py \ + --nnodes ${NNODES} \ + --node_rank ${NODE_RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT} \ + --nproc_per_node ${GPUS_PER_NODE} \ + ${RUN_COMMAND} \ No newline at end of file diff --git a/tools/run_dist_slurm.sh b/tools/run_dist_slurm.sh new file mode 100644 index 0000000..bd73d0b --- /dev/null +++ b/tools/run_dist_slurm.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# -------------------------------------------------------------------------------------------------------------------------- +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# -------------------------------------------------------------------------------------------------------------------------- +# Modified from https://github.com/open-mmlab/mmdetection/blob/3b53fe15d87860c6941f3dda63c0f27422da6266/tools/slurm_train.sh +# -------------------------------------------------------------------------------------------------------------------------- + +set -x + +PARTITION=$1 +JOB_NAME=$2 +GPUS=$3 +RUN_COMMAND=${@:4} +if [ $GPUS -lt 8 ]; then + GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} +else + GPUS_PER_NODE=${GPUS_PER_NODE:-8} +fi +CPUS_PER_TASK=${CPUS_PER_TASK:-4} +SRUN_ARGS=${SRUN_ARGS:-""} + +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + ${RUN_COMMAND} + diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..01a98e4 --- /dev/null +++ b/util/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ diff --git a/util/anno_convert.py b/util/anno_convert.py new file mode 100644 index 0000000..c3bb09c --- /dev/null +++ b/util/anno_convert.py @@ -0,0 +1,422 @@ +# ---------------------------------------------- +# Created by Wei-Jie Huang +# A collection of annotation conversion scripts +# We provide this for reference only +# ---------------------------------------------- + + +import json +from pathlib import Path +from tqdm import tqdm + + +r""" +coco_anno_dict is like { + "images": list of image_info's + "annotations": list of annotation_info's + "categories": list of categorie_info's +} +where img_info is like: { + "id": ..., # 0-indexed + "width": ..., + "height": ..., + "file_name": ..., +}, annotation_info is like: { + "id": ..., # 0-indexed + "image_id": ..., + "category_id": ..., + "segmentation": ..., + "iscrowd": ..., + "area": ..., + "bbox": ..., # (x, y, w, h) +}, and category_info is like: { + "id": ..., # 1-indexed + "name": ..., +} +""" + + +def sim10k_to_coco( + src_path: str = "VOC2012/Annotations", + des_path: str = "annotations/sim10k_caronly.json", + categories: tuple = ("car",) + ) -> None: + + r""" Convert Sim10k (in VOC format) into COCO format. + Args: + src_path: path of the directory containing VOC-format annotations + des_path: destination of the converted COCO-fomat annotation + categories: only category ``car`` is considered by default + """ + + from xml.etree import ElementTree + + src_path = Path(src_path) + des_path = Path(des_path) + assert src_path.exists(), "Annotation directory does not exist" + if des_path.exists(): + print(f"{des_path} exists. Override? (y/n)", end=" ") + if input() != "y": + print("Abort") + return + else: + des_path.parent.mkdir(parents=True, exist_ok=True) + + # Initialization + coco_anno_dict = { + "images": [], + "categories": [], + "annotations": [], + } + num_images = 0 + num_categories = 0 + num_annotations = 0 + + # Categories + category_to_id = {} + for category in categories: + coco_anno_dict["categories"].append({ + "id": num_categories + 1, + "name": category + }) + category_to_id[category] = num_categories + 1 + num_categories += 1 + + # Start Conversion + for anno_file in tqdm(list(src_path.glob("*.xml"))): + et_root = ElementTree.parse(anno_file).getroot() + + ##### Images ##### + img_info = { + "id": num_images, + "file_name": anno_file.stem + ".jpg", + } + num_images += 1 + + # Image Size + size = et_root.find("size") + img_info["width"] = int(size.find("width").text) + img_info["height"] = int(size.find("height").text) + + coco_anno_dict["images"].append(img_info) + + ##### Annotations ##### + for anno_object in et_root.findall("object"): + category = anno_object.find("name").text + if category not in categories: + continue + anno_info = { + "id": num_annotations, + "image_id": img_info["id"], + "category_id": category_to_id[category], + "segmentation": [], + "iscrowd": 0 + } + num_annotations += 1 + + # Bounding box + bndbox = anno_object.find("bndbox") + xmin = float(bndbox.find("xmin").text) + ymin = float(bndbox.find("ymin").text) + xmax = float(bndbox.find("xmax").text) + ymax = float(bndbox.find("ymax").text) + # COCO format expects (x, y, w, h) + anno_info["bbox"] = [xmin, ymin, round(xmax - xmin, 2), round(ymax - ymin, 2)] + anno_info["area"] = round(anno_info["bbox"][2] * anno_info["bbox"][3], 2) + + coco_anno_dict["annotations"].append(anno_info) + + print("# of images:", num_images) + print("# of categories:", 
num_categories) + print("# of annotations:", num_annotations) + + with open(des_path, 'w') as f: + f.write(json.dumps(coco_anno_dict, indent=4)) + print(f"Convert successfully to {des_path}") + + +def bdd100k_daytime_to_coco( + src_path: str = "labels/bdd100k_labels_images_train.json", + des_path: str = "annotations/bdd_daytime_train.json", + categories: tuple = ( + "person", "rider", "car", "truck", "bus", "train", "motor", "bike") + ) -> None: + + r""" Extract ``daytime`` subset from BDD100k dataset and convert into COCO format. + Args: + src_path: source of the annotation json file + des_path: destination of the converted COCO-fomat annotation + categories: categories used + """ + + src_path = Path(src_path) + des_path = Path(des_path) + assert src_path.exists(), "Source annotation file does not exist" + if des_path.exists(): + print(f"{des_path} exists. Override? (y/n)", end=" ") + if input() != "y": + print("Abort") + return + else: + des_path.parent.mkdir(parents=True, exist_ok=True) + + # Initialization + coco_anno_dict = { + "images": [], + "categories": [], + "annotations": [], + } + num_images = 0 + num_categories = 0 + num_annotations = 0 + + # Categories + category_to_id = {} + for category in categories: + coco_anno_dict["categories"].append({ + "id": num_categories + 1, + "name": category + }) + category_to_id[category] = num_categories + 1 + num_categories += 1 + + with open(src_path, 'r') as f: + raw_img_annos = json.load(f) + # Start Conversion + for raw_img_anno in tqdm(raw_img_annos): + if raw_img_anno["attributes"]["timeofday"] != "daytime": + continue + + ##### Images ##### + img_info = { + "id": num_images, + "file_name": raw_img_anno["name"], + "height": 720, + "width": 1280 + } + coco_anno_dict["images"].append(img_info) + num_images += 1 + + ##### Annotations ##### + for label in raw_img_anno["labels"]: + if label["category"] not in category_to_id or "box2d" not in label: + continue + anno_info = { + "id": num_annotations, + "image_id": img_info["id"], + "category_id": category_to_id[label["category"]], + "segmentation": [], + "iscrowd": 0, + } + num_annotations += 1 + + # Bbox + x1 = label["box2d"]["x1"] + y1 = label["box2d"]["y1"] + x2 = label["box2d"]["x2"] + y2 = label["box2d"]["y2"] + anno_info["bbox"] = [x1, y1, x2 - x1, y2 - y1] + anno_info["area"] = float((x2 - x1) * (y2 - y1)) + coco_anno_dict["annotations"].append(anno_info) + + print("# of images:", num_images) + print("# of categories:", num_categories) + print("# of annotations:", num_annotations) + + with open(des_path, 'w') as f: + f.write(json.dumps(coco_anno_dict, indent=4)) + print(f"Convert successfully to {des_path}") + + +def cityscapes_to_coco( + src_path: str = "gtFine/train", + des_path: str = "annotations/cityscapes_train.json", + car_only: bool = False, + foggy: bool = False, + categories: tuple = ( + "person", "rider", "car", "truck", "bus", "train", "motor", "bike") + ) -> None: + + r"""Convert Cityscapes into COCO format. + Ref: https://github.com/facebookresearch/Detectron/blob/7aa91aa/tools/convert_cityscapes_to_coco.py + Args: + src_path: path of the directory containing Cityscapes annotations + des_path: destination of the converted COCO-fomat annotation + car_only: whether extract category ``car`` only. used in Syn-to-real adaptation + foggy: whether extract from foggy cityscapes. 
used in weather adaptation + categories: categories used + """ + + def get_instances_with_polygons(imageFileName): + r""" Ref: https://github.com/facebookresearch/Detectron/issues/111#issuecomment-363425465""" + import os + import sys + import cv2 + import numpy as np + from PIL import Image + from cityscapesscripts.evaluation.instance import Instance + from cityscapesscripts.helpers.csHelpers import labels, id2label + + # Load image + img = Image.open(imageFileName) + + # Image as numpy array + imgNp = np.array(img) + + # Initialize label categories + instances = {} + for label in labels: + instances[label.name] = [] + + # Loop through all instance ids in instance image + for instanceId in np.unique(imgNp): + if instanceId < 1000: + continue + + instanceObj = Instance(imgNp, instanceId) + instanceObj_dict = instanceObj.toDict() + + if id2label[instanceObj.labelID].hasInstances: + mask = (imgNp == instanceId).astype(np.uint8) + contour, hier = cv2.findContours( + mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + + polygons = [c.reshape(-1).tolist() for c in contour] + instanceObj_dict["contours"] = polygons + + instances[id2label[instanceObj.labelID].name].append( + instanceObj_dict) + return instances + + def polygon_to_bbox(polygon: list) -> list: + """Convert polygon into COCO-format bounding box.""" + + # https://github.com/facebookresearch/maskrcnn-benchmark/issues/288#issuecomment-449098063 + TO_REMOVE = 1 + + x0 = min(min(p[::2]) for p in polygon) + x1 = max(max(p[::2]) for p in polygon) + y0 = min(min(p[1::2]) for p in polygon) + y1 = max(max(p[1::2]) for p in polygon) + + bbox = [x0, y0, x1 -x0 + TO_REMOVE, y1 - y0 + TO_REMOVE] + return bbox + + src_path = Path(src_path) + des_path = Path(des_path) + assert src_path.exists(), "Source annotation file does not exist" + if des_path.exists(): + print(f"{des_path} exists. Override? 
(y/n)", end=" ") + if input() != "y": + print("Abort") + return + else: + des_path.parent.mkdir(parents=True, exist_ok=True) + + # Initialization + coco_anno_dict = { + "images": [], + "categories": [], + "annotations": [], + } + num_images = 0 + num_categories = 0 + num_annotations = 0 + + # Categories + if car_only: + categories = ("car",) + category_to_id = {} + for category in categories: + coco_anno_dict["categories"].append({ + "id": num_categories + 1, + "name": category + }) + category_to_id[category] = num_categories + 1 + num_categories += 1 + + # Start Conversion + for file in tqdm(list(src_path.rglob("*instanceIds.png"))): + ##### Images ##### + img_info = {"id": num_images} + num_images += 1 + img_info["file_name"] = \ + str(file.name).split("_", maxsplit=1)[0] + "/" + \ + str(file.name).replace("gtFine", "leftImg8bit").replace("_instanceIds", "") + if foggy: + img_info["file_name"] = \ + img_info["file_name"].replace("leftImg8bit", "leftImg8bit_foggy_beta_0.02") + with open(str(file).replace("instanceIds.png", "polygons.json"), "r") as f: + polygon_info = json.load(f) + img_info["width"] = polygon_info["imgWidth"] + img_info["height"] = polygon_info["imgHeight"] + coco_anno_dict["images"].append(img_info) + + ##### Annotations ##### + instances = get_instances_with_polygons(str(file.absolute())) + for category in instances.keys(): + if category not in categories: + continue + for instance in instances[category]: + anno_info = { + "id": num_annotations, + "image_id": img_info["id"], + "category_id": category_to_id[category], + "segmentation": [], + "iscrowd": 0, + "area": instance["pixelCount"], + "bbox": polygon_to_bbox(instance["contours"]), + } + num_annotations += 1 + coco_anno_dict["annotations"].append(anno_info) + + print("# of images:", num_images) + print("# of categories:", num_categories) + print("# of annotations:", num_annotations) + + with open(des_path, 'w') as f: + f.write(json.dumps(coco_anno_dict, indent=4)) + print(f"Convert successfully to {des_path}") + + +if __name__ == "__main__": + sim10k_to_coco( + src_path="VOC2012/Annotations", + des_path="annotations/sim10k_caronly.json" + ) + bdd100k_daytime_to_coco( + src_path="labels/bdd100k_labels_images_train.json", + des_path="annotations/bdd_daytime_train.json" + ) + bdd100k_daytime_to_coco( + src_path="labels/bdd100k_labels_images_val.json", + des_path="annotations/bdd_daytime_val.json" + ) + cityscapes_to_coco( + src_path="gtFine/train", + des_path="annotations/cityscapes_train.json", + ) + cityscapes_to_coco( + src_path="gtFine/val", + des_path="annotations/cityscapes_val.json", + ) + cityscapes_to_coco( + src_path="gtFine/train", + des_path="annotations/cityscapes_caronly_train.json", + car_only=True, + ) + cityscapes_to_coco( + src_path="gtFine/val", + des_path="annotations/cityscapes_caronly_val.json", + car_only=True, + ) + cityscapes_to_coco( + src_path="gtFine/train", + des_path="annotations/foggy_cityscapes_train.json", + foggy=True, + ) + cityscapes_to_coco( + src_path="gtFine/val", + des_path="annotations/foggy_cityscapes_val.json", + foggy=True, + ) diff --git a/util/box_ops.py b/util/box_ops.py new file mode 100644 index 0000000..3b87b2b --- /dev/null +++ b/util/box_ops.py @@ -0,0 +1,98 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/util/misc.py b/util/misc.py new file mode 100644 index 0000000..31f0584 --- /dev/null +++ b/util/misc.py @@ -0,0 +1,534 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if float(torchvision.__version__[:3]) < 0.5: + import math + from torchvision.ops.misc import _NewEmptyTensorOp + def _check_size_scale_factor(dim, size, scale_factor): + # type: (int, Optional[List[int]], Optional[float]) -> None + if size is None and scale_factor is None: + raise ValueError("either size or scale_factor should be defined") + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + if not (scale_factor is not None and len(scale_factor) != dim): + raise ValueError( + "scale_factor shape must match input shape. " + "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) + ) + def _output_size(dim, input, size, scale_factor): + # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int] + assert dim == 2 + _check_size_scale_factor(dim, size, scale_factor) + if size is not None: + return size + # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat + assert scale_factor is not None and isinstance(scale_factor, (int, float)) + scale_factors = [scale_factor, scale_factor] + # math.floor might return float in py2.7 + return [ + int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) + ] +elif float(torchvision.__version__[:3]) < 0.7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
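+        Only ``count`` and ``total`` (and therefore ``global_avg``) are all-reduced across processes;
+        ``median``, ``avg``, ``max`` and ``value`` remain local to each rank.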
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
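+
+    Example (illustrative): averaging a dict of scalar CUDA loss tensors for logging:
+
+        loss_dict_reduced = reduce_dict({"loss_ce": loss_ce, "loss_bbox": loss_bbox})
+        # same keys, each value averaged over all ranks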
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None,logger=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + if logger is not None and is_main_process(): + logger.info(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + if logger is not None and is_main_process(): + logger.info('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + else: + + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + 
branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device, non_blocking=False): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device, non_blocking=non_blocking) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device, non_blocking=non_blocking) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def record_stream(self, *args, **kwargs): + self.tensors.record_stream(*args, **kwargs) + if self.mask is not None: + self.mask.record_stream(*args, **kwargs) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_local_size(): + if not is_dist_avail_and_initialized(): + return 1 + return int(os.environ['LOCAL_SIZE']) + + +def get_local_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return int(os.environ['LOCAL_RANK']) + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(cfg): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + cfg.DIST.RANK = 
int(os.environ["RANK"]) + cfg.DIST.WORLD_SIZE = int(os.environ['WORLD_SIZE']) + cfg.DIST.GPU = int(os.environ['LOCAL_RANK']) + cfg.DIST.DIST_URL = 'env://' + os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) + elif 'SLURM_PROCID' in os.environ: + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput( + 'scontrol show hostname {} | head -n1'.format(node_list)) + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['LOCAL_SIZE'] = str(num_gpus) + cfg.DIST.DIST_URL = 'env://' + cfg.DIST.WORLD_SIZE = ntasks + cfg.DIST.RANK = proc_id + cfg.DIST.GPU = proc_id % num_gpus + else: + print('Not using distributed mode') + cfg.DIST.DISTRIBUTED = False + return + + cfg.DIST.DISTRIBUTED = True + + torch.cuda.set_device(cfg.DIST.GPU) + cfg.DIST.DIST_BACKEND = 'nccl' + print('| distributed init (rank {}): {}'.format( + cfg.DIST.RANK, cfg.DIST.DIST_URL), flush=True) + torch.distributed.init_process_group(backend=cfg.DIST.DIST_BACKEND, init_method=cfg.DIST.DIST_URL, + world_size=cfg.DIST.WORLD_SIZE, rank=cfg.DIST.RANK) + torch.distributed.barrier() + setup_for_distributed(cfg.DIST.RANK == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. 
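+    For example (illustrative), ``interpolate(torch.zeros(0, 3, 32, 32), size=(64, 64))`` yields an
+    empty tensor of shape ``(0, 3, 64, 64)`` on torchvision versions where ``F.interpolate`` would
+    fail on an empty batch.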
+ """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + if float(torchvision.__version__[:3]) < 0.5: + return _NewEmptyTensorOp.apply(input, output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +def get_total_grad_norm(parameters, norm_type=2): + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + device = parameters[0].grad.device + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), + norm_type) + return total_norm + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1/x2) + diff --git a/util/plot_utils.py b/util/plot_utils.py new file mode 100644 index 0000000..f6ff051 --- /dev/null +++ b/util/plot_utils.py @@ -0,0 +1,113 @@ +# ------------------------------------------------------------------------ +# Modified by Wei-Jie Huang +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Plotting utilities to visualize training logs. +""" +import torch +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt + +from pathlib import Path, PurePath + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): + ''' + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. 
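+
+    :: Example (illustrative) - plot the default fields from two experiment directories
+       (each is assumed to contain a 'log.txt' produced during training):
+
+        from pathlib import Path
+        plot_logs([Path('experiments/city_to_foggy/setting1'),
+                   Path('experiments/city_to_foggy/setting2')])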
+ + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # verify valid dir(s) and that every item in list is Path object + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if dir.exists(): + continue + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs + + +