diff --git a/README.md b/README.md
index 5c23997..27792f6 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,177 @@
-# SOMA
-[ICCV' 23] Novel Scenes & Classes: Towards Adaptive Open-set Object Detection
+# [Novel Scenes & Classes: Towards Adaptive Open-set Object Detection (ICCV-23 ORAL)](assets/paper.pdf)
+
+By [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)
+
+The paper link will be updated once the CVF open-access version is available.
+
+
+
+
+
+Domain Adaptive Object Detection (DAOD) strongly assumes a shared class space between the two domains.
+
+This work breaks that assumption and formulates Adaptive Open-set Object Detection (AOOD), which allows the target domain to contain novel-class objects.
+
+The object detector is trained with base-class labels in the source domain, and aims to detect base-class objects and identify novel-class objects as unknown in the target domain.
+
+If you have any ideas or problems you would like to discuss, feel free to reach out via [E-mail](mailto:wuyangli2-c@my.cityu.edu.hk).
+
+# 💡 Preparation
+
+## Step 1: Clone and Install the Project
+
+### Clone the repository
+
+```bash
+git clone https://github.com/CityU-AIM-Group/SOMA.git
+```
+
+### Install the project following [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR)
+
+Note that the following matches our experimental environment, which is slightly different from the official one.
+
+```bash
+# Linux, CUDA>=9.2, GCC>=5.4
+# (ours) CUDA=10.2, GCC=8.4, NVIDIA V100
+# Establish the conda environment
+
+conda create -n aood python=3.7 pip
+conda activate aood
+conda install pytorch=1.5.1 torchvision=0.6.1 cudatoolkit=10.2 -c pytorch
+pip install -r requirements.txt
+
+# Compile the project
+cd ./models/ops
+sh ./make.sh
+
+# Unit test (you should see all checks return True)
+python test.py
+
+# NOTE: if you encounter a permission-denied error when starting training
+cd ../../
+chmod -R 777 ./
+```
+
+## Step 2: Download Necessary Resources
+
+### Download pre-processed datasets (VOC format) from the following links
+
+| | (Foggy) Cityscapes | Pascal VOC | Clipart | BDD100K |
+| :------------: | :------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------: |
+| Official Links | [Imgs](https://www.cityscapes-dataset.com/login/) | [Imgs+Labels](https://pjreddie.com/projects/pascal-voc-dataset-mirror/) | - | - |
+| Our Links | [Labels](https://portland-my.sharepoint.com/:u:/g/personal/wuyangli2-c_my_cityu_edu_hk/EVNAjK2JkG9ChREzzqdqJkYBLoZ_VOqkMdhWasN_BETGWw?e=fP9Ae4) | - | [Imgs+Labels](https://portland-my.sharepoint.com/:u:/g/personal/wuyangli2-c_my_cityu_edu_hk/Edz2YcXHuStIqwM_NA7k8FMBGLeyAGQcSjdSR-vYaVx_vw?e=es6KDW) | [Imgs+Labels](https://portland-my.sharepoint.com/:u:/g/personal/wuyangli2-c_my_cityu_edu_hk/EeiO6O36QgZKnTcUZMInACIB0dfWEg4OFyoEZnZCkibKHA?e=6byqBX) |
+
+### Download DINO-pretrained ResNet-50 from this [link](https://portland-my.sharepoint.com/:u:/g/personal/wuyangli2-c_my_cityu_edu_hk/EVnK9IPi91ZPuNmwpeSWGHABqhSFQK52I7xGzroXKeuyzA?e=EnlwgO)
+
+## Step 3: Change the Path
+
+### Organize the datasets as follows
+
+```
+[DATASET_PATH]
+└─ Cityscapes
+ └─ AOOD_Annotations
+ └─ AOOD_Main
+ └─ train_source.txt
+ └─ train_target.txt
+ └─ val_source.txt
+ └─ val_target.txt
+ └─ leftImg8bit
+ └─ train
+ └─ val
+ └─ leftImg8bit_foggy
+ └─ train
+ └─ val
+└─ bdd_daytime
+ └─ Annotations
+ └─ ImageSets
+ └─ JPEGImages
+└─ clipart
+ └─ Annotations
+ └─ ImageSets
+ └─ JPEGImages
+└─ VOCdevkit
+ └─ VOC2007
+ └─ VOC2012
+```
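+
+If your datasets were downloaded elsewhere, one way to assemble this layout is with symbolic links. A minimal sketch, assuming the extracted folders sit under `~/downloads` (hypothetical paths; adjust to your environment):
+
+```bash
+# Gather all datasets under a single root, which the configs read as DATASET.COCO_PATH
+export DATASET_PATH=~/data/aood
+mkdir -p $DATASET_PATH
+
+# Link each extracted dataset into the common root
+ln -s ~/downloads/Cityscapes  $DATASET_PATH/Cityscapes
+ln -s ~/downloads/bdd_daytime $DATASET_PATH/bdd_daytime
+ln -s ~/downloads/clipart     $DATASET_PATH/clipart
+ln -s ~/downloads/VOCdevkit   $DATASET_PATH/VOCdevkit
+```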
+
+### Change the data root folder in config files
+
+Replace DATASET.COCO_PATH in all YAML files under [configs](configs) with your data root $DATASET_PATH, e.g., Line 22 of [soma_aood_city_to_foggy_r50.yaml](configs/soma_aood_city_to_foggy_r50.yaml).
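+
+For example, a one-line substitution can rewrite the default root shipped in the configs (`/home/wuyangli2/data/`); this is a sketch using GNU sed, so double-check the files afterwards:
+
+```bash
+# Point every config at your own data root
+sed -i "s#/home/wuyangli2/data/#${DATASET_PATH}/#g" configs/*.yaml
+```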
+
+### Change the path of DINO-pretrained backbone
+
+Replace the backbone loading path at Line 107 of [backbone.py](models/backbone.py).
+
+# 🔥 Start Training
+
+We use two GPUs for training, with 2 source images and 2 target images per GPU as input.
+
+```bash
+GPUS_PER_NODE=2
+./tools/run_dist_launch.sh 2 python main.py --config_file {CONFIG_FILE} --opts DATASET.AOOD_SETTING 1
+```
+
+We provide the scripts used in our experiments in [run.sh](./run.sh). Settings placed after `--opts` override the default config, following the maskrcnn-benchmark convention.
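+
+For instance, a Cityscapes-to-Foggy run that switches to class split 2 and writes to a custom output folder could look like the sketch below (the output directory name is only an example):
+
+```bash
+GPUS_PER_NODE=2
+./tools/run_dist_launch.sh 2 python main.py \
+    --config_file configs/soma_aood_city_to_foggy_r50.yaml \
+    --opts DATASET.AOOD_SETTING 2 OUTPUT_DIR exps/r50_aood_c2f_setting2
+```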
+
+# 📦 Well-trained models
+
+Will be provided later
+
+
+
+
+# 💬 Notification
+
+- The core idea is to select informative motifs (which can be treated as a mix-up of object queries) for self-training.
+- You can try the DA version of [OW-DETR](https://github.com/akshitac8/OW-DETR) in this repository by setting the flag below (a full command is sketched after this list):
+```bash
+--opts AOOD.OW_DETR_ON True
+```
+- Adopting SAM to address AOOD may be a good direction.
+- To visualize unknown boxes, post-processing is needed at Line 736 of [PostProcess](models/motif_detr.py).
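+
+As referenced above, a sketch of a full command enabling the OW-DETR variant (the config file is only an example):
+
+```bash
+GPUS_PER_NODE=2
+./tools/run_dist_launch.sh 2 python main.py \
+    --config_file configs/soma_aood_city_to_foggy_r50.yaml \
+    --opts AOOD.OW_DETR_ON True
+```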
+
+# 📝 Citation
+
+If you find this work helpful for your project, please give it a star and a citation. We sincerely appreciate your acknowledgment.
+
+```BibTeX
+@InProceedings{li2023novel,
+ title={Novel Scenes & Classes: Towards Adaptive Open-set Object Detection},
+ author={Li, Wuyang and Guo, Xiaoqing and Yuan, Yixuan},
+ booktitle={ICCV},
+ year={2023}
+}
+```
+
+Relevant project:
+
+Exploring a similar issue for the classification task. [[link]](https://openaccess.thecvf.com/content/CVPR2023/html/Li_Adjustment_and_Alignment_for_Unbiased_Open_Set_Domain_Adaptation_CVPR_2023_paper.html)
+
+```BibTeX
+@InProceedings{Li_2023_CVPR,
+ author = {Li, Wuyang and Liu, Jie and Han, Bo and Yuan, Yixuan},
+ title = {Adjustment and Alignment for Unbiased Open Set Domain Adaptation},
+ booktitle = {CVPR},
+ year = {2023},
+}
+```
+
+# 🤞 Acknowledgements
+
+We greatly appreciate the tremendous efforts behind the following works.
+
+- This work is based on the DAOD framework [AQT](https://github.com/weii41392/AQT).
+- Our work is highly inspired by [OW-DETR](https://github.com/akshitac8/OW-DETR) and [OpenDet](https://github.com/csuhan/opendet2).
+- The implementation of the basic detector is based on [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR).
+
+# 📒 Abstract
+
+Domain Adaptive Object Detection (DAOD) transfers an object detector to a novel domain free of labels. However, in the real world, besides encountering novel scenes, novel domains always contain novel-class objects de facto, which are ignored in existing research. Thus, we formulate and study a more practical setting, Adaptive Open-set Object Detection (AOOD), considering both novel scenes and classes. Directly combining off-the-shelf cross-domain and open-set approaches is sub-optimal since their low-order dependence, such as the confidence score, is insufficient for AOOD with its two dimensions of novel information. To address this, we propose a novel Structured Motif Matching (SOMA) framework for AOOD, which models the high-order relation with motifs, i.e., statistically significant subgraphs, and formulates the AOOD solution as motif matching to learn with high-order patterns. In a nutshell, SOMA consists of Structure-aware Novel-class Learning (SNL) and Structure-aware Transfer Learning (STL). As for SNL, we establish an instance-oriented graph to capture the class-independent object features hidden in different base classes. Then, a high-order metric is proposed to match the most significant motif as high-order patterns, serving motif-guided novel-class learning. In STL, we set up a semantic-oriented graph to model the class-dependent relation across domains, and match unlabelled objects with high-order motifs to align the cross-domain distribution with structural awareness. Extensive experiments demonstrate that the proposed SOMA achieves state-of-the-art performance.
+
+![image](./assets/overall.png)
diff --git a/assets/mot.png b/assets/mot.png
new file mode 100644
index 0000000..72a5962
Binary files /dev/null and b/assets/mot.png differ
diff --git a/assets/overall.png b/assets/overall.png
new file mode 100644
index 0000000..a0cd42b
Binary files /dev/null and b/assets/overall.png differ
diff --git a/assets/paper.pdf b/assets/paper.pdf
new file mode 100644
index 0000000..1f1f4ba
Binary files /dev/null and b/assets/paper.pdf differ
diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 0000000..93702da
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,69 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+"""
+Benchmark inference speed of Deformable DETR.
+"""
+import os
+import time
+import argparse
+
+import torch
+
+from main import get_args_parser as get_main_args_parser
+from models import build_model
+from datasets import build_dataset
+from util.misc import nested_tensor_from_tensor_list
+
+
+def get_benckmark_arg_parser():
+ parser = argparse.ArgumentParser('Benchmark inference speed of Deformable DETR.')
+ parser.add_argument('--num_iters', type=int, default=300, help='total iters to benchmark speed')
+ parser.add_argument('--warm_iters', type=int, default=5, help='ignore first several iters that are very slow')
+ parser.add_argument('--batch_size', type=int, default=1, help='batch size in inference')
+ parser.add_argument('--resume', type=str, help='load the pre-trained checkpoint')
+ return parser
+
+
+@torch.no_grad()
+def measure_average_inference_time(model, inputs, num_iters=100, warm_iters=5):
+ ts = []
+ for iter_ in range(num_iters):
+ torch.cuda.synchronize()
+ t_ = time.perf_counter()
+ model(inputs)
+ torch.cuda.synchronize()
+ t = time.perf_counter() - t_
+ if iter_ >= warm_iters:
+ ts.append(t)
+ print(ts)
+ return sum(ts) / len(ts)
+
+
+def benchmark():
+ args, _ = get_benckmark_arg_parser().parse_known_args()
+ main_args = get_main_args_parser().parse_args(_)
+ assert args.warm_iters < args.num_iters and args.num_iters > 0 and args.warm_iters >= 0
+ assert args.batch_size > 0
+ assert args.resume is None or os.path.exists(args.resume)
+ dataset = build_dataset('val', main_args)
+ model, _, _ = build_model(main_args)
+ model.cuda()
+ model.eval()
+ if args.resume is not None:
+ ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage)
+ model.load_state_dict(ckpt['model'])
+ inputs = nested_tensor_from_tensor_list([dataset.__getitem__(0)[0].cuda() for _ in range(args.batch_size)])
+ t = measure_average_inference_time(model, inputs, args.num_iters, args.warm_iters)
+ return 1.0 / t * args.batch_size
+
+
+if __name__ == '__main__':
+ fps = benchmark()
+ print(f'Inference Speed: {fps:.1f} FPS')
+
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..93f574c
--- /dev/null
+++ b/config.py
@@ -0,0 +1,167 @@
+from yacs.config import CfgNode as CN
+from numpy import pi
+
+_C = CN()
+
+# ------------------------------------------------------------------------
+# Training
+# ------------------------------------------------------------------------
+_C.TRAIN = CN()
+_C.TRAIN.LR = 2e-4
+_C.TRAIN.LR_BACKBONE_NAMES = ["backbone.0"]
+_C.TRAIN.LR_BACKBONE = 2e-5
+_C.TRAIN.LR_LINEAR_PROJ_NAMES = ['reference_points', 'sampling_offsets']
+_C.TRAIN.LR_LINEAR_PROJ_MULT = 0.1
+_C.TRAIN.BATCH_SIZE = 2
+_C.TRAIN.WEIGHT_DECAY = 1e-4
+_C.TRAIN.EPOCHS = 50
+_C.TRAIN.LR_DROP = 40
+_C.TRAIN.LR_DROP_EPOCHS = None
+_C.TRAIN.CLIP_MAX_NORM = 0.1 # gradient clipping max norm
+_C.TRAIN.SGD = False # AdamW is used when setting this false
+
+
+# ------------------------------------------------------------------------
+# Model
+# ------------------------------------------------------------------------
+_C.MODEL = CN()
+
+# Variants of Deformable DETR
+_C.MODEL.WITH_BOX_REFINE = False
+_C.MODEL.TWO_STAGE = False
+
+# Model parameters
+_C.MODEL.FROZEN_WEIGHTS = None # Path to the pretrained model. If set, only the mask head will be trained
+
+# * Backbone
+_C.MODEL.BACKBONE = 'resnet50' # Name of the convolutional backbone to use
+_C.MODEL.DILATION = False # If true, we replace stride with dilation in the last convolutional block (DC5)
+_C.MODEL.POSITION_EMBEDDING = 'sine' # ('sine', 'learned') Type of positional embedding to use on top of the image features
+_C.MODEL.POSITION_EMBEDDING_SCALE = 2 * pi # position / size * scale
+_C.MODEL.NUM_FEATURE_LEVELS = 4 # number of feature levels
+
+# * Transformer
+_C.MODEL.ENC_LAYERS = 6 # Number of encoding layers in the transformer
+_C.MODEL.DEC_LAYERS = 6 # Number of decoding layers in the transformer
+_C.MODEL.DIM_FEEDFORWARD = 1024 # Intermediate size of the feedforward layers in the transformer blocks
+_C.MODEL.HIDDEN_DIM = 256 # Size of the embeddings (dimension of the transformer)
+_C.MODEL.DROPOUT = 0.1 # Dropout applied in the transformer
+_C.MODEL.NHEADS = 8 # Number of attention heads inside the transformer's attentions
+_C.MODEL.NUM_QUERIES = 300 # Number of query slots
+_C.MODEL.DEC_N_POINTS = 4
+_C.MODEL.ENC_N_POINTS = 4
+
+# * Segmentation
+_C.MODEL.MASKS = False # Train segmentation head if the flag is provided
+
+# * Domain Adaptation
+_C.MODEL.BACKBONE_ALIGN = False
+_C.MODEL.SPACE_ALIGN = False
+_C.MODEL.CHANNEL_ALIGN = False
+_C.MODEL.INSTANCE_ALIGN = False
+
+# ------------------------------------------------------------------------
+# Deformable DETR baseline Loss
+# ------------------------------------------------------------------------
+_C.LOSS = CN()
+_C.LOSS.AUX_LOSS = True # auxiliary decoding losses (loss at each layer)
+
+# * Matcher
+_C.LOSS.SET_COST_CLASS = 2. # Class coefficient in the matching cost
+_C.LOSS.SET_COST_BBOX = 5. # L1 box coefficient in the matching cost
+_C.LOSS.SET_COST_GIOU = 2. # giou box coefficient in the matching cost
+
+# * Loss coefficients
+_C.LOSS.MASK_LOSS_COEF = 1.
+_C.LOSS.DICE_LOSS_COEF = 1.
+_C.LOSS.CLS_LOSS_COEF = 2.
+_C.LOSS.BBOX_LOSS_COEF = 5.
+_C.LOSS.GIOU_LOSS_COEF = 2.
+
+_C.LOSS.SPACE_QUERY_LOSS_COEF = 0.1
+_C.LOSS.CHANNEL_QUERY_LOSS_COEF = 0.1
+_C.LOSS.INSTANCE_QUERY_LOSS_COEF = 0.1
+_C.LOSS.FOCAL_ALPHA = 0.25
+_C.LOSS.DA_GAMMA = 0
+
+
+# ------------------------------------------------------------------------
+# dataset parameters
+# ------------------------------------------------------------------------
+_C.DATASET = CN()
+_C.DATASET.DA_MODE = 'source_only' # ('source_only', 'uda', 'oracle')
+_C.DATASET.NUM_CLASSES = 9 # This should be set as max_class_id + 1
+_C.DATASET.DATASET_FILE = 'cityscapes_to_foggy_cityscapes'
+_C.DATASET.COCO_PATH = '../datasets'
+_C.DATASET.COCO_PANOPTIC_PATH = None
+_C.DATASET.REMOVE_DIFFICULT = False
+
+
+# ------------------------------------------------------------------------
+# Distributed
+# ------------------------------------------------------------------------
+_C.DIST = CN()
+_C.DIST.DISTRIBUTED = False
+_C.DIST.RANK = None
+_C.DIST.WORLD_SIZE = None
+_C.DIST.GPU = None
+_C.DIST.DIST_URL = None
+_C.DIST.DIST_BACKEND = None
+
+# ------------------------------------------------------------------------
+# Miscellaneous
+# ------------------------------------------------------------------------
+_C.OUTPUT_DIR = '' # path where to save, empty for no saving
+_C.DEVICE = 'cuda' # device to use for training / testing
+_C.RESUME = '' # resume from checkpoint
+_C.START_EPOCH = 0 # start epoch
+_C.EVAL = False
+_C.NUM_WORKERS = 2
+_C.CACHE_MODE = False # whether to cache images on memory
+_C.SEED = 42 # Note that this cannot strictly guarantee identical results across runs; the cause is unclear
+
+# ------------------------------------------------------------------------
+# Adaptive Open-set Object Detection (AOOD)
+# ------------------------------------------------------------------------
+_C.AOOD = CN()
+_C.AOOD.OPEN_SET = CN()
+_C.AOOD.CROSS_DOMAIN = CN()
+
+_C.AOOD.OW_DETR_ON = False
+_C.AOOD.MOTIF_ON = False
+_C.AOOD.OPENDET_DETR_ON = False
+_C.DATASET.AOOD_SETTING = 1
+_C.DATASET.AOOD_TASK = 4
+_C.DATASET.AOOD_SCENE = 'cityscapes'
+# _C.DATASET.AOOD_SCENE = 'pascal'
+
+# global alignment + def-detr baseline
+_C.AOOD.CROSS_DOMAIN.BACKBONE_LAMBDA = 1.0
+_C.LOSS.BACKBONE_LOSS_COEF = 0.1
+_C.LOAD_OPTIMIZER = True
+_C.EVAL_EPOCH = 29
+
+# For novel-class
+_C.AOOD.OPEN_SET.MOTIF_ON = False
+_C.AOOD.OPEN_SET.KNN = 5
+_C.AOOD.OPEN_SET.TH = 0.5
+_C.AOOD.OPEN_SET.MOTIF_LOSS_COEF = 1.0
+_C.AOOD.OPEN_SET.WITH_SELF_LABELING = False
+_C.AOOD.OPEN_SET.UNK_PROB = 0.0
+_C.AOOD.OPEN_SET.WARM_UP = -1 # -1 indicates no warm-up
+_C.AOOD.OPEN_SET.MOTIF_UPDATE = True
+_C.AOOD.OPEN_SET.ALPHA = 0.01
+
+# For novel-scene
+_C.AOOD.CROSS_DOMAIN.WARM_UP = -1
+_C.AOOD.CROSS_DOMAIN.MOTIF_ON = False
+_C.AOOD.CROSS_DOMAIN.MOTIF_LOSS_COEF = 0.01
+_C.AOOD.CROSS_DOMAIN.KNN = 5
+_C.AOOD.CROSS_DOMAIN.BETA = 1.0
+
+
+def get_cfg_defaults():
+ """Get a yacs CfgNode object with default values for my_project."""
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ return _C.clone()
diff --git a/configs/soma_aood_city_to_bdd100k_r50.yaml b/configs/soma_aood_city_to_bdd100k_r50.yaml
new file mode 100644
index 0000000..e9ee55a
--- /dev/null
+++ b/configs/soma_aood_city_to_bdd100k_r50.yaml
@@ -0,0 +1,93 @@
+CACHE_MODE: False
+AOOD:
+ MOTIF_ON: True
+ # novel-class
+ OPEN_SET:
+ WARM_UP: 8
+ TH: 0.5
+ MOTIF_ON: True
+ ALPHA: 0.01
+ KNN: 5
+ UNK_PROB: 0.5
+ MOTIF_LOSS_COEF: 0.05
+ # novel-scene
+ CROSS_DOMAIN:
+ MOTIF_ON: True
+ BETA: 1.0 # std scaling
+DATASET:
+ DA_MODE: aood
+ AOOD_SETTING: 1 # different class splittings 1-4
+ AOOD_SCENE: 'cityscapes'
+ COCO_PANOPTIC_PATH: None
+ COCO_PATH: /home/wuyangli2/data/
+ DATASET_FILE: cityscapes_to_bdd_daytime
+ NUM_CLASSES: 4 # 3 known + 1
+ REMOVE_DIFFICULT: False
+DEVICE: cuda
+DIST:
+ DISTRIBUTED: False
+ DIST_BACKEND: nccl
+ DIST_URL: env://
+ GPU: 0
+ RANK: 0
+ WORLD_SIZE: 4
+EVAL: False
+LOSS:
+ AUX_LOSS: True
+ BACKBONE_LOSS_COEF: 0.1
+ BBOX_LOSS_COEF: 5.0
+ CHANNEL_QUERY_LOSS_COEF: 0.1
+ CLS_LOSS_COEF: 2.0
+ DA_GAMMA: 0
+ DICE_LOSS_COEF: 1.0
+ FOCAL_ALPHA: 0.25
+ GIOU_LOSS_COEF: 2.0
+ INSTANCE_QUERY_LOSS_COEF: 0.1
+ MASK_LOSS_COEF: 1.0
+ SET_COST_BBOX: 5.0
+ SET_COST_CLASS: 2.0
+ SET_COST_GIOU: 2.0
+ SPACE_QUERY_LOSS_COEF: 0.1
+MODEL:
+ BACKBONE: resnet50
+ BACKBONE_ALIGN: True
+ CHANNEL_ALIGN: False
+ DEC_LAYERS: 3
+ DEC_N_POINTS: 4
+ DILATION: False
+ DIM_FEEDFORWARD: 1024
+ DROPOUT: 0.1
+ ENC_LAYERS: 3
+ ENC_N_POINTS: 4
+ FROZEN_WEIGHTS: None
+ HIDDEN_DIM: 256
+ INSTANCE_ALIGN: False
+ MASKS: False
+ NHEADS: 8
+ NUM_FEATURE_LEVELS: 4
+ NUM_QUERIES: 100
+ POSITION_EMBEDDING: sine
+ POSITION_EMBEDDING_SCALE: 6.283185307179586
+ SPACE_ALIGN: False
+ TWO_STAGE: False
+ WITH_BOX_REFINE: False
+NUM_WORKERS: 2
+OUTPUT_DIR: exps/r50_aood_c2b
+RESUME:
+LOAD_OPTIMIZER: False
+SEED: 42 # This cannot strictly guarantee identical results; the cause is unclear.
+START_EPOCH: 0
+EVAL_EPOCH: 9
+TRAIN:
+ BATCH_SIZE: 4 # each gpu, 2 for source and 2 for target
+ CLIP_MAX_NORM: 0.1
+ EPOCHS: 20
+ LR: 0.0002
+ LR_BACKBONE: 2e-05
+ LR_BACKBONE_NAMES: ['backbone.0']
+ LR_DROP: 15
+ LR_DROP_EPOCHS: None
+ LR_LINEAR_PROJ_MULT: 0.1
+ LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets']
+ SGD: False
+ WEIGHT_DECAY: 0.0001
\ No newline at end of file
diff --git a/configs/soma_aood_city_to_foggy_r50.yaml b/configs/soma_aood_city_to_foggy_r50.yaml
new file mode 100644
index 0000000..4af00ff
--- /dev/null
+++ b/configs/soma_aood_city_to_foggy_r50.yaml
@@ -0,0 +1,93 @@
+CACHE_MODE: False
+AOOD:
+ MOTIF_ON: True
+ # novel-class
+ OPEN_SET:
+ WARM_UP: 9
+ TH: 0.5
+ MOTIF_ON: True
+ ALPHA: 0.01
+ KNN: 5
+ UNK_PROB: 0.5
+ MOTIF_LOSS_COEF: 0.1
+ # novel-scene
+ CROSS_DOMAIN:
+ MOTIF_ON: True
+ BETA: 1.0 # std scaling
+DATASET:
+ DA_MODE: aood
+ AOOD_SETTING: 1 # different class splittings 1-4
+ AOOD_SCENE: 'cityscapes'
+ COCO_PANOPTIC_PATH: None
+ COCO_PATH: /home/wuyangli2/data/
+ DATASET_FILE: cityscapes_to_foggy_cityscapes
+ NUM_CLASSES: 4 # 3 known + 1
+ REMOVE_DIFFICULT: False
+DEVICE: cuda
+DIST:
+ DISTRIBUTED: False
+ DIST_BACKEND: nccl
+ DIST_URL: env://
+ GPU: 0
+ RANK: 0
+ WORLD_SIZE: 4
+EVAL: False
+LOSS:
+ AUX_LOSS: True
+ BACKBONE_LOSS_COEF: 0.1
+ BBOX_LOSS_COEF: 5.0
+ CHANNEL_QUERY_LOSS_COEF: 0.1
+ CLS_LOSS_COEF: 2.0
+ DA_GAMMA: 0
+ DICE_LOSS_COEF: 1.0
+ FOCAL_ALPHA: 0.25
+ GIOU_LOSS_COEF: 2.0
+ INSTANCE_QUERY_LOSS_COEF: 0.1
+ MASK_LOSS_COEF: 1.0
+ SET_COST_BBOX: 5.0
+ SET_COST_CLASS: 2.0
+ SET_COST_GIOU: 2.0
+ SPACE_QUERY_LOSS_COEF: 0.1
+MODEL:
+ BACKBONE: resnet50
+ BACKBONE_ALIGN: True
+ CHANNEL_ALIGN: False
+ DEC_LAYERS: 3
+ DEC_N_POINTS: 4
+ DILATION: False
+ DIM_FEEDFORWARD: 1024
+ DROPOUT: 0.1
+ ENC_LAYERS: 3
+ ENC_N_POINTS: 4
+ FROZEN_WEIGHTS: None
+ HIDDEN_DIM: 256
+ INSTANCE_ALIGN: False
+ MASKS: False
+ NHEADS: 8
+ NUM_FEATURE_LEVELS: 4
+ NUM_QUERIES: 100
+ POSITION_EMBEDDING: sine
+ POSITION_EMBEDDING_SCALE: 6.283185307179586
+ SPACE_ALIGN: False
+ TWO_STAGE: False
+ WITH_BOX_REFINE: False
+NUM_WORKERS: 2
+OUTPUT_DIR: exps/r50_aood_c2f
+RESUME:
+LOAD_OPTIMIZER: False
+SEED: 42 # This cannot strictly guarantee identical results; the cause is unclear.
+START_EPOCH: 0
+EVAL_EPOCH: 29
+TRAIN:
+ BATCH_SIZE: 4 # each gpu, 2 for source and 2 for target
+ CLIP_MAX_NORM: 0.1
+ EPOCHS: 65
+ LR: 0.0002
+ LR_BACKBONE: 2e-05
+ LR_BACKBONE_NAMES: ['backbone.0']
+ LR_DROP: 40
+ LR_DROP_EPOCHS: None
+ LR_LINEAR_PROJ_MULT: 0.1
+ LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets']
+ SGD: False
+ WEIGHT_DECAY: 0.0001
\ No newline at end of file
diff --git a/configs/soma_aood_pascal_to_clipart_r50.yaml b/configs/soma_aood_pascal_to_clipart_r50.yaml
new file mode 100644
index 0000000..7599bce
--- /dev/null
+++ b/configs/soma_aood_pascal_to_clipart_r50.yaml
@@ -0,0 +1,93 @@
+EVAL_EPOCH: 1
+CACHE_MODE: False
+AOOD:
+ MOTIF_ON: True
+ # novel-class
+ OPEN_SET:
+ WARM_UP: 9
+ TH: 1.0
+ MOTIF_ON: True
+ ALPHA: 0.01
+ KNN: 5
+ UNK_PROB: 0.5
+ MOTIF_LOSS_COEF: 0.01
+ # novel-scene
+ CROSS_DOMAIN:
+ MOTIF_ON: True
+ BETA: 2.0 # std scaling
+DATASET:
+ DA_MODE: aood
+ AOOD_SETTING: 1
+ AOOD_SCENE: 'pascal'
+ COCO_PANOPTIC_PATH: None
+ COCO_PATH: /home/wuyangli2/data/
+ DATASET_FILE: pascal_to_clipart
+ NUM_CLASSES: 11 # 10 known + 1
+ REMOVE_DIFFICULT: False
+DEVICE: cuda
+DIST:
+ DISTRIBUTED: False
+ DIST_BACKEND: nccl
+ DIST_URL: env://
+ GPU: 0
+ RANK: 0
+ WORLD_SIZE: 4
+EVAL: False
+LOSS:
+ AUX_LOSS: True
+ BACKBONE_LOSS_COEF: 0.1
+ BBOX_LOSS_COEF: 5.0
+ CHANNEL_QUERY_LOSS_COEF: 0.1
+ CLS_LOSS_COEF: 2.0
+ DA_GAMMA: 0
+ DICE_LOSS_COEF: 1.0
+ FOCAL_ALPHA: 0.25
+ GIOU_LOSS_COEF: 2.0
+ INSTANCE_QUERY_LOSS_COEF: 0.1
+ MASK_LOSS_COEF: 1.0
+ SET_COST_BBOX: 5.0
+ SET_COST_CLASS: 2.0
+ SET_COST_GIOU: 2.0
+ SPACE_QUERY_LOSS_COEF: 0.1
+MODEL:
+ BACKBONE: resnet50
+ BACKBONE_ALIGN: True
+ CHANNEL_ALIGN: False
+ DEC_LAYERS: 3
+ DEC_N_POINTS: 4
+ DILATION: False
+ DIM_FEEDFORWARD: 1024
+ DROPOUT: 0.1
+ ENC_LAYERS: 3
+ ENC_N_POINTS: 4
+ FROZEN_WEIGHTS: None
+ HIDDEN_DIM: 256
+ INSTANCE_ALIGN: False
+ MASKS: False
+ NHEADS: 8
+ NUM_FEATURE_LEVELS: 4
+ NUM_QUERIES: 100
+ POSITION_EMBEDDING: sine
+ POSITION_EMBEDDING_SCALE: 6.283185307179586
+ SPACE_ALIGN: False
+ TWO_STAGE: False
+ WITH_BOX_REFINE: False
+NUM_WORKERS: 2
+OUTPUT_DIR: exps/r50_aood_p2c
+RESUME:
+LOAD_OPTIMIZER: False
+SEED: 42 # This cannot strictly guarantee identical results; the cause is unclear.
+START_EPOCH: 0
+TRAIN:
+ BATCH_SIZE: 4 # each gpu, 2 for source and 2 for target
+ CLIP_MAX_NORM: 0.1
+ EPOCHS: 30
+ LR: 0.0002
+ LR_BACKBONE: 2e-05
+ LR_BACKBONE_NAMES: ['backbone.0']
+ LR_DROP: 25
+ LR_DROP_EPOCHS: None
+ LR_LINEAR_PROJ_MULT: 0.1
+ LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets']
+ SGD: False
+ WEIGHT_DECAY: 0.0001
\ No newline at end of file
diff --git a/datasets/DAOD.py b/datasets/DAOD.py
new file mode 100644
index 0000000..337b1c4
--- /dev/null
+++ b/datasets/DAOD.py
@@ -0,0 +1,226 @@
+# ------------------------------------------------------------------------
+# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection
+# Modified by Wuyang Li
+# ----------------------------------------------
+# Created by Wei-Jie Huang
+# ----------------------------------------------
+from pathlib import Path
+from torch.utils.data import Dataset
+from datasets.coco import CocoDetection, make_coco_transforms
+from datasets.aood import AOODDetection
+from util.misc import get_local_rank, get_local_size, nested_tensor_from_tensor_list
+
+def get_paths(root):
+ root = Path(root)
+ return {
+ 'cityscapes': {
+ 'train_img': root / 'Cityscapes/leftImg8bit/train',
+ 'val_img': root / 'Cityscapes/leftImg8bit/val',
+ 'train_anno': root / 'Cityscapes/cocoAnnotations/cityscapes_train_cocostyle.json',
+ 'val_anno': root / 'Cityscapes/cocoAnnotations/cityscapes_foggy_val_cocostyle.json',
+
+ 'train_xml': root / 'Cityscapes/AOOD_Annotations',
+ 'val_xml': root / 'Cityscapes/AOOD_Annotations',
+ 'train_data_list': root / 'Cityscapes/AOOD_Main/train_source.txt',
+ 'val_data_list': root / 'Cityscapes/AOOD_Main/val_source.txt',
+ },
+ 'cityscapes_caronly': {
+ 'train_img': root / 'Cityscapes/leftImg8bit/train',
+ 'train_anno': root / 'Cityscapes/annotations/cityscapes_caronly_train.json',
+ 'val_img': root / 'Cityscapes/leftImg8bit/val',
+ 'val_anno': root / 'Cityscapes/annotations/cityscapes_caronly_val.json',
+ },
+ 'foggy_cityscapes': {
+ 'train_img': root / 'Cityscapes/leftImg8bit_foggy/train',
+ 'train_anno': root / 'Cityscapes/cocoAnnotations/cityscapes_foggy_train_cocostyle.json',
+ 'val_img': root / 'Cityscapes/leftImg8bit_foggy/val',
+ 'val_anno': root / 'Cityscapes/cocoAnnotations/cityscapes_foggy_val_cocostyle.json',
+
+ 'train_xml': root / 'Cityscapes/AOOD_Annotations',
+ 'train_data_list': root / 'Cityscapes/AOOD_Main/train_target.txt',
+
+ 'val_xml': root / 'Cityscapes/AOOD_Annotations',
+ 'val_data_list': root / 'Cityscapes/AOOD_Main/val_target.txt',
+
+ },
+ 'sim10k': {
+ 'train_img': root / 'sim10k/VOC2012/JPEGImages',
+ 'train_anno': root / 'sim10k/annotations/sim10k_caronly.json',
+ },
+ 'bdd_daytime': {
+ 'train_img': root / 'bdd_daytime/JPEGImages',
+ 'val_img': root / 'bdd_daytime/JPEGImages',
+ 'train_xml': root / 'bdd_daytime/Annotations',
+ 'train_data_list': root / 'bdd_daytime/ImageSets/Main/train.txt',
+ 'val_xml': root / 'bdd_daytime/Annotations',
+ 'val_data_list': root / 'bdd_daytime/ImageSets/Main/val.txt',
+
+ },
+ 'pascal': {
+ 'train_img': root / 'VOCdevkit/VOC2012/JPEGImages',
+ 'train_xml': root / 'VOCdevkit/VOC2012/Annotations',
+ 'train_data_list': root / 'VOCdevkit/VOC2012/ImageSets/Main/trainval.txt',
+ 'val_img': root / 'VOCdevkit/VOC2012/JPEGImages',
+ 'val_xml': root / 'VOCdevkit/VOC2012/Annotations',
+ 'val_data_list': root / 'VOCdevkit/VOC2012/ImageSets/Main/trainval.txt',
+ },
+ 'clipart': {
+ 'train_img': root / 'clipart/JPEGImages',
+ 'train_xml': root / 'clipart/Annotations',
+ 'train_data_list': root / 'clipart/ImageSets/Main/all.txt',
+ 'val_img': root / 'clipart/JPEGImages',
+ 'val_xml': root / 'clipart/Annotations',
+ 'val_data_list': root / 'clipart/ImageSets/Main/all.txt',
+ },
+ }
+
+class AOODDataset(Dataset):
+ def __init__(self, source_img_folder, source_ann_folder, source_data_list, target_img_folder, target_ann_folder, target_data_list,
+ transforms, setting, scene):
+ self.source = AOODDetection(
+ img_folder=source_img_folder,
+ ann_folder=source_ann_folder,
+ data_list = source_data_list,
+ remove_unk = True,
+ transforms=transforms,
+ setting=setting,
+ scene = scene[0],
+ )
+
+ self.target = AOODDetection(
+ img_folder=target_img_folder,
+ ann_folder=target_ann_folder,
+ data_list=target_data_list,
+ transforms=transforms,
+ remove_unk=False,
+ setting=setting,
+ scene = scene[1],
+
+ )
+
+ def __len__(self):
+ return max(len(self.source), len(self.target))
+ # return min(len(self.source), len(self.target))
+
+ def __getitem__(self, idx):
+ source_img, source_target = self.source[idx % len(self.source)]
+ target_img, _ = self.target[idx % len(self.target)]
+ return source_img, target_img, source_target
+
+class DADataset(Dataset):
+ def __init__(self, source_img_folder, source_ann_file, target_img_folder, target_ann_file,
+ transforms, return_masks, cache_mode=False, local_rank=0, local_size=1):
+ self.source = CocoDetection(
+ img_folder=source_img_folder,
+ ann_file=source_ann_file,
+ transforms=transforms,
+ return_masks=return_masks,
+ cache_mode=cache_mode,
+ local_rank=local_rank,
+ local_size=local_size
+ )
+
+ self.target = CocoDetection(
+ img_folder=target_img_folder,
+ ann_file=target_ann_file,
+ transforms=transforms,
+ return_masks=return_masks,
+ cache_mode=cache_mode,
+ local_rank=local_rank,
+ local_size=local_size
+ )
+
+ def __len__(self):
+ return max(len(self.source), len(self.target))
+
+ def __getitem__(self, idx):
+ source_img, source_target = self.source[idx % len(self.source)]
+ target_img, _ = self.target[idx % len(self.target)]
+ return source_img, target_img, source_target
+
+
+def collate_fn(batch):
+ source_imgs, target_imgs, source_targets = list(zip(*batch))
+ samples = nested_tensor_from_tensor_list(source_imgs + target_imgs)
+ return samples, source_targets
+
+
+def build(image_set, cfg, multi_task_eval_id=4):
+ paths = get_paths(cfg.DATASET.COCO_PATH)
+ source_domain, target_domain = cfg.DATASET.DATASET_FILE.split('_to_')
+ if image_set == 'val':
+ if cfg.DATASET.DA_MODE == 'aood':
+ return AOODDetection(
+ img_folder=paths[target_domain]['val_img'],
+ ann_folder=paths[target_domain]['val_xml'],
+ data_list=paths[target_domain]['val_data_list'],
+ transforms=make_coco_transforms(image_set),
+ remove_unk=False,
+ setting= cfg.DATASET.AOOD_SETTING,
+ scene = target_domain,
+ multi_task_eval_id = multi_task_eval_id,
+ is_eval =True,
+
+ )
+ else:
+ return CocoDetection(
+ img_folder=paths[target_domain]['val_img'],
+ ann_file=paths[target_domain]['val_anno'],
+ transforms=make_coco_transforms(image_set),
+ return_masks=cfg.MODEL.MASKS,
+ cache_mode=cfg.CACHE_MODE,
+ local_rank=get_local_rank(),
+ local_size=get_local_size()
+ )
+ elif image_set == 'train':
+ if cfg.DATASET.DA_MODE == 'source_only':
+ return CocoDetection(
+ img_folder=paths[source_domain]['train_img'],
+ ann_file=paths[source_domain]['train_anno'],
+ transforms=make_coco_transforms(image_set),
+ return_masks=cfg.MODEL.MASKS,
+ cache_mode=cfg.CACHE_MODE,
+ local_rank=get_local_rank(),
+ local_size=get_local_size(),
+ )
+ elif cfg.DATASET.DA_MODE == 'oracle':
+ return CocoDetection(
+ img_folder=paths[target_domain]['train_img'],
+ ann_file=paths[target_domain]['train_anno'],
+ transforms=make_coco_transforms(image_set),
+ return_masks=cfg.MODEL.MASKS,
+ cache_mode=cfg.CACHE_MODE,
+ local_rank=get_local_rank(),
+ local_size=get_local_size()
+ )
+ elif cfg.DATASET.DA_MODE == 'uda':
+ return DADataset(
+ source_img_folder=paths[source_domain]['train_img'],
+ source_ann_file=paths[source_domain]['train_anno'],
+ target_img_folder=paths[target_domain]['train_img'],
+ target_ann_file=paths[target_domain]['train_anno'],
+ transforms=make_coco_transforms(image_set),
+ return_masks=cfg.MODEL.MASKS,
+ cache_mode=cfg.CACHE_MODE,
+ local_rank=get_local_rank(),
+ local_size=get_local_size()
+ )
+
+ elif cfg.DATASET.DA_MODE == 'aood':
+ return AOODDataset(
+ source_img_folder=paths[source_domain]['train_img'],
+ source_ann_folder=paths[source_domain]['train_xml'],
+ source_data_list=paths[source_domain]['train_data_list'],
+
+ target_img_folder=paths[target_domain]['train_img'],
+ target_ann_folder=paths[target_domain]['train_xml'],
+ target_data_list=paths[target_domain]['train_data_list'],
+
+ transforms=make_coco_transforms(image_set),
+ setting=cfg.DATASET.AOOD_SETTING,
+ scene = [source_domain, target_domain]
+ )
+ else:
+ raise ValueError(f'Unknown argument cfg.DATASET.DA_MODE {cfg.DATASET.DA_MODE}')
+ raise ValueError(f'unknown image set {image_set}')
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000..1b4787b
--- /dev/null
+++ b/datasets/__init__.py
@@ -0,0 +1,46 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+import torch.utils.data
+from .torchvision_datasets import CocoDetection
+
+from .coco import build as build_coco
+
+
+def get_coco_api_from_dataset(dataset):
+ for _ in range(10):
+ # if isinstance(dataset, torchvision.datasets.CocoDetection):
+ # break
+ if isinstance(dataset, torch.utils.data.Subset):
+ dataset = dataset.dataset
+ if isinstance(dataset, CocoDetection):
+ return dataset.coco
+ else:
+ return dataset
+
+
+def build_dataset(image_set, cfg,multi_task_eval_id=4):
+ if cfg.DATASET.DATASET_FILE == 'coco':
+ return build_coco(image_set, cfg)
+ if cfg.DATASET.DATASET_FILE == 'coco_panoptic':
+ # to avoid making panopticapi required for coco
+ from .coco_panoptic import build as build_coco_panoptic
+ return build_coco_panoptic(image_set, cfg)
+ DAOD_dataset = [
+ 'cityscapes_to_foggy_cityscapes',
+ 'sim10k_to_cityscapes_caronly',
+ 'cityscapes_to_bdd_daytime',
+ 'pascal_to_clipart',
+ ]
+ if cfg.DATASET.DATASET_FILE in DAOD_dataset:
+ from .DAOD import build
+ return build(image_set, cfg, multi_task_eval_id=multi_task_eval_id)
+ raise ValueError(f'dataset {cfg.DATASET.DATASET_FILE} not supported')
diff --git a/datasets/aood.py b/datasets/aood.py
new file mode 100644
index 0000000..5a8f707
--- /dev/null
+++ b/datasets/aood.py
@@ -0,0 +1,339 @@
+# ------------------------------------------------------------------------
+# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection
+# Modified by Wuyang Li
+# ------------------------------------------------------------------------
+import functools
+import torch
+import os
+import tarfile
+import collections
+import logging
+import copy
+from torchvision.datasets import VisionDataset
+import itertools
+import util.misc as utils
+import xml.etree.ElementTree as ET
+from PIL import Image
+import datasets.transforms as T
+
+class AOODDetection(VisionDataset):
+
+ def get_aood_settings_cityscapes(self, setting=2):
+
+ NMES = ['person', 'car', 'train', 'rider', 'truck', 'motorcycle', 'bicycle', 'bus']
+ UNK = ["unknown"]
+
+ if setting == 1: # different semantics
+ BASE_CLASSES = ['car', 'truck', 'bus']
+ NOVEL_CLASSES = ['person','motorcycle','train', 'bicycle' , 'rider']
+ elif setting == 2: # similar semantics
+ BASE_CLASSES = ['person', 'bicycle', 'bus']
+ NOVEL_CLASSES = ['car', 'truck','train', 'motorcycle', 'rider' ]
+ elif setting == 3: # frequency down
+ BASE_CLASSES = ['person', 'car', 'rider']
+ NOVEL_CLASSES = [ 'bicycle', 'train', 'truck', 'motorcycle', 'bus']
+ elif setting == 4: # frequency top
+ BASE_CLASSES = [ 'motorcycle', 'truck', 'bus']
+ NOVEL_CLASSES = ['person', 'train', 'car','bicycle', 'rider']
+
+ ALL_CLASSES= tuple(itertools.chain(BASE_CLASSES, NOVEL_CLASSES))
+ CLASS_NAMES= tuple(itertools.chain(BASE_CLASSES, UNK))
+
+ return BASE_CLASSES, NOVEL_CLASSES, ALL_CLASSES, CLASS_NAMES
+
+
+ def get_aood_settings_pascal_voc(self, setting=1):
+ PASCAL_CLASSES = [
+ "aeroplane",
+ "bicycle",
+ "bird",
+ "boat",
+ "bottle",
+ "bus",
+ "car",
+ "cat",
+ "chair",
+ "cow",
+ "diningtable",
+ "dog",
+ "horse",
+ "motorbike",
+ "person",
+ "pottedplant",
+ "sheep",
+ "sofa",
+ "train",
+ "tvmonitor"]
+ UNK = ["unknown"]
+
+ # if setting == 1:
+ # BASE_CLASSES = ALL_CLASSES[:19]
+ # NOVEL_CLASSES = ALL_CLASSES[19:]
+ # elif setting == 2:
+ # BASE_CLASSES = ALL_CLASSES[:15]
+ # NOVEL_CLASSES = ALL_CLASSES[15:]
+ # elif setting == 3:
+ # BASE_CLASSES = ALL_CLASSES[:10]
+ # NOVEL_CLASSES = ALL_CLASSES[10:]
+ # elif setting == 4:
+ # BASE_CLASSES = ALL_CLASSES[:5]
+ # NOVEL_CLASSES = ALL_CLASSES[5:]
+
+ BASE_CLASSES = PASCAL_CLASSES[:10]
+ NOVEL_CLASSES = PASCAL_CLASSES[10:]
+
+ ALL_CLASSES= tuple(itertools.chain(PASCAL_CLASSES))
+ BASE_CLASSES= tuple(itertools.chain(BASE_CLASSES))
+ NOVEL_CLASSES= tuple(itertools.chain(NOVEL_CLASSES))
+
+ CLASS_NAMES = tuple(itertools.chain(BASE_CLASSES, UNK))
+ return BASE_CLASSES, NOVEL_CLASSES, ALL_CLASSES, CLASS_NAMES
+
+
+ def __init__(self,
+ # args,
+ img_folder,
+ ann_folder,
+ data_list,
+ transforms=None,
+ remove_unk=False,
+ setting=1,
+ scene='pascal',
+ multi_task_eval_id = 4,
+ is_eval=False
+ ):
+ super(AOODDetection, self).__init__(img_folder)
+
+ self.images = []
+ self.annotations = []
+ self.imgids = []
+ self.imgid2annotations = {}
+ self.image_set = []
+ self.transforms = transforms
+ self.remove_unk = remove_unk
+ self.is_eval = is_eval
+
+ self.id_to_task_nme = {
+ 1: 'het-sem',
+ 2: 'hom-sem',
+ 3: 'freq-dec',
+ 4: 'freq-inc'
+ }
+
+ self.scene = scene
+ if self.scene == 'cityscapes' or self.scene == 'bdd_daytime' or self.scene == 'foggy_cityscapes':
+ self.BASE_CLASSES, self.NOVEL_CLASSES, self.ALL_CLASSES, self.CLASS_NAMES = self.get_aood_settings_cityscapes(setting)
+ elif self.scene == 'pascal' or self.scene == 'clipart':
+ self.BASE_CLASSES, self.NOVEL_CLASSES, self.ALL_CLASSES, self.CLASS_NAMES = self.get_aood_settings_pascal_voc(setting)
+ else:
+ raise KeyError('undefined aood scenes')
+
+ self.num_classes = len(self.CLASS_NAMES) # K+1 for model training
+ self.num_base = len(self.BASE_CLASSES) # K
+ self.unk_id = self.num_classes - 1
+
+ self.NOVEL_CLASSES_PER_TASK = self.NOVEL_CLASSES[:multi_task_eval_id]
+
+ num_novel_per_task = len(self.NOVEL_CLASSES_PER_TASK )
+
+ all_classes_id = range(len(self.ALL_CLASSES))
+
+ self.bdd2city={
+ 'bike':'bicycle',
+ 'motor': 'motorcycle',
+ }
+
+ self.base_id = all_classes_id[:self.num_base] # 0-k
+ self.novel_id = all_classes_id[self.num_base:] # k- all
+ self.per_task_novel_id = all_classes_id[self.num_base:self.num_base + num_novel_per_task] # k- sel
+
+ with open(data_list, "r") as f:
+ file_names = [x.strip() for x in f.readlines()]
+
+ if self.remove_unk: # remove images without base-class objects
+ if utils.is_main_process():
+ print(''.join(80*['=']))
+ print('source domain training set:')
+ if self.scene != 'pascal':
+ print('AOOD Task: {}'.format(self.id_to_task_nme[setting]))
+ print("BASE_CLASSES: {}".format(self.BASE_CLASSES))
+ print("REALLOCATED CLASSES: {}".format(self.CLASS_NAMES))
+ file_names = self.filter_imgs_without_base_objects(ann_folder, file_names)
+ elif is_eval: # inference: remove images without base-class objects
+ if utils.is_main_process():
+ print(''.join(80*['=']))
+ print('target domain test set (task {}):'.format(multi_task_eval_id))
+ print("BASE_CLASSES: {}".format(self.BASE_CLASSES))
+ print("NOVEL_CLASSES: {}".format(self.NOVEL_CLASSES_PER_TASK))
+ file_names = self.filter_imgs_without_base_novel_objects(ann_folder, file_names)
+ else: # target domain training set: preserve all images
+ if utils.is_main_process():
+ print(''.join(80*['=']))
+ print('target domain training set:')
+ print('num images: {}'.format(len(file_names)))
+ print("BASE_CLASSES: {}".format(self.BASE_CLASSES))
+ print("NOVEL_CLASSES: {}".format(self.NOVEL_CLASSES))
+
+ self.image_set.extend(file_names)
+
+ suffix = ".png" if self.scene == 'cityscapes' or self.scene == 'foggy_cityscapes' else '.jpg'
+ self.images.extend([os.path.join(img_folder, x + suffix) for x in file_names])
+ self.annotations.extend([os.path.join(ann_folder, x + ".xml") for x in file_names])
+
+ self.imgids = list(range(len(file_names)))
+ self.imgids2img = dict(zip(self.imgids, file_names))
+ self.imgid2annotations.update(dict(zip(self.imgids, self.annotations)))
+
+ assert (len(self.images) == len(self.annotations) == len(self.imgids))
+
+ def filter_imgs_without_base_novel_objects(self, ann_folder, file_names):
+ new_file_names = []
+ for x in file_names:
+ anno = os.path.join(ann_folder, x + ".xml")
+ tree = ET.parse(anno)
+ target = self.parse_voc_xml(tree.getroot())
+ flag=True
+ for obj in target['annotation']['object']:
+ cls = obj["name"]
+ # if cls in self.bdd2city.keys():
+ if cls in self.bdd2city.keys() and self.scene == 'bdd_daytime':
+ cls = self.bdd2city[cls]
+ if cls not in self.BASE_CLASSES and cls not in self.NOVEL_CLASSES_PER_TASK:
+ flag=False
+ break
+ if flag:
+ new_file_names.append(x)
+
+ print('original images: {}, after removing images without base and novel objects: {}.'.format(len(file_names), len(new_file_names)))
+ return new_file_names
+
+ def filter_imgs_without_base_objects(self, ann_folder, file_names):
+ new_file_names = []
+ for x in file_names:
+ anno = os.path.join(ann_folder, x + ".xml")
+ tree = ET.parse(anno)
+ target = self.parse_voc_xml(tree.getroot())
+
+ for obj in target['annotation']['object']:
+ cls = obj["name"]
+ if cls in self.bdd2city.keys():
+ cls = self.bdd2city[cls]
+
+ if cls in self.BASE_CLASSES:
+ new_file_names.append(x)
+ break
+ print('original images: {}, after removing images without base objects: {}.'.format(len(file_names), len(new_file_names)))
+ return new_file_names
+
+ @functools.lru_cache(maxsize=None)
+ def load_instances(self, img_id):
+ tree = ET.parse(self.imgid2annotations[img_id])
+ target = self.parse_voc_xml(tree.getroot())
+ instances = []
+ for obj in target['annotation']['object']:
+ cls = obj["name"]
+
+ if cls in self.bdd2city.keys():
+ cls = self.bdd2city[cls]
+ bbox = obj["bndbox"]
+ bbox = [float(bbox[x]) for x in ["xmin", "ymin", "xmax", "ymax"]]
+ bbox[0] -= 1.0
+ bbox[1] -= 1.0
+ instance = dict(
+ category_id=self.ALL_CLASSES.index(cls),
+ bbox=bbox,
+ area=(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
+ image_id=img_id
+ )
+ instances.append(instance)
+ return target, instances
+
+ def remove_novel_instances(self, target):
+ # for the labelled training
+ entry = copy.copy(target)
+ for annotation in copy.copy(entry):
+ if annotation["category_id"] not in self.base_id:
+ entry.remove(annotation)
+ return entry
+
+ def label_all_novel_instances_as_unk(self, target):
+ # for the unlabelled training
+ entry = copy.copy(target)
+ for annotation in copy.copy(entry):
+ # for annotation in entry:
+ if annotation["category_id"] not in self.base_id:
+ annotation["category_id"] = self.unk_id
+
+ return entry
+
+ def label_per_task_novel_instances_as_unk(self, target):
+ # for the unlabelled training
+ entry = copy.copy(target)
+ for annotation in copy.copy(entry):
+ # for annotation in entry:
+ if annotation["category_id"] in self.base_id:
+ continue
+ elif annotation["category_id"] in self.per_task_novel_id:
+ annotation["category_id"] = self.unk_id
+ else:
+ entry.remove(annotation)
+ return entry
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is a dictionary of the XML tree.
+ """
+
+ img = Image.open(self.images[index]).convert('RGB')
+ target, instances = self.load_instances(self.imgids[index])
+
+ if self.remove_unk:
+ instances = self.remove_novel_instances(instances)
+ elif self.is_eval:
+ instances = self.label_per_task_novel_instances_as_unk(instances)
+ else:
+ instances = self.label_all_novel_instances_as_unk(instances)
+
+ w, h = map(target['annotation']['size'].get, ['width', 'height'])
+ target = dict(
+ image_id=torch.tensor([self.imgids[index]], dtype=torch.int64),
+ labels=torch.tensor([i['category_id'] for i in instances], dtype=torch.int64),
+ area=torch.tensor([i['area'] for i in instances], dtype=torch.float32),
+ boxes=torch.as_tensor([i['bbox'] for i in instances], dtype=torch.float32),
+ orig_size=torch.as_tensor([int(h), int(w)]),
+ size=torch.as_tensor([int(h), int(w)]),
+ iscrowd=torch.zeros(len(instances), dtype=torch.uint8)
+ )
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+ return img, target
+
+ def __len__(self):
+ return len(self.images)
+
+ def parse_voc_xml(self, node):
+ voc_dict = {}
+ children = list(node)
+ if children:
+ def_dic = collections.defaultdict(list)
+ for dc in map(self.parse_voc_xml, children):
+ for ind, v in dc.items():
+ def_dic[ind].append(v)
+ if node.tag == 'annotation':
+ def_dic['object'] = [def_dic['object']]
+ voc_dict = {
+ node.tag:
+ {ind: v[0] if len(v) == 1 else v
+ for ind, v in def_dic.items()}
+ }
+ if node.text:
+ text = node.text.strip()
+ if not children:
+ voc_dict[node.tag] = text
+ return voc_dict
+
diff --git a/datasets/aood_eval.py b/datasets/aood_eval.py
new file mode 100644
index 0000000..8417c3f
--- /dev/null
+++ b/datasets/aood_eval.py
@@ -0,0 +1,633 @@
+# ------------------------------------------------------------------------
+# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection
+# Modified by Wuyang Li
+# ------------------------------------------------------------------------
+import os
+import shutil
+import datetime
+import functools
+import subprocess
+import xml.etree.ElementTree as ET
+import numpy as np
+import torch
+import logging
+from util.misc import all_gather
+from numpy import *
+from collections import OrderedDict, defaultdict
+
+class AOODEvaluator:
+ def __init__(self, voc_gt, iou_types, args=None, use_07_metric=True, ovthresh=list(range(50, 100, 5))):
+ assert tuple(iou_types) == ('bbox',)
+ self.use_07_metric = use_07_metric
+ self.ovthresh = ovthresh
+ self.voc_gt = voc_gt
+ self.eps = torch.finfo(torch.float64).eps
+ self.num_classes = len(self.voc_gt.CLASS_NAMES) # base1, base2 ,.. base n, unk
+ self._class_names = self.voc_gt.CLASS_NAMES
+ self.AP = torch.zeros(self.num_classes, 1)
+ self.all_recs = defaultdict(list)
+ self.all_precs = defaultdict(list)
+ self.recs = defaultdict(list)
+ self.precs = defaultdict(list)
+ self.num_unks = defaultdict(list)
+ self.unk_det_as_knowns = defaultdict(list)
+ self.tp_plus_fp_cs = defaultdict(list)
+ self.fp_os = defaultdict(list)
+ self.coco_eval = dict(bbox=lambda: None)
+ # self.coco_eval['bbox'].stats = torch.tensor([])
+ self.coco_eval['bbox'].stats = dict()
+ self.coco_eval['bbox'].eval = dict()
+
+ self.img_ids = []
+ self.lines = []
+ self.lines_cls = []
+
+ self.total_num_class = self.voc_gt.num_classes
+ self.unknown_class_index = self.voc_gt.unk_id
+ self.num_seen_classes = self.voc_gt.num_base
+ self.known_classes = self.voc_gt.BASE_CLASSES
+ self.unknown_classes = self.voc_gt.NOVEL_CLASSES_PER_TASK
+
+ def update(self, predictions):
+ for img_id, pred in predictions.items():
+ pred_boxes, pred_labels, pred_scores = [pred[k].cpu() for k in ['boxes', 'labels', 'scores']]
+ # print(img_id)
+ # image_id = self.voc_gt.convert_image_id(int(img_id), to_string=True)
+ img_name = self.voc_gt.imgids2img[int(img_id)]
+ self.img_ids.append(img_id)
+ classes = pred_labels.tolist()
+ for (xmin, ymin, xmax, ymax), cls, score in zip(pred_boxes.tolist(), classes , pred_scores.tolist()):
+ xmin += 1
+ ymin += 1
+ self.lines.append(f"{img_name} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}")
+ self.lines_cls.append(cls)
+
+ def compute_avg_precision_at_many_recall_level_for_unk(self, precisions, recalls):
+ precs = {}
+ for r in range(1, 10):
+ r = r/10
+ p = self.compute_avg_precision_at_a_recall_level_for_unk(precisions, recalls, recall_level=r)
+ precs[r] = p
+ return precs
+
+ def compute_avg_precision_at_a_recall_level_for_unk(self, precisions, recalls, recall_level=0.5):
+ precs = {}
+ for iou, recall in recalls.items():
+ prec = []
+ for cls_id, rec in enumerate(recall):
+ if cls_id == self.unknown_class_index and len(rec)>0:
+ p = precisions[iou][cls_id][min(range(len(rec)), key=lambda i: abs(rec[i] - recall_level))]
+ prec.append(p)
+ if len(prec) > 0:
+ precs[iou] = np.mean(prec)
+ else:
+ precs[iou] = 0
+ return precs
+
+ def compute_WI_at_many_recall_level(self, recalls, tp_plus_fp_cs, fp_os):
+ wi_at_recall = {}
+ for r in range(1, 10):
+ r = r/10
+ wi = self.compute_WI_at_a_recall_level(recalls, tp_plus_fp_cs, fp_os, recall_level=r)
+ wi_at_recall[r] = wi
+ return wi_at_recall
+
+ def compute_WI_at_a_recall_level(self, recalls, tp_plus_fp_cs, fp_os, recall_level=0.5):
+ wi_at_iou = {}
+ for iou, recall in recalls.items():
+ tp_plus_fps = []
+ fps = []
+ for cls_id, rec in enumerate(recall):
+ if cls_id in range(self.num_seen_classes) and len(rec) > 0:
+ index = min(range(len(rec)), key=lambda i: abs(rec[i] - recall_level))
+ tp_plus_fp = tp_plus_fp_cs[iou][cls_id][index]
+ tp_plus_fps.append(tp_plus_fp)
+ fp = fp_os[iou][cls_id][index]
+ fps.append(fp)
+ if len(tp_plus_fps) > 0:
+ wi_at_iou[iou] = np.mean(fps) / np.mean(tp_plus_fps)
+ else:
+ wi_at_iou[iou] = 0
+ return wi_at_iou
+
+ def synchronize_between_processes(self):
+ self.img_ids = torch.tensor(self.img_ids, dtype=torch.int64)
+ self.lines_cls = torch.tensor(self.lines_cls, dtype=torch.int64)
+ self.img_ids, self.lines, self.lines_cls = self.merge(self.img_ids, self.lines, self.lines_cls)
+
+ def merge(self, img_ids, lines, lines_cls):
+ flatten = lambda ls: [s for l in ls for s in l]
+ all_img_ids = torch.cat(all_gather(img_ids))
+ all_lines_cls = torch.cat(all_gather(lines_cls))
+ all_lines = flatten(all_gather(lines))
+ return all_img_ids, all_lines, all_lines_cls
+
+ def accumulate(self):
+ for class_label_ind, class_label in enumerate(self.voc_gt.CLASS_NAMES):
+
+ # lines_by_class = [l + '\n' for l, c in zip(self.lines, self.lines_cls.tolist()) if c == class_label_ind+1]
+ lines_by_class = [l + '\n' for l, c in zip(self.lines, self.lines_cls.tolist()) if c == class_label_ind] # coco start from 1
+ if len(lines_by_class) == 0:
+ lines_by_class = []
+ print(class_label + " has " + str(len(lines_by_class)) + " predictions.")
+ ovthresh = 50
+ ovthresh_ind, _ = map(self.ovthresh.index, [50, 75])
+ self.rec, self.prec, self.AP[class_label_ind, ovthresh_ind], self.unk_det_as_known, \
+ self.num_unk, self.tp_plus_fp_closed_set, self.fp_open_set = voc_eval(
+ lines_by_class,
+ self.voc_gt.annotations,
+ self.voc_gt.image_set,
+ class_label,
+ ovthresh=ovthresh / 100.0,
+ use_07_metric=self.use_07_metric,
+ known_classes=self.known_classes,
+ unknown_classes=self.unknown_classes) #[-1]
+
+ self.AP[class_label_ind, ovthresh_ind] = self.AP[class_label_ind, ovthresh_ind] * 100
+ self.all_recs[ovthresh].append(self.rec)
+ self.all_precs[ovthresh].append(self.prec)
+ self.num_unks[ovthresh].append(self.num_unk)
+ self.unk_det_as_knowns[ovthresh].append(self.unk_det_as_known)
+ self.tp_plus_fp_cs[ovthresh].append(self.tp_plus_fp_closed_set)
+ self.fp_os[ovthresh].append(self.fp_open_set)
+ try:
+ self.recs[ovthresh].append(self.rec[-1] * 100)
+ self.precs[ovthresh].append(self.prec[-1] * 100)
+ except:
+ self.recs[ovthresh].append(0.)
+ self.precs[ovthresh].append(0.)
+
+ def summarize(self, fmt='{:.06f}'):
+ o50, _ = map(self.ovthresh.index, [50, 75])
+ mAP = float(self.AP.mean())
+ mAP50 = float(self.AP[:, o50].mean())
+ # print('detection mAP50:', fmt.format(mAP50))
+ # print('detection mAP:', fmt.format(mAP))
+ # print('---AP50---')
+ wi = self.compute_WI_at_many_recall_level(self.all_recs, self.tp_plus_fp_cs, self.fp_os)
+ # print('Wilderness Impact: ' + str(wi))
+ avg_precision_unk = self.compute_avg_precision_at_many_recall_level_for_unk(self.all_precs, self.all_recs)
+ # print('avg_precision: ' + str(avg_precision_unk))
+ total_num_unk_det_as_known = {iou: np.sum(x) for iou, x in self.unk_det_as_knowns.items()} #torch.sum(self.unk_det_as_knowns[:, o50]) #[np.sum(x) for x in self.unk_det_as_knowns[:, o50]]
+ total_num_unk = self.num_unks[50][0]
+ # print('Absolute OSE (total_num_unk_det_as_known): ' + str(total_num_unk_det_as_known))
+ # print('total_num_unk ' + str(total_num_unk))
+ # print("AP50: " + str(['%.1f' % x for x in self.AP[:, o50]]))
+ # print("Precisions50: " + str(['%.1f' % x for x in self.precs[50]]))
+ # print("Recall50: " + str(['%.1f' % x for x in self.recs[50]]))
+ # print("Unknown AP50: " + str(self.AP[:, o50][-1]))
+ # print("Unknown Precisions50: " + str(self.precs[50][-1]))
+ # print("Unknown Recall50: " + str(self.recs[50][-1]))
+
+ # import ipdb; ipdb.set_trace()
+
+
+ AP50 = [round(x.item(),2) for x in self.AP[:, o50]]
+
+ base_AP50 = AP50[:-1]
+ # base_mAP = mean(base_AP50)
+ base_mAP = round(mean(base_AP50), 2)
+ novel_AP50 = round(AP50[-1], 5)
+
+ novel_recall = round(self.recs[50][-1], 2)
+ wi_08 = round(wi[0.8][50] *100, 3)
+
+ AOSE = total_num_unk_det_as_known[50]
+
+
+ class_name = list(self.voc_gt.CLASS_NAMES)
+ base_name = class_name[:-1]
+
+        title = base_name + ['mAP_base', 'WI_08', 'AOSE', 'Recall_novel', 'novel_AP50']
+ ap_map_wi_aose_ar = base_AP50 + [base_mAP]+[wi_08]+[AOSE]+[novel_recall]+[novel_AP50]
+ report_results = [str(base_mAP)]+[str(novel_recall)] + [str(wi_08)]+[str(round(AOSE))]
+
+ ap_map_wi_aose_ar = [str(i) for i in ap_map_wi_aose_ar]
+ ap_map_wi_aose_ar=' & '.join(ap_map_wi_aose_ar)
+
+ title = ' & '.join(title)
+ # "Wilderness Impact: ": str(wi),
+ # "avg_unk_precision: ": str(avg_precision_unk),
+ # for class_name, ap in zip(self.voc_gt.CLASS_NAMES, self.AP[:, o50].cpu().tolist()):
+ # print(class_name, fmt.format(ap))
+ # for class_name, ap in zip(self.voc_gt.CLASS_NAMES, self.AP[:, o50].cpu().tolist()):
+ # print(class_name, fmt.format(ap))
+ self.coco_eval['bbox'].stats = {
+ "detection_mAP50: ": float(self.AP[:, o50].mean()),
+ "Absolute OSE (total_num_unk_det_as_known): ": str(total_num_unk_det_as_known),
+ "total_num_unk: ": str(total_num_unk),
+ "base_mAP: ": base_mAP,
+ "AP50: ": str(['%.1f' % x for x in self.AP[:, o50]]),
+ "Recall50: ": str(['%.1f' % x for x in self.recs[50]]),
+ "Unknown AP50: ": str(self.AP[:, o50][-1]),
+ "Unknown Recall50: ": str(self.recs[50][-1]),
+ "ap_map_wi_aose_ar": ap_map_wi_aose_ar,
+ "report_results": report_results,
+ "title": title,
+
+ }
+ print(self.coco_eval['bbox'].stats )
+
+def voc_ap(rec, prec, use_07_metric=False):
+ """ ap = voc_ap(rec, prec, [use_07_metric])
+ Compute VOC AP given precision and recall.
+ If use_07_metric is true, uses the
+ VOC 07 11 point method (default:False).
+ """
+ if use_07_metric:
+ # 11 point metric
+ ap = 0.
+ for t in np.arange(0., 1.1, 0.1):
+ if np.sum(rec >= t) == 0:
+ p = 0
+ else:
+ p = np.max(prec[rec >= t])
+ ap = ap + p / 11.
+ else:
+ # correct AP calculation
+ # first append sentinel values at the end
+ mrec = np.concatenate(([0.], rec, [1.]))
+ mpre = np.concatenate(([0.], prec, [0.]))
+
+ # compute the precision envelope
+ for i in range(mpre.size - 1, 0, -1):
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+ # to calculate area under PR curve, look for points
+ # where X axis (recall) changes value
+ i = np.where(mrec[1:] != mrec[:-1])[0]
+
+ # and sum (\Delta recall) * prec
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+ return ap
+
+
+@functools.lru_cache(maxsize=None)
+def parse_rec(filename, known_classes, unknown_classes):
+ """ Parse a PASCAL VOC xml file """
+ VOC_CLASS_NAMES_COCOFIED = [
+ "bike", "motor"
+ ]
+ BASE_VOC_CLASS_NAMES = [
+ "bicycle", "motorcycle"
+ ]
+
+ # VOC_CLASS_NAMES_COCOFIED = [
+ # "airplane", "dining table", "motorcycle",
+ # "potted plant", "couch", "tv"
+ # ]
+ # BASE_VOC_CLASS_NAMES = [
+ # "aeroplane", "diningtable", "motorbike",
+ # "pottedplant", "sofa", "tvmonitor"
+ # ]
+
+ tree = ET.parse(filename)
+
+ objects = []
+ for obj in tree.findall('object'):
+ obj_struct = {}
+ cls_name = obj.find('name').text
+
+ if cls_name in VOC_CLASS_NAMES_COCOFIED:
+ cls_name = BASE_VOC_CLASS_NAMES[VOC_CLASS_NAMES_COCOFIED.index(cls_name)]
+
+ # if cls_name not in known_classes:
+ if cls_name in known_classes:
+ cls_name = cls_name
+ elif cls_name in unknown_classes:
+ cls_name = 'unknown'
+ else:
+ continue
+ obj_struct['name'] = cls_name
+ obj_struct['difficult'] = int(obj.find('difficult').text)
+ bbox = obj.find('bndbox')
+ obj_struct['bbox'] = [int(bbox.find('xmin').text),
+ int(bbox.find('ymin').text),
+ int(bbox.find('xmax').text),
+ int(bbox.find('ymax').text)]
+ objects.append(obj_struct)
+
+ return objects
+
+
+def voc_eval(detpath,
+ annopath,
+ imagesetfile,
+ classname,
+ ovthresh=0.5,
+ use_07_metric=False,
+ known_classes=None,
+ unknown_classes= None):
+ # --------------------------------------------------------
+ # Fast/er R-CNN
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Bharath Hariharan
+ # --------------------------------------------------------
+
+ """rec, prec, ap = voc_eval(detpath,
+ annopath,
+ imagesetfile,
+ classname,
+ [ovthresh],
+ [use_07_metric])
+ Top level function that does the PASCAL VOC evaluation.
+ detpath: Path to detections
+ detpath.format(classname) should produce the detection results file.
+ annopath: Path to annotations
+ annopath.format(imagename) should be the xml annotations file.
+ imagesetfile: Text file containing the list of images, one image per line.
+    classname: Category name
+    known_classes: Base-class names kept as-is when parsing annotations
+    unknown_classes: Novel-class names remapped to 'unknown'
+ [ovthresh]: Overlap threshold (default = 0.5)
+ [use_07_metric]: Whether to use VOC07's 11 point AP computation
+ (default False)
+ """
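+    # Each line of a detection results file is expected to look like
+    #   <image_id> <confidence> <xmin> <ymin> <xmax> <ymax>
+    # e.g. "munich_000042_000019 0.912 604 187 742 288" (illustrative values).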
+
+
+ def iou(BBGT, bb):
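+        # IoU between a single detection box `bb` and every ground-truth box in `BBGT`;
+        # the +1 terms follow the PASCAL VOC convention of inclusive pixel coordinates.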
+ ixmin = np.maximum(BBGT[:, 0], bb[0])
+ iymin = np.maximum(BBGT[:, 1], bb[1])
+ ixmax = np.minimum(BBGT[:, 2], bb[2])
+ iymax = np.minimum(BBGT[:, 3], bb[3])
+ iw = np.maximum(ixmax - ixmin + 1., 0.)
+ ih = np.maximum(iymax - iymin + 1., 0.)
+ inters = iw * ih
+
+ # union
+ uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
+ (BBGT[:, 2] - BBGT[:, 0] + 1.) *
+ (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
+
+ overlaps = inters / uni
+ ovmax = np.max(overlaps)
+ jmax = np.argmax(overlaps)
+ return ovmax, jmax
+
+ # assumes detections are in detpath.format(classname)
+ # assumes annotations are in annopath.format(imagename)
+ # assumes imagesetfile is a text file with each line an image name
+ # cachedir caches the annotations in a pickle file
+
+ # read list of images
+ if isinstance(imagesetfile, list):
+ lines = imagesetfile
+ else:
+ with open(imagesetfile, 'r') as f:
+ lines = f.readlines()
+ # imagenames = [x.strip().split('/')[-1] for x in lines]
+ # imagenames = [x.strip() for x in lines]
+ imagenames = [x.split('/')[-1] for x in lines]
+
+ # load annots
+ recs = {}
+ if isinstance(annopath, list):
+ for a in annopath:
+ imagename = os.path.splitext(os.path.basename(a))[0]
+ recs[imagename] = parse_rec(a, tuple(known_classes), tuple(unknown_classes))
+ else:
+ for i, imagename in enumerate(imagenames):
+ recs[imagename] = parse_rec(annopath.format(imagename), tuple(known_classes), tuple(unknown_classes))
+
+ # extract gt objects for this class
+ class_recs = {}
+ npos = 0
+
+ for imagename in imagenames:
+ R = [obj for obj in recs[imagename] if obj['name'] == classname]
+ bbox = np.array([x['bbox'] for x in R])
+        difficult = np.array([x['difficult'] for x in R]).astype(bool)
+ det = [False] * len(R)
+ npos = npos + sum(~difficult)
+ class_recs[imagename] = {'bbox': bbox,
+ 'difficult': difficult,
+ 'det': det}
+
+ # read dets
+ if isinstance(detpath, list):
+ lines = detpath
+ else:
+ detfile = detpath.format(classname)
+ with open(detfile, 'r') as f:
+ lines = f.readlines()
+
+
+ splitlines = [x.strip().split(' ') for x in lines]
+ image_ids = [x[0] for x in splitlines]
+ confidence = np.array([float(x[1]) for x in splitlines])
+ if len(splitlines) == 0:
+ BB = np.array([[float(z) for z in x[2:]] for x in splitlines]).reshape(-1, 4)
+ else:
+ BB = np.array([[float(z) for z in x[2:]] for x in splitlines])#.reshape(-1, 4)
+
+ # if BB.size == 0:
+ # return 0, 0, 0
+
+ # sort by confidence
+ sorted_ind = np.argsort(-confidence)
+ BB = BB[sorted_ind, :]
+
+ image_ids = [image_ids[x].split('/')[-1] for x in sorted_ind]
+
+ # go down dets and mark TPs and FPs
+ nd = len(image_ids)
+ tp = np.zeros(nd)
+ fp = np.zeros(nd)
+ for d in range(nd):
+ R = class_recs[image_ids[d]]
+ bb = BB[d, :].astype(float)
+ ovmax = -np.inf
+ BBGT = R['bbox'].astype(float)
+
+ if BBGT.size > 0:
+ ovmax, jmax = iou(BBGT, bb)
+
+ if ovmax > ovthresh:
+ if not R['difficult'][jmax]:
+ if not R['det'][jmax]:
+ tp[d] = 1.
+ R['det'][jmax] = 1
+ else:
+ fp[d] = 1.
+ else:
+ fp[d] = 1.
+
+ # compute precision recall
+ fp = np.cumsum(fp)
+ tp = np.cumsum(tp)
+ rec = tp / float(npos)
+ # avoid divide by zero in case the first detection matches a difficult
+ # ground truth
+ prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+ ap = voc_ap(rec, prec, use_07_metric)
+
+ '''
+ Computing Absolute Open-Set Error (A-OSE) and Wilderness Impact (WI)
+ ===========
+ Absolute OSE = # of unknown objects classified as known objects of class 'classname'
+ WI = FP_openset / (TP_closed_set + FP_closed_set)
+ '''
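+    # tp_plus_fp_closed_set and fp_open_set returned below are cumulative counts ordered
+    # by detection confidence; the caller evaluates WI at a chosen recall level
+    # (commonly recall = 0.8 in the open-set detection literature).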
+ # logger = logging.getLogger(__name__)
+
+ # Finding GT of unknown objects
+ unknown_class_recs = {}
+ n_unk = 0
+ for imagename in imagenames:
+ R = [obj for obj in recs[imagename] if obj["name"] == 'unknown']
+ bbox = np.array([x["bbox"] for x in R])
+        difficult = np.array([x["difficult"] for x in R]).astype(bool)
+ det = [False] * len(R)
+ n_unk = n_unk + sum(~difficult)
+ unknown_class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}
+
+ if classname == 'unknown':
+ return rec, prec, ap, 0., n_unk, None, None
+
+ # Go down each detection and see if it has an overlap with an unknown object.
+ # If so, it is an unknown object that was classified as known.
+ is_unk = np.zeros(nd)
+ for d in range(nd):
+ R = unknown_class_recs[image_ids[d]]
+ bb = BB[d, :].astype(float)
+ ovmax = -np.inf
+ BBGT = R["bbox"].astype(float)
+
+ if BBGT.size > 0:
+ # compute overlaps
+ # intersection
+ ixmin = np.maximum(BBGT[:, 0], bb[0])
+ iymin = np.maximum(BBGT[:, 1], bb[1])
+ ixmax = np.minimum(BBGT[:, 2], bb[2])
+ iymax = np.minimum(BBGT[:, 3], bb[3])
+ iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
+ ih = np.maximum(iymax - iymin + 1.0, 0.0)
+ inters = iw * ih
+
+ # union
+ uni = (
+ (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
+ + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
+ - inters
+ )
+
+ overlaps = inters / uni
+ ovmax = np.max(overlaps)
+ jmax = np.argmax(overlaps)
+
+ if ovmax > ovthresh:
+ is_unk[d] = 1.0
+
+ is_unk_sum = np.sum(is_unk)
+    tp_plus_fp_closed_set = tp + fp
+ fp_open_set = np.cumsum(is_unk)
+
+
+ return rec, prec, ap, is_unk_sum, n_unk, tp_plus_fp_closed_set, fp_open_set
+
+
+def bbox_nms(boxes, scores, overlap_threshold=0.4, score_threshold=0.0, mask=False):
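+    # Greedy per-class NMS on torch tensors: `boxes` is (N, 4) in xyxy format and
+    # `scores` is (N, num_classes); returns one tensor of kept box indices per class,
+    # or a stacked boolean mask when `mask=True`.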
+ def overlap(box1, box2=None, rectint=False, eps=1e-6):
+ area = lambda boxes=None, x1=None, y1=None, x2=None, y2=None: (boxes[..., 2] - boxes[..., 0]) * (
+ boxes[..., 3] - boxes[..., 1]) if boxes is not None else (x2 - x1).clamp(min=0) * (y2 - y1).clamp(
+ min=0)
+
+ if box2 is None and not isinstance(box1, list) and box1.dim() == 3:
+ return torch.stack(list(map(overlap, box1)))
+ b1, b2 = [(b if b.dim() == 2 else b.unsqueeze(0)).t().contiguous() for b in
+ [box1, (box2 if box2 is not None else box1)]]
+
+ xx1 = torch.max(b1[0].unsqueeze(1), b2[0].unsqueeze(0))
+ yy1 = torch.max(b1[1].unsqueeze(1), b2[1].unsqueeze(0))
+ xx2 = torch.min(b1[2].unsqueeze(1), b2[2].unsqueeze(0))
+ yy2 = torch.min(b1[3].unsqueeze(1), b2[3].unsqueeze(0))
+
+ inter = area(x1=xx1, y1=yy1, x2=xx2, y2=yy2)
+ return inter / (area(b1.t()).unsqueeze(1) + area(b2.t()).unsqueeze(0) - inter + eps) if not rectint else inter
+
+ O = overlap(boxes)
+ I = scores.sort(0)[1]
+ M = scores.gather(0, I).ge(score_threshold)
+ M = M if M.any() else M.fill_(1)
+ pick = []
+
+ for i, m in zip(I.t(), M.t()):
+ p = []
+ i = i[m]
+ while len(i) > 1:
+ p.append(i[-1])
+ m = O[:, i[-1]][i].lt(overlap_threshold)
+ m[-1] = 0
+ i = i[m]
+ pick.append(torch.tensor(p + i.tolist(), dtype=torch.int64))
+
+ return pick if not mask else torch.stack(
+ [torch.zeros(len(scores), dtype=torch.bool).scatter_(0, p, 1) for p in pick])
+
+
+def package_submission(out_dir, image_file_name, class_labels, VOCYEAR, SUBSET, TASK, tar=True, **kwargs):
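+    # Writes VOC-devkit-style result files (one txt file per class) under
+    # <out_dir>/results/<VOCYEAR>/Main/ and optionally packs them into a tar.gz.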
+ def cls(file_path, class_label_ind, scores):
+ with open(file_path, 'w') as f:
+ f.writelines(map('{} {}\n'.format, image_file_name, scores[:, class_label_ind].tolist()))
+
+ def det(file_path, class_label_ind, scores, proposals, keep):
+ zipped = []
+ for example_idx, basename in enumerate(image_file_name):
+ I = keep[example_idx][class_label_ind]
+ zipped.extend((basename, s) + tuple(p) for s, p in zip(scores[example_idx][I, class_label_ind].tolist(),
+ proposals[example_idx][I, :4].add(1).tolist()))
+ with open(file_path, 'w') as f:
+ f.writelines(map('{} {} {:.0f} {:.0f} {:.0f} {:.0f} \n'.format, *zip(*zipped)))
+
+ task_a, task_b = TASK.split('_')
+ resdir = os.path.join(out_dir, 'results')
+ respath = os.path.join(resdir, VOCYEAR, 'Main', '%s_{}_{}_%s.txt'.format(task_b, SUBSET))
+
+ if os.path.exists(resdir):
+ shutil.rmtree(resdir)
+ os.makedirs(os.path.join(resdir, VOCYEAR, 'Main'))
+
+ for class_label_ind, class_label in enumerate(class_labels):
+ dict(det=det, cls=cls)[task_b](respath.replace('%s', '{}').format(task_a, class_label), class_label_ind,
+ **kwargs)
+
+ if tar:
+ subprocess.check_call(['tar', '-czf', 'results-{}-{}-{}.tar.gz'.format(VOCYEAR, TASK, SUBSET), 'results'],
+ cwd=out_dir)
+
+ return respath
+
+
+def detection_mean_ap(out_dir, image_file_name, class_labels, VOCYEAR, SUBSET, VOC_DEVKIT_VOCYEAR, scores=None,
+ boxes=None, nms_score_threshold=1e-4, nms_overlap_threshold=0.4, tar=False, octave=False,
+ cmd='octave --eval', env=None, stdout_stderr=open(os.devnull, 'wb'), do_nms=True):
+ if scores is not None:
+ nms = list(map(lambda s, p: bbox_nms(p, s, overlap_threshold=nms_overlap_threshold,
+ score_threshold=nms_score_threshold), scores, boxes)) if do_nms else [
+ torch.arange(len(p)) for p in boxes]
+
+ else:
+ nms = torch.arange(len(class_labels)).unsqueeze(0).unsqueeze(-1).expand(len(image_file_name), len(class_labels),
+ 1)
+ scores = torch.zeros(len(image_file_name), len(class_labels), len(class_labels))
+
+ imgsetpath = os.path.join(VOC_DEVKIT_VOCYEAR, 'ImageSets', 'Main', SUBSET + '.txt')
+ detrespath = package_submission(out_dir, image_file_name, class_labels, VOCYEAR, SUBSET, 'comp4_det', tar=tar,
+ scores=scores, proposals=boxes, nms=nms)
+
+ if octave:
+ imgsetpath_fix = os.path.join(out_dir, detection_mean_ap.__name__ + '.txt')
+ with open(imgsetpath_fix, 'w') as f:
+ f.writelines([line[:-1] + ' -1\n' for line in open(imgsetpath)])
+ procs = [subprocess.Popen(cmd.split() + [
+ "oldpwd = pwd; cd('{}/..'); addpath(fullfile(pwd, 'VOCcode')); VOCinit; cd(oldpwd); VOCopts.testset = '{}'; VOCopts.detrespath = '{}'; VOCopts.imgsetpath = '{}'; classlabel = '{}'; warning('off', 'Octave:possible-matlab-short-circuit-operator'); warning('off', 'Octave:num-to-str'); [rec, prec, ap] = VOCevaldet(VOCopts, 'comp4', classlabel, false); dlmwrite(sprintf(VOCopts.detrespath, 'resu4', classlabel), ap); quit;".format(
+ VOC_DEVKIT_VOCYEAR, SUBSET, detrespath, imgsetpath_fix, class_label)], stdout=stdout_stderr,
+ stderr=stdout_stderr, env=env) for class_label in class_labels]
+ res = list(map(lambda class_label, proc: proc.wait() or float(open(detrespath % ('resu4', class_label)).read()),
+ class_labels, procs))
+
+ else:
+        # The voc_eval defined above takes known/unknown class lists (no cache directory)
+        # and returns (rec, prec, ap, ...); index 2 selects the AP value.
+        res = [voc_eval(detrespath.replace('%s', '{}').format('comp4', '{}'),
+                        os.path.join(VOC_DEVKIT_VOCYEAR, 'Annotations', '{}.xml'), imgsetpath, class_label,
+                        use_07_metric=True, known_classes=class_labels, unknown_classes=[])[2]
+               for class_label in class_labels]
+
+ return torch.tensor(res).mean(), res
\ No newline at end of file
diff --git a/datasets/coco.py b/datasets/coco.py
new file mode 100644
index 0000000..2415895
--- /dev/null
+++ b/datasets/coco.py
@@ -0,0 +1,163 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+from pathlib import Path
+import torch
+import torch.utils.data
+from pycocotools import mask as coco_mask
+from .torchvision_datasets import CocoDetection as TvCocoDetection
+from util.misc import get_local_rank, get_local_size
+import datasets.transforms as T
+
+
+class CocoDetection(TvCocoDetection):
+ def __init__(self, img_folder, ann_file, transforms, return_masks, cache_mode=False, local_rank=0, local_size=1):
+ super(CocoDetection, self).__init__(img_folder, ann_file,
+ cache_mode=cache_mode, local_rank=local_rank, local_size=local_size)
+ self._transforms = transforms
+ self.prepare = ConvertCocoPolysToMask(return_masks)
+
+ def __getitem__(self, idx):
+ img, target = super(CocoDetection, self).__getitem__(idx)
+ image_id = self.ids[idx]
+ target = {'image_id': image_id, 'annotations': target}
+ img, target = self.prepare(img, target)
+ if self._transforms is not None:
+ img, target = self._transforms(img, target)
+ return img, target
+
+
+def convert_coco_poly_to_mask(segmentations, height, width):
+ masks = []
+ for polygons in segmentations:
+ rles = coco_mask.frPyObjects(polygons, height, width)
+ mask = coco_mask.decode(rles)
+ if len(mask.shape) < 3:
+ mask = mask[..., None]
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
+ mask = mask.any(dim=2)
+ masks.append(mask)
+ if masks:
+ masks = torch.stack(masks, dim=0)
+ else:
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
+ return masks
+
+
+class ConvertCocoPolysToMask(object):
+ def __init__(self, return_masks=False):
+ self.return_masks = return_masks
+
+ def __call__(self, image, target):
+ w, h = image.size
+
+ image_id = target["image_id"]
+ image_id = torch.tensor([image_id])
+
+ anno = target["annotations"]
+
+ anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
+
+ boxes = [obj["bbox"] for obj in anno]
+ # guard against no boxes via resizing
+ boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
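+        # COCO annotations store boxes as [x, y, w, h]; convert to [x0, y0, x1, y1]
+        # before clamping to the image bounds.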
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2].clamp_(min=0, max=w)
+ boxes[:, 1::2].clamp_(min=0, max=h)
+
+ classes = [obj["category_id"] for obj in anno]
+ classes = torch.tensor(classes, dtype=torch.int64)
+
+ if self.return_masks:
+ segmentations = [obj["segmentation"] for obj in anno]
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
+
+ keypoints = None
+ if anno and "keypoints" in anno[0]:
+ keypoints = [obj["keypoints"] for obj in anno]
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
+ num_keypoints = keypoints.shape[0]
+ if num_keypoints:
+ keypoints = keypoints.view(num_keypoints, -1, 3)
+
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+ boxes = boxes[keep]
+ classes = classes[keep]
+ if self.return_masks:
+ masks = masks[keep]
+ if keypoints is not None:
+ keypoints = keypoints[keep]
+
+ target = {}
+ target["boxes"] = boxes
+ target["labels"] = classes
+ if self.return_masks:
+ target["masks"] = masks
+ target["image_id"] = image_id
+ if keypoints is not None:
+ target["keypoints"] = keypoints
+
+ # for conversion to coco api
+ area = torch.tensor([obj["area"] for obj in anno])
+ iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
+ target["area"] = area[keep]
+ target["iscrowd"] = iscrowd[keep]
+
+ target["orig_size"] = torch.as_tensor([int(h), int(w)])
+ target["size"] = torch.as_tensor([int(h), int(w)])
+
+ return image, target
+
+
+def make_coco_transforms(image_set):
+
+ normalize = T.Compose([
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
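+    # Multi-scale training: the shorter side is sampled from `scales` while the longer
+    # side is capped at 1333, following the Deformable DETR recipe.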
+
+ if image_set == 'train':
+ return T.Compose([
+ T.RandomHorizontalFlip(),
+ T.RandomSelect(
+ T.RandomResize(scales, max_size=1333),
+ T.Compose([
+ T.RandomResize([400, 500, 600]),
+ T.RandomSizeCrop(384, 600),
+ T.RandomResize(scales, max_size=1333),
+ ])
+ ),
+ normalize,
+ ])
+
+ if image_set == 'val':
+ return T.Compose([
+ T.RandomResize([800], max_size=1333),
+ normalize,
+ ])
+
+ raise ValueError(f'unknown {image_set}')
+
+
+def build(image_set, cfg):
+ root = Path(cfg.DATASET.COCO_PATH)
+ assert root.exists(), f'provided COCO path {root} does not exist'
+ mode = 'instances'
+ PATHS = {
+ "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
+ "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
+ }
+
+ img_folder, ann_file = PATHS[image_set]
+ dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=cfg.MODEL.MASKS,
+ cache_mode=cfg.CACHE_MODE, local_rank=get_local_rank(), local_size=get_local_size())
+ return dataset
diff --git a/datasets/coco_eval.py b/datasets/coco_eval.py
new file mode 100644
index 0000000..011ee53
--- /dev/null
+++ b/datasets/coco_eval.py
@@ -0,0 +1,267 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+COCO evaluator that works in distributed mode.
+
+Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
+The difference is that there is less copy-pasting from pycocotools
+in the end of the file, as python3 can suppress prints with contextlib
+"""
+import os
+import contextlib
+import copy
+import numpy as np
+import torch
+
+from pycocotools.cocoeval import COCOeval
+from pycocotools.coco import COCO
+import pycocotools.mask as mask_util
+
+from util.misc import all_gather
+
+
+class CocoEvaluator(object):
+ def __init__(self, coco_gt, iou_types):
+ assert isinstance(iou_types, (list, tuple))
+ coco_gt = copy.deepcopy(coco_gt)
+ self.coco_gt = coco_gt
+
+ self.iou_types = iou_types
+ self.coco_eval = {}
+ for iou_type in iou_types:
+ self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
+
+ self.img_ids = []
+ self.eval_imgs = {k: [] for k in iou_types}
+
+ def update(self, predictions):
+ img_ids = list(np.unique(list(predictions.keys())))
+ self.img_ids.extend(img_ids)
+
+ for iou_type in self.iou_types:
+ results = self.prepare(predictions, iou_type)
+
+ # suppress pycocotools prints
+ with open(os.devnull, 'w') as devnull:
+ with contextlib.redirect_stdout(devnull):
+ coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
+ coco_eval = self.coco_eval[iou_type]
+
+ coco_eval.cocoDt = coco_dt
+ coco_eval.params.imgIds = list(img_ids)
+ img_ids, eval_imgs = evaluate(coco_eval)
+
+ self.eval_imgs[iou_type].append(eval_imgs)
+
+ def synchronize_between_processes(self):
+ for iou_type in self.iou_types:
+ self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
+ create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
+
+ def accumulate(self):
+ for coco_eval in self.coco_eval.values():
+ coco_eval.accumulate()
+
+ def summarize(self):
+ for iou_type, coco_eval in self.coco_eval.items():
+ print("IoU metric: {}".format(iou_type))
+ coco_eval.summarize()
+
+ def prepare(self, predictions, iou_type):
+ if iou_type == "bbox":
+ return self.prepare_for_coco_detection(predictions)
+ elif iou_type == "segm":
+ return self.prepare_for_coco_segmentation(predictions)
+ elif iou_type == "keypoints":
+ return self.prepare_for_coco_keypoint(predictions)
+ else:
+ raise ValueError("Unknown iou type {}".format(iou_type))
+
+ def prepare_for_coco_detection(self, predictions):
+ coco_results = []
+ for original_id, prediction in predictions.items():
+ if len(prediction) == 0:
+ continue
+
+ boxes = prediction["boxes"]
+ boxes = convert_to_xywh(boxes).tolist()
+ scores = prediction["scores"].tolist()
+ labels = prediction["labels"].tolist()
+
+ coco_results.extend(
+ [
+ {
+ "image_id": original_id,
+ "category_id": labels[k],
+ "bbox": box,
+ "score": scores[k],
+ }
+ for k, box in enumerate(boxes)
+ ]
+ )
+ return coco_results
+
+ def prepare_for_coco_segmentation(self, predictions):
+ coco_results = []
+ for original_id, prediction in predictions.items():
+ if len(prediction) == 0:
+ continue
+
+ scores = prediction["scores"]
+ labels = prediction["labels"]
+ masks = prediction["masks"]
+
+ masks = masks > 0.5
+
+ scores = prediction["scores"].tolist()
+ labels = prediction["labels"].tolist()
+
+ rles = [
+ mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
+ for mask in masks
+ ]
+ for rle in rles:
+ rle["counts"] = rle["counts"].decode("utf-8")
+
+ coco_results.extend(
+ [
+ {
+ "image_id": original_id,
+ "category_id": labels[k],
+ "segmentation": rle,
+ "score": scores[k],
+ }
+ for k, rle in enumerate(rles)
+ ]
+ )
+ return coco_results
+
+ def prepare_for_coco_keypoint(self, predictions):
+ coco_results = []
+ for original_id, prediction in predictions.items():
+ if len(prediction) == 0:
+ continue
+
+ boxes = prediction["boxes"]
+ boxes = convert_to_xywh(boxes).tolist()
+ scores = prediction["scores"].tolist()
+ labels = prediction["labels"].tolist()
+ keypoints = prediction["keypoints"]
+ keypoints = keypoints.flatten(start_dim=1).tolist()
+
+ coco_results.extend(
+ [
+ {
+ "image_id": original_id,
+ "category_id": labels[k],
+ 'keypoints': keypoint,
+ "score": scores[k],
+ }
+ for k, keypoint in enumerate(keypoints)
+ ]
+ )
+ return coco_results
+
+
+def convert_to_xywh(boxes):
+ xmin, ymin, xmax, ymax = boxes.unbind(1)
+ return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
+
+
+def merge(img_ids, eval_imgs):
+ all_img_ids = all_gather(img_ids)
+ all_eval_imgs = all_gather(eval_imgs)
+
+ merged_img_ids = []
+ for p in all_img_ids:
+ merged_img_ids.extend(p)
+
+ merged_eval_imgs = []
+ for p in all_eval_imgs:
+ merged_eval_imgs.append(p)
+
+ merged_img_ids = np.array(merged_img_ids)
+ merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
+
+ # keep only unique (and in sorted order) images
+ merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+ merged_eval_imgs = merged_eval_imgs[..., idx]
+
+ return merged_img_ids, merged_eval_imgs
+
+
+def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
+ img_ids, eval_imgs = merge(img_ids, eval_imgs)
+ img_ids = list(img_ids)
+ eval_imgs = list(eval_imgs.flatten())
+
+ coco_eval.evalImgs = eval_imgs
+ coco_eval.params.imgIds = img_ids
+ coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
+
+
+#################################################################
+# From pycocotools, just removed the prints and fixed
+# a Python3 bug about unicode not defined
+#################################################################
+
+
+def evaluate(self):
+ '''
+ Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
+ :return: None
+ '''
+ # tic = time.time()
+ # print('Running per image evaluation...')
+ p = self.params
+ # add backward compatibility if useSegm is specified in params
+ if p.useSegm is not None:
+ p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
+ print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
+ # print('Evaluate annotation type *{}*'.format(p.iouType))
+ p.imgIds = list(np.unique(p.imgIds))
+ if p.useCats:
+ p.catIds = list(np.unique(p.catIds))
+ p.maxDets = sorted(p.maxDets)
+ self.params = p
+
+ self._prepare()
+ # loop through images, area range, max detection number
+ catIds = p.catIds if p.useCats else [-1]
+
+ if p.iouType == 'segm' or p.iouType == 'bbox':
+ computeIoU = self.computeIoU
+ elif p.iouType == 'keypoints':
+ computeIoU = self.computeOks
+ self.ious = {
+ (imgId, catId): computeIoU(imgId, catId)
+ for imgId in p.imgIds
+ for catId in catIds}
+
+ evaluateImg = self.evaluateImg
+ maxDet = p.maxDets[-1]
+ evalImgs = [
+ evaluateImg(imgId, catId, areaRng, maxDet)
+ for catId in catIds
+ for areaRng in p.areaRng
+ for imgId in p.imgIds
+ ]
+ # this is NOT in the pycocotools code, but could be done outside
+ evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
+ self._paramsEval = copy.deepcopy(self.params)
+ # toc = time.time()
+ # print('DONE (t={:0.2f}s).'.format(toc-tic))
+ return p.imgIds, evalImgs
+
+#################################################################
+# end of straight copy from pycocotools, just removing the prints
+#################################################################
diff --git a/datasets/coco_panoptic.py b/datasets/coco_panoptic.py
new file mode 100644
index 0000000..176fcf5
--- /dev/null
+++ b/datasets/coco_panoptic.py
@@ -0,0 +1,109 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+import json
+from pathlib import Path
+
+import numpy as np
+import torch
+from PIL import Image
+
+from panopticapi.utils import rgb2id
+from util.box_ops import masks_to_boxes
+
+from .coco import make_coco_transforms
+
+
+class CocoPanoptic:
+ def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True):
+ with open(ann_file, 'r') as f:
+ self.coco = json.load(f)
+
+ # sort 'images' field so that they are aligned with 'annotations'
+ # i.e., in alphabetical order
+ self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id'])
+ # sanity check
+ if "annotations" in self.coco:
+ for img, ann in zip(self.coco['images'], self.coco['annotations']):
+ assert img['file_name'][:-4] == ann['file_name'][:-4]
+
+ self.img_folder = img_folder
+ self.ann_folder = ann_folder
+ self.ann_file = ann_file
+ self.transforms = transforms
+ self.return_masks = return_masks
+
+ def __getitem__(self, idx):
+ ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx]
+ img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg')
+ ann_path = Path(self.ann_folder) / ann_info['file_name']
+
+ img = Image.open(img_path).convert('RGB')
+ w, h = img.size
+ if "segments_info" in ann_info:
+ masks = np.asarray(Image.open(ann_path), dtype=np.uint32)
+ masks = rgb2id(masks)
+
+ ids = np.array([ann['id'] for ann in ann_info['segments_info']])
+ masks = masks == ids[:, None, None]
+
+ masks = torch.as_tensor(masks, dtype=torch.uint8)
+ labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64)
+
+ target = {}
+ target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]])
+ if self.return_masks:
+ target['masks'] = masks
+ target['labels'] = labels
+
+ target["boxes"] = masks_to_boxes(masks)
+
+ target['size'] = torch.as_tensor([int(h), int(w)])
+ target['orig_size'] = torch.as_tensor([int(h), int(w)])
+ if "segments_info" in ann_info:
+ for name in ['iscrowd', 'area']:
+ target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']])
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.coco['images'])
+
+ def get_height_and_width(self, idx):
+ img_info = self.coco['images'][idx]
+ height = img_info['height']
+ width = img_info['width']
+ return height, width
+
+
+def build(image_set, cfg):
+ img_folder_root = Path(cfg.DATASET.COCO_PATH)
+ ann_folder_root = Path(cfg.COCO_PANOPTIC_PATH)
+ assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist'
+ assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist'
+ mode = 'panoptic'
+ PATHS = {
+ "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'),
+ "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'),
+ }
+
+ img_folder, ann_file = PATHS[image_set]
+ img_folder_path = img_folder_root / img_folder
+ ann_folder = ann_folder_root / f'{mode}_{img_folder}'
+ ann_file = ann_folder_root / ann_file
+
+ dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file,
+ transforms=make_coco_transforms(image_set), return_masks=cfg.MODEL.MASKS)
+
+ return dataset
diff --git a/datasets/data_prefetcher.py b/datasets/data_prefetcher.py
new file mode 100644
index 0000000..fbe12ff
--- /dev/null
+++ b/datasets/data_prefetcher.py
@@ -0,0 +1,71 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+import torch
+
+def to_cuda(samples, targets, device):
+ samples = samples.to(device, non_blocking=True)
+ targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets]
+ return samples, targets
+
+class data_prefetcher():
+ def __init__(self, loader, device, prefetch=True):
+ self.loader = iter(loader)
+ self.prefetch = prefetch
+ self.device = device
+ if prefetch:
+ self.stream = torch.cuda.Stream()
+ self.preload()
+
+ def preload(self):
+ try:
+ self.next_samples, self.next_targets = next(self.loader)
+ except StopIteration:
+ self.next_samples = None
+ self.next_targets = None
+ return
+ # if record_stream() doesn't work, another option is to make sure device inputs are created
+ # on the main stream.
+ # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
+ # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
+ # Need to make sure the memory allocated for next_* is not still in use by the main stream
+ # at the time we start copying to next_*:
+ # self.stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.stream):
+ self.next_samples, self.next_targets = to_cuda(self.next_samples, self.next_targets, self.device)
+ # more code for the alternative if record_stream() doesn't work:
+ # copy_ will record the use of the pinned source tensor in this side stream.
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
+ # self.next_input = self.next_input_gpu
+ # self.next_target = self.next_target_gpu
+
+ # With Amp, it isn't necessary to manually convert data to half.
+ # if args.fp16:
+ # self.next_input = self.next_input.half()
+ # else:
+
+ def next(self):
+ if self.prefetch:
+ torch.cuda.current_stream().wait_stream(self.stream)
+ samples = self.next_samples
+ targets = self.next_targets
+ if samples is not None:
+ samples.record_stream(torch.cuda.current_stream())
+ if targets is not None:
+ for t in targets:
+ for k, v in t.items():
+ v.record_stream(torch.cuda.current_stream())
+ self.preload()
+ else:
+ try:
+ samples, targets = next(self.loader)
+ samples, targets = to_cuda(samples, targets, self.device)
+ except StopIteration:
+ samples = None
+ targets = None
+ return samples, targets
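+
+
+# One way to consume the prefetcher (illustrative sketch):
+#   prefetcher = data_prefetcher(data_loader, device, prefetch=True)
+#   samples, targets = prefetcher.next()
+#   while samples is not None:
+#       ...                                   # forward / backward on this batch
+#       samples, targets = prefetcher.next()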
diff --git a/datasets/panoptic_eval.py b/datasets/panoptic_eval.py
new file mode 100644
index 0000000..d63adc9
--- /dev/null
+++ b/datasets/panoptic_eval.py
@@ -0,0 +1,54 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+import json
+import os
+
+import util.misc as utils
+
+try:
+ from panopticapi.evaluation import pq_compute
+except ImportError:
+ pass
+
+
+class PanopticEvaluator(object):
+ def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"):
+ self.gt_json = ann_file
+ self.gt_folder = ann_folder
+ if utils.is_main_process():
+ if not os.path.exists(output_dir):
+ os.mkdir(output_dir)
+ self.output_dir = output_dir
+ self.predictions = []
+
+ def update(self, predictions):
+ for p in predictions:
+ with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f:
+ f.write(p.pop("png_string"))
+
+ self.predictions += predictions
+
+ def synchronize_between_processes(self):
+ all_predictions = utils.all_gather(self.predictions)
+ merged_predictions = []
+ for p in all_predictions:
+ merged_predictions += p
+ self.predictions = merged_predictions
+
+ def summarize(self):
+ if utils.is_main_process():
+ json_data = {"annotations": self.predictions}
+ predictions_json = os.path.join(self.output_dir, "predictions.json")
+ with open(predictions_json, "w") as f:
+ f.write(json.dumps(json_data))
+ return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir)
+ return None
diff --git a/datasets/samplers.py b/datasets/samplers.py
new file mode 100644
index 0000000..fe5ce38
--- /dev/null
+++ b/datasets/samplers.py
@@ -0,0 +1,141 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from codes in torch.utils.data.distributed
+# ------------------------------------------------------------------------
+
+import os
+import math
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler
+
+
+class DistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
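+    # Example (illustrative): with a dataset of 10 samples and num_replicas=4,
+    # num_samples = ceil(10 / 4) = 3 and total_size = 12; the index list is padded by
+    # repeating its head, and rank r then takes the contiguous slice [3r, 3r + 3).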
+
+ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+ self.shuffle = shuffle
+
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ indices += indices[: (self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset : offset + self.num_samples]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+class NodeDistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
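+    # Unlike DistributedSampler above, indices are first filtered so that
+    # i % local_size == local_rank; with cache_mode enabled, each process then only
+    # draws images it has cached locally (see torchvision_datasets/coco.py).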
+
+ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ if local_rank is None:
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ if local_size is None:
+ local_size = int(os.environ.get('LOCAL_SIZE', 1))
+ self.dataset = dataset
+ self.shuffle = shuffle
+ self.num_replicas = num_replicas
+ self.num_parts = local_size
+ self.rank = rank
+ self.local_rank = local_rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+
+ self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts
+
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+ indices = [i for i in indices if i % self.num_parts == self.local_rank]
+
+ # add extra samples to make it evenly divisible
+ indices += indices[:(self.total_size_parts - len(indices))]
+ assert len(indices) == self.total_size_parts
+
+ # subsample
+ indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
diff --git a/datasets/torchvision_datasets/__init__.py b/datasets/torchvision_datasets/__init__.py
new file mode 100644
index 0000000..47502cc
--- /dev/null
+++ b/datasets/torchvision_datasets/__init__.py
@@ -0,0 +1,9 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+from .coco import CocoDetection
diff --git a/datasets/torchvision_datasets/coco.py b/datasets/torchvision_datasets/coco.py
new file mode 100644
index 0000000..9bf70b9
--- /dev/null
+++ b/datasets/torchvision_datasets/coco.py
@@ -0,0 +1,86 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from torchvision
+# ------------------------------------------------------------------------
+
+"""
+Copy-Paste from torchvision, but add utility of caching images on memory
+"""
+from torchvision.datasets.vision import VisionDataset
+from PIL import Image
+import os
+import os.path
+import tqdm
+from io import BytesIO
+
+
+class CocoDetection(VisionDataset):
+    """MS Coco Detection Dataset.
+ Args:
+ root (string): Root directory where images are downloaded to.
+ annFile (string): Path to json annotation file.
+ transform (callable, optional): A function/transform that takes in an PIL image
+ and returns a transformed version. E.g, ``transforms.ToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
+ and returns a transformed version.
+ """
+
+ def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None,
+ cache_mode=False, local_rank=0, local_size=1):
+ super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
+ from pycocotools.coco import COCO
+ self.coco = COCO(annFile)
+ self.ids = list(sorted(self.coco.imgs.keys()))
+ self.cache_mode = cache_mode
+ self.local_rank = local_rank
+ self.local_size = local_size
+ if cache_mode:
+ self.cache = {}
+ self.cache_images()
+
+ def cache_images(self):
+ self.cache = {}
+ for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids):
+ if index % self.local_size != self.local_rank:
+ continue
+ path = self.coco.loadImgs(img_id)[0]['file_name']
+ with open(os.path.join(self.root, path), 'rb') as f:
+ self.cache[path] = f.read()
+
+ def get_image(self, path):
+ if self.cache_mode:
+ if path not in self.cache.keys():
+ with open(os.path.join(self.root, path), 'rb') as f:
+ self.cache[path] = f.read()
+ return Image.open(BytesIO(self.cache[path])).convert('RGB')
+ return Image.open(os.path.join(self.root, path)).convert('RGB')
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+ """
+ coco = self.coco
+ img_id = self.ids[index]
+ ann_ids = coco.getAnnIds(imgIds=img_id)
+ target = coco.loadAnns(ann_ids)
+
+ path = coco.loadImgs(img_id)[0]['file_name']
+
+ img = self.get_image(path)
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.ids)
diff --git a/datasets/transforms.py b/datasets/transforms.py
new file mode 100644
index 0000000..50738ae
--- /dev/null
+++ b/datasets/transforms.py
@@ -0,0 +1,286 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Transforms and data augmentation for both image + bbox.
+"""
+import random
+
+import PIL
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as F
+
+from util.box_ops import box_xyxy_to_cxcywh
+from util.misc import interpolate
+
+
+def crop(image, target, region):
+ cropped_image = F.crop(image, *region)
+
+ target = target.copy()
+ i, j, h, w = region
+
+ # should we do something wrt the original size?
+ target["size"] = torch.tensor([h, w])
+
+ fields = ["labels", "area", "iscrowd"]
+
+ if "boxes" in target:
+ boxes = target["boxes"]
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
+ cropped_boxes = cropped_boxes.clamp(min=0)
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
+ target["area"] = area
+ fields.append("boxes")
+
+ if "masks" in target:
+ # FIXME should we update the area here if there are no boxes?
+ target['masks'] = target['masks'][:, i:i + h, j:j + w]
+ fields.append("masks")
+
+ # remove elements for which the boxes or masks that have zero area
+ if "boxes" in target or "masks" in target:
+ # favor boxes selection when defining which elements to keep
+ # this is compatible with previous implementation
+ if "boxes" in target:
+ cropped_boxes = target['boxes'].reshape(-1, 2, 2)
+ keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
+ else:
+ keep = target['masks'].flatten(1).any(1)
+
+ for field in fields:
+ target[field] = target[field][keep]
+
+ return cropped_image, target
+
+
+def hflip(image, target):
+ flipped_image = F.hflip(image)
+
+ w, h = image.size
+
+ target = target.copy()
+ if "boxes" in target:
+ boxes = target["boxes"]
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
+ target["boxes"] = boxes
+
+ if "masks" in target:
+ target['masks'] = target['masks'].flip(-1)
+
+ return flipped_image, target
+
+
+def resize(image, target, size, max_size=None):
+ # size can be min_size (scalar) or (w, h) tuple
+
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
+ w, h = image_size
+ if max_size is not None:
+ min_original_size = float(min((w, h)))
+ max_original_size = float(max((w, h)))
+ if max_original_size / min_original_size * size > max_size:
+ size = int(round(max_size * min_original_size / max_original_size))
+
+ if (w <= h and w == size) or (h <= w and h == size):
+ return (h, w)
+
+ if w < h:
+ ow = size
+ oh = int(size * h / w)
+ else:
+ oh = size
+ ow = int(size * w / h)
+
+ return (oh, ow)
+
+ def get_size(image_size, size, max_size=None):
+ if isinstance(size, (list, tuple)):
+ return size[::-1]
+ else:
+ return get_size_with_aspect_ratio(image_size, size, max_size)
+
+ size = get_size(image.size, size, max_size)
+ rescaled_image = F.resize(image, size)
+
+ if target is None:
+ return rescaled_image, None
+
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+ ratio_width, ratio_height = ratios
+
+ target = target.copy()
+ if "boxes" in target:
+ boxes = target["boxes"]
+ scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
+ target["boxes"] = scaled_boxes
+
+ if "area" in target:
+ area = target["area"]
+ scaled_area = area * (ratio_width * ratio_height)
+ target["area"] = scaled_area
+
+ h, w = size
+ target["size"] = torch.tensor([h, w])
+
+ if "masks" in target:
+ target['masks'] = interpolate(
+ target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
+
+ return rescaled_image, target
+
+
+def pad(image, target, padding):
+ # assumes that we only pad on the bottom right corners
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
+ if target is None:
+ return padded_image, None
+ target = target.copy()
+ # should we do something wrt the original size?
+    target["size"] = torch.tensor(padded_image.size[::-1])
+ if "masks" in target:
+ target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
+ return padded_image, target
+
+
+class RandomCrop(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img, target):
+ region = T.RandomCrop.get_params(img, self.size)
+ return crop(img, target, region)
+
+
+class RandomSizeCrop(object):
+ def __init__(self, min_size: int, max_size: int):
+ self.min_size = min_size
+ self.max_size = max_size
+
+ def __call__(self, img: PIL.Image.Image, target: dict):
+ w = random.randint(self.min_size, min(img.width, self.max_size))
+ h = random.randint(self.min_size, min(img.height, self.max_size))
+ region = T.RandomCrop.get_params(img, [h, w])
+ return crop(img, target, region)
+
+
+class CenterCrop(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img, target):
+ image_width, image_height = img.size
+ crop_height, crop_width = self.size
+ crop_top = int(round((image_height - crop_height) / 2.))
+ crop_left = int(round((image_width - crop_width) / 2.))
+ return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
+
+
+class RandomHorizontalFlip(object):
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, target):
+ if random.random() < self.p:
+ return hflip(img, target)
+ return img, target
+
+
+class RandomResize(object):
+ def __init__(self, sizes, max_size=None):
+ assert isinstance(sizes, (list, tuple))
+ self.sizes = sizes
+ self.max_size = max_size
+
+ def __call__(self, img, target=None):
+ size = random.choice(self.sizes)
+ return resize(img, target, size, self.max_size)
+
+
+class RandomPad(object):
+ def __init__(self, max_pad):
+ self.max_pad = max_pad
+
+ def __call__(self, img, target):
+ pad_x = random.randint(0, self.max_pad)
+ pad_y = random.randint(0, self.max_pad)
+ return pad(img, target, (pad_x, pad_y))
+
+
+class RandomSelect(object):
+ """
+ Randomly selects between transforms1 and transforms2,
+ with probability p for transforms1 and (1 - p) for transforms2
+ """
+ def __init__(self, transforms1, transforms2, p=0.5):
+ self.transforms1 = transforms1
+ self.transforms2 = transforms2
+ self.p = p
+
+ def __call__(self, img, target):
+ if random.random() < self.p:
+ return self.transforms1(img, target)
+ return self.transforms2(img, target)
+
+
+class ToTensor(object):
+ def __call__(self, img, target):
+ return F.to_tensor(img), target
+
+
+class RandomErasing(object):
+
+ def __init__(self, *args, **kwargs):
+ self.eraser = T.RandomErasing(*args, **kwargs)
+
+ def __call__(self, img, target):
+ return self.eraser(img), target
+
+
+class Normalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, image, target=None):
+ image = F.normalize(image, mean=self.mean, std=self.std)
+ if target is None:
+ return image, None
+ target = target.copy()
+ h, w = image.shape[-2:]
+ if "boxes" in target:
+ boxes = target["boxes"]
+ boxes = box_xyxy_to_cxcywh(boxes)
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
+ target["boxes"] = boxes
+ return image, target
+
+
+class Compose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, image, target):
+ for t in self.transforms:
+ image, target = t(image, target)
+ return image, target
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
diff --git a/engine.py b/engine.py
new file mode 100644
index 0000000..b41fd43
--- /dev/null
+++ b/engine.py
@@ -0,0 +1,168 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Train and eval functions used in main.py
+"""
+import math
+import os
+import sys
+from typing import Iterable
+
+import torch
+import util.misc as utils
+from datasets.coco_eval import CocoEvaluator
+from datasets.panoptic_eval import PanopticEvaluator
+from datasets.data_prefetcher import data_prefetcher
+
+
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
+ device: torch.device, epoch: int, max_norm: float = 0, logger=None):
+ model.train()
+ criterion.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
+ metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = 10
+
+ prefetcher = data_prefetcher(data_loader, device, prefetch=True)
+ samples, targets = prefetcher.next()
+
+ # for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
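+    # Iterate by batch index and pull data from the prefetcher so that host-to-device
+    # copies overlap with computation on the previous batch.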
+    for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header, logger=logger):
+ outputs = model(samples)
+ loss_dict = criterion(outputs, targets)
+ weight_dict = criterion.weight_dict
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+
+ # reduce losses over all GPUs for logging purposes
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
+ loss_dict_reduced_unscaled = {f'{k}_unscaled': v
+ for k, v in loss_dict_reduced.items()}
+ loss_dict_reduced_scaled = {k: v * weight_dict[k]
+ for k, v in loss_dict_reduced.items() if k in weight_dict}
+ losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
+
+ loss_value = losses_reduced_scaled.item()
+
+ if not math.isfinite(loss_value):
+ print("Loss is {}, stopping training".format(loss_value))
+ print(loss_dict_reduced)
+ sys.exit(1)
+
+ optimizer.zero_grad()
+ losses.backward()
+ if max_norm > 0:
+ grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+ else:
+ grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
+ optimizer.step()
+
+ metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
+ metric_logger.update(class_error=loss_dict_reduced['class_error'])
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(grad_norm=grad_total_norm)
+
+ samples, targets = prefetcher.next()
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, logger=None):
+ model.eval()
+ criterion.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
+ header = 'Test:'
+
+ iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
+ coco_evaluator = CocoEvaluator(base_ds, iou_types)
+ # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
+
+ panoptic_evaluator = None
+ if 'panoptic' in postprocessors.keys():
+ panoptic_evaluator = PanopticEvaluator(
+ data_loader.dataset.ann_file,
+ data_loader.dataset.ann_folder,
+ output_dir=os.path.join(output_dir, "panoptic_eval"),
+ )
+
+    for samples, targets in metric_logger.log_every(data_loader, 10, header, logger=logger):
+ samples = samples.to(device)
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+ outputs = model(samples)
+ loss_dict = criterion(outputs, targets)
+ weight_dict = criterion.weight_dict
+
+ # reduce losses over all GPUs for logging purposes
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
+ loss_dict_reduced_scaled = {k: v * weight_dict[k]
+ for k, v in loss_dict_reduced.items() if k in weight_dict}
+ loss_dict_reduced_unscaled = {f'{k}_unscaled': v
+ for k, v in loss_dict_reduced.items()}
+ metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
+ **loss_dict_reduced_scaled,
+ **loss_dict_reduced_unscaled)
+ metric_logger.update(class_error=loss_dict_reduced['class_error'])
+
+ orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
+ results = postprocessors['bbox'](outputs, orig_target_sizes)
+ if 'segm' in postprocessors.keys():
+ target_sizes = torch.stack([t["size"] for t in targets], dim=0)
+ results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
+ res = {target['image_id'].item(): output for target, output in zip(targets, results)}
+ if coco_evaluator is not None:
+ coco_evaluator.update(res)
+
+ if panoptic_evaluator is not None:
+ res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
+ for i, target in enumerate(targets):
+ image_id = target["image_id"].item()
+ file_name = f"{image_id:012d}.png"
+ res_pano[i]["image_id"] = image_id
+ res_pano[i]["file_name"] = file_name
+
+ panoptic_evaluator.update(res_pano)
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ if coco_evaluator is not None:
+ coco_evaluator.synchronize_between_processes()
+ if panoptic_evaluator is not None:
+ panoptic_evaluator.synchronize_between_processes()
+
+ # accumulate predictions from all images
+ if coco_evaluator is not None:
+ coco_evaluator.accumulate()
+ coco_evaluator.summarize()
+ panoptic_res = None
+ if panoptic_evaluator is not None:
+ panoptic_res = panoptic_evaluator.summarize()
+ stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+ if coco_evaluator is not None:
+ if 'bbox' in postprocessors.keys():
+ stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
+ if 'segm' in postprocessors.keys():
+ stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
+ if panoptic_res is not None:
+ stats['PQ_all'] = panoptic_res["All"]
+ stats['PQ_th'] = panoptic_res["Things"]
+ stats['PQ_st'] = panoptic_res["Stuff"]
+ return stats, coco_evaluator
diff --git a/engine_aood.py b/engine_aood.py
new file mode 100644
index 0000000..613898c
--- /dev/null
+++ b/engine_aood.py
@@ -0,0 +1,172 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Train and eval functions used in main.py
+"""
+import math
+import os
+import sys
+from typing import Iterable
+
+import torch
+import util.misc as utils
+from datasets.aood_eval import AOODEvaluator
+from datasets.panoptic_eval import PanopticEvaluator
+from datasets.data_prefetcher import data_prefetcher
+
+
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
+ device: torch.device, epoch: int, max_norm: float = 0, logger=None):
+ model.train()
+ criterion.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
+ metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = 10
+
+ prefetcher = data_prefetcher(data_loader, device, prefetch=True)
+ samples, targets = prefetcher.next()
+
+ # for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
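+    # The loop is driven by an index range so that batches can be pulled from the prefetcher,
+    # which already returns (samples, targets) placed on the device.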
+ for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header, logger=logger):
+ outputs = model(samples)
+ # loss_dict = criterion(outputs, targets, epoch)
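+        # Unlike the plain engine, the AOOD criterion also receives the raw samples and the
+        # current epoch in addition to the model outputs and targets.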
+ loss_dict = criterion(samples, outputs, targets, epoch)
+ weight_dict = criterion.weight_dict
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+
+ # reduce losses over all GPUs for logging purposes
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
+ loss_dict_reduced_unscaled = {f'{k}_unscaled': v
+ for k, v in loss_dict_reduced.items()}
+ loss_dict_reduced_scaled = {k: v * weight_dict[k]
+ for k, v in loss_dict_reduced.items() if k in weight_dict}
+ losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
+
+ loss_value = losses_reduced_scaled.item()
+
+ if not math.isfinite(loss_value):
+ print("Loss is {}, stopping training".format(loss_value))
+ print(loss_dict_reduced)
+ sys.exit(1)
+
+ optimizer.zero_grad()
+ losses.backward()
+ if max_norm > 0:
+ grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+ else:
+ grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
+ optimizer.step()
+
+ metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
+ metric_logger.update(class_error=loss_dict_reduced['class_error'])
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(grad_norm=grad_total_norm)
+
+ samples, targets = prefetcher.next()
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
+ model.eval()
+ criterion.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
+ header = 'Test:'
+
+ iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
+ coco_evaluator = AOODEvaluator(base_ds, iou_types)
+ # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
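+    # AOODEvaluator stands in for the standard CocoEvaluator; in addition to base-class AP it
+    # is expected to produce the open-set summary consumed by main.py (e.g. the
+    # 'ap_map_wi_aose_ar' string and the 'base_mAP: ' entry).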
+
+ panoptic_evaluator = None
+ if 'panoptic' in postprocessors.keys():
+ panoptic_evaluator = PanopticEvaluator(
+ data_loader.dataset.ann_file,
+ data_loader.dataset.ann_folder,
+ output_dir=os.path.join(output_dir, "panoptic_eval"),
+ )
+
+ for samples, targets in metric_logger.log_every(data_loader, 10, header):
+ samples = samples.to(device)
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+ outputs = model(samples)
+        loss_dict = criterion(samples, outputs, targets, epoch=0)
+ weight_dict = criterion.weight_dict
+
+ # reduce losses over all GPUs for logging purposes
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
+ loss_dict_reduced_scaled = {k: v * weight_dict[k]
+ for k, v in loss_dict_reduced.items() if k in weight_dict}
+ loss_dict_reduced_unscaled = {f'{k}_unscaled': v
+ for k, v in loss_dict_reduced.items()}
+ metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
+ **loss_dict_reduced_scaled,
+ **loss_dict_reduced_unscaled)
+ metric_logger.update(class_error=loss_dict_reduced['class_error'])
+
+ orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
+
+ results = postprocessors['bbox'](outputs, orig_target_sizes, show_box=False)
+
+ if 'segm' in postprocessors.keys():
+ target_sizes = torch.stack([t["size"] for t in targets], dim=0)
+ results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
+ res = {target['image_id'].item(): output for target, output in zip(targets, results)}
+ if coco_evaluator is not None:
+ coco_evaluator.update(res)
+
+ if panoptic_evaluator is not None:
+ res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
+ for i, target in enumerate(targets):
+ image_id = target["image_id"].item()
+ file_name = f"{image_id:012d}.png"
+ res_pano[i]["image_id"] = image_id
+ res_pano[i]["file_name"] = file_name
+
+ panoptic_evaluator.update(res_pano)
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ if coco_evaluator is not None:
+ coco_evaluator.synchronize_between_processes()
+ if panoptic_evaluator is not None:
+ panoptic_evaluator.synchronize_between_processes()
+
+ # accumulate predictions from all images
+ if coco_evaluator is not None:
+ coco_evaluator.accumulate()
+ coco_evaluator.summarize()
+ panoptic_res = None
+ if panoptic_evaluator is not None:
+ panoptic_res = panoptic_evaluator.summarize()
+ stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+ if coco_evaluator is not None:
+ # if 'bbox' in postprocessors.keys():
+ # stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
+ if 'segm' in postprocessors.keys():
+ stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
+ if panoptic_res is not None:
+ stats['PQ_all'] = panoptic_res["All"]
+ stats['PQ_th'] = panoptic_res["Things"]
+ stats['PQ_st'] = panoptic_res["Stuff"]
+ # return stats, coco_evaluator
+ return coco_evaluator.coco_eval['bbox'].stats, coco_evaluator
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..94aaf74
--- /dev/null
+++ b/main.py
@@ -0,0 +1,323 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+
+import os
+import argparse
+import datetime
+import json
+import random
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+import datasets
+import datasets.DAOD as DAOD
+import util.misc as utils
+import datasets.samplers as samplers
+from datasets import build_dataset, get_coco_api_from_dataset
+from models import build_model
+from config import get_cfg_defaults
+
+
+def setup(args):
+ cfg = get_cfg_defaults()
+ if args.config_file:
+ cfg.merge_from_file(args.config_file)
+ if args.opts:
+ cfg.merge_from_list(args.opts)
+ utils.init_distributed_mode(cfg)
+ cfg.freeze()
+ if cfg.OUTPUT_DIR:
+ Path(cfg.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+ os.system(f'cp {args.config_file} {cfg.OUTPUT_DIR}')
+ ddetr_src = 'models/deformable_detr.py'
+ ddetr_des = Path(cfg.OUTPUT_DIR) / 'deformable_detr.py.backup'
+ dtrans_src = 'models/deformable_transformer.py'
+ dtrans_des = Path(cfg.OUTPUT_DIR) / 'deformable_transformer.py.backup'
+ main_src = 'main.py'
+ main_des = Path(cfg.OUTPUT_DIR) / 'main.py.backup'
+ os.system(f'cp {ddetr_src} {ddetr_des}')
+ os.system(f'cp {dtrans_src} {dtrans_des}')
+ os.system(f'cp {main_src} {main_des}')
+ return cfg
+
+def main(cfg):
+ # align = cfg.MODEL.BACKBONE_ALIGN or cfg.MODEL.SPACE_ALIGN or cfg.MODEL.CHANNEL_ALIGN or cfg.MODEL.INSTANCE_ALIGN
+ # assert align == (cfg.DATASET.DA_MODE == 'uda')
+ # print("git:\n {}\n".format(utils.get_sha()))
+ print(cfg)
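+    # 'osda' selects the open-set engine (engine_aood): its criterion takes
+    # (samples, outputs, targets, epoch) and evaluation uses AOODEvaluator; any other
+    # DA_MODE falls back to the standard Deformable DETR engine.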
+ if cfg.DATASET.DA_MODE == 'osda':
+ from engine_aood import evaluate, train_one_epoch
+ else:
+ from engine import evaluate, train_one_epoch
+ if cfg.MODEL.FROZEN_WEIGHTS is not None:
+ assert cfg.MODEL.MASKS, "Frozen training is meant for segmentation only"
+
+ device = torch.device(cfg.DEVICE)
+ # fix the seed for reproducibility
+ seed = cfg.SEED + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+
+ model, criterion, postprocessors = build_model(cfg)
+ model.to(device)
+
+ model_without_ddp = model
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print('number of params:', n_parameters)
+
+ dataset_train = build_dataset(image_set='train', cfg=cfg)
+ dataset_val = build_dataset(image_set='val', cfg=cfg)
+
+ if cfg.DIST.DISTRIBUTED:
+ if cfg.CACHE_MODE:
+ sampler_train = samplers.NodeDistributedSampler(dataset_train)
+ sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
+ else:
+ sampler_train = samplers.DistributedSampler(dataset_train)
+ sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
+ else:
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ if cfg.DATASET.DA_MODE == 'uda' or cfg.DATASET.DA_MODE == 'osda':
+ assert cfg.TRAIN.BATCH_SIZE % 2 == 0, f'cfg.TRAIN.BATCH_SIZE {cfg.TRAIN.BATCH_SIZE} should be a multiple of 2'
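+        # In the adaptation settings each dataset item carries a source/target image pair
+        # (stacked by DAOD.collate_fn), so only BATCH_SIZE // 2 indices are drawn per batch;
+        # this matches the B // 2 source/target split inside the model and criterion.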
+ batch_sampler_train = torch.utils.data.BatchSampler(
+ sampler_train, cfg.TRAIN.BATCH_SIZE//2, drop_last=True)
+ data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
+ collate_fn=DAOD.collate_fn, num_workers=cfg.NUM_WORKERS,
+ pin_memory=True)
+ else:
+ batch_sampler_train = torch.utils.data.BatchSampler(
+ sampler_train, cfg.TRAIN.BATCH_SIZE, drop_last=True)
+ data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
+ collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS,
+ pin_memory=True)
+ data_loader_val = DataLoader(dataset_val, cfg.TRAIN.BATCH_SIZE, sampler=sampler_val,
+ drop_last=False, collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS,
+ pin_memory=True)
+
+ # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
+ def match_name_keywords(n, name_keywords):
+ out = False
+ for b in name_keywords:
+ if b in n:
+ out = True
+ break
+ return out
+
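+    # Three LR groups: (1) default parameters at cfg.TRAIN.LR, (2) backbone parameters at
+    # cfg.TRAIN.LR_BACKBONE, and (3) parameters whose names match cfg.TRAIN.LR_LINEAR_PROJ_NAMES,
+    # scaled by LR_LINEAR_PROJ_MULT.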
+ param_dicts = [
+ {
+ "params":
+ [p for n, p in model_without_ddp.named_parameters()
+ if not match_name_keywords(n, cfg.TRAIN.LR_BACKBONE_NAMES) and not match_name_keywords(n, cfg.TRAIN.LR_LINEAR_PROJ_NAMES) and p.requires_grad],
+ "lr": cfg.TRAIN.LR,
+ },
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, cfg.TRAIN.LR_BACKBONE_NAMES) and p.requires_grad],
+ "lr": cfg.TRAIN.LR_BACKBONE,
+ },
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, cfg.TRAIN.LR_LINEAR_PROJ_NAMES) and p.requires_grad],
+ "lr": cfg.TRAIN.LR * cfg.TRAIN.LR_LINEAR_PROJ_MULT,
+ }
+ ]
+ if cfg.TRAIN.SGD:
+ optimizer = torch.optim.SGD(param_dicts, lr=cfg.TRAIN.LR, momentum=0.9,
+ weight_decay=cfg.TRAIN.WEIGHT_DECAY)
+ else:
+ optimizer = torch.optim.AdamW(param_dicts, lr=cfg.TRAIN.LR,
+ weight_decay=cfg.TRAIN.WEIGHT_DECAY)
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.TRAIN.LR_DROP)
+
+ if cfg.DIST.DISTRIBUTED:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.DIST.GPU])
+ model_without_ddp = model.module
+
+ if cfg.DATASET.DATASET_FILE == "coco_panoptic":
+ # We also evaluate AP during panoptic training, on original coco DS
+ coco_val = datasets.coco.build("val", cfg)
+ base_ds = get_coco_api_from_dataset(coco_val)
+ else:
+ base_ds = get_coco_api_from_dataset(dataset_val)
+
+ if cfg.MODEL.FROZEN_WEIGHTS is not None:
+ checkpoint = torch.load(cfg.MODEL.FROZEN_WEIGHTS, map_location='cpu')
+ model_without_ddp.detr.load_state_dict(checkpoint['model'])
+
+ output_dir = Path(cfg.OUTPUT_DIR)
+
+ import logging
+ LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
+ logging.basicConfig(
+ filename=cfg.OUTPUT_DIR +'/_rank_{}_'.format(utils.get_rank())+str(__file__)[:-3] + '_' + time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + '.log',
+ level=logging.INFO, format=LOG_FORMAT, filemode='w')
+ console = logging.StreamHandler()
+ console.setLevel(logging.INFO)
+ logging.getLogger('').addHandler(console)
+ logger = logging.getLogger("main_da")
+
+ if cfg.RESUME: # [BUG] write after freezing cfgs
+ if cfg.RESUME.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ cfg.RESUME, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(cfg.RESUME, map_location='cpu')
+ missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
+ unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
+ if len(missing_keys) > 0:
+ print('Missing Keys: {}'.format(missing_keys))
+ if len(unexpected_keys) > 0:
+ print('Unexpected Keys: {}'.format(unexpected_keys))
+ if not cfg.EVAL and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint and cfg.LOAD_OPTIMIZER:
+ import copy
+ p_groups = copy.deepcopy(optimizer.param_groups)
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ for pg, pg_old in zip(optimizer.param_groups, p_groups):
+ pg['lr'] = pg_old['lr']
+ pg['initial_lr'] = pg_old['initial_lr']
+ # print(optimizer.param_groups)
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
+ override_resumed_lr_drop = True
+ if override_resumed_lr_drop:
+ print('Warning: (hack) override_resumed_lr_drop is set to True, so cfg.TRAIN.LR_DROP would override lr_drop in resumed lr_scheduler.')
+ lr_scheduler.step_size = cfg.TRAIN.LR_DROP
+ lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+ lr_scheduler.step(lr_scheduler.last_epoch)
+ cfg.START_EPOCH = checkpoint['epoch'] + 1
+
+
+
+ # check the resumed model
+ # if not cfg.EVAL:
+ # test_stats, coco_evaluator = evaluate(
+ # model, criterion, postprocessors, data_loader_val, base_ds, device, cfg.OUTPUT_DIR
+ # )
+ # if utils.is_main_process():
+ # for key in test_stats.keys():
+ # logger.info('{} {}'.format(key, test_stats[key]))
+
+ if cfg.EVAL:
+ test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
+ data_loader_val, base_ds, device, cfg.OUTPUT_DIR)
+
+ if cfg.OUTPUT_DIR:
+            title = test_stats['title'] + '\n'
+            results = 'Epoch : ' + test_stats['ap_map_wi_aose_ar'] + '\n'
+            if utils.is_main_process():
+                results_dir = output_dir / 'eval_results.txt'
+                with open(results_dir, 'a') as f:
+                    f.write(title)
+                    f.write(results)
+
+ utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
+ return
+
+ print("Start training")
+ start_time = time.time()
+
+ best_mAP = 0
+ checkpoint_dir = 'check'
+ for epoch in range(cfg.START_EPOCH, cfg.TRAIN.EPOCHS):
+ if cfg.DIST.DISTRIBUTED:
+ sampler_train.set_epoch(epoch)
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train, optimizer, device, epoch, cfg.TRAIN.CLIP_MAX_NORM, logger=logger)
+ lr_scheduler.step()
+ if cfg.OUTPUT_DIR:
+ checkpoint_paths = [output_dir / 'checkpoint.pth']
+ # extra checkpoint before LR drop and every 5 epochs
+ if (epoch + 1) % cfg.TRAIN.LR_DROP == 0:
+ checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
+
+ for checkpoint_path in checkpoint_paths:
+ utils.save_on_master({
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch,
+ 'cfg': cfg,
+ }, checkpoint_path)
+
+ test_stats, coco_evaluator = evaluate(
+ model, criterion, postprocessors, data_loader_val, base_ds, device, cfg.OUTPUT_DIR
+ )
+ mAP_tmp = test_stats['base_mAP: ']
+ if mAP_tmp > best_mAP and utils.is_main_process():
+
+ if os.path.exists(checkpoint_dir):
+ os.remove(checkpoint_dir)
+
+ checkpoint_dir = output_dir / f'best_{epoch:02}_{round(mAP_tmp,3)}.pth'
+
+ utils.save_on_master({
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch,
+ 'cfg': cfg,
+ }, checkpoint_dir)
+ best_mAP = mAP_tmp
+
+
+ if cfg.OUTPUT_DIR and utils.is_main_process():
+ title = test_stats['title'] + '\n'
+ results = test_stats['ap_map_wi_aose_ar']
+ results = 'Epoch {}: '.format(epoch) + results + '\n'
+            results_dir = output_dir / 'eval_results.txt'
+            if epoch == 0:
+                with open(results_dir, 'a') as f:
+                    f.write(title)
+
+            with open(results_dir, 'a') as f:
+                f.write(results)
+
+
+ # if utils.is_main_process():
+ for key in test_stats.keys():
+ logger.info('{} {}'.format(key, test_stats[key]))
+ # for evaluation logs
+ if coco_evaluator is not None:
+ (output_dir / 'eval').mkdir(exist_ok=True)
+ if "bbox" in coco_evaluator.coco_eval:
+ filenames = ['latest.pth']
+ if epoch % 50 == 0:
+ filenames.append(f'{epoch:03}.pth')
+ for name in filenames:
+ torch.save(coco_evaluator.coco_eval["bbox"].eval,
+ output_dir / "eval" / name)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('Deformable DETR Detector')
+ parser.add_argument('--config_file', default='', type=str)
+ parser.add_argument("--opts", default=None, nargs=argparse.REMAINDER)
+ args = parser.parse_args()
+ cfg = setup(args)
+ main(cfg)
diff --git a/main_multi_eval.py b/main_multi_eval.py
new file mode 100644
index 0000000..bce75de
--- /dev/null
+++ b/main_multi_eval.py
@@ -0,0 +1,341 @@
+# ------------------------------------------------------------------------
+# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection
+# Modified by Wuyang Li
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+import os
+import argparse
+import datetime
+import json
+import random
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+import datasets
+import datasets.DAOD as DAOD
+import util.misc as utils
+import datasets.samplers as samplers
+from datasets import build_dataset, get_coco_api_from_dataset
+
+from models import build_model
+from config import get_cfg_defaults
+import logging
+
+def setup(args):
+ cfg = get_cfg_defaults()
+ if args.config_file:
+ cfg.merge_from_file(args.config_file)
+ if args.opts:
+ cfg.merge_from_list(args.opts)
+
+ utils.init_distributed_mode(cfg)
+ cfg.freeze()
+
+ if cfg.OUTPUT_DIR:
+ Path(cfg.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+ os.system(f'cp {args.config_file} {cfg.OUTPUT_DIR}')
+ ddetr_src = 'models/motif_detr.py'
+ ddetr_des = Path(cfg.OUTPUT_DIR) / 'motif_detr.py.backup'
+ dtrans_src = 'models/deformable_transformer.py'
+ dtrans_des = Path(cfg.OUTPUT_DIR) / 'deformable_transformer.py.backup'
+ main_src = 'main.py'
+ main_des = Path(cfg.OUTPUT_DIR) / 'main.py.backup'
+ os.system(f'cp {ddetr_src} {ddetr_des}')
+ os.system(f'cp {dtrans_src} {dtrans_des}')
+ os.system(f'cp {main_src} {main_des}')
+
+ return cfg
+
+
+def main(cfg):
+
+ # align = cfg.MODEL.BACKBONE_ALIGN or cfg.MODEL.SPACE_ALIGN or cfg.MODEL.CHANNEL_ALIGN or cfg.MODEL.INSTANCE_ALIGN
+ # assert align == (cfg.DATASET.DA_MODE == 'uda')
+ # print("git:\n {}\n".format(utils.get_sha()))
+
+ print(cfg)
+
+ if cfg.DATASET.DA_MODE == 'aood':
+ from engine_aood import evaluate, train_one_epoch
+ else:
+ from engine import evaluate, train_one_epoch
+
+ if cfg.MODEL.FROZEN_WEIGHTS is not None:
+ assert cfg.MODEL.MASKS, "Frozen training is meant for segmentation only"
+
+ device = torch.device(cfg.DEVICE)
+ # fix the seed for reproducibility
+ seed = cfg.SEED + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+
+ model, criterion, postprocessors = build_model(cfg)
+ model.to(device)
+
+ model_without_ddp = model
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print('number of params:', n_parameters)
+
+ dataset_train = build_dataset(image_set='train', cfg=cfg)
+
+    if cfg.DATASET.DATASET_FILE != 'pascal_to_clipart':
+        num_eval_novel_classes = [3, 4, 5]   # eval with 3/4/5 novel classes
+    else:
+        num_eval_novel_classes = [6, 8, 10]  # eval with 6/8/10 novel classes
+ num_sub_tasks = len(num_eval_novel_classes)
+
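+    # One validation split per sub-task: the same model is re-evaluated with a different
+    # number of classes treated as novel/unknown.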
+ dataset_val_list = []
+ sampler_val_list = []
+
+ for i in range(num_sub_tasks):
+ dataset_val_list.append(build_dataset(image_set='val', cfg=cfg, multi_task_eval_id=num_eval_novel_classes[i]))
+
+ if cfg.DIST.DISTRIBUTED:
+ if cfg.CACHE_MODE:
+ sampler_train = samplers.NodeDistributedSampler(dataset_train)
+ # sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
+            for dataset_val in dataset_val_list:
+                sampler_val_list.append(samplers.NodeDistributedSampler(dataset_val, shuffle=False))
+ else:
+ sampler_train = samplers.DistributedSampler(dataset_train)
+ # sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
+            for dataset_val in dataset_val_list:
+                sampler_val_list.append(samplers.DistributedSampler(dataset_val, shuffle=False))
+ else:
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
+ # sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+        for dataset_val in dataset_val_list:
+            sampler_val_list.append(torch.utils.data.SequentialSampler(dataset_val))
+
+ if cfg.DATASET.DA_MODE == 'uda' or cfg.DATASET.DA_MODE == 'aood':
+ assert cfg.TRAIN.BATCH_SIZE % 2 == 0, f'cfg.TRAIN.BATCH_SIZE {cfg.TRAIN.BATCH_SIZE} should be a multiple of 2'
+ batch_sampler_train = torch.utils.data.BatchSampler(
+ sampler_train, cfg.TRAIN.BATCH_SIZE//2, drop_last=True)
+ data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
+ collate_fn=DAOD.collate_fn, num_workers=cfg.NUM_WORKERS,
+ pin_memory=True)
+ else:
+ batch_sampler_train = torch.utils.data.BatchSampler(
+ sampler_train, cfg.TRAIN.BATCH_SIZE, drop_last=True)
+ data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
+ collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS,
+ pin_memory=True)
+
+ data_loader_val_list = []
+ for i in range(num_sub_tasks):
+ data_loader_val_list.append(
+ DataLoader(dataset_val_list[i], cfg.TRAIN.BATCH_SIZE, sampler=sampler_val_list[i],
+ drop_last=False, collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS,
+ pin_memory=True))
+
+ # data_loader_val = DataLoader(dataset_val, cfg.TRAIN.BATCH_SIZE, sampler=sampler_val,
+ # drop_last=False, collate_fn=utils.collate_fn, num_workers=cfg.NUM_WORKERS,
+ # pin_memory=True)
+ # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
+ def match_name_keywords(n, name_keywords):
+ out = False
+ for b in name_keywords:
+ if b in n:
+ out = True
+ break
+ return out
+
+ param_dicts = [
+ {
+ "params":
+ [p for n, p in model_without_ddp.named_parameters()
+ if not match_name_keywords(n, cfg.TRAIN.LR_BACKBONE_NAMES) and not match_name_keywords(n, cfg.TRAIN.LR_LINEAR_PROJ_NAMES) and p.requires_grad],
+ "lr": cfg.TRAIN.LR,
+ },
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, cfg.TRAIN.LR_BACKBONE_NAMES) and p.requires_grad],
+ "lr": cfg.TRAIN.LR_BACKBONE,
+ },
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, cfg.TRAIN.LR_LINEAR_PROJ_NAMES) and p.requires_grad],
+ "lr": cfg.TRAIN.LR * cfg.TRAIN.LR_LINEAR_PROJ_MULT,
+ }
+ ]
+ if cfg.TRAIN.SGD:
+ optimizer = torch.optim.SGD(param_dicts, lr=cfg.TRAIN.LR, momentum=0.9,
+ weight_decay=cfg.TRAIN.WEIGHT_DECAY)
+ else:
+ optimizer = torch.optim.AdamW(param_dicts, lr=cfg.TRAIN.LR,
+ weight_decay=cfg.TRAIN.WEIGHT_DECAY)
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.TRAIN.LR_DROP)
+
+ if cfg.DIST.DISTRIBUTED:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.DIST.GPU])
+ model_without_ddp = model.module
+
+ if cfg.DATASET.DATASET_FILE == "coco_panoptic":
+ # We also evaluate AP during panoptic training, on original coco DS
+ coco_val = datasets.coco.build("val", cfg)
+ base_ds = get_coco_api_from_dataset(coco_val)
+ else:
+ # base_ds = get_coco_api_from_dataset(dataset_val)
+        base_ds_list = []
+ for i in range(num_sub_tasks):
+ base_ds_list.append(get_coco_api_from_dataset(dataset_val_list[i]))
+
+ if cfg.MODEL.FROZEN_WEIGHTS is not None:
+ checkpoint = torch.load(cfg.MODEL.FROZEN_WEIGHTS, map_location='cpu')
+ model_without_ddp.detr.load_state_dict(checkpoint['model'])
+ output_dir = Path(cfg.OUTPUT_DIR)
+
+ LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
+ logging.basicConfig(
+ filename=cfg.OUTPUT_DIR +'/_rank_{}_'.format(utils.get_rank())+str(__file__)[:-3] + '_' + time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + '.log',
+ level=logging.INFO, format=LOG_FORMAT, filemode='w')
+ console = logging.StreamHandler()
+ console.setLevel(logging.INFO)
+ logging.getLogger('').addHandler(console)
+ logger = logging.getLogger("main_da")
+
+ if cfg.RESUME: # [BUG] write after freezing cfgs
+ if cfg.RESUME.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ cfg.RESUME, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(cfg.RESUME, map_location='cpu')
+ missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
+ unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
+ if len(missing_keys) > 0:
+ print('Missing Keys: {}'.format(missing_keys))
+ if len(unexpected_keys) > 0:
+ print('Unexpected Keys: {}'.format(unexpected_keys))
+ if not cfg.EVAL and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint and cfg.LOAD_OPTIMIZER:
+ import copy
+ p_groups = copy.deepcopy(optimizer.param_groups)
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ for pg, pg_old in zip(optimizer.param_groups, p_groups):
+ pg['lr'] = pg_old['lr']
+ pg['initial_lr'] = pg_old['initial_lr']
+ # print(optimizer.param_groups)
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
+ override_resumed_lr_drop = True
+ if override_resumed_lr_drop:
+ print('Warning: (hack) override_resumed_lr_drop is set to True, so cfg.TRAIN.LR_DROP would override lr_drop in resumed lr_scheduler.')
+ lr_scheduler.step_size = cfg.TRAIN.LR_DROP
+ lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+ lr_scheduler.step(lr_scheduler.last_epoch)
+ cfg.START_EPOCH = checkpoint['epoch'] + 1
+ if cfg.EVAL:
+ epoch = 0
+ per_task_results = []
+ for i in range(num_sub_tasks):
+ test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
+ data_loader_val_list[i], base_ds_list[i], device, cfg.OUTPUT_DIR)
+ title = test_stats['title'] + '\n'
+
+ per_task_results += test_stats['report_results']
+ results = 'Epoch {}: '.format(epoch) + test_stats['ap_map_wi_aose_ar'] + '\n'
+ results_dir = output_dir /'eval_results.txt'
+ if utils.is_main_process():
+ if epoch == 0:
+ with open(results_dir, 'a') as f:
+ f.write(title)
+
+ with open(results_dir, 'a') as f:
+ f.write(results)
+ if utils.is_main_process():
+ report_results = 'Epoch {}: '.format(epoch) + ' & '.join(per_task_results) + '\n'
+ with open(results_dir, 'a') as f:
+ f.write(report_results)
+ f.write(''.join(100*['=']))
+
+ utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
+ return
+
+ print("Start training")
+ start_time = time.time()
+
+ best_mAP = 0
+ checkpoint_dir = 'initial_dir'
+
+ for epoch in range(cfg.START_EPOCH, cfg.TRAIN.EPOCHS):
+ if cfg.DIST.DISTRIBUTED:
+ sampler_train.set_epoch(epoch)
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train, optimizer, device, epoch, cfg.TRAIN.CLIP_MAX_NORM, logger=logger)
+ lr_scheduler.step()
+
+        if epoch > cfg.EVAL_EPOCH:
+ per_task_results = []
+ results_dir = output_dir /'eval_results.txt'
+ for i in range(num_sub_tasks):
+ test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
+ data_loader_val_list[i], base_ds_list[i], device, cfg.OUTPUT_DIR)
+ if utils.is_main_process():
+ title = test_stats['title'] + '\n'
+                    per_task_results += test_stats['report_results']
+                    results = '[Epoch {}] [Task {}]'.format(epoch, i) + test_stats['ap_map_wi_aose_ar'] + '\n'
+
+                    if epoch == 0 and i == 0:
+ with open(results_dir, 'a') as f:
+ f.write(title)
+
+ with open(results_dir, 'a') as f:
+ f.write(results)
+ if utils.is_main_process():
+ report_results = 'Epoch {}: '.format(epoch) + ' & '.join(per_task_results) + '\n'
+ with open(results_dir, 'a') as f:
+ f.write(report_results)
+ f.write(''.join(100*['=']) + '\n')
+
+ mAP_tmp = test_stats['base_mAP: ']
+ if mAP_tmp > best_mAP:
+
+ if os.path.exists(checkpoint_dir):
+ os.remove(checkpoint_dir)
+
+ checkpoint_dir = output_dir / f'best_{epoch:02}_{round(mAP_tmp,3)}.pth'
+ utils.save_on_master({
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch,
+ 'cfg': cfg,
+ }, checkpoint_dir)
+ best_mAP = mAP_tmp
+            # saving extra checkpoints: every second epoch once training reaches cfg.TRAIN.LR_DROP - 5
+            if epoch > cfg.TRAIN.LR_DROP - 6 and epoch % 2 == 0:
+ model_dir = output_dir / f'model_{epoch:02}.pth'
+ utils.save_on_master({
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch,
+ 'cfg': cfg,
+ }, model_dir)
+
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('Deformable DETR Detector')
+ parser.add_argument('--config_file', default='', type=str)
+ parser.add_argument("--opts", default=None, nargs=argparse.REMAINDER)
+ args = parser.parse_args()
+ cfg = setup(args)
+ main(cfg)
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..17bade6
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,25 @@
+# ------------------------------------------------------------------------
+# Modified by Wuyang Li
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+def build_model(cfg):
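+    # Dispatch on the config flags: OW-DETR baseline, the motif-based detector, or plain Deformable DETR.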
+ if cfg.AOOD.OW_DETR_ON:
+ print('using ow detr!')
+ from .ow_detr import build
+ elif cfg.AOOD.MOTIF_ON:
+ print('using motif detr!')
+ from .motif_detr import build
+ else:
+ print('using def detr!')
+ from .deformable_detr import build
+ return build(cfg)
+
+
diff --git a/models/backbone.py b/models/backbone.py
new file mode 100644
index 0000000..b9d488e
--- /dev/null
+++ b/models/backbone.py
@@ -0,0 +1,139 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Backbone modules.
+"""
+from collections import OrderedDict
+from torchvision.models.resnet import resnet50
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
+from util.misc import NestedTensor, is_main_process
+from .position_encoding import build_position_encoding
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
+ without which any other models than torchvision.models.resnet[18,34,50,101]
+ produce nans.
+ """
+
+ def __init__(self, n, eps=1e-5):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+ self.eps = eps
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it fuser-friendly
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ eps = self.eps
+ scale = w * (rv + eps).rsqrt()
+ bias = b - rm * scale
+ return x * scale + bias
+
+
+class BackboneBase(nn.Module):
+
+ def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool):
+ super().__init__()
+ for name, parameter in backbone.named_parameters():
+ if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+ parameter.requires_grad_(False)
+ if return_interm_layers:
+ # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+ return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
+ self.strides = [8, 16, 32]
+ self.num_channels = [512, 1024, 2048]
+ else:
+ return_layers = {'layer4': "0"}
+ self.strides = [32]
+ self.num_channels = [2048]
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self.body(tensor_list.tensors)
+ out: Dict[str, NestedTensor] = {}
+ for name, x in xs.items():
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ return out
+
+class Backbone(BackboneBase):
+ """ResNet backbone with frozen BatchNorm."""
+ def __init__(self, name: str,
+ train_backbone: bool,
+ return_interm_layers: bool,
+ dilation: bool):
+ norm_layer = FrozenBatchNorm2d
+ # backbone = getattr(torchvision.models, name)(
+ # replace_stride_with_dilation=[False, False, dilation],
+ # pretrained=is_main_process(), norm_layer=norm_layer)
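+        # Instead of the ImageNet-supervised torchvision weights above, load the self-supervised
+        # DINO ResNet-50 checkpoint from the project root (see the README download link);
+        # strict=False tolerates keys that are absent from the checkpoint.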
+ backbone = resnet50(pretrained=False, replace_stride_with_dilation=[False, False, dilation],
+ norm_layer=norm_layer)
+ state_dict = torch.load('./dino_resnet50_pretrain.pth')
+ backbone.load_state_dict(state_dict, strict=False)
+ assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded"
+ super().__init__(backbone, train_backbone, return_interm_layers)
+ if dilation:
+ self.strides[-1] = self.strides[-1] // 2
+
+class Joiner(nn.Sequential):
+ def __init__(self, backbone, position_embedding):
+ super().__init__(backbone, position_embedding)
+ self.strides = backbone.strides
+ self.num_channels = backbone.num_channels
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self[0](tensor_list)
+ out: List[NestedTensor] = []
+ pos = []
+ for name, x in sorted(xs.items()):
+ out.append(x)
+
+ # position encoding
+ for x in out:
+ pos.append(self[1](x).to(x.tensors.dtype))
+
+ return out, pos
+
+def build_backbone(cfg):
+ position_embedding = build_position_encoding(cfg)
+ train_backbone = cfg.TRAIN.LR_BACKBONE > 0
+ return_interm_layers = cfg.MODEL.MASKS or (cfg.MODEL.NUM_FEATURE_LEVELS > 1)
+ backbone = Backbone(cfg.MODEL.BACKBONE, train_backbone, return_interm_layers, cfg.MODEL.DILATION)
+ model = Joiner(backbone, position_embedding)
+ return model
diff --git a/models/deformable_detr.py b/models/deformable_detr.py
new file mode 100644
index 0000000..14d5e97
--- /dev/null
+++ b/models/deformable_detr.py
@@ -0,0 +1,580 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Deformable DETR model and criterion classes.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+import math
+
+from util import box_ops
+from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+ accuracy, get_world_size, interpolate,
+ is_dist_avail_and_initialized, inverse_sigmoid)
+
+from .backbone import build_backbone
+from .matcher import build_matcher
+from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm,
+ dice_loss, sigmoid_focal_loss)
+from .deformable_transformer import build_deforamble_transformer
+from .utils import GradientReversal
+import copy
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class DeformableDETR(nn.Module):
+ """ This is the Deformable DETR module that performs object detection """
+ def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
+ aux_loss=True, with_box_refine=False, two_stage=False,
+ backbone_align=False, space_align=False, channel_align=False, instance_align=False):
+ """ Initializes the model.
+ Parameters:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_classes: number of object classes
+            num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects
+ DETR can detect in a single image. For COCO, we recommend 100 queries.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ with_box_refine: iterative bounding box refinement
+ two_stage: two-stage Deformable DETR
+ """
+ super().__init__()
+ self.num_queries = num_queries
+ self.transformer = transformer
+ hidden_dim = transformer.d_model
+ self.class_embed = nn.Linear(hidden_dim, num_classes)
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+ self.num_feature_levels = num_feature_levels
+ if not two_stage:
+ self.query_embed = nn.Embedding(num_queries, hidden_dim*2)
+ if num_feature_levels > 1:
+ num_backbone_outs = len(backbone.strides)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone.num_channels[_]
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ for _ in range(num_feature_levels - num_backbone_outs):
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ in_channels = hidden_dim
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ self.input_proj = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )])
+ self.backbone = backbone
+ self.aux_loss = aux_loss
+ self.with_box_refine = with_box_refine
+ self.two_stage = two_stage
+ # self.uda = backbone_align or space_align or channel_align or instance_align
+ self.uda = True
+ self.backbone_align = backbone_align
+ self.space_align = space_align
+ self.channel_align = channel_align
+ self.instance_align = instance_align
+
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ self.class_embed.bias.data = torch.ones(num_classes) * bias_value
+ nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+ for proj in self.input_proj:
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
+ nn.init.constant_(proj[0].bias, 0)
+
+ # if two-stage, the last class_embed and bbox_embed is for region proposal generation
+ num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers
+ if with_box_refine:
+ self.class_embed = _get_clones(self.class_embed, num_pred)
+ self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
+ nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
+ # hack implementation for iterative bounding box refinement
+ self.transformer.decoder.bbox_embed = self.bbox_embed
+ else:
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
+ self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
+ self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
+ self.transformer.decoder.bbox_embed = None
+ if two_stage:
+ # hack implementation for two-stage
+ self.transformer.decoder.class_embed = self.class_embed
+ for box_embed in self.bbox_embed:
+ nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
+ if backbone_align:
+ self.grl = GradientReversal()
+ self.backbone_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.backbone_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+ if space_align:
+ self.space_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.space_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+ if channel_align:
+ self.channel_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.channel_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+ if instance_align:
+ self.instance_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.instance_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+
+ def forward(self, samples: NestedTensor):
+ """ The forward expects a NestedTensor, which consists of:
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+
+ It returns a dict with the following elements:
+ - "pred_logits": the classification logits (including no-object) for all queries.
+ Shape= [batch_size x num_queries x (num_classes + 1)]
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
+ (center_x, center_y, height, width). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
+            - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
+                             dictionaries containing the two above keys for each decoder layer.
+ """
+ if not isinstance(samples, NestedTensor):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, pos = self.backbone(samples)
+
+ srcs = []
+ masks = []
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ srcs.append(self.input_proj[l](src))
+ masks.append(mask)
+ assert mask is not None
+ if self.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.num_feature_levels):
+ if l == _len_srcs:
+ src = self.input_proj[l](features[-1].tensors)
+ else:
+ src = self.input_proj[l](srcs[-1])
+ m = samples.mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ pos.append(pos_l)
+
+ query_embeds = None
+ if not self.two_stage:
+ query_embeds = self.query_embed.weight
+ hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, da_output, _ = self.transformer(srcs, masks, pos, query_embeds)
+
+ outputs_classes = []
+ outputs_coords = []
+ for lvl in range(hs.shape[0]):
+ if lvl == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[lvl - 1]
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.class_embed[lvl](hs[lvl])
+ tmp = self.bbox_embed[lvl](hs[lvl])
+ if reference.shape[-1] == 4:
+ tmp += reference
+ else:
+ assert reference.shape[-1] == 2
+ tmp[..., :2] += reference
+ outputs_coord = tmp.sigmoid()
+ outputs_classes.append(outputs_class)
+ outputs_coords.append(outputs_coord)
+ outputs_class = torch.stack(outputs_classes)
+ outputs_coord = torch.stack(outputs_coords)
+
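+        # During adaptation training the batch stacks source images first and target images second;
+        # detection losses supervise only the source half, so predictions are cut to B // 2 here,
+        # while the domain discriminators below still see both halves.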
+ if self.training and self.uda:
+ B = outputs_class.shape[1]
+ outputs_class = outputs_class[:, :B//2]
+ outputs_coord = outputs_coord[:, :B//2]
+ if self.two_stage:
+ enc_outputs_class = enc_outputs_class[:B//2]
+ enc_outputs_coord_unact = enc_outputs_coord_unact[:B//2]
+ if self.backbone_align:
+ da_output['backbone'] = torch.cat([self.backbone_D(self.grl(src.flatten(2).transpose(1, 2))) for src in srcs], dim=1)
+ if self.space_align:
+ da_output['space_query'] = self.space_D(da_output['space_query'])
+ if self.channel_align:
+ da_output['channel_query'] = self.channel_D(da_output['channel_query'])
+ if self.instance_align:
+ da_output['instance_query'] = self.instance_D(da_output['instance_query'])
+
+ out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
+ if self.aux_loss:
+ out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
+
+ if self.two_stage:
+ enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
+ out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord}
+
+ if self.training and self.uda:
+ out['da_output'] = da_output
+
+ return out
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ return [{'pred_logits': a, 'pred_boxes': b}
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25, da_gamma=2, cfg=None):
+ """ Create the criterion.
+ Parameters:
+ num_classes: number of object categories, omitting the special no-object category
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ focal_alpha: alpha in Focal Loss
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.focal_alpha = focal_alpha
+ self.da_gamma = da_gamma
+ self.placeholder = cfg.OPEN_SET.PLACEHOLDER
+
+ def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
+ """Classification loss (NLL)
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ assert 'pred_logits' in outputs
+ src_logits = outputs['pred_logits']
+
+ idx = self._get_src_permutation_idx(indices)
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+ dtype=torch.int64, device=src_logits.device)
+ target_classes[idx] = target_classes_o
+
+ target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+ dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+
+ target_classes_onehot = target_classes_onehot[:,:,:-1]
+
+
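+        # Open-set placeholder (cfg.OPEN_SET.PLACEHOLDER): matched foreground queries additionally
+        # receive a soft target on the last class channel, presumably the unknown-object logit
+        # in the AOOD setting.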
+ if self.placeholder > 0:
+ obj_idx = target_classes_onehot.sum(-1) > 0
+ tmp = target_classes_onehot[obj_idx]
+ tmp[:,-1] = self.placeholder
+ target_classes_onehot[obj_idx] = tmp
+
+ loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
+ losses = {'loss_ce': loss_ce}
+
+ if log:
+ # TODO this should probably be a separate loss, not hacked in this one here
+ losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+ return losses
+
+ @torch.no_grad()
+ def loss_cardinality(self, outputs, targets, indices, num_boxes):
+        """ Compute the cardinality error, i.e., the absolute error in the number of predicted non-empty boxes
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
+ """
+ pred_logits = outputs['pred_logits']
+ device = pred_logits.device
+ tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+ losses = {'cardinality_error': card_err}
+ return losses
+
+ def loss_boxes(self, outputs, targets, indices, num_boxes):
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+ The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.
+ """
+ assert 'pred_boxes' in outputs
+ idx = self._get_src_permutation_idx(indices)
+ src_boxes = outputs['pred_boxes'][idx]
+ target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+ losses = {}
+ losses['loss_bbox'] = loss_bbox.sum() / num_boxes
+
+ loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
+ box_ops.box_cxcywh_to_xyxy(src_boxes),
+ box_ops.box_cxcywh_to_xyxy(target_boxes)))
+ losses['loss_giou'] = loss_giou.sum() / num_boxes
+ return losses
+
+ def loss_masks(self, outputs, targets, indices, num_boxes):
+ """Compute the losses related to the masks: the focal loss and the dice loss.
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+ """
+ assert "pred_masks" in outputs
+
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+
+ src_masks = outputs["pred_masks"]
+
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
+ target_masks = target_masks.to(src_masks)
+
+ src_masks = src_masks[src_idx]
+ # upsample predictions to the target size
+ src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
+ mode="bilinear", align_corners=False)
+ src_masks = src_masks[:, 0].flatten(1)
+
+ target_masks = target_masks[tgt_idx].flatten(1)
+
+ losses = {
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+ "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+ }
+ return losses
+
+ def loss_da(self, outputs, use_focal=False):
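+        # Domain-adversarial loss: the first half of the batch is labelled 0 (source) and the
+        # second half 1 (target); use_focal applies focal weighting with exponent da_gamma.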
+ B = outputs.shape[0]
+ assert B % 2 == 0
+
+ targets = torch.empty_like(outputs)
+ targets[:B//2] = 0
+ targets[B//2:] = 1
+
+ loss = F.binary_cross_entropy_with_logits(outputs, targets, reduction='none')
+
+ if use_focal:
+ prob = outputs.sigmoid()
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = loss * ((1 - p_t) ** self.da_gamma)
+
+ return loss.mean()
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+ loss_map = {
+ 'labels': self.loss_labels,
+ 'cardinality': self.loss_cardinality,
+ 'boxes': self.loss_boxes,
+ 'masks': self.loss_masks
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+ def forward(self, samples, outputs, targets, epoch=0):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depend on the losses applied; see each loss' doc
+ """
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}
+
+ # Retrieve the matching between the outputs of the last layer and the targets
+ indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+ num_boxes = sum(len(t["labels"]) for t in targets)
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_boxes)
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ kwargs = {}
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if 'aux_outputs' in outputs:
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
+ indices = self.matcher(aux_outputs, targets)
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs['log'] = False
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if 'enc_outputs' in outputs:
+ enc_outputs = outputs['enc_outputs']
+ bin_targets = copy.deepcopy(targets)
+ for bt in bin_targets:
+ bt['labels'] = torch.zeros_like(bt['labels'])
+ indices = self.matcher(enc_outputs, bin_targets)
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs['log'] = False
+ l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_enc': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if 'da_output' in outputs:
+ for k, v in outputs['da_output'].items():
+ losses[f'loss_{k}'] = self.loss_da(v, use_focal='query' in k)
+
+ return losses
+
+
+class PostProcess(nn.Module):
+ """ This module converts the model's output into the format expected by the coco api"""
+
+ @torch.no_grad()
+ def forward(self, outputs, target_sizes):
+ """ Perform the computation
+ Parameters:
+ outputs: raw outputs of the model
+            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
+                          For evaluation, this must be the original image size (before any data augmentation)
+                          For visualization, this should be the image size after data augmentation, but before padding
+ """
+ out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
+
+ assert len(out_logits) == len(target_sizes)
+ assert target_sizes.shape[1] == 2
+
+ prob = out_logits.sigmoid()
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
+ scores = topk_values
+ topk_boxes = topk_indexes // out_logits.shape[2]
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+def build(cfg):
+ device = torch.device(cfg.DEVICE)
+
+ backbone = build_backbone(cfg)
+
+ transformer = build_deforamble_transformer(cfg)
+ model = DeformableDETR(
+ backbone,
+ transformer,
+ num_classes=cfg.DATASET.NUM_CLASSES,
+ num_queries=cfg.MODEL.NUM_QUERIES,
+ num_feature_levels=cfg.MODEL.NUM_FEATURE_LEVELS,
+ aux_loss=cfg.LOSS.AUX_LOSS,
+ with_box_refine=cfg.MODEL.WITH_BOX_REFINE,
+ two_stage=cfg.MODEL.TWO_STAGE,
+ backbone_align=cfg.MODEL.BACKBONE_ALIGN,
+ space_align=cfg.MODEL.SPACE_ALIGN,
+ channel_align=cfg.MODEL.CHANNEL_ALIGN,
+ instance_align=cfg.MODEL.INSTANCE_ALIGN
+ )
+ if cfg.MODEL.MASKS:
+ model = DETRsegm(model, freeze_detr=(cfg.MODEL.FROZEN_WEIGHTS is not None))
+ matcher = build_matcher(cfg)
+ weight_dict = {'loss_ce': cfg.LOSS.CLS_LOSS_COEF, 'loss_bbox': cfg.LOSS.BBOX_LOSS_COEF}
+ weight_dict['loss_giou'] = cfg.LOSS.GIOU_LOSS_COEF
+ if cfg.MODEL.MASKS:
+ weight_dict["loss_mask"] = cfg.LOSS.MASK_LOSS_COEF
+ weight_dict["loss_dice"] = cfg.LOSS.DICE_LOSS_COEF
+ # TODO this is a hack
+ if cfg.LOSS.AUX_LOSS:
+ aux_weight_dict = {}
+ for i in range(cfg.MODEL.DEC_LAYERS - 1):
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
+ aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()})
+ weight_dict.update(aux_weight_dict)
+
+ weight_dict['loss_backbone'] = cfg.LOSS.BACKBONE_LOSS_COEF
+ weight_dict['loss_space_query'] = cfg.LOSS.SPACE_QUERY_LOSS_COEF
+ weight_dict['loss_channel_query'] = cfg.LOSS.CHANNEL_QUERY_LOSS_COEF
+ weight_dict['loss_instance_query'] = cfg.LOSS.INSTANCE_QUERY_LOSS_COEF
+
+ # losses = ['labels', 'boxes', 'cardinality']
+ losses = ['labels', 'boxes']
+ if cfg.MODEL.MASKS:
+ losses += ["masks"]
+ # num_classes, matcher, weight_dict, losses, focal_alpha=0.25
+    criterion = SetCriterion(cfg.DATASET.NUM_CLASSES, matcher, weight_dict, losses, focal_alpha=cfg.LOSS.FOCAL_ALPHA, da_gamma=cfg.LOSS.DA_GAMMA, cfg=cfg)
+ criterion.to(device)
+ postprocessors = {'bbox': PostProcess()}
+ if cfg.MODEL.MASKS:
+ postprocessors['segm'] = PostProcessSegm()
+ if cfg.DATASET.DATASET_FILE == "coco_panoptic":
+ is_thing_map = {i: i <= 90 for i in range(201)}
+ postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)
+
+ return model, criterion, postprocessors
\ No newline at end of file
diff --git a/models/deformable_transformer.py b/models/deformable_transformer.py
new file mode 100644
index 0000000..0f57d4f
--- /dev/null
+++ b/models/deformable_transformer.py
@@ -0,0 +1,476 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+import copy
+from typing import Optional, List
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+
+from util.misc import inverse_sigmoid
+from models.ops.modules import MSDeformAttn
+from models.utils import DomainAttention, GradientReversal, remove_mask_and_warp
+
+class DeformableTransformer(nn.Module):
+ def __init__(self, d_model=256, nhead=8,
+ num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
+ activation="relu", return_intermediate_dec=False,
+ num_feature_levels=4, dec_n_points=4, enc_n_points=4,
+ two_stage=False, two_stage_num_proposals=300,
+ space_align=False, channel_align=False, instance_align=False):
+ super().__init__()
+
+ self.d_model = d_model
+ self.nhead = nhead
+ self.two_stage = two_stage
+ self.two_stage_num_proposals = two_stage_num_proposals
+
+ self.space_align = space_align
+ self.channel_align = channel_align
+ self.instance_align = instance_align
+
+ encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, enc_n_points, space_align, channel_align)
+ self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
+
+ decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, dec_n_points, instance_align)
+ self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec)
+
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+ if two_stage:
+ self.enc_output = nn.Linear(d_model, d_model)
+ self.enc_output_norm = nn.LayerNorm(d_model)
+ self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
+ self.pos_trans_norm = nn.LayerNorm(d_model * 2)
+ else:
+ self.reference_points = nn.Linear(d_model, 2)
+
+ if space_align:
+ self.space_query = nn.Parameter(torch.empty(1, 1, d_model))
+ if channel_align:
+ # self.channel_query is actually an embedding layer for channel query
+ # We keep the name for consistency
+ self.channel_query = nn.Linear(d_model, 1)
+ self.grl = GradientReversal()
+ if instance_align:
+ self.instance_query = nn.Parameter(torch.empty(1, 1, d_model))
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ if not self.two_stage:
+ xavier_uniform_(self.reference_points.weight.data, gain=1.0)
+ constant_(self.reference_points.bias.data, 0.)
+ normal_(self.level_embed)
+
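+    # Sinusoidal embedding of the (unactivated) top-k proposal boxes: each of the 4
+    # box coordinates is mapped to a 128-d sine/cosine code, giving a tensor of size
+    # [N, topk, 512] that pos_trans projects into decoder query and target embeddings.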
+ def get_proposal_pos_embed(self, proposals):
+ num_pos_feats = 128
+ temperature = 10000
+ scale = 2 * math.pi
+
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+ # N, L, 4
+ proposals = proposals.sigmoid() * scale
+ # N, L, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # N, L, 4, 64, 2
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
+ return pos
+
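+    # Two-stage mode: build one proposal per encoder token (a regular grid per level
+    # with level-dependent size 0.05 * 2**lvl), mark proposals on padded regions or
+    # outside (0.01, 0.99) as invalid by setting them to +inf in logit space, and
+    # project the encoder memory for the proposal classification/regression heads.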
+ def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
+ N_, S_, C_ = memory.shape
+ base_scale = 4.0
+ proposals = []
+ _cur = 0
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
+ proposals.append(proposal)
+ _cur += (H_ * W_)
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+ return output_memory, output_proposals
+
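+    # Fraction of each padded feature map that contains valid (non-padded) pixels, per axis.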
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
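+    # Forward pass: flatten the multi-scale features, optionally build the space and
+    # channel alignment queries (training only), run the encoder, prepare decoder
+    # inputs from either the two-stage proposals or the learned query embeddings, and
+    # run the decoder with an optional instance query. Returns the decoder states,
+    # reference points, two-stage outputs (or None), the collected domain-alignment
+    # outputs (da_output), and the encoder memory.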
+ def forward(self, srcs, masks, pos_embeds, query_embed=None):
+ assert self.two_stage or query_embed is not None
+
+ # prepare input for encoder
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ src = src.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ space_query, channel_query, instance_query = None, None, None
+ if self.training:
+ if self.space_align:
+ space_query = self.space_query.expand(src_flatten.shape[0], -1, -1)
+ if self.channel_align:
+ src_warped, pos_warped = remove_mask_and_warp(
+ src_flatten, lvl_pos_embed_flatten, mask_flatten, level_start_index, spatial_shapes
+ )
+ channel_query = self.channel_query(self.grl(src_warped+pos_warped)).flatten(0, 1).transpose(1, 2)
+
+ # encoder
+ memory, space_query, channel_query = self.encoder(
+ src_flatten, space_query, channel_query, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten
+ )
+
+ da_output = {}
+ if self.training:
+ if self.space_align:
+ da_output['space_query'] = torch.cat(space_query, dim=1)
+ if self.channel_align:
+ da_output['channel_query'] = torch.cat(channel_query, dim=1)
+
+ # prepare input for decoder
+ bs, _, c = memory.shape
+ if self.two_stage:
+ output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
+
+ # hack implementation for two-stage Deformable DETR
+ enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
+ enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
+
+ topk = self.two_stage_num_proposals
+ topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
+ topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
+ topk_coords_unact = topk_coords_unact.detach()
+ reference_points = topk_coords_unact.sigmoid()
+ init_reference_out = reference_points
+ pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
+ query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
+ else:
+ query_embed, tgt = torch.split(query_embed, c, dim=1)
+ query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
+ tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
+ reference_points = self.reference_points(query_embed).sigmoid()
+ init_reference_out = reference_points
+
+ if self.training and self.instance_align:
+ instance_query = self.instance_query.expand(tgt.shape[0], -1, -1)
+
+ # decoder
+ hs, inter_references, instance_query = self.decoder(
+ tgt, instance_query, reference_points, memory, spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten
+ )
+
+ if self.training and self.instance_align:
+ da_output['instance_query'] = instance_query
+
+ inter_references_out = inter_references
+ if self.two_stage:
+ return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact, da_output, memory
+ return hs, init_reference_out, inter_references_out, None, None, da_output, memory
+
+
+class DeformableTransformerEncoderLayer(nn.Module):
+ def __init__(self,
+ d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4,
+ space_align=False, channel_align=False):
+ super().__init__()
+
+ self.space_align = space_align
+ self.channel_align = channel_align
+ if space_align:
+ self.space_attn = DomainAttention(d_model, n_heads, dropout)
+ if channel_align:
+ self.channel_attn = DomainAttention(d_model, n_heads, dropout)
+
+ # self attention
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, src):
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+ src = src + self.dropout3(src2)
+ src = self.norm2(src)
+ return src
+
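+    # During training, the space query attends to the encoder tokens and the channel
+    # query attends to the spatially warped tokens via DomainAttention; the resulting
+    # queries are later fed to the domain discriminators in the detector.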
+ def forward(self, src, space_query, channel_query, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
+ # self attention
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ if self.training:
+ if self.space_align:
+ space_query = self.space_attn(space_query, src, pos, padding_mask)
+ if self.channel_align:
+ src_warped, pos_warped = remove_mask_and_warp(src, pos, padding_mask, level_start_index, spatial_shapes)
+ channel_query = self.channel_attn(
+ channel_query, # bsz * num_feature_levels, 1, H*W
+ src_warped.flatten(0, 1).transpose(1, 2), # bsz * num_feature_levels, C, H*W
+ pos_warped.flatten(0, 1).transpose(1, 2)
+ )
+
+ # ffn
+ src = self.forward_ffn(src)
+
+ return src, space_query, channel_query
+
+
+class DeformableTransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(self, src, space_query, channel_query, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
+ output = src
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
+ space_querys = []
+ channel_querys = []
+ for _, layer in enumerate(self.layers):
+ output, space_query, channel_query = layer(
+ output, space_query, channel_query, pos, reference_points, spatial_shapes, level_start_index, padding_mask
+ )
+ space_querys.append(space_query)
+ channel_querys.append(channel_query)
+
+ return output, space_querys, channel_querys
+
+
+class DeformableTransformerDecoderLayer(nn.Module):
+ def __init__(self, d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4,
+ instance_align=False):
+ super().__init__()
+
+ self.instance_align = instance_align
+ if instance_align:
+ self.instance_attn = DomainAttention(d_model, n_heads, dropout)
+
+ # cross attention
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # self attention
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout3 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward(self, tgt, instance_query, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None):
+ # self attention
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+ # cross attention
+ tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
+ reference_points,
+ src, src_spatial_shapes, level_start_index, src_padding_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ if self.training and self.instance_align:
+ instance_query = self.instance_attn(instance_query, tgt, query_pos)
+
+ # ffn
+ tgt = self.forward_ffn(tgt)
+
+ return tgt, instance_query
+
+
+class DeformableTransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.return_intermediate = return_intermediate
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
+ self.bbox_embed = None
+ self.class_embed = None
+
+ def forward(self, tgt, instance_query, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
+ query_pos=None, src_padding_mask=None):
+ output = tgt
+
+ intermediate = []
+ intermediate_reference_points = []
+ for lid, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = reference_points[:, :, None] \
+ * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
+ output, instance_query = layer(
+ output, instance_query, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask
+ )
+
+ # hack implementation for iterative bounding box refinement
+ if self.bbox_embed is not None:
+ tmp = self.bbox_embed[lid](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp
+ new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(intermediate_reference_points), instance_query
+
+ # https://github.com/fundamentalvision/Deformable-DETR/issues/43
+ return [output], [reference_points], instance_query
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+def build_deforamble_transformer(cfg):
+ return DeformableTransformer(
+ d_model=cfg.MODEL.HIDDEN_DIM,
+ nhead=cfg.MODEL.NHEADS,
+ num_encoder_layers=cfg.MODEL.ENC_LAYERS,
+ num_decoder_layers=cfg.MODEL.DEC_LAYERS,
+ dim_feedforward=cfg.MODEL.DIM_FEEDFORWARD,
+ dropout=cfg.MODEL.DROPOUT,
+ activation="relu",
+ return_intermediate_dec=True,
+ num_feature_levels=cfg.MODEL.NUM_FEATURE_LEVELS,
+ dec_n_points=cfg.MODEL.DEC_N_POINTS,
+ enc_n_points=cfg.MODEL.ENC_N_POINTS,
+ two_stage=cfg.MODEL.TWO_STAGE,
+ two_stage_num_proposals=cfg.MODEL.NUM_QUERIES,
+ space_align=cfg.MODEL.SPACE_ALIGN,
+ channel_align=cfg.MODEL.CHANNEL_ALIGN,
+ instance_align=cfg.MODEL.INSTANCE_ALIGN)
+
+
diff --git a/models/matcher.py b/models/matcher.py
new file mode 100644
index 0000000..9fed154
--- /dev/null
+++ b/models/matcher.py
@@ -0,0 +1,104 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Modules to compute the matching cost and solve the corresponding LSAP.
+"""
+import torch
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+
+from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+class HungarianMatcher(nn.Module):
+ """This class computes an assignment between the targets and the predictions of the network
+
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+ while the others are un-matched (and thus treated as non-objects).
+ """
+
+ def __init__(self,
+ cost_class: float = 1,
+ cost_bbox: float = 1,
+ cost_giou: float = 1):
+ """Creates the matcher
+
+ Params:
+ cost_class: This is the relative weight of the classification error in the matching cost
+ cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+ cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+ """
+ super().__init__()
+ self.cost_class = cost_class
+ self.cost_bbox = cost_bbox
+ self.cost_giou = cost_giou
+        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
+
+ def forward(self, outputs, targets):
+ """ Performs the matching
+
+ Params:
+ outputs: This is a dict that contains at least these entries:
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+ "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+ objects in the target) containing the class labels
+ "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ """
+ with torch.no_grad():
+ bs, num_queries = outputs["pred_logits"].shape[:2]
+
+ # We flatten to compute the cost matrices in a batch
+ out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
+ out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
+
+ # Also concat the target labels and boxes
+ tgt_ids = torch.cat([v["labels"] for v in targets])
+ tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+ # Compute the classification cost.
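+            # Focal-style cost: assigning a prediction to a target class costs
+            # pos_cost_class - neg_cost_class evaluated at that class probability.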
+ alpha = 0.25
+ gamma = 2.0
+ neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
+ pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+ cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+
+ # Compute the L1 cost between boxes
+ cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+            # Compute the giou cost between boxes
+ cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
+ box_cxcywh_to_xyxy(tgt_bbox))
+
+ # Final cost matrix
+ C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+ C = C.view(bs, num_queries, -1).cpu()
+
+ sizes = [len(v["boxes"]) for v in targets]
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+def build_matcher(cfg):
+ return HungarianMatcher(cost_class=cfg.LOSS.SET_COST_CLASS,
+ cost_bbox=cfg.LOSS.SET_COST_BBOX,
+ cost_giou=cfg.LOSS.SET_COST_GIOU)
diff --git a/models/motif_detr.py b/models/motif_detr.py
new file mode 100644
index 0000000..9a205fd
--- /dev/null
+++ b/models/motif_detr.py
@@ -0,0 +1,854 @@
+# ------------------------------------------------------------------------
+# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection
+# Modified by Wuyang Li
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Deformable DETR model and criterion classes.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+import math
+
+from util import box_ops
+from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+ accuracy, get_world_size, interpolate,
+ is_dist_avail_and_initialized, inverse_sigmoid)
+
+from .backbone import build_backbone
+from .matcher import build_matcher
+from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm,
+ dice_loss, sigmoid_focal_loss)
+from .deformable_transformer import build_deforamble_transformer
+from .utils import GradientReversal
+import copy
+from copy import deepcopy
+
+def sim_matrix(a, b, eps=1e-8):
+ """
+    Cosine similarity between the rows of a and b; eps is added for numerical stability.
+ """
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+ a_norm = a / torch.clamp(a_n, min=eps)
+ b_norm = b / torch.clamp(b_n, min=eps)
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+ return sim_mt
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class DeformableDETR(nn.Module):
+ """ This is the Deformable DETR module that performs object detection """
+ def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
+ aux_loss=True, with_box_refine=False, two_stage=False, from_cfg=None):
+ """ Initializes the model.
+ Parameters:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_classes: number of object classes
+            num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
+ DETR can detect in a single image. For COCO, we recommend 100 queries.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ with_box_refine: iterative bounding box refinement
+ two_stage: two-stage Deformable DETR
+ """
+ super().__init__()
+
+ self.from_cfg = from_cfg
+ self.num_queries = num_queries
+ self.transformer = transformer
+ hidden_dim = transformer.d_model
+ self.class_embed = nn.Linear(hidden_dim, num_classes)
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+ self.num_feature_levels = num_feature_levels
+ if not two_stage:
+ self.query_embed = nn.Embedding(num_queries, hidden_dim*2)
+ if num_feature_levels > 1:
+ num_backbone_outs = len(backbone.strides)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone.num_channels[_]
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ for _ in range(num_feature_levels - num_backbone_outs):
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ in_channels = hidden_dim
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ self.input_proj = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )])
+ self.backbone = backbone
+ self.aux_loss = aux_loss
+ self.with_box_refine = with_box_refine
+ self.two_stage = two_stage
+
+ self.da = self.from_cfg['da']
+ self.backbone_align = self.from_cfg['backbone_align']
+ self.space_align = self.from_cfg['space_align']
+ self.channel_align = self.from_cfg['channel_align']
+ self.instance_align = self.from_cfg['instance_align']
+
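+        # Running per-class mean/std of matched query embeddings (a memory bank used
+        # by the motif losses); the last row is reserved for the unknown class. Note
+        # that the feature size is hard-coded to 256, i.e. it assumes hidden_dim == 256.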
+ self.register_buffer('cls_means', torch.zeros(num_classes, 256))
+ self.register_buffer('cls_stds', torch.zeros(num_classes, 256))
+
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ self.class_embed.bias.data = torch.ones(num_classes) * bias_value
+ nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+ for proj in self.input_proj:
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
+ nn.init.constant_(proj[0].bias, 0)
+
+ # if two-stage, the last class_embed and bbox_embed is for region proposal generation
+ num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers
+
+ if with_box_refine:
+ self.class_embed = _get_clones(self.class_embed, num_pred)
+ self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
+ nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
+ # hack implementation for iterative bounding box refinement
+ self.transformer.decoder.bbox_embed = self.bbox_embed
+ else:
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
+ self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
+ self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
+ self.transformer.decoder.bbox_embed = None
+ if two_stage:
+ # hack implementation for two-stage
+ self.transformer.decoder.class_embed = self.class_embed
+ for box_embed in self.bbox_embed:
+ nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
+ if self.backbone_align:
+ self.grl = GradientReversal(lambda_=self.from_cfg['backbone_adv_lambda'])
+ self.backbone_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.backbone_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+ if self.space_align:
+ self.space_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.space_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+ if self.channel_align:
+ self.channel_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.channel_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+ if self.instance_align:
+ self.instance_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.instance_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+
+ def forward(self, samples: NestedTensor, targets=None):
+ """ The forward expects a NestedTensor, which consists of:
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+
+ It returns a dict with the following elements:
+ - "pred_logits": the classification logits (including no-object) for all queries.
+ Shape= [batch_size x num_queries x (num_classes + 1)]
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
+ (center_x, center_y, height, width). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
+               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
+                                dictionaries containing the two above keys for each decoder layer.
+ """
+ out = {}
+ if not isinstance(samples, NestedTensor):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, pos = self.backbone(samples)
+
+ srcs = []
+ masks = []
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ srcs.append(self.input_proj[l](src))
+ masks.append(mask)
+ assert mask is not None
+ if self.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.num_feature_levels):
+ if l == _len_srcs:
+ src = self.input_proj[l](features[-1].tensors)
+ else:
+ src = self.input_proj[l](srcs[-1])
+ m = samples.mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ pos.append(pos_l)
+
+ query_embeds = None
+ if not self.two_stage:
+ query_embeds = self.query_embed.weight
+
+ # send to def-transformer
+ hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, da_output, memory = self.transformer(srcs, masks, pos, query_embeds)
+ # hs: lvl, bs, 100, 256
+
+ outputs_classes = []
+ outputs_coords = []
+ for lvl in range(hs.shape[0]):
+ if lvl == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[lvl - 1]
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.class_embed[lvl](hs[lvl])
+ tmp = self.bbox_embed[lvl](hs[lvl])
+ if reference.shape[-1] == 4:
+ tmp += reference
+ else:
+ assert reference.shape[-1] == 2
+ tmp[..., :2] += reference
+ outputs_coord = tmp.sigmoid()
+ outputs_classes.append(outputs_class)
+ outputs_coords.append(outputs_coord)
+ outputs_class = torch.stack(outputs_classes)
+ outputs_coord = torch.stack(outputs_coords)
+
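+        # Expose the logits of both domains, the class-center memory bank and the
+        # classifiers to the criterion, which needs them for the motif-based
+        # open-set and cross-domain losses.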
+ out['pred_logits_both'] = outputs_class[-1]
+ out['is_training'] = self.training
+
+ out['cls_means'] = self.cls_means
+ out['cls_stds'] = self.cls_stds
+ out['final_classifier'] = self.class_embed[-1]
+ out['first_classifier'] = self.class_embed[0]
+
+ if self.training and self.da:
+ B = outputs_class.shape[1]
+ outputs_class = outputs_class[:, :B//2]
+ outputs_coord = outputs_coord[:, :B//2]
+ if self.two_stage:
+ enc_outputs_class = enc_outputs_class[:B//2]
+ enc_outputs_coord_unact = enc_outputs_coord_unact[:B//2]
+ if self.backbone_align:
+ da_output['backbone'] = torch.cat([self.backbone_D(self.grl(src.flatten(2).transpose(1, 2))) for src in srcs], dim=1)
+ if self.space_align:
+ da_output['space_query'] = self.space_D(da_output['space_query'])
+ if self.channel_align:
+ da_output['channel_query'] = self.channel_D(da_output['channel_query'])
+ if self.instance_align:
+ da_output['instance_query'] = self.instance_D(da_output['instance_query'])
+
+
+ out.update({'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'object_embedding': hs[-1], 'first_embedding': hs[0]})
+ if self.aux_loss:
+ out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
+
+ if self.two_stage:
+ enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
+ out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord}
+ if self.training and self.da:
+ out['da_output'] = da_output
+ return out
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ return [{'pred_logits': a, 'pred_boxes': b}
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+ def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25, da_gamma=2, from_cfg = None):
+ """ Create the criterion.
+ Parameters:
+ num_classes: number of object categories, omitting the special no-object category
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ focal_alpha: alpha in Focal Loss
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.focal_alpha = focal_alpha
+ self.da_gamma = da_gamma
+ self.from_cfg = from_cfg
+ self.unk_prob = from_cfg['unk_prob']
+ self.bce_loss = nn.BCELoss()
+ self.pretrain_th = from_cfg['pretrain_th']
+ self.std_scaling = from_cfg['std_scaling']
+ self.alpha = from_cfg['alpha']
+ self.with_openset = from_cfg['with_openset']
+
+ def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
+        """Classification loss (sigmoid focal loss)
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ assert 'pred_logits' in outputs
+ src_logits = outputs['pred_logits']
+
+ idx = self._get_src_permutation_idx(indices)
+
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+ dtype=torch.int64, device=src_logits.device)
+ target_classes[idx] = target_classes_o
+
+ target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+ dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+ target_classes_onehot = target_classes_onehot[:,:,:-1]
+
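+        # Open-set label smoothing: every foreground target additionally receives a
+        # small probability (unk_prob) on the unknown (last) class.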
+ if self.unk_prob > 0 and self.with_openset:
+ obj_idx = target_classes_onehot.sum(-1) > 0
+ tmp = target_classes_onehot[obj_idx]
+ tmp[:,-1] = self.unk_prob
+ target_classes_onehot[obj_idx] = tmp
+
+ loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
+ losses = {'loss_ce': loss_ce}
+
+ if log:
+ # TODO this should probably be a separate loss, not hacked in this one here
+ losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+ return losses
+
+ @torch.no_grad()
+ def loss_cardinality(self, outputs, targets, indices, num_boxes):
+        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
+ """
+ pred_logits = outputs['pred_logits']
+ device = pred_logits.device
+ tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+ losses = {'cardinality_error': card_err}
+ return losses
+
+ def loss_boxes(self, outputs, targets, indices, num_boxes):
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+ """
+ assert 'pred_boxes' in outputs
+ idx = self._get_src_permutation_idx(indices)
+ src_boxes = outputs['pred_boxes'][idx]
+ target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+ losses = {}
+ losses['loss_bbox'] = loss_bbox.sum() / num_boxes
+
+ loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
+ box_ops.box_cxcywh_to_xyxy(src_boxes),
+ box_ops.box_cxcywh_to_xyxy(target_boxes)))
+ losses['loss_giou'] = loss_giou.sum() / num_boxes
+ return losses
+
+ def loss_masks(self, outputs, targets, indices, num_boxes):
+ """Compute the losses related to the masks: the focal loss and the dice loss.
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+ """
+ assert "pred_masks" in outputs
+
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+
+ src_masks = outputs["pred_masks"]
+
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
+ target_masks = target_masks.to(src_masks)
+
+ src_masks = src_masks[src_idx]
+ # upsample predictions to the target size
+ src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
+ mode="bilinear", align_corners=False)
+ src_masks = src_masks[:, 0].flatten(1)
+
+ target_masks = target_masks[tgt_idx].flatten(1)
+
+ losses = {
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+ "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+ }
+ return losses
+
+ def loss_da(self, outputs, use_focal=False):
+
+ B = outputs.shape[0]
+ assert B % 2 == 0
+
+ targets = torch.empty_like(outputs)
+ targets[:B//2] = 0
+ targets[B//2:] = 1
+
+ loss = F.binary_cross_entropy_with_logits(outputs, targets, reduction='none')
+ if use_focal:
+ prob = outputs.sigmoid()
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = loss * ((1 - p_t) ** self.da_gamma)
+
+ return loss.mean()
+
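+    # Open-set motif loss (training only). For every query left unmatched by the
+    # Hungarian matcher, consider pairs (base-class center, its farthest center) and
+    # pick the pair whose directions from the query are most opposed (smallest cosine
+    # similarity plus a distance-imbalance term); the query and the two centers are
+    # averaged into a "motif" embedding. The top-K motifs (os_KNN) are pushed towards
+    # the unknown class with a BCE loss, and the unknown-class mean/std in the memory
+    # bank are EMA-updated with the statistics of the selected motifs.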
+ def loss_openset(self, outputs, indices, targets):
+
+ ctrs = outputs['cls_means'][:-1]
+ obj_emb = outputs['object_embedding'] # bs, 100, 256
+
+ ctrs_labels = torch.arange(self.num_classes-1).to(ctrs.device)
+ # mtch_idx = self._get_src_permutation_idx(indices) # bs, idx
+ unmtch_idx = self._get_src_unmatched_permutation_idx(indices, num_query=100)
+ unmtch_emb = obj_emb[unmtch_idx]
+
+ pair_dis = self.eu_dis(ctrs, ctrs)
+        top_k_idx = torch.sort(pair_dis, descending=True, dim=-1)[1][:, 0]  # index of the farthest class center for each base-class center
+
+ ctrs_1 = ctrs
+ ctrs_2 = ctrs_1[top_k_idx]
+
+ ctrs_1_labels = ctrs_labels
+ ctrs_2_labels = ctrs_labels[top_k_idx]
+
+ motif_embeds_list = []
+ angle_list = []
+        class_1 = []
+        class_2 = []
+
+ for i in range(len(unmtch_emb)):
+
+ vct1 = (ctrs_1 - unmtch_emb[i])
+ vct2 = (ctrs_2 - unmtch_emb[i])
+
+ dis1 = (ctrs_1 - unmtch_emb[i]).norm(dim=-1)
+ dis2 = (ctrs_2 - unmtch_emb[i]).norm(dim=-1)
+ delta_dis = (dis1-dis2).abs()
+ dis_base = (ctrs_1 - ctrs_2).norm(dim=-1)
+
+            # smaller value = the two far-apart centers lie on opposite sides of the query at balanced distances
+            angle = torch.nn.functional.cosine_similarity(vct1, vct2) + delta_dis / dis_base
+
+ motif_idx = angle.argmin()
+ angle_list.append(angle.min().unsqueeze(dim=0))
+
+ motif_emb = torch.stack([ctrs_1[motif_idx], ctrs_2[motif_idx], unmtch_emb[i]], dim=0)
+            class_1.append(ctrs_1_labels[motif_idx].unsqueeze(0))
+            class_2.append(ctrs_2_labels[motif_idx].unsqueeze(0))
+
+ motif_embeds_list.append(motif_emb.mean(dim=0)[None,:])
+
+ motif_embeds = torch.cat(motif_embeds_list,dim=0)
+ neg_angles = -torch.cat(angle_list)
+
+ assert motif_embeds.size(0) == neg_angles.size(0)
+
+ select_idx = neg_angles.topk(self.from_cfg['os_KNN'])[1]
+
+        class_1 = torch.cat(class_1)[select_idx].unsqueeze(-1)
+        class_2 = torch.cat(class_2)[select_idx].unsqueeze(-1)
+
+ motif_embeds_topk = motif_embeds[select_idx]
+
+ classifier = outputs['final_classifier']
+ motif_prob = classifier(motif_embeds_topk).sigmoid()
+
+ target = torch.full_like(motif_prob, 0.0).detach()
+        target[:, -1] = 1.0
+
+ loss = self.bce_loss(motif_prob, target)
+
+ # update memory bank
+ with torch.no_grad():
+ ctrs = outputs['cls_means']
+ stds = outputs['cls_stds']
+
+ ema = self.alpha
+ avg_emb_base = motif_embeds_topk.mean(0)
+ ctrs[-1] = (1. - ema) * ctrs[-1] + ema * avg_emb_base
+
+ std_emb_base = motif_embeds_topk.std(0)
+ stds[-1] = (1. - ema) * stds[-1] + ema * std_emb_base
+
+ outputs['cls_means'] = ctrs
+ outputs['cls_stds'] = stds
+
+ return loss
+
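+    # Cross-domain motif loss (training only). Target-domain queries whose summed
+    # class scores exceed pretrain_th are combined with class centers shifted by
+    # +/- std_scaling * stds: each query picks the shifted-center pair whose
+    # directions from it are most opposed and averages the three into a motif
+    # embedding. The classifier output on the motif is trained with a soft label
+    # split 0.5/0.5 between the chosen pair's class and the query's nearest class
+    # center (1.0 when the two coincide), presumably aligning the domains class-wise.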
+ def loss_crossdomain(self, outputs, targets, indices):
+ q_embs = outputs['object_embedding'] # bs, 100, 256 # source: 0: bs//2; target: bs//2:
+
+ B = q_embs.shape[0]
+ assert B % 2 == 0
+
+ q_tg_pred = outputs['pred_logits_both'][B//2:]
+ q_tg_scores = q_tg_pred.view(-1, q_tg_pred.size(-1)).sigmoid()
+
+ ctrs = outputs['cls_means']
+ stds = outputs['cls_stds']
+
+ ctrs_labels = torch.arange(self.num_classes).to(ctrs.device)
+
+ scaling_factor = ctrs.new_ones(ctrs.size(0)) * self.std_scaling
+ scaling_factor = scaling_factor[:,None]
+
+ a = ctrs + scaling_factor * stds
+ b = ctrs - scaling_factor * stds
+
+ # add centers
+ ctrs_1 = torch.cat([a, b], dim=0)
+ ctrs_1_labels = torch.cat([ctrs_labels, ctrs_labels])
+ ctrs_2 = torch.cat([b, a], dim=0)
+
+ q_tg_raw = q_embs[B//2:].view(-1, q_embs.size(-1))
+ # score_mask = q_tg_scores.max(-1)[0] > self.pretrain_th
+ score_mask = q_tg_scores.sum(-1) > self.pretrain_th
+ q_tg = q_tg_raw[score_mask]
+
+        if len(q_tg) < self.from_cfg['da_KNN']:
+            # not enough confident target queries: return a zero loss that keeps the graph connected
+            return q_tg_scores.sum() * 0
+
+ sr_label = []
+ motif_embeds_list =[]
+ for i in range(len(q_tg)):
+ vct1 = (ctrs_1 - q_tg[i])
+ vct2 = (ctrs_2 - q_tg[i])
+ angle = torch.nn.functional.cosine_similarity(vct1,vct2)
+ motif_idx = angle.argmin(-1)
+
+ ctr_1 = ctrs_1[motif_idx]
+ ctr_2 = ctrs_2[motif_idx]
+
+ motif_emb = torch.stack([ctr_1, q_tg[i], ctr_2], dim=0)
+ sr_label.append(ctrs_1_labels[motif_idx].unsqueeze(dim=0))
+ motif_embeds_list.append(motif_emb.mean(dim=0)[None,:])
+
+ motif_embeds = torch.cat(motif_embeds_list,dim=0)
+
+ sr_label = torch.cat(sr_label)
+ tg_label = self.eu_dis(q_tg, ctrs).argmin(-1)
+
+ prob = outputs['final_classifier'](motif_embeds)
+ target_motif = torch.zeros(prob.size()).to(prob.device)
+
+        # split the soft label 0.5 / 0.5 between the class of the chosen center pair
+        # (sr_label) and the nearest-center class (tg_label); when the two coincide,
+        # the single 0.5 entry is doubled to 1.0
+        prob_tmp = 0.5
+        tg = torch.full_like(sr_label[:, None].float(), prob_tmp)
+        target_motif.scatter_(1, sr_label[:, None], tg)
+        target_motif.scatter_(1, tg_label[:, None], tg)
+        target_motif[target_motif.sum(-1) == 0.5] *= 2
+
+ # loss = sigmoid_focal_loss(prob, target_motif, prob.size(0), alpha=self.focal_alpha, gamma=2)
+ loss = self.bce_loss(prob.sigmoid(), target_motif.detach())
+ return loss
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ def _get_src_unmatched_permutation_idx(self, indices, num_query=100):
+        # indices of object queries that were NOT matched to any ground-truth box
+ bs = len(indices)
+ queries = torch.arange(num_query)
+ batch_idx = []
+ src_idx = []
+ for i, (src, _) in enumerate(indices):
+ combined = torch.cat(
+ (queries, src))
+ uniques, counts = combined.unique(return_counts=True)
+ unmatched_box = uniques[counts == 1]
+ batch_idx.append(torch.full_like(unmatched_box, i))
+ src_idx.append(unmatched_box)
+ batch_idx = torch.cat(batch_idx)
+ src_idx = torch.cat(src_idx)
+ return batch_idx, src_idx
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
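+    # Pairwise p-norm (Euclidean by default) distance matrix between the rows of a and b.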
+    def eu_dis(self, a, b, p=2):
+        return torch.norm(a[:, None] - b, dim=2, p=p)
+
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+ loss_map = {
+ 'labels': self.loss_labels,
+ 'cardinality': self.loss_cardinality,
+ 'boxes': self.loss_boxes,
+ 'masks': self.loss_masks
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
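+    # EMA update of the per-class mean/std memory bank from the matched query
+    # embeddings; the unknown-class entry (last row) additionally tracks the
+    # average of the base-class statistics.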
+ @torch.no_grad()
+    def update_class_centers(self, outputs, targets, indices, ema=0.01):
+
+ ctrs = outputs['cls_means']
+ stds = outputs['cls_stds']
+ q_embs = outputs['object_embedding'] # bs, 100, 256
+
+ matched_idx = self._get_src_permutation_idx(indices) # bs, idx
+ matched_q = q_embs[matched_idx]
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+
+ for i in target_classes_o.unique():
+ per_cls_q = matched_q[target_classes_o==i]
+ avg_emb = per_cls_q.mean(dim=0)
+ ctrs[i] = (1. - ema) * ctrs[i] + ema * avg_emb.detach()
+
+ if per_cls_q.size(0) > 2:
+ std_emb = per_cls_q.std(dim=0)
+ stds[i] = (1. - ema) * stds[i] + ema * std_emb.detach()
+
+ avg_emb_base = ctrs[:-1].mean(0)
+ ctrs[-1] = (1. - ema) * ctrs[-1] + ema * avg_emb_base
+
+
+ std_emb_base = stds[:-1].mean(0)
+ stds[-1] = (1. - ema) * stds[-1] + ema * std_emb_base
+
+ outputs['cls_means'] = ctrs
+ outputs['cls_stds'] = stds
+ return outputs
+
+ def forward(self, samples, outputs, targets, epoch=0):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depend on the losses applied; see each loss' doc
+ """
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}
+ # Compute all the requested losses
+ losses = {}
+ # Retrieve the matching between the outputs of the last layer and the targets
+ indices = self.matcher(outputs_without_aux, targets)
+ if self.training:
+ outputs = self.update_class_centers(outputs, targets, indices, ema=self.alpha)
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+ num_boxes = sum(len(t["labels"]) for t in targets)
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_boxes)
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+
+ for loss in self.losses:
+ kwargs = {}
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs))
+
+ if self.training and self.from_cfg['with_openset'] and epoch > self.from_cfg['warm_up_epoch']:
+ losses['loss_openset'] = self.loss_openset(outputs, indices, targets)
+
+ if self.training and self.from_cfg['with_crossdomain'] and epoch > self.from_cfg['warm_up_epoch']:
+ losses['loss_crossdomain'] = self.loss_crossdomain(outputs, targets, indices)
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if 'aux_outputs' in outputs:
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
+ indices = self.matcher(aux_outputs, targets)
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs['log'] = False
+
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if 'enc_outputs' in outputs:
+ enc_outputs = outputs['enc_outputs']
+ bin_targets = copy.deepcopy(targets)
+ for bt in bin_targets:
+ bt['labels'] = torch.zeros_like(bt['labels'])
+ indices = self.matcher(enc_outputs, bin_targets)
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs['log'] = False
+ l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_enc': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if 'da_output' in outputs:
+ for k, v in outputs['da_output'].items():
+ losses[f'loss_{k}'] = self.loss_da(v, use_focal='query' in k)
+
+ return losses
+
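+    # --- Usage sketch (illustrative only, not used by the code above) -----------
+    # `forward` returns an *unweighted* loss dict; the training loop is expected to
+    # scale each term with the `weight_dict` assembled in `build(cfg)` below
+    # (assuming, as in Deformable DETR, that it is stored as `criterion.weight_dict`):
+    #
+    #     loss_dict = criterion(samples, outputs, targets, epoch=epoch)
+    #     total = sum(loss_dict[k] * criterion.weight_dict[k]
+    #                 for k in loss_dict if k in criterion.weight_dict)
+    #     total.backward()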
+
+class PostProcess(nn.Module):
+ """ This module converts the model's output into the format expected by the coco api"""
+
+ @torch.no_grad()
+ def forward(self, outputs, target_sizes, show_box=False):
+ """ Perform the computation
+ Parameters:
+ outputs: raw outputs of the model
+                          target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
+                                        For evaluation, this must be the original image size (before any data augmentation)
+                                        For visualization, this should be the image size after data augmentation, but before padding
+ """
+ out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
+
+ assert len(out_logits) == len(target_sizes)
+ assert target_sizes.shape[1] == 2
+
+ prob = out_logits.sigmoid()
+ if show_box:
+            # For qualitative visualization: suppress unknown-class predictions.
+            # TODO: may differ from the old implementation; needs checking.
+ bs, num_q, num_class = prob.size()
+ unk_mask = prob.argmax(-1) != num_class - 1
+ prob[unk_mask] = 0.0
+
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
+ scores = topk_values
+ topk_boxes = topk_indexes // out_logits.shape[2]
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+ results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
+ return results
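+        # Usage sketch (illustrative): `target_sizes` holds the original (h, w) of each
+        # image so the normalized cxcywh boxes are mapped back to absolute pixels, e.g.
+        #
+        #     postprocess = PostProcess()
+        #     target_sizes = torch.tensor([[1024, 2048]])   # (batch, 2) = (h, w)
+        #     results = postprocess(outputs, target_sizes)
+        #     results[0]['boxes']                           # (100, 4) xyxy, pixel coords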
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
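+
+    # Example (sketch): a 3-layer MLP of this form is what DETR-style detectors use as
+    # the box head, mapping 256-d query embeddings to 4 box parameters:
+    #
+    #     bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
+    #     bbox_head(torch.randn(2, 100, 256)).shape   # -> (2, 100, 4)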
+
+def build(cfg):
+
+ device = torch.device(cfg.DEVICE)
+ backbone = build_backbone(cfg)
+ transformer = build_deforamble_transformer(cfg)
+
+ from_cfg = dict(
+ backbone_align=cfg.MODEL.BACKBONE_ALIGN,
+ space_align=cfg.MODEL.SPACE_ALIGN,
+ channel_align=cfg.MODEL.CHANNEL_ALIGN,
+ instance_align=cfg.MODEL.INSTANCE_ALIGN,
+ da=cfg.DATASET.DA_MODE == 'uda' or cfg.DATASET.DA_MODE == 'aood',
+ batch_size=cfg.TRAIN.BATCH_SIZE,
+ with_openset=cfg.AOOD.OPEN_SET.MOTIF_ON,
+ os_KNN=cfg.AOOD.OPEN_SET.KNN,
+ pretrain_th=cfg.AOOD.OPEN_SET.TH,
+ with_crossdomain=cfg.AOOD.CROSS_DOMAIN.MOTIF_ON,
+ da_KNN=cfg.AOOD.CROSS_DOMAIN.KNN,
+ unk_prob=cfg.AOOD.OPEN_SET.UNK_PROB,
+ backbone_adv_lambda=cfg.AOOD.CROSS_DOMAIN.BACKBONE_LAMBDA,
+ warm_up_epoch=cfg.AOOD.OPEN_SET.WARM_UP,
+ std_scaling=cfg.AOOD.CROSS_DOMAIN.BETA,
+ motif_update=cfg.AOOD.OPEN_SET.MOTIF_UPDATE,
+ alpha=cfg.AOOD.OPEN_SET.ALPHA,
+ )
+ print(from_cfg)
+
+ model = DeformableDETR(
+ backbone,
+ transformer,
+ num_classes=cfg.DATASET.NUM_CLASSES,
+ num_queries=cfg.MODEL.NUM_QUERIES,
+ num_feature_levels=cfg.MODEL.NUM_FEATURE_LEVELS,
+ aux_loss=cfg.LOSS.AUX_LOSS,
+ with_box_refine=cfg.MODEL.WITH_BOX_REFINE,
+ two_stage=cfg.MODEL.TWO_STAGE,
+        from_cfg=from_cfg,
+ )
+ if cfg.MODEL.MASKS:
+ model = DETRsegm(model, freeze_detr=(cfg.MODEL.FROZEN_WEIGHTS is not None))
+ matcher = build_matcher(cfg)
+ weight_dict = {'loss_ce': cfg.LOSS.CLS_LOSS_COEF, 'loss_bbox': cfg.LOSS.BBOX_LOSS_COEF}
+ weight_dict['loss_giou'] = cfg.LOSS.GIOU_LOSS_COEF
+ if cfg.MODEL.MASKS:
+ weight_dict["loss_mask"] = cfg.LOSS.MASK_LOSS_COEF
+ weight_dict["loss_dice"] = cfg.LOSS.DICE_LOSS_COEF
+ # TODO this is a hack
+ if cfg.LOSS.AUX_LOSS:
+ aux_weight_dict = {}
+ for i in range(cfg.MODEL.DEC_LAYERS - 1):
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
+ aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()})
+ weight_dict.update(aux_weight_dict)
+
+ weight_dict['loss_backbone'] = cfg.LOSS.BACKBONE_LOSS_COEF
+ weight_dict['loss_space_query'] = cfg.LOSS.SPACE_QUERY_LOSS_COEF
+ weight_dict['loss_channel_query'] = cfg.LOSS.CHANNEL_QUERY_LOSS_COEF
+ weight_dict['loss_instance_query'] = cfg.LOSS.INSTANCE_QUERY_LOSS_COEF
+
+ weight_dict['loss_crossdomain'] = cfg.AOOD.CROSS_DOMAIN.MOTIF_LOSS_COEF
+ weight_dict['loss_openset'] = cfg.AOOD.OPEN_SET.MOTIF_LOSS_COEF
+
+ losses = ['labels', 'boxes']
+ if cfg.MODEL.MASKS:
+ losses += ["masks"]
+ # num_classes, matcher, weight_dict, losses, focal_alpha=0.25
+ criterion = SetCriterion(
+ cfg.DATASET.NUM_CLASSES,
+ matcher,
+ weight_dict,
+ losses,
+ focal_alpha=cfg.LOSS.FOCAL_ALPHA,
+ da_gamma=cfg.LOSS.DA_GAMMA,
+        from_cfg=from_cfg,
+    )
+ criterion.to(device)
+ postprocessors = {'bbox': PostProcess()}
+ if cfg.MODEL.MASKS:
+ postprocessors['segm'] = PostProcessSegm()
+ if cfg.DATASET.DATASET_FILE == "coco_panoptic":
+ is_thing_map = {i: i <= 90 for i in range(201)}
+ postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)
+
+ return model, criterion, postprocessors
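+
+# Usage sketch (illustrative): `build(cfg)` is the entry point the training/eval scripts
+# are expected to call, with `cfg` being the project's config node:
+#
+#     model, criterion, postprocessors = build(cfg)
+#     model.to(torch.device(cfg.DEVICE))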
diff --git a/models/ops/functions/__init__.py b/models/ops/functions/__init__.py
new file mode 100644
index 0000000..8a2197b
--- /dev/null
+++ b/models/ops/functions/__init__.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn_func import MSDeformAttnFunction
+
diff --git a/models/ops/functions/ms_deform_attn_func.py b/models/ops/functions/ms_deform_attn_func.py
new file mode 100644
index 0000000..8c5df8c
--- /dev/null
+++ b/models/ops/functions/ms_deform_attn_func.py
@@ -0,0 +1,61 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+import MultiScaleDeformableAttention as MSDA
+
+
+class MSDeformAttnFunction(Function):
+ @staticmethod
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
+ ctx.im2col_step = im2col_step
+ output = MSDA.ms_deform_attn_forward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+ grad_value, grad_sampling_loc, grad_attn_weight = \
+ MSDA.ms_deform_attn_backward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
+
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+ # for debug and test only,
+ # need to use cuda version instead
+ N_, S_, M_, D_ = value.shape
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+ # N_*M_, D_, Lq_, P_
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
+ mode='bilinear', padding_mode='zeros', align_corners=False)
+ sampling_value_list.append(sampling_value_l_)
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
+ return output.transpose(1, 2).contiguous()
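+
+# Sanity-check sketch (illustrative, in the spirit of the repo's models/ops/test.py): the
+# pure-PyTorch reference above should match the compiled CUDA op up to float tolerance.
+# All tensors below are placeholders with the shapes expected by MSDeformAttn.forward.
+#
+#     out_ref = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
+#     out_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index,
+#                                           sampling_locations, attention_weights, im2col_step)
+#     assert torch.allclose(out_ref, out_cuda, rtol=1e-2, atol=1e-3)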
diff --git a/models/ops/make.sh b/models/ops/make.sh
new file mode 100644
index 0000000..106b685
--- /dev/null
+++ b/models/ops/make.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+python setup.py build install
diff --git a/models/ops/modules/__init__.py b/models/ops/modules/__init__.py
new file mode 100644
index 0000000..f82cb1a
--- /dev/null
+++ b/models/ops/modules/__init__.py
@@ -0,0 +1,9 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn import MSDeformAttn
diff --git a/models/ops/modules/ms_deform_attn.py b/models/ops/modules/ms_deform_attn.py
new file mode 100644
index 0000000..663d64a
--- /dev/null
+++ b/models/ops/modules/ms_deform_attn.py
@@ -0,0 +1,115 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import warnings
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_
+
+from ..functions import MSDeformAttnFunction
+
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+ return (n & (n-1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+ """
+ Multi-Scale Deformable Attention Module
+ :param d_model hidden dimension
+ :param n_levels number of feature levels
+ :param n_heads number of attention heads
+ :param n_points number of sampling points per attention head per feature level
+ """
+ super().__init__()
+ if d_model % n_heads != 0:
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+ _d_per_head = d_model // n_heads
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ if not _is_power_of_2(_d_per_head):
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ "which is more efficient in our CUDA implementation.")
+
+ self.im2col_step = 64
+
+ self.d_model = d_model
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.output_proj = nn.Linear(d_model, d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.sampling_offsets.weight.data, 0.)
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.)
+ constant_(self.attention_weights.bias.data, 0.)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.)
+
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
+ """
+ :param query (N, Length_{query}, C)
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
+
+ :return output (N, Length_{query}, C)
+ """
+ N, Len_q, _ = query.shape
+ N, Len_in, _ = input_flatten.shape
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+ # N, Len_q, n_heads, n_levels, n_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+ sampling_locations = reference_points[:, :, None, :, None, :] \
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ else:
+ raise ValueError(
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
+ output = MSDeformAttnFunction.apply(
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
+ output = self.output_proj(output)
+ return output
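+
+    # Usage sketch (illustrative, shapes only): with n_levels=4 flattened feature maps,
+    #
+    #     attn = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4).cuda()
+    #     shapes = torch.as_tensor([[64, 64], [32, 32], [16, 16], [8, 8]], device='cuda')
+    #     starts = torch.cat((shapes.new_zeros(1), (shapes[:, 0] * shapes[:, 1]).cumsum(0)[:-1]))
+    #     x = torch.rand(2, int((shapes[:, 0] * shapes[:, 1]).sum()), 256, device='cuda')
+    #     query = torch.rand(2, 300, 256, device='cuda')
+    #     ref = torch.rand(2, 300, 4, 2, device='cuda')   # normalized reference points
+    #     out = attn(query, ref, x, shapes, starts)       # -> (2, 300, 256)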
diff --git a/models/ops/setup.py b/models/ops/setup.py
new file mode 100644
index 0000000..a0131bc
--- /dev/null
+++ b/models/ops/setup.py
@@ -0,0 +1,71 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "src")
+
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+ sources = main_file + source_cpu
+ extension = CppExtension
+ extra_compile_args = {"cxx": []}
+ define_macros = []
+
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ else:
+        raise NotImplementedError('CUDA is not available')
+
+ sources = [os.path.join(extensions_dir, s) for s in sources]
+ include_dirs = [extensions_dir]
+ ext_modules = [
+ extension(
+ "MultiScaleDeformableAttention",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+ return ext_modules
+
+setup(
+ name="MultiScaleDeformableAttention",
+ version="1.0",
+ author="Weijie Su",
+ url="https://github.com/fundamentalvision/Deformable-DETR",
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+ packages=find_packages(exclude=("configs", "tests",)),
+ ext_modules=get_extensions(),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/models/ops/src/cpu/ms_deform_attn_cpu.cpp b/models/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000..e1bf854
--- /dev/null
+++ b/models/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,41 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/models/ops/src/cpu/ms_deform_attn_cpu.h b/models/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000..81b7b58
--- /dev/null
+++ b/models/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,33 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
+
diff --git a/models/ops/src/cuda/ms_deform_attn_cuda.cu b/models/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000..d6d5836
--- /dev/null
+++ b/models/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,153 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data<int64_t>(),
+                level_start_index.data<int64_t>(),
+                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                columns.data<scalar_t>());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+                                    grad_output_g.data<scalar_t>(),
+                                    value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                                    spatial_shapes.data<int64_t>(),
+                                    level_start_index.data<int64_t>(),
+                                    sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                                    attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                                    batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                                    grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                                    grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                                    grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
\ No newline at end of file
diff --git a/models/ops/src/cuda/ms_deform_attn_cuda.h b/models/ops/src/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000..c7ae53f
--- /dev/null
+++ b/models/ops/src/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,30 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
diff --git a/models/ops/src/cuda/ms_deform_im2col_cuda.cuh b/models/ops/src/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000..6bc2acb
--- /dev/null
+++ b/models/ops/src/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
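+ // In this multi-block variant several blocks may contribute to the same sampling location, so thread 0 flushes the block's partial sums to global memory with atomicAdd instead of a plain store.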
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
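+ // Launch one thread per output element (batch * query * head * channel); each thread accumulates bilinear samples over all levels and points.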
+ ms_deformable_im2col_gpu_kernel<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
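+ // Kernel selection by channel count: power-of-two channels up to 1024 get fully specialized shared-memory reductions;
+ // other channels <= 1024 use the generic v1 (< 64) or v2 (>= 64) shared-memory reduction; channels > 1024 that are a
+ // multiple of 1024 use the multi-block shared-memory reduction; anything else falls back to global-memory atomicAdd.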
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
\ No newline at end of file
diff --git a/models/ops/src/ms_deform_attn.h b/models/ops/src/ms_deform_attn.h
new file mode 100644
index 0000000..ac0ef2e
--- /dev/null
+++ b/models/ops/src/ms_deform_attn.h
@@ -0,0 +1,62 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/models/ops/src/vision.cpp b/models/ops/src/vision.cpp
new file mode 100644
index 0000000..2201f63
--- /dev/null
+++ b/models/ops/src/vision.cpp
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
diff --git a/models/ops/test.py b/models/ops/test.py
new file mode 100644
index 0000000..8dbf6d5
--- /dev/null
+++ b/models/ops/test.py
@@ -0,0 +1,89 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+N, M, D = 1, 2, 2
+Lq, L, P = 2, 2, 2
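+# N: batch size, M: attention heads, D: channels per head, Lq: number of queries, L: feature levels, P: sampling points per level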
+shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
+S = sum([(H*W).item() for H, W in shapes])
+
+
+torch.manual_seed(3)
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_double():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_float():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ func = MSDeformAttnFunction.apply
+
+ value.requires_grad = grad_value
+ sampling_locations.requires_grad = grad_sampling_loc
+ attention_weights.requires_grad = grad_attn_weight
+
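+ # gradcheck compares the analytic gradients of the CUDA backward pass against numerical finite differences; double precision keeps the comparison within tolerance.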
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+ check_forward_equal_with_pytorch_double()
+ check_forward_equal_with_pytorch_float()
+
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+ check_gradient_numerical(channels, True, True, True)
+
+
+
diff --git a/models/ow_detr.py b/models/ow_detr.py
new file mode 100644
index 0000000..dd4be54
--- /dev/null
+++ b/models/ow_detr.py
@@ -0,0 +1,763 @@
+# ------------------------------------------------------------------------
+# Novel Scenes & Classes: Towards Adaptive Open-set Object Detection
+# Modified by Wuyang Li
+# ------------------------------------------------------------------------
+# OW-DETR: Open-world Detection Transformer
+# Akshita Gupta^, Sanath Narayan^, K J Joseph, Salman Khan, Fahad Shahbaz Khan, Mubarak Shah
+# https://arxiv.org/pdf/2112.01513.pdf
+# ------------------------------------------------------------------------
+# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+Deformable DETR model and criterion classes.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+import math
+import pickle
+from util import box_ops
+from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+ accuracy, get_world_size, interpolate,
+ is_dist_avail_and_initialized, inverse_sigmoid)
+from .backbone import build_backbone
+from .matcher import build_matcher
+from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm,
+ dice_loss, sigmoid_focal_loss) #, sigmoid_focal_loss_CA)
+from .deformable_transformer import build_deforamble_transformer
+import copy
+import heapq
+import operator
+import os
+from copy import deepcopy
+from .utils import GradientReversal
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class DeformableDETR(nn.Module):
+ """ This is the Deformable DETR module that performs object detection """
+ def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
+ aux_loss=True, with_box_refine=False, two_stage=False,
+ unmatched_boxes=False, novelty_cls=False, featdim=1024):
+ """ Initializes the model.
+ Parameters:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_classes: number of object classes
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
+ DETR can detect in a single image. For COCO, we recommend 100 queries.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ with_box_refine: iterative bounding box refinement
+ two_stage: two-stage Deformable DETR
+ """
+ super().__init__()
+
+ self.num_queries = num_queries
+ self.transformer = transformer
+ hidden_dim = transformer.d_model
+
+ self.class_embed = nn.Linear(hidden_dim, num_classes)
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+ self.num_feature_levels = num_feature_levels
+
+ ###
+ self.featdim = featdim
+ self.unmatched_boxes = unmatched_boxes
+ self.novelty_cls = novelty_cls
+ if self.novelty_cls:
+ self.nc_class_embed = nn.Linear(hidden_dim, 1)
+
+ if not two_stage:
+ self.query_embed = nn.Embedding(num_queries, hidden_dim*2)
+ if num_feature_levels > 1:
+ num_backbone_outs = len(backbone.strides)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone.num_channels[_]
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ for _ in range(num_feature_levels - num_backbone_outs):
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ in_channels = hidden_dim
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ self.input_proj = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )])
+ self.backbone = backbone
+ self.aux_loss = aux_loss
+ self.with_box_refine = with_box_refine
+ self.two_stage = two_stage
+
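+ # Domain discriminator on backbone features: the gradient reversal layer flips gradients so the backbone is pushed toward domain-invariant representations.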
+ self.grl = GradientReversal(lambda_=1.0)
+ self.backbone_D = MLP(hidden_dim, hidden_dim, 1, 3)
+ for layer in self.backbone_D.layers:
+ nn.init.xavier_uniform_(layer.weight, gain=1)
+ nn.init.constant_(layer.bias, 0)
+
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ self.class_embed.bias.data = torch.ones(num_classes) * bias_value
+ if self.novelty_cls:
+ self.nc_class_embed.bias.data = torch.ones(1) * bias_value
+ nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+ for proj in self.input_proj:
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
+ nn.init.constant_(proj[0].bias, 0)
+
+ # if two-stage, the last class_embed and bbox_embed is for region proposal generation
+ num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers
+ if with_box_refine:
+ self.class_embed = _get_clones(self.class_embed, num_pred)
+ self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
+ nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
+ # hack implementation for iterative bounding box refinement
+ self.transformer.decoder.bbox_embed = self.bbox_embed
+ else:
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
+ self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
+
+ if self.novelty_cls:
+ self.nc_class_embed = nn.ModuleList([self.nc_class_embed for _ in range(num_pred)])
+
+ self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
+ self.transformer.decoder.bbox_embed = None
+ if two_stage:
+ # hack implementation for two-stage
+ self.transformer.decoder.class_embed = self.class_embed
+ for box_embed in self.bbox_embed:
+ nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
+
+ def forward(self, samples: NestedTensor):
+ """ The forward expects a NestedTensor, which consists of:
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+ It returns a dict with the following elements:
+ - "pred_logits": the classification logits (including no-object) for all queries.
+ Shape= [batch_size x num_queries x (num_classes + 1)]
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
+ (center_x, center_y, height, width). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
+ - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
+ dictionnaries containing the two above keys for each decoder layer.
+ """
+ if not isinstance(samples, NestedTensor):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, pos = self.backbone(samples)
+
+ srcs = []
+ masks = []
+ if self.featdim == 512:
+ dim_index = 0
+ elif self.featdim == 1024:
+ dim_index = 1
+ else:
+ dim_index = 2
+
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ ## [Info] extracting the resnet features which are used for selecting unmatched queries
+ if self.unmatched_boxes:
+ if l == dim_index:
+ resnet_1024_feature = src.clone() # 2X1024X61X67
+ else:
+ resnet_1024_feature = None
+ srcs.append(self.input_proj[l](src))
+ masks.append(mask)
+ assert mask is not None
+ if self.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.num_feature_levels):
+ if l == _len_srcs:
+ src = self.input_proj[l](features[-1].tensors)
+ else:
+ src = self.input_proj[l](srcs[-1])
+ m = samples.mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ pos.append(pos_l)
+
+ query_embeds = None
+ if not self.two_stage:
+ query_embeds = self.query_embed.weight
+
+ # hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, query_embeds)
+
+ hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, da_output, memory = self.transformer(srcs, masks, pos, query_embeds)
+
+ outputs_classes = []
+ outputs_coords = []
+ outputs_classes_nc = []
+
+ for lvl in range(hs.shape[0]):
+ if lvl == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[lvl - 1]
+
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.class_embed[lvl](hs[lvl])
+
+ ## novelty classification
+ if self.novelty_cls:
+ outputs_class_nc = self.nc_class_embed[lvl](hs[lvl])
+
+ tmp = self.bbox_embed[lvl](hs[lvl])
+ if reference.shape[-1] == 4:
+ tmp += reference
+ else:
+ assert reference.shape[-1] == 2
+ tmp[..., :2] += reference
+ outputs_coord = tmp.sigmoid()
+ outputs_classes.append(outputs_class)
+ if self.novelty_cls:
+ outputs_classes_nc.append(outputs_class_nc)
+ outputs_coords.append(outputs_coord)
+
+ outputs_class = torch.stack(outputs_classes)
+ outputs_coord = torch.stack(outputs_coords)
+ if self.novelty_cls:
+ output_class_nc = torch.stack(outputs_classes_nc)
+ out= {}
+ if self.training:
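+ # Training batches are stacked as [source; target]: only the source half keeps detection outputs, while features from both halves feed the domain discriminator below.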
+ B = outputs_class.shape[1]
+ outputs_class = outputs_class[:, :B//2]
+ outputs_coord = outputs_coord[:, :B//2]
+ output_class_nc = output_class_nc[:, :B//2]
+ if self.two_stage:
+ enc_outputs_class = enc_outputs_class[:B//2]
+ enc_outputs_coord_unact = enc_outputs_coord_unact[:B//2]
+
+ out['da_backbone'] = torch.cat([self.backbone_D(self.grl(src.flatten(2).transpose(1, 2))) for src in srcs], dim=1)
+
+ out.update({'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'resnet_1024_feat': resnet_1024_feature})
+
+ if self.novelty_cls:
+ out['pred_nc_logits'] = output_class_nc[-1]
+
+ if self.aux_loss:
+ out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, output_class_nc=None)
+ if self.novelty_cls:
+ out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, output_class_nc=output_class_nc)
+ if self.two_stage:
+ enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
+ out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord}
+
+ return out
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord, output_class_nc=None):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ # import pdb;pdb.set_trace()
+ if output_class_nc is not None:
+ xx = [{'pred_logits': a, 'pred_nc_logits': c, 'pred_boxes': b}
+ for a, c, b in zip(outputs_class[:-1], output_class_nc[:-1], outputs_coord[:-1])]
+ else:
+ xx = [{'pred_logits': a, 'pred_boxes': b}
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+ return xx
+
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+ def __init__(self, args, num_classes, matcher, weight_dict, losses, invalid_cls_logits, focal_alpha=0.25):
+ """ Create the criterion.
+ Parameters:
+ num_classes: number of object categories, omitting the special no-object category
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ focal_alpha: alpha in Focal Loss
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.focal_alpha = focal_alpha
+ self.nc_epoch = -1
+ self.invalid_cls_logits = invalid_cls_logits
+ self.unmatched_boxes = True
+ self.top_unk = 5
+ self.num_seen_classes = 4
+
+
+ def loss_NC_labels(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices, log=True):
+ """Novelty classification loss
+ target labels will contain class as 1
+ owod_indices -> indices combining matched indices + pseudo-labeled indices
+ owod_targets -> targets combining GT targets + pseudo-labeled unknown targets
+ target_classes_o -> contains all 1's
+ """
+ assert 'pred_nc_logits' in outputs
+ src_logits = outputs['pred_nc_logits']
+
+ idx = self._get_src_permutation_idx(owod_indices)
+ target_classes_o = torch.cat([torch.full_like(t["labels"][J], 0) for t, (_, J) in zip(owod_targets, owod_indices)])
+ target_classes = torch.full(src_logits.shape[:2], 1, dtype=torch.int64, device=src_logits.device)
+ target_classes[idx] = target_classes_o
+
+ target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+
+ target_classes_onehot = target_classes_onehot[:,:,:-1]
+ loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
+
+ losses = {'loss_NC': loss_ce}
+ return losses
+
+ def loss_labels(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices, log=True):
+ """Classification loss (NLL)
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ assert 'pred_logits' in outputs
+ ## comment lines from 317-320 when running for oracle settings
+ temp_src_logits = outputs['pred_logits'].clone()
+ temp_src_logits[:,:, self.invalid_cls_logits] = -10e10
+ src_logits = temp_src_logits
+
+ if self.unmatched_boxes:
+ idx = self._get_src_permutation_idx(owod_indices)
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(owod_targets, owod_indices)])
+ else:
+ idx = self._get_src_permutation_idx(indices)
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
+ target_classes[idx] = target_classes_o
+ target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+ dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+ target_classes_onehot = target_classes_onehot[:,:,:-1]
+ loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
+
+ losses = {'loss_ce': loss_ce}
+
+ if log:
+ # TODO this should probably be a separate loss, not hacked in this one here
+ losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+ return losses
+
+ @torch.no_grad()
+ def loss_cardinality(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices):
+ """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
+ """
+ temp_pred_logits = outputs['pred_logits'].clone()
+ temp_pred_logits[:,:, self.invalid_cls_logits] = -10e10
+ pred_logits = temp_pred_logits
+
+ device = pred_logits.device
+ if self.unmatched_boxes:
+ tgt_lengths = torch.as_tensor([len(v["labels"]) for v in owod_targets], device=device)
+ else:
+ tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+ losses = {'cardinality_error': card_err}
+ return losses
+
+ def loss_boxes(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices):
+ # def loss_boxes(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices, ca_owod_targets, ca_owod_indices):
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+ The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.
+ """
+
+ assert 'pred_boxes' in outputs
+ idx = self._get_src_permutation_idx(indices)
+ src_boxes = outputs['pred_boxes'][idx]
+ target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+ losses = {}
+ losses['loss_bbox'] = loss_bbox.sum() / num_boxes
+
+ loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
+ box_ops.box_cxcywh_to_xyxy(src_boxes),
+ box_ops.box_cxcywh_to_xyxy(target_boxes)))
+ losses['loss_giou'] = loss_giou.sum() / num_boxes
+ return losses
+
+ def loss_masks(self, outputs, targets, indices, num_boxes, current_epoch, owod_targets, owod_indices):
+ """Compute the losses related to the masks: the focal loss and the dice loss.
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+ """
+ assert "pred_masks" in outputs
+
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+
+ src_masks = outputs["pred_masks"]
+
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
+ target_masks = target_masks.to(src_masks)
+
+ src_masks = src_masks[src_idx]
+ # upsample predictions to the target size
+ src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
+ mode="bilinear", align_corners=False)
+ src_masks = src_masks[:, 0].flatten(1)
+
+ target_masks = target_masks[tgt_idx].flatten(1)
+
+ losses = {
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+ "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+ }
+ return losses
+
+ def save_dict(self, di_, filename_):
+ with open(filename_, 'wb') as f:
+ pickle.dump(di_, f)
+
+ def load_dict(self, filename_):
+ with open(filename_, 'rb') as f:
+ ret_dict = pickle.load(f)
+ return ret_dict
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ def _get_src_single_permutation_idx(self, indices, index):
+ ## Only need the src query index selection from this function for attention feature selection
+ batch_idx = [torch.full_like(src, i) for i, src in enumerate(indices)][0]
+ src_idx = indices[0]
+ return batch_idx, src_idx
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, epoch, owod_targets, owod_indices, **kwargs):
+ loss_map = {
+ 'labels': self.loss_labels,
+ 'NC_labels': self.loss_NC_labels,
+ 'cardinality': self.loss_cardinality,
+ 'boxes': self.loss_boxes,
+ 'masks': self.loss_masks
+ }
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
+ return loss_map[loss](outputs, targets, indices, num_boxes, epoch, owod_targets, owod_indices, **kwargs)
+
+ def loss_da(self, outputs, use_focal=False):
+
+ B = outputs.shape[0]
+ assert B % 2 == 0
+
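+ # Domain labels: the first half of the batch is source (0), the second half is target (1).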
+ targets = torch.empty_like(outputs)
+ targets[:B//2] = 0
+ targets[B//2:] = 1
+
+ loss = F.binary_cross_entropy_with_logits(outputs, targets, reduction='none')
+
+ if use_focal:
+ prob = outputs.sigmoid()
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = loss * ((1 - p_t) ** self.da_gamma)
+
+ return loss.mean()
+ def forward(self, samples, outputs, targets, epoch):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ # if self.nc_epoch > 0:
+ # loss_epoch = 9
+ # else:
+ # loss_epoch = 0
+ loss_epoch = -1
+
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}
+ indices = self.matcher(outputs_without_aux, targets)
+
+ owod_targets = deepcopy(targets)
+ owod_indices = deepcopy(indices)
+
+
+ owod_outputs = outputs_without_aux.copy()
+ owod_device = owod_outputs["pred_boxes"].device
+
+ if self.unmatched_boxes and epoch >= loss_epoch and self.training:
+ ## get pseudo unmatched boxes from this section
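+ # Unknown mining (following OW-DETR): queries left unmatched by the Hungarian matcher are scored by the mean backbone activation inside their predicted boxes, and the top-k are pseudo-labeled as unknown.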
+ res_feat = torch.mean(outputs['resnet_1024_feat'], 1)
+ queries = torch.arange(outputs['pred_logits'].shape[1])
+ for i in range(len(indices)):
+ combined = torch.cat((queries, self._get_src_single_permutation_idx(indices[i], i)[-1])) ## need to fix the indexing
+ uniques, counts = combined.unique(return_counts=True)
+ unmatched_indices = uniques[counts == 1]
+ boxes = outputs_without_aux['pred_boxes'][i] #[unmatched_indices,:]
+ img = samples.tensors[i].cpu().permute(1,2,0).numpy()
+ h, w = img.shape[:-1]
+ img_w = torch.tensor(w, device=owod_device)
+ img_h = torch.tensor(h, device=owod_device)
+ unmatched_boxes = box_ops.box_cxcywh_to_xyxy(boxes)
+ unmatched_boxes = unmatched_boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(owod_device)
+ means_bb = torch.zeros(queries.shape[0]).to(unmatched_boxes)
+ bb = unmatched_boxes
+ for j, _ in enumerate(means_bb):
+ if j in unmatched_indices:
+ upsaple = nn.Upsample(size=(img_h,img_w), mode='bilinear')
+ img_feat = upsaple(res_feat[i].unsqueeze(0).unsqueeze(0))
+ img_feat = img_feat.squeeze(0).squeeze(0)
+ xmin = bb[j,:][0].long()
+ ymin = bb[j,:][1].long()
+ xmax = bb[j,:][2].long()
+ ymax = bb[j,:][3].long()
+ means_bb[j] = torch.mean(img_feat[ymin:ymax,xmin:xmax])
+ if torch.isnan(means_bb[j]):
+ means_bb[j] = -10e10
+ else:
+ means_bb[j] = -10e10
+
+ _, topk_inds = torch.topk(means_bb, self.top_unk)
+ topk_inds = torch.as_tensor(topk_inds)
+
+ topk_inds = topk_inds.cpu()
+
+ unk_label = torch.as_tensor([self.num_classes-1], device=owod_device)
+ owod_targets[i]['labels'] = torch.cat((owod_targets[i]['labels'], unk_label.repeat_interleave(self.top_unk)))
+ owod_indices[i] = (torch.cat((owod_indices[i][0], topk_inds)), torch.cat((owod_indices[i][1], (owod_targets[i]['labels'] == unk_label).nonzero(as_tuple=True)[0].cpu())))
+
+ # Compute the average number of target boxes across all nodes, for normalization purposes
+ num_boxes = sum(len(t["labels"]) for t in targets)
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_boxes)
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ kwargs = {}
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, epoch, owod_targets, owod_indices, **kwargs))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if 'aux_outputs' in outputs:
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
+ indices = self.matcher(aux_outputs, targets)
+
+ owod_targets = deepcopy(targets)
+ owod_indices = deepcopy(indices)
+
+ aux_owod_outputs = aux_outputs.copy()
+ owod_device = aux_owod_outputs["pred_boxes"].device
+
+ if self.unmatched_boxes and epoch >= loss_epoch and self.training:
+ ## get pseudo unmatched boxes from this section
+ res_feat = torch.mean(outputs['resnet_1024_feat'], 1) #2 X 67 X 50
+ queries = torch.arange(aux_owod_outputs['pred_logits'].shape[1])
+ for i in range(len(indices)):
+ combined = torch.cat((queries, self._get_src_single_permutation_idx(indices[i], i)[-1])) ## need to fix the indexing
+ uniques, counts = combined.unique(return_counts=True)
+ unmatched_indices = uniques[counts == 1]
+ boxes = aux_owod_outputs['pred_boxes'][i] #[unmatched_indices,:]
+ img = samples.tensors[i].cpu().permute(1,2,0).numpy()
+ h, w = img.shape[:-1]
+ img_w = torch.tensor(w, device=owod_device)
+ img_h = torch.tensor(h, device=owod_device)
+ unmatched_boxes = box_ops.box_cxcywh_to_xyxy(boxes)
+ unmatched_boxes = unmatched_boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(owod_device)
+ means_bb = torch.zeros(queries.shape[0]).to(unmatched_boxes) #torch.zeros(unmatched_boxes.shape[0])
+ bb = unmatched_boxes
+ ## [INFO]: iterating over the full list of boxes and then selecting the unmatched ones
+ for j, _ in enumerate(means_bb):
+ if j in unmatched_indices:
+ upsaple = nn.Upsample(size=(img_h,img_w), mode='bilinear')
+ img_feat = upsaple(res_feat[i].unsqueeze(0).unsqueeze(0))
+ img_feat = img_feat.squeeze(0).squeeze(0)
+ xmin = bb[j,:][0].long()
+ ymin = bb[j,:][1].long()
+ xmax = bb[j,:][2].long()
+ ymax = bb[j,:][3].long()
+ means_bb[j] = torch.mean(img_feat[ymin:ymax,xmin:xmax])
+ if torch.isnan(means_bb[j]):
+ means_bb[j] = -10e10
+ else:
+ means_bb[j] = -10e10
+
+ _, topk_inds = torch.topk(means_bb, self.top_unk)
+ topk_inds = torch.as_tensor(topk_inds)
+
+ topk_inds = topk_inds.cpu()
+ unk_label = torch.as_tensor([self.num_classes-1], device=owod_device)
+ owod_targets[i]['labels'] = torch.cat((owod_targets[i]['labels'], unk_label.repeat_interleave(self.top_unk)))
+ owod_indices[i] = (torch.cat((owod_indices[i][0], topk_inds)), torch.cat((owod_indices[i][1], (owod_targets[i]['labels'] == unk_label).nonzero(as_tuple=True)[0].cpu())))
+
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs['log'] = False
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, epoch, owod_targets, owod_indices, **kwargs)
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if 'enc_outputs' in outputs:
+ enc_outputs = outputs['enc_outputs']
+ bin_targets = copy.deepcopy(targets)
+ for bt in bin_targets:
+ bt['labels'] = torch.zeros_like(bt['labels'])
+ indices = self.matcher(enc_outputs, bin_targets)
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs['log'] = False
+ l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_enc': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ # import ipdb; ipdb.set_trace()
+ if 'da_backbone' in outputs:
+ # for k, v in outputs['da_output'].items():
+ losses['loss_backbone'] = self.loss_da(outputs['da_backbone'], use_focal=False)
+ # print(losses)
+ return losses
+
+
+class PostProcess(nn.Module):
+ """ This module converts the model's output into the format expected by the coco api"""
+
+ @torch.no_grad()
+ def forward(self, outputs, target_sizes):
+ """ Perform the computation
+ Parameters:
+ outputs: raw outputs of the model
+ target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
+ For evaluation, this must be the original image size (before any data augmentation)
+ For visualization, this should be the image size after data augment, but before padding
+ """
+ out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
+
+ assert len(out_logits) == len(target_sizes)
+ assert target_sizes.shape[1] == 2
+
+ prob = out_logits.sigmoid()
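+ # Take the top-100 (query, class) pairs over the flattened query-class scores; the query index and class label are recovered from the flat index.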
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
+ scores = topk_values
+ topk_boxes = topk_indexes // out_logits.shape[2]
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+# def build_ow_detr(cfg):
+def build(cfg):
+ num_classes = 11
+ # num_classes = 4
+ print(num_classes)
+ # if args.dataset == "coco_panoptic":
+ # num_classes = 250
+ device = torch.device(cfg.DEVICE)
+
+ backbone = build_backbone(cfg)
+
+ transformer = build_deforamble_transformer(cfg)
+
+ prev_intro_cls = 0
+ curr_intro_cls = 10
+ # curr_intro_cls = 3
+ seen_classes = prev_intro_cls + curr_intro_cls
+ invalid_cls_logits = list(range(seen_classes, num_classes-1)) # the unknown class index is not included in the invalid class range
+ print("Invalid class range: " + str(invalid_cls_logits))
+
+ model = DeformableDETR(
+ backbone,
+ transformer,
+ num_classes=cfg.DATASET.NUM_CLASSES,
+ num_queries=cfg.MODEL.NUM_QUERIES,
+ num_feature_levels=cfg.MODEL.NUM_FEATURE_LEVELS,
+ aux_loss=cfg.LOSS.AUX_LOSS,
+ with_box_refine=cfg.MODEL.WITH_BOX_REFINE,
+ two_stage=cfg.MODEL.TWO_STAGE,
+ unmatched_boxes=True,
+ novelty_cls=True,
+ featdim=1024,
+ )
+ if cfg.MODEL.MASKS:
+ model = DETRsegm(model, freeze_detr=(cfg.MODEL.FROZEN_WEIGHTS is not None))
+ matcher = build_matcher(cfg)
+ weight_dict = {'loss_ce': cfg.LOSS.CLS_LOSS_COEF, 'loss_bbox': cfg.LOSS.BBOX_LOSS_COEF}
+ weight_dict['loss_giou'] = cfg.LOSS.GIOU_LOSS_COEF
+ if cfg.MODEL.MASKS:
+ weight_dict["loss_mask"] = cfg.LOSS.MASK_LOSS_COEF
+ weight_dict["loss_dice"] = cfg.LOSS.DICE_LOSS_COEF
+
+ weight_dict['loss_backbone'] = cfg.LOSS.BACKBONE_LOSS_COEF
+ weight_dict['loss_NC'] = 0.1
+
+ # TODO this is a hack
+ if cfg.LOSS.AUX_LOSS:
+ aux_weight_dict = {}
+ for i in range(cfg.MODEL.DEC_LAYERS - 1):
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
+ aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()})
+ weight_dict.update(aux_weight_dict)
+
+ losses = ['labels', 'NC_labels', 'boxes', 'cardinality']
+ # if args.masks:
+ # losses += ["masks"]
+ criterion = SetCriterion(cfg, num_classes, matcher, weight_dict, losses, invalid_cls_logits, focal_alpha=cfg.LOSS.FOCAL_ALPHA)
+ criterion.to(device)
+ postprocessors = {'bbox': PostProcess()}
+ return model, criterion, postprocessors
\ No newline at end of file
diff --git a/models/position_encoding.py b/models/position_encoding.py
new file mode 100644
index 0000000..fa09903
--- /dev/null
+++ b/models/position_encoding.py
@@ -0,0 +1,99 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+from util.misc import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
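+ # Apply sin to even feature indices and cos to odd ones, then concatenate the y- and x-encodings along the channel dimension.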
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = torch.cat([
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return pos
+
+
+def build_position_encoding(cfg):
+ N_steps = cfg.MODEL.HIDDEN_DIM // 2
+ if cfg.MODEL.POSITION_EMBEDDING in ('v2', 'sine'):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+ elif cfg.MODEL.POSITION_EMBEDDING in ('v3', 'learned'):
+ position_embedding = PositionEmbeddingLearned(N_steps)
+ else:
+ raise ValueError(f"not supported {cfg.MODEL.POSITION_EMBEDDING}")
+
+ return position_embedding
diff --git a/models/segmentation.py b/models/segmentation.py
new file mode 100644
index 0000000..b1460f0
--- /dev/null
+++ b/models/segmentation.py
@@ -0,0 +1,371 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+This file provides the definition of the convolutional heads used to predict masks, as well as the losses
+"""
+import io
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+
+import util.box_ops as box_ops
+from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
+
+try:
+ from panopticapi.utils import id2rgb, rgb2id
+except ImportError:
+ pass
+
+
+class DETRsegm(nn.Module):
+ def __init__(self, detr, freeze_detr=False):
+ super().__init__()
+ self.detr = detr
+
+ if freeze_detr:
+ for p in self.parameters():
+ p.requires_grad_(False)
+
+ hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
+ self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0)
+ self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)
+
+ def forward(self, samples: NestedTensor):
+ if not isinstance(samples, NestedTensor):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, pos = self.detr.backbone(samples)
+
+ bs = features[-1].tensors.shape[0]
+
+ src, mask = features[-1].decompose()
+ src_proj = self.detr.input_proj(src)
+ hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])
+
+ outputs_class = self.detr.class_embed(hs)
+ outputs_coord = self.detr.bbox_embed(hs).sigmoid()
+ out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
+ if self.detr.aux_loss:
+ out["aux_outputs"] = [
+ {"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
+ ]
+
+ # FIXME h_boxes takes the last one computed, keep this in mind
+ bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
+
+ seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
+ outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
+
+ out["pred_masks"] = outputs_seg_masks
+ return out
+
+
+class MaskHeadSmallConv(nn.Module):
+ """
+ Simple convolutional head, using group norm.
+    Upsampling is done using an FPN approach.
+ """
+
+ def __init__(self, dim, fpn_dims, context_dim):
+ super().__init__()
+
+ inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
+ self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
+ self.gn1 = torch.nn.GroupNorm(8, dim)
+ self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
+ self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
+ self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
+ self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
+ self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
+ self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
+ self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
+ self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
+ self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)
+
+ self.dim = dim
+
+ self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
+ self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
+ self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_uniform_(m.weight, a=1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, bbox_mask, fpns):
+ def expand(tensor, length):
+ return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
+
+ x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
+
+ x = self.lay1(x)
+ x = self.gn1(x)
+ x = F.relu(x)
+ x = self.lay2(x)
+ x = self.gn2(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter1(fpns[0])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay3(x)
+ x = self.gn3(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter2(fpns[1])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay4(x)
+ x = self.gn4(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter3(fpns[2])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay5(x)
+ x = self.gn5(x)
+ x = F.relu(x)
+
+ x = self.out_lay(x)
+ return x
+
+
+class MHAttentionMap(nn.Module):
+ """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
+
+ def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True):
+ super().__init__()
+ self.num_heads = num_heads
+ self.hidden_dim = hidden_dim
+ self.dropout = nn.Dropout(dropout)
+
+ self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
+ self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
+
+ nn.init.zeros_(self.k_linear.bias)
+ nn.init.zeros_(self.q_linear.bias)
+ nn.init.xavier_uniform_(self.k_linear.weight)
+ nn.init.xavier_uniform_(self.q_linear.weight)
+ self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
+
+ def forward(self, q, k, mask=None):
+ q = self.q_linear(q)
+ k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
+ qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
+ kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
+ weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
+
+ if mask is not None:
+ weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
+ weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
+ weights = self.dropout(weights)
+ return weights
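+
+
+def _example_mh_attention_map():
+    """Illustrative sketch only (assumed toy shapes): the module returns per-head
+    attention maps of shape (batch, num_queries, num_heads, H, W)."""
+    batch, num_queries, hidden_dim, h, w = 2, 5, 256, 16, 16
+    attn = MHAttentionMap(query_dim=hidden_dim, hidden_dim=hidden_dim, num_heads=8, dropout=0.0)
+    q = torch.randn(batch, num_queries, hidden_dim)   # decoder output embeddings
+    k = torch.randn(batch, hidden_dim, h, w)          # projected encoder memory
+    return attn(q, k)                                 # -> (2, 5, 8, 16, 16)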
+
+
+def dice_loss(inputs, targets, num_boxes):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ """
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ numerator = 2 * (inputs * targets).sum(1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss.sum() / num_boxes
+
+
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+               positive vs negative examples. Default = 0.25.
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
+ Returns:
+ Loss tensor
+ """
+ prob = inputs.sigmoid()
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ return loss.mean(1).sum() / num_boxes
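+
+
+def _example_mask_losses():
+    """Illustrative sketch only (assumed toy shapes): both losses take flattened
+    per-mask logits and binary targets, and are normalized by num_boxes."""
+    num_boxes = 4
+    inputs = torch.randn(num_boxes, 32 * 32)                       # raw mask logits
+    targets = torch.randint(0, 2, (num_boxes, 32 * 32)).float()    # binary masks
+    return dice_loss(inputs, targets, num_boxes), sigmoid_focal_loss(inputs, targets, num_boxes)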
+
+
+class PostProcessSegm(nn.Module):
+ def __init__(self, threshold=0.5):
+ super().__init__()
+ self.threshold = threshold
+
+ @torch.no_grad()
+ def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
+ assert len(orig_target_sizes) == len(max_target_sizes)
+ max_h, max_w = max_target_sizes.max(0)[0].tolist()
+ outputs_masks = outputs["pred_masks"].squeeze(2)
+ outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
+ outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
+
+ for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
+ img_h, img_w = t[0], t[1]
+ results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
+ results[i]["masks"] = F.interpolate(
+ results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
+ ).byte()
+
+ return results
+
+
+class PostProcessPanoptic(nn.Module):
+ """This class converts the output of the model to the final panoptic result, in the format expected by the
+ coco panoptic API """
+
+ def __init__(self, is_thing_map, threshold=0.85):
+ """
+ Parameters:
+            is_thing_map: This is a dict whose keys are the class ids, and whose values are booleans indicating whether
+                          the class is a thing (True) or a stuff (False) class
+ threshold: confidence threshold: segments with confidence lower than this will be deleted
+ """
+ super().__init__()
+ self.threshold = threshold
+ self.is_thing_map = is_thing_map
+
+ def forward(self, outputs, processed_sizes, target_sizes=None):
+ """ This function computes the panoptic prediction from the model's predictions.
+ Parameters:
+ outputs: This is a dict coming directly from the model. See the model doc for the content.
+ processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
+ model, ie the size after data augmentation but before batching.
+ target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
+ of each prediction. If left to None, it will default to the processed_sizes
+ """
+ if target_sizes is None:
+ target_sizes = processed_sizes
+ assert len(processed_sizes) == len(target_sizes)
+ out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
+ assert len(out_logits) == len(raw_masks) == len(target_sizes)
+ preds = []
+
+ def to_tuple(tup):
+ if isinstance(tup, tuple):
+ return tup
+ return tuple(tup.cpu().tolist())
+
+ for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
+ out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
+ ):
+            # we filter empty queries and detections below the confidence threshold
+ scores, labels = cur_logits.softmax(-1).max(-1)
+ keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
+ cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+ cur_scores = cur_scores[keep]
+ cur_classes = cur_classes[keep]
+ cur_masks = cur_masks[keep]
+ cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0)
+ cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])
+
+ h, w = cur_masks.shape[-2:]
+ assert len(cur_boxes) == len(cur_classes)
+
+ # It may be that we have several predicted masks for the same stuff class.
+ # In the following, we track the list of masks ids for each stuff class (they are merged later on)
+ cur_masks = cur_masks.flatten(1)
+ stuff_equiv_classes = defaultdict(lambda: [])
+ for k, label in enumerate(cur_classes):
+ if not self.is_thing_map[label.item()]:
+ stuff_equiv_classes[label.item()].append(k)
+
+ def get_ids_area(masks, scores, dedup=False):
+ # This helper function creates the final panoptic segmentation image
+ # It also returns the area of the masks that appears on the image
+
+ m_id = masks.transpose(0, 1).softmax(-1)
+
+ if m_id.shape[-1] == 0:
+ # We didn't detect any mask :(
+ m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
+ else:
+ m_id = m_id.argmax(-1).view(h, w)
+
+ if dedup:
+ # Merge the masks corresponding to the same stuff class
+ for equiv in stuff_equiv_classes.values():
+ if len(equiv) > 1:
+ for eq_id in equiv:
+ m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
+
+ final_h, final_w = to_tuple(target_size)
+
+ seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
+ seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
+
+ np_seg_img = (
+ torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
+ )
+ m_id = torch.from_numpy(rgb2id(np_seg_img))
+
+ area = []
+ for i in range(len(scores)):
+ area.append(m_id.eq(i).sum().item())
+ return area, seg_img
+
+ area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
+ if cur_classes.numel() > 0:
+            # We now filter out empty masks, as long as some masks remain
+ while True:
+ filtered_small = torch.as_tensor(
+ [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
+ )
+ if filtered_small.any().item():
+ cur_scores = cur_scores[~filtered_small]
+ cur_classes = cur_classes[~filtered_small]
+ cur_masks = cur_masks[~filtered_small]
+ area, seg_img = get_ids_area(cur_masks, cur_scores)
+ else:
+ break
+
+ else:
+ cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
+
+ segments_info = []
+ for i, a in enumerate(area):
+ cat = cur_classes[i].item()
+ segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
+ del cur_classes
+
+ with io.BytesIO() as out:
+ seg_img.save(out, format="PNG")
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
+ preds.append(predictions)
+ return preds
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000..c298d89
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,124 @@
+# ----------------------------------------------
+# Created by Wei-Jie Huang
+# ----------------------------------------------
+
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+def remove_mask_and_warp(src, pos, padding_mask, level_start_index, spatial_shapes):
+ """ Removes padding mask in sequence and warps each level of tokens into fixed-sized sequences.
+
+ Args:
+ src, pos (batch_size, sequence_length, d_model): patch tokens and position encodings
+ padding_mask (batch_size, sequence_length): key padding mask
+ level_start_index (num_feature_levels): start index of each feature level
+ spatial_shapes (num_feature_levels, 2): spatial shape (H, W) of each feature level
+
+ Returns:
+ src_warped, pos_warped (batch_size, num_feature_levels, C, C): warped patch tokens and
+        position encodings. The last two dimensions indicate the warped sequence length
+        (i.e., sqrt(C) * sqrt(C) = C) and the model dimension, respectively.
+ """
+
+ B, _, C = src.shape
+ sqrt_C = int(C ** 0.5)
+ src_warped = []
+ pos_warped = []
+ for start, shape in zip(level_start_index, spatial_shapes):
+ H, W = shape
+ s = src[:, start:start+H*W].view(B, H, W, C).permute(0, 3, 1, 2)
+ p = pos[:, start:start+H*W].view(B, H, W, C).permute(0, 3, 1, 2)
+ m = padding_mask[:, start:start+H*W].view(B, H, W)
+
+ not_m = ~m
+ real_H = not_m.sum(1).max(1).values
+ real_W = not_m.sum(2).max(1).values
+
+ src_warped.append(torch.stack([F.adaptive_avg_pool2d(s_i[:, :real_H[i], :real_W[i]], sqrt_C) for i, s_i in enumerate(s)]))
+ pos_warped.append(torch.stack([F.adaptive_avg_pool2d(p_i[:, :real_H[i], :real_W[i]], sqrt_C) for i, p_i in enumerate(p)]))
+
+ src_warped = torch.stack(src_warped, dim=1).flatten(-2).transpose(-2, -1)
+ pos_warped = torch.stack(pos_warped, dim=1).flatten(-2).transpose(-2, -1)
+ return src_warped, pos_warped
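+
+
+def _example_remove_mask_and_warp():
+    """Illustrative sketch only: two toy feature levels with no padding are warped
+    into fixed-size sqrt(C) x sqrt(C) grids. In the real pipeline the shape/index
+    arguments come from the deformable transformer; plain Python lists are used
+    here purely for the sketch."""
+    B, C = 2, 256
+    spatial_shapes = [(16, 16), (8, 8)]
+    level_start_index = [0, 16 * 16]
+    L = 16 * 16 + 8 * 8
+    src = torch.randn(B, L, C)
+    pos = torch.randn(B, L, C)
+    padding_mask = torch.zeros(B, L, dtype=torch.bool)   # no padded tokens
+    # Returns tensors of shape (B, num_feature_levels, C, C).
+    return remove_mask_and_warp(src, pos, padding_mask, level_start_index, spatial_shapes)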
+
+
+class DomainAttention(nn.Module):
+ """ Wraps domain-adapting cross attention and MLP into a module.
+ The operations are similar to those in Transformer, including normalization
+ layers and dropout layers, while MLP is simplified as a linear layer.
+
+ Args:
+ d_model: total dimension of the model.
+ n_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights.
+ """
+
+ def __init__(self, d_model, n_heads, dropout):
+ super(DomainAttention, self).__init__()
+ self.grl = GradientReversal()
+ self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+ self.linear = nn.Linear(d_model, d_model)
+ self.dropout2 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, query, src, pos=None, padding_mask=None):
+ """ Args:
+ query (batch_size, num_queries, d_model): discriminator query
+ src, pos (batch_size, sequence_length, d_model): patch tokens and position encodings
+ padding_mask (batch_size, sequence_length): key padding mask
+ """
+ r_query, _ = self.cross_attn(
+ query=query.transpose(0, 1),
+ key=self.grl(self.with_pos_embed(src, pos)).transpose(0, 1),
+ value=self.grl(src).transpose(0, 1),
+ key_padding_mask=padding_mask,
+ )
+ query = query + self.dropout1(r_query.transpose(0, 1))
+ query = self.norm1(query)
+ query = query + self.dropout2(self.linear(query))
+ query = self.norm2(query)
+ return query
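+
+
+def _example_domain_attention():
+    """Illustrative sketch only (assumed toy shapes): a single discriminator query
+    attends over flattened patch tokens; gradients reaching the tokens are reversed
+    by the internal GradientReversal layer."""
+    d_model, n_heads = 256, 8
+    attn = DomainAttention(d_model, n_heads, dropout=0.1)
+    query = torch.zeros(2, 1, d_model)    # (batch_size, num_queries, d_model)
+    src = torch.randn(2, 100, d_model)    # flattened patch tokens
+    pos = torch.randn(2, 100, d_model)    # matching position encodings
+    return attn(query, src, pos=pos)      # -> (2, 1, d_model)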
+
+
+# ------------------------------------------------------------------------------------------------------------------------------
+# Copy-paste from https://github.com/jvanvugt/pytorch-domain-adaptation/blob/35ac3a5a04b5e1cf5b2145b6c442c2d678362eef/utils.py
+# ------------------------------------------------------------------------------------------------------------------------------
+
+
+class GradientReversalFunction(torch.autograd.Function):
+ """
+ Gradient Reversal Layer from:
+ Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
+ Forward pass is the identity function. In the backward pass,
+ the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
+ """
+
+ @staticmethod
+ def forward(ctx, x, lambda_):
+ ctx.lambda_ = lambda_
+ return x.clone()
+
+ @staticmethod
+ def backward(ctx, grads):
+ lambda_ = ctx.lambda_
+ lambda_ = grads.new_tensor(lambda_)
+ dx = -lambda_ * grads
+ return dx, None
+
+
+class GradientReversal(nn.Module):
+ def __init__(self, lambda_=1):
+ super(GradientReversal, self).__init__()
+ self.lambda_ = lambda_
+
+ def forward(self, x):
+ return GradientReversalFunction.apply(x, self.lambda_)
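+
+
+def _example_gradient_reversal():
+    """Illustrative sanity sketch only: the forward pass is the identity, while the
+    backward pass multiplies upstream gradients by -lambda_."""
+    x = torch.randn(4, 8, requires_grad=True)
+    GradientReversal(lambda_=0.5)(x).sum().backward()
+    # d(sum)/dx is all ones, so the reversed gradient is a constant -0.5 tensor.
+    return torch.allclose(x.grad, torch.full_like(x, -0.5))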
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..96e9e35
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+pycocotools
+tqdm
+cython
+scipy
+yacs
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..91f37be
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,29 @@
+# Cityscapes to Foggy Cityscapes
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \
+--opts DATASET.AOOD_SETTING 1 OUTPUT_DIR experiments/city_to_foggy/setting1
+
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \
+--opts DATASET.AOOD_SETTING 2 OUTPUT_DIR experiments/city_to_foggy/setting2
+
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \
+--opts DATASET.AOOD_SETTING 3 OUTPUT_DIR experiments/city_to_foggy/setting3
+
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_foggy_r50.yaml \
+--opts DATASET.AOOD_SETTING 4 OUTPUT_DIR experiments/city_to_foggy/setting4
+
+# Pascal VOC to Clipart
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_pascal_to_clipart_r50.yaml \
+--opts OUTPUT_DIR experiments/pascal_to_clipart
+
+# Cityscapes to BDD100K (daytime)
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_bdd100k_r50.yaml \
+--opts DATASET.AOOD_SETTING 1 OUTPUT_DIR experiments/city_to_bdd100k/setting1
+
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_bdd100k_r50.yaml \
+--opts DATASET.AOOD_SETTING 2 OUTPUT_DIR experiments/city_to_bdd100k/setting2
+
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_bdd100k_r50.yaml \
+--opts DATASET.AOOD_SETTING 3 OUTPUT_DIR experiments/city_to_bdd100k/setting3
+
+GPUS_PER_NODE=2 MASTER_PORT=21001 ./tools/run_dist_launch.sh 2 python main_multi_eval.py --config_file configs/soma_aood_city_to_bdd100k_r50.yaml \
+--opts DATASET.AOOD_SETTING 4 OUTPUT_DIR experiments/city_to_bdd100k/setting4
\ No newline at end of file
diff --git a/tools/launch.py b/tools/launch.py
new file mode 100644
index 0000000..2b3ceaa
--- /dev/null
+++ b/tools/launch.py
@@ -0,0 +1,192 @@
+# --------------------------------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# --------------------------------------------------------------------------------------------------------------------------
+# Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py
+# --------------------------------------------------------------------------------------------------------------------------
+
+r"""
+`torch.distributed.launch` is a module that spawns up multiple distributed
+training processes on each of the training nodes.
+The utility can be used for single-node distributed training, in which one or
+more processes per node will be spawned. The utility can be used for either
+CPU training or GPU training. If the utility is used for GPU training,
+each distributed process will be operating on a single GPU. This can achieve
+well-improved single-node training performance. It can also be used in
+multi-node distributed training, by spawning up multiple processes on each node
+for well-improved multi-node distributed training performance as well.
+This will be especially beneficial for systems with multiple Infiniband
+interfaces that have direct-GPU support, since all of them can be utilized for
+aggregated communication bandwidth.
+In both cases of single-node distributed training or multi-node distributed
+training, this utility will launch the given number of processes per node
+(``--nproc_per_node``). If used for GPU training, this number needs to be less
+than or equal to the number of GPUs on the current system (``nproc_per_node``),
+and each process will be operating on a single GPU from *GPU 0 to
+GPU (nproc_per_node - 1)*.
+**How to use this module:**
+1. Single-Node multi-process distributed training
+::
+ >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
+ YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
+ arguments of your training script)
+2. Multi-Node multi-process distributed training: (e.g. two nodes)
+Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*
+::
+ >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
+ --nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
+ --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
+ and all other arguments of your training script)
+Node 2:
+::
+ >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
+ --nnodes=2 --node_rank=1 --master_addr="192.168.1.1"
+ --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
+ and all other arguments of your training script)
+3. To look up what optional arguments this module offers:
+::
+ >>> python -m torch.distributed.launch --help
+**Important Notices:**
+1. This utility and multi-process distributed (single-node or
+multi-node) GPU training currently only achieves the best performance using
+the NCCL distributed backend. Thus NCCL backend is the recommended backend to
+use for GPU training.
+2. In your training program, you must parse the command-line argument:
+``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
+If your training program uses GPUs, you should ensure that your code only
+runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:
+Parsing the local_rank argument
+::
+ >>> import argparse
+ >>> parser = argparse.ArgumentParser()
+ >>> parser.add_argument("--local_rank", type=int)
+ >>> args = parser.parse_args()
+Set your device to local rank using either
+::
+    >>> torch.cuda.set_device(args.local_rank) # before your code runs
+or
+::
+    >>> with torch.cuda.device(args.local_rank):
+ >>> # your code to run
+3. In your training program, you are supposed to call the following function
+at the beginning to start the distributed backend. You need to make sure that
+the init_method uses ``env://``, which is the only supported ``init_method``
+by this module.
+::
+ torch.distributed.init_process_group(backend='YOUR BACKEND',
+ init_method='env://')
+4. In your training program, you can either use regular distributed functions
+or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
+training program uses GPUs for training and you would like to use
+:func:`torch.nn.parallel.DistributedDataParallel` module,
+here is how to configure it.
+::
+    model = torch.nn.parallel.DistributedDataParallel(model,
+                                                      device_ids=[args.local_rank],
+                                                      output_device=args.local_rank)
+Please ensure that the ``device_ids`` argument is set to the only GPU device id
+that your code will be operating on. This is generally the local rank of the
+process. In other words, ``device_ids`` needs to be ``[args.local_rank]``,
+and ``output_device`` needs to be ``args.local_rank`` in order to use this
+utility.
+5. Another way is to pass ``local_rank`` to the subprocesses via the environment
+variable ``LOCAL_RANK``. This behavior is enabled when you launch the script with
+``--use_env=True``. You must adjust the subprocess example above to replace
+``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
+will not pass ``--local_rank`` when you specify this flag.
+.. warning::
+ ``local_rank`` is NOT globally unique: it is only unique per process
+ on a machine. Thus, don't use it to decide if you should, e.g.,
+ write to a networked filesystem. See
+ https://github.com/pytorch/pytorch/issues/12042 for an example of
+ how things can go wrong if you don't do this correctly.
+"""
+
+
+import sys
+import subprocess
+import os
+import socket
+from argparse import ArgumentParser, REMAINDER
+
+import torch
+
+
+def parse_args():
+ """
+ Helper function parsing the command line options
+ @retval ArgumentParser
+ """
+ parser = ArgumentParser(description="PyTorch distributed training launch "
+                                        "helper utility that will spawn up "
+ "multiple distributed processes")
+
+ # Optional arguments for the launch helper
+ parser.add_argument("--nnodes", type=int, default=1,
+ help="The number of nodes to use for distributed "
+ "training")
+ parser.add_argument("--node_rank", type=int, default=0,
+ help="The rank of the node for multi-node distributed "
+ "training")
+ parser.add_argument("--nproc_per_node", type=int, default=1,
+ help="The number of processes to launch on each node, "
+ "for GPU training, this is recommended to be set "
+ "to the number of GPUs in your system so that "
+ "each process can be bound to a single GPU.")
+ parser.add_argument("--master_addr", default="127.0.0.1", type=str,
+ help="Master node (rank 0)'s address, should be either "
+ "the IP address or the hostname of node 0, for "
+ "single node multi-proc training, the "
+ "--master_addr can simply be 127.0.0.1")
+ parser.add_argument("--master_port", default=29500, type=int,
+ help="Master node (rank 0)'s free port that needs to "
+                             "be used for communication during distributed "
+ "training")
+
+ # positional
+ parser.add_argument("training_script", type=str,
+ help="The full path to the single GPU training "
+ "program/script to be launched in parallel, "
+ "followed by all the arguments for the "
+ "training script")
+
+ # rest from the training program
+ parser.add_argument('training_script_args', nargs=REMAINDER)
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ # world size in terms of number of processes
+ dist_world_size = args.nproc_per_node * args.nnodes
+
+ # set PyTorch distributed related environmental variables
+ current_env = os.environ.copy()
+ current_env["MASTER_ADDR"] = args.master_addr
+ current_env["MASTER_PORT"] = str(args.master_port)
+ current_env["WORLD_SIZE"] = str(dist_world_size)
+
+ processes = []
+
+ for local_rank in range(0, args.nproc_per_node):
+ # each process's rank
+ dist_rank = args.nproc_per_node * args.node_rank + local_rank
+ current_env["RANK"] = str(dist_rank)
+ current_env["LOCAL_RANK"] = str(local_rank)
+
+ cmd = [args.training_script] + args.training_script_args
+
+ process = subprocess.Popen(cmd, env=current_env)
+ processes.append(process)
+
+ for process in processes:
+ process.wait()
+ if process.returncode != 0:
+ raise subprocess.CalledProcessError(returncode=process.returncode,
+ cmd=process.args)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/tools/run_dist_launch.sh b/tools/run_dist_launch.sh
new file mode 100644
index 0000000..f6f6c4f
--- /dev/null
+++ b/tools/run_dist_launch.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+set -x
+
+GPUS=$1
+RUN_COMMAND=${@:2}
+if [ $GPUS -lt 8 ]; then
+ GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS}
+else
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
+fi
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+MASTER_PORT=${MASTER_PORT:-"29500"}
+NODE_RANK=${NODE_RANK:-0}
+
+let "NNODES=GPUS/GPUS_PER_NODE"
+
+python ./tools/launch.py \
+ --nnodes ${NNODES} \
+ --node_rank ${NODE_RANK} \
+ --master_addr ${MASTER_ADDR} \
+ --master_port ${MASTER_PORT} \
+ --nproc_per_node ${GPUS_PER_NODE} \
+ ${RUN_COMMAND}
\ No newline at end of file
diff --git a/tools/run_dist_slurm.sh b/tools/run_dist_slurm.sh
new file mode 100644
index 0000000..bd73d0b
--- /dev/null
+++ b/tools/run_dist_slurm.sh
@@ -0,0 +1,33 @@
+#!/usr/bin/env bash
+# --------------------------------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# --------------------------------------------------------------------------------------------------------------------------
+# Modified from https://github.com/open-mmlab/mmdetection/blob/3b53fe15d87860c6941f3dda63c0f27422da6266/tools/slurm_train.sh
+# --------------------------------------------------------------------------------------------------------------------------
+
+set -x
+
+PARTITION=$1
+JOB_NAME=$2
+GPUS=$3
+RUN_COMMAND=${@:4}
+if [ $GPUS -lt 8 ]; then
+ GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS}
+else
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
+fi
+CPUS_PER_TASK=${CPUS_PER_TASK:-4}
+SRUN_ARGS=${SRUN_ARGS:-""}
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ ${SRUN_ARGS} \
+ ${RUN_COMMAND}
+
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000..01a98e4
--- /dev/null
+++ b/util/__init__.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
diff --git a/util/anno_convert.py b/util/anno_convert.py
new file mode 100644
index 0000000..c3bb09c
--- /dev/null
+++ b/util/anno_convert.py
@@ -0,0 +1,422 @@
+# ----------------------------------------------
+# Created by Wei-Jie Huang
+# A collection of annotation conversion scripts
+# We provide this for reference only
+# ----------------------------------------------
+
+
+import json
+from pathlib import Path
+from tqdm import tqdm
+
+
+r"""
+coco_anno_dict is like {
+ "images": list of image_info's
+ "annotations": list of annotation_info's
+ "categories": list of categorie_info's
+}
+where img_info is like: {
+ "id": ..., # 0-indexed
+ "width": ...,
+ "height": ...,
+ "file_name": ...,
+}, annotation_info is like: {
+ "id": ..., # 0-indexed
+ "image_id": ...,
+ "category_id": ...,
+ "segmentation": ...,
+ "iscrowd": ...,
+ "area": ...,
+ "bbox": ..., # (x, y, w, h)
+}, and category_info is like: {
+ "id": ..., # 1-indexed
+ "name": ...,
+}
+"""
+
+
+def sim10k_to_coco(
+ src_path: str = "VOC2012/Annotations",
+ des_path: str = "annotations/sim10k_caronly.json",
+ categories: tuple = ("car",)
+ ) -> None:
+
+ r""" Convert Sim10k (in VOC format) into COCO format.
+ Args:
+ src_path: path of the directory containing VOC-format annotations
+        des_path: destination of the converted COCO-format annotation
+ categories: only category ``car`` is considered by default
+ """
+
+ from xml.etree import ElementTree
+
+ src_path = Path(src_path)
+ des_path = Path(des_path)
+ assert src_path.exists(), "Annotation directory does not exist"
+ if des_path.exists():
+ print(f"{des_path} exists. Override? (y/n)", end=" ")
+ if input() != "y":
+ print("Abort")
+ return
+ else:
+ des_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Initialization
+ coco_anno_dict = {
+ "images": [],
+ "categories": [],
+ "annotations": [],
+ }
+ num_images = 0
+ num_categories = 0
+ num_annotations = 0
+
+ # Categories
+ category_to_id = {}
+ for category in categories:
+ coco_anno_dict["categories"].append({
+ "id": num_categories + 1,
+ "name": category
+ })
+ category_to_id[category] = num_categories + 1
+ num_categories += 1
+
+ # Start Conversion
+ for anno_file in tqdm(list(src_path.glob("*.xml"))):
+ et_root = ElementTree.parse(anno_file).getroot()
+
+ ##### Images #####
+ img_info = {
+ "id": num_images,
+ "file_name": anno_file.stem + ".jpg",
+ }
+ num_images += 1
+
+ # Image Size
+ size = et_root.find("size")
+ img_info["width"] = int(size.find("width").text)
+ img_info["height"] = int(size.find("height").text)
+
+ coco_anno_dict["images"].append(img_info)
+
+ ##### Annotations #####
+ for anno_object in et_root.findall("object"):
+ category = anno_object.find("name").text
+ if category not in categories:
+ continue
+ anno_info = {
+ "id": num_annotations,
+ "image_id": img_info["id"],
+ "category_id": category_to_id[category],
+ "segmentation": [],
+ "iscrowd": 0
+ }
+ num_annotations += 1
+
+ # Bounding box
+ bndbox = anno_object.find("bndbox")
+ xmin = float(bndbox.find("xmin").text)
+ ymin = float(bndbox.find("ymin").text)
+ xmax = float(bndbox.find("xmax").text)
+ ymax = float(bndbox.find("ymax").text)
+ # COCO format expects (x, y, w, h)
+ anno_info["bbox"] = [xmin, ymin, round(xmax - xmin, 2), round(ymax - ymin, 2)]
+ anno_info["area"] = round(anno_info["bbox"][2] * anno_info["bbox"][3], 2)
+
+ coco_anno_dict["annotations"].append(anno_info)
+
+ print("# of images:", num_images)
+ print("# of categories:", num_categories)
+ print("# of annotations:", num_annotations)
+
+ with open(des_path, 'w') as f:
+ f.write(json.dumps(coco_anno_dict, indent=4))
+ print(f"Convert successfully to {des_path}")
+
+
+def bdd100k_daytime_to_coco(
+ src_path: str = "labels/bdd100k_labels_images_train.json",
+ des_path: str = "annotations/bdd_daytime_train.json",
+ categories: tuple = (
+ "person", "rider", "car", "truck", "bus", "train", "motor", "bike")
+ ) -> None:
+
+ r""" Extract ``daytime`` subset from BDD100k dataset and convert into COCO format.
+ Args:
+ src_path: source of the annotation json file
+ des_path: destination of the converted COCO-fomat annotation
+ categories: categories used
+ """
+
+ src_path = Path(src_path)
+ des_path = Path(des_path)
+ assert src_path.exists(), "Source annotation file does not exist"
+ if des_path.exists():
+ print(f"{des_path} exists. Override? (y/n)", end=" ")
+ if input() != "y":
+ print("Abort")
+ return
+ else:
+ des_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Initialization
+ coco_anno_dict = {
+ "images": [],
+ "categories": [],
+ "annotations": [],
+ }
+ num_images = 0
+ num_categories = 0
+ num_annotations = 0
+
+ # Categories
+ category_to_id = {}
+ for category in categories:
+ coco_anno_dict["categories"].append({
+ "id": num_categories + 1,
+ "name": category
+ })
+ category_to_id[category] = num_categories + 1
+ num_categories += 1
+
+ with open(src_path, 'r') as f:
+ raw_img_annos = json.load(f)
+ # Start Conversion
+ for raw_img_anno in tqdm(raw_img_annos):
+ if raw_img_anno["attributes"]["timeofday"] != "daytime":
+ continue
+
+ ##### Images #####
+ img_info = {
+ "id": num_images,
+ "file_name": raw_img_anno["name"],
+ "height": 720,
+ "width": 1280
+ }
+ coco_anno_dict["images"].append(img_info)
+ num_images += 1
+
+ ##### Annotations #####
+ for label in raw_img_anno["labels"]:
+ if label["category"] not in category_to_id or "box2d" not in label:
+ continue
+ anno_info = {
+ "id": num_annotations,
+ "image_id": img_info["id"],
+ "category_id": category_to_id[label["category"]],
+ "segmentation": [],
+ "iscrowd": 0,
+ }
+ num_annotations += 1
+
+ # Bbox
+ x1 = label["box2d"]["x1"]
+ y1 = label["box2d"]["y1"]
+ x2 = label["box2d"]["x2"]
+ y2 = label["box2d"]["y2"]
+ anno_info["bbox"] = [x1, y1, x2 - x1, y2 - y1]
+ anno_info["area"] = float((x2 - x1) * (y2 - y1))
+ coco_anno_dict["annotations"].append(anno_info)
+
+ print("# of images:", num_images)
+ print("# of categories:", num_categories)
+ print("# of annotations:", num_annotations)
+
+ with open(des_path, 'w') as f:
+ f.write(json.dumps(coco_anno_dict, indent=4))
+ print(f"Convert successfully to {des_path}")
+
+
+def cityscapes_to_coco(
+ src_path: str = "gtFine/train",
+ des_path: str = "annotations/cityscapes_train.json",
+ car_only: bool = False,
+ foggy: bool = False,
+ categories: tuple = (
+ "person", "rider", "car", "truck", "bus", "train", "motor", "bike")
+ ) -> None:
+
+ r"""Convert Cityscapes into COCO format.
+ Ref: https://github.com/facebookresearch/Detectron/blob/7aa91aa/tools/convert_cityscapes_to_coco.py
+ Args:
+ src_path: path of the directory containing Cityscapes annotations
+        des_path: destination of the converted COCO-format annotation
+        car_only: whether to extract only the ``car`` category; used in synthetic-to-real adaptation
+        foggy: whether to extract from Foggy Cityscapes; used in weather adaptation
+ categories: categories used
+ """
+
+ def get_instances_with_polygons(imageFileName):
+ r""" Ref: https://github.com/facebookresearch/Detectron/issues/111#issuecomment-363425465"""
+ import os
+ import sys
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from cityscapesscripts.evaluation.instance import Instance
+ from cityscapesscripts.helpers.csHelpers import labels, id2label
+
+ # Load image
+ img = Image.open(imageFileName)
+
+ # Image as numpy array
+ imgNp = np.array(img)
+
+ # Initialize label categories
+ instances = {}
+ for label in labels:
+ instances[label.name] = []
+
+ # Loop through all instance ids in instance image
+ for instanceId in np.unique(imgNp):
+ if instanceId < 1000:
+ continue
+
+ instanceObj = Instance(imgNp, instanceId)
+ instanceObj_dict = instanceObj.toDict()
+
+ if id2label[instanceObj.labelID].hasInstances:
+ mask = (imgNp == instanceId).astype(np.uint8)
+ contour, hier = cv2.findContours(
+ mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+
+ polygons = [c.reshape(-1).tolist() for c in contour]
+ instanceObj_dict["contours"] = polygons
+
+ instances[id2label[instanceObj.labelID].name].append(
+ instanceObj_dict)
+ return instances
+
+ def polygon_to_bbox(polygon: list) -> list:
+ """Convert polygon into COCO-format bounding box."""
+
+ # https://github.com/facebookresearch/maskrcnn-benchmark/issues/288#issuecomment-449098063
+ TO_REMOVE = 1
+
+ x0 = min(min(p[::2]) for p in polygon)
+ x1 = max(max(p[::2]) for p in polygon)
+ y0 = min(min(p[1::2]) for p in polygon)
+ y1 = max(max(p[1::2]) for p in polygon)
+
+        bbox = [x0, y0, x1 - x0 + TO_REMOVE, y1 - y0 + TO_REMOVE]
+ return bbox
+
+ src_path = Path(src_path)
+ des_path = Path(des_path)
+ assert src_path.exists(), "Source annotation file does not exist"
+ if des_path.exists():
+ print(f"{des_path} exists. Override? (y/n)", end=" ")
+ if input() != "y":
+ print("Abort")
+ return
+ else:
+ des_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Initialization
+ coco_anno_dict = {
+ "images": [],
+ "categories": [],
+ "annotations": [],
+ }
+ num_images = 0
+ num_categories = 0
+ num_annotations = 0
+
+ # Categories
+ if car_only:
+ categories = ("car",)
+ category_to_id = {}
+ for category in categories:
+ coco_anno_dict["categories"].append({
+ "id": num_categories + 1,
+ "name": category
+ })
+ category_to_id[category] = num_categories + 1
+ num_categories += 1
+
+ # Start Conversion
+ for file in tqdm(list(src_path.rglob("*instanceIds.png"))):
+ ##### Images #####
+ img_info = {"id": num_images}
+ num_images += 1
+ img_info["file_name"] = \
+ str(file.name).split("_", maxsplit=1)[0] + "/" + \
+ str(file.name).replace("gtFine", "leftImg8bit").replace("_instanceIds", "")
+ if foggy:
+ img_info["file_name"] = \
+ img_info["file_name"].replace("leftImg8bit", "leftImg8bit_foggy_beta_0.02")
+ with open(str(file).replace("instanceIds.png", "polygons.json"), "r") as f:
+ polygon_info = json.load(f)
+ img_info["width"] = polygon_info["imgWidth"]
+ img_info["height"] = polygon_info["imgHeight"]
+ coco_anno_dict["images"].append(img_info)
+
+ ##### Annotations #####
+ instances = get_instances_with_polygons(str(file.absolute()))
+ for category in instances.keys():
+ if category not in categories:
+ continue
+ for instance in instances[category]:
+ anno_info = {
+ "id": num_annotations,
+ "image_id": img_info["id"],
+ "category_id": category_to_id[category],
+ "segmentation": [],
+ "iscrowd": 0,
+ "area": instance["pixelCount"],
+ "bbox": polygon_to_bbox(instance["contours"]),
+ }
+ num_annotations += 1
+ coco_anno_dict["annotations"].append(anno_info)
+
+ print("# of images:", num_images)
+ print("# of categories:", num_categories)
+ print("# of annotations:", num_annotations)
+
+ with open(des_path, 'w') as f:
+ f.write(json.dumps(coco_anno_dict, indent=4))
+ print(f"Convert successfully to {des_path}")
+
+
+if __name__ == "__main__":
+ sim10k_to_coco(
+ src_path="VOC2012/Annotations",
+ des_path="annotations/sim10k_caronly.json"
+ )
+ bdd100k_daytime_to_coco(
+ src_path="labels/bdd100k_labels_images_train.json",
+ des_path="annotations/bdd_daytime_train.json"
+ )
+ bdd100k_daytime_to_coco(
+ src_path="labels/bdd100k_labels_images_val.json",
+ des_path="annotations/bdd_daytime_val.json"
+ )
+ cityscapes_to_coco(
+ src_path="gtFine/train",
+ des_path="annotations/cityscapes_train.json",
+ )
+ cityscapes_to_coco(
+ src_path="gtFine/val",
+ des_path="annotations/cityscapes_val.json",
+ )
+ cityscapes_to_coco(
+ src_path="gtFine/train",
+ des_path="annotations/cityscapes_caronly_train.json",
+ car_only=True,
+ )
+ cityscapes_to_coco(
+ src_path="gtFine/val",
+ des_path="annotations/cityscapes_caronly_val.json",
+ car_only=True,
+ )
+ cityscapes_to_coco(
+ src_path="gtFine/train",
+ des_path="annotations/foggy_cityscapes_train.json",
+ foggy=True,
+ )
+ cityscapes_to_coco(
+ src_path="gtFine/val",
+ des_path="annotations/foggy_cityscapes_val.json",
+ foggy=True,
+ )
diff --git a/util/box_ops.py b/util/box_ops.py
new file mode 100644
index 0000000..3b87b2b
--- /dev/null
+++ b/util/box_ops.py
@@ -0,0 +1,98 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Utilities for bounding box manipulation and GIoU.
+"""
+import torch
+from torchvision.ops.boxes import box_area
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [(x0 + x1) / 2, (y0 + y1) / 2,
+ (x1 - x0), (y1 - y0)]
+ return torch.stack(b, dim=-1)
+
+
+# modified from torchvision to also return the union
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / union
+ return iou, union
+
+
+def generalized_box_iou(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/
+
+ The boxes should be in [x0, y0, x1, y1] format
+
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
+ and M = len(boxes2)
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+ iou, union = box_iou(boxes1, boxes2)
+
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ area = wh[:, :, 0] * wh[:, :, 1]
+
+ return iou - (area - union) / area
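+
+
+def _example_generalized_box_iou():
+    """Illustrative sketch only: identical boxes give GIoU = 1, while disjoint boxes
+    are penalized by the empty area of their enclosing box."""
+    boxes1 = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
+    boxes2 = torch.tensor([[0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0]])
+    return generalized_box_iou(boxes1, boxes2)   # -> tensor([[1.0000, -0.7778]])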
+
+
+def masks_to_boxes(masks):
+ """Compute the bounding boxes around the provided masks
+
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
+
+ Returns a [N, 4] tensors, with the boxes in xyxy format
+ """
+ if masks.numel() == 0:
+ return torch.zeros((0, 4), device=masks.device)
+
+ h, w = masks.shape[-2:]
+
+ y = torch.arange(0, h, dtype=torch.float)
+ x = torch.arange(0, w, dtype=torch.float)
+ y, x = torch.meshgrid(y, x)
+
+ x_mask = (masks * x.unsqueeze(0))
+ x_max = x_mask.flatten(1).max(-1)[0]
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ y_mask = (masks * y.unsqueeze(0))
+ y_max = y_mask.flatten(1).max(-1)[0]
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
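+
+
+def _example_masks_to_boxes():
+    """Illustrative sketch only: a single rectangular mask yields its enclosing box
+    in xyxy format."""
+    masks = torch.zeros(1, 10, 10)
+    masks[0, 2:5, 3:8] = 1.0                 # rows 2-4, columns 3-7
+    return masks_to_boxes(masks)             # -> tensor([[3., 2., 7., 4.]])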
diff --git a/util/misc.py b/util/misc.py
new file mode 100644
index 0000000..31f0584
--- /dev/null
+++ b/util/misc.py
@@ -0,0 +1,534 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import os
+import subprocess
+import time
+from collections import defaultdict, deque
+import datetime
+import pickle
+from typing import Optional, List
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch import Tensor
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+if float(torchvision.__version__[:3]) < 0.5:
+ import math
+ from torchvision.ops.misc import _NewEmptyTensorOp
+ def _check_size_scale_factor(dim, size, scale_factor):
+ # type: (int, Optional[List[int]], Optional[float]) -> None
+ if size is None and scale_factor is None:
+ raise ValueError("either size or scale_factor should be defined")
+ if size is not None and scale_factor is not None:
+ raise ValueError("only one of size or scale_factor should be defined")
+ if not (scale_factor is not None and len(scale_factor) != dim):
+ raise ValueError(
+ "scale_factor shape must match input shape. "
+ "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
+ )
+ def _output_size(dim, input, size, scale_factor):
+ # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
+ assert dim == 2
+ _check_size_scale_factor(dim, size, scale_factor)
+ if size is not None:
+ return size
+ # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
+ assert scale_factor is not None and isinstance(scale_factor, (int, float))
+ scale_factors = [scale_factor, scale_factor]
+ # math.floor might return float in py2.7
+ return [
+ int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
+ ]
+elif float(torchvision.__version__[:3]) < 0.7:
+ from torchvision.ops import _new_empty_tensor
+ from torchvision.ops.misc import _output_size
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
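+
+
+def _example_smoothed_value():
+    """Illustrative sketch only: the window median and the global average smooth a
+    noisy series of scalar values."""
+    meter = SmoothedValue(window_size=3, fmt="{median:.2f} ({global_avg:.2f})")
+    for v in (1.0, 2.0, 9.0):
+        meter.update(v)
+    return str(meter)                        # -> "2.00 (4.00)"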
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device="cuda")
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+ if local_size != max_size:
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that all processes
+ have the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.all_reduce(values)
+ if average:
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
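+
+
+def _example_reduce_dict():
+    """Illustrative sketch only: in a single-process run the dict is returned
+    unchanged; under an initialized distributed group the values are averaged
+    (or summed) across ranks."""
+    losses = {"loss_ce": torch.tensor(0.5), "loss_bbox": torch.tensor(0.2)}
+    return reduce_dict(losses, average=True)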
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None, logger=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}',
+ 'max mem: {memory:.0f}'
+ ])
+ else:
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ])
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ if logger is not None and is_main_process():
+ logger.info(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ if logger is not None and is_main_process():
+ logger.info('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+        else:
+            print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
+ sha = 'N/A'
+ diff = "clean"
+ branch = 'N/A'
+ try:
+ sha = _run(['git', 'rev-parse', 'HEAD'])
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
+ diff = _run(['git', 'diff-index', 'HEAD'])
+ diff = "has uncommited changes" if diff else "clean"
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+def collate_fn(batch):
+ batch = list(zip(*batch))
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
+ return tuple(batch)
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], :img.shape[2]] = False
+ else:
+ raise ValueError('not supported')
+ return NestedTensor(tensor, mask)
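+
+
+def _example_nested_tensor():
+    """Illustrative sketch only: images of different sizes are zero-padded to a
+    common shape, and the boolean mask marks the padded pixels with True."""
+    imgs = [torch.randn(3, 480, 640), torch.randn(3, 512, 512)]
+    tensors, mask = nested_tensor_from_tensor_list(imgs).decompose()
+    return tensors.shape, mask.shape         # -> (2, 3, 512, 640), (2, 512, 640)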
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device, non_blocking=False):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device, non_blocking=non_blocking)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def record_stream(self, *args, **kwargs):
+ self.tensors.record_stream(*args, **kwargs)
+ if self.mask is not None:
+ self.mask.record_stream(*args, **kwargs)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+def setup_for_distributed(is_master):
+ """
+    This function disables printing when not in the master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return int(os.environ['LOCAL_SIZE'])
+
+
+def get_local_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return int(os.environ['LOCAL_RANK'])
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(cfg):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ cfg.DIST.RANK = int(os.environ["RANK"])
+ cfg.DIST.WORLD_SIZE = int(os.environ['WORLD_SIZE'])
+ cfg.DIST.GPU = int(os.environ['LOCAL_RANK'])
+ cfg.DIST.DIST_URL = 'env://'
+ os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
+ elif 'SLURM_PROCID' in os.environ:
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ addr = subprocess.getoutput(
+ 'scontrol show hostname {} | head -n1'.format(node_list))
+ os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['RANK'] = str(proc_id)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['LOCAL_SIZE'] = str(num_gpus)
+ cfg.DIST.DIST_URL = 'env://'
+ cfg.DIST.WORLD_SIZE = ntasks
+ cfg.DIST.RANK = proc_id
+ cfg.DIST.GPU = proc_id % num_gpus
+ else:
+ print('Not using distributed mode')
+ cfg.DIST.DISTRIBUTED = False
+ return
+
+ cfg.DIST.DISTRIBUTED = True
+
+ torch.cuda.set_device(cfg.DIST.GPU)
+ cfg.DIST.DIST_BACKEND = 'nccl'
+ print('| distributed init (rank {}): {}'.format(
+ cfg.DIST.RANK, cfg.DIST.DIST_URL), flush=True)
+ torch.distributed.init_process_group(backend=cfg.DIST.DIST_BACKEND, init_method=cfg.DIST.DIST_URL,
+ world_size=cfg.DIST.WORLD_SIZE, rank=cfg.DIST.RANK)
+ torch.distributed.barrier()
+ setup_for_distributed(cfg.DIST.RANK == 0)
+
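+# Typical multi-GPU launch (illustrative; the script name is a placeholder):
+# the 'env://' branch above expects RANK, WORLD_SIZE and LOCAL_RANK in the
+# environment, which torch.distributed.launch sets when run with --use_env:
+#   python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py
+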
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ if target.numel() == 0:
+ return [torch.zeros([], device=output.device)]
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
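+# Example (illustrative): for logits of shape [N, num_classes] (num_classes >= 5)
+# and integer labels of shape [N],
+#   top1, top5 = accuracy(logits, labels, topk=(1, 5))
+# returns two 0-dim tensors holding percentages in [0, 100].
+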
+
+def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+ """
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+ This will eventually be supported natively by PyTorch, and this
+ class can go away.
+ """
+ if float(torchvision.__version__[:3]) < 0.7:
+ if input.numel() > 0:
+ return torch.nn.functional.interpolate(
+ input, size, scale_factor, mode, align_corners
+ )
+
+ output_shape = _output_size(2, input, size, scale_factor)
+ output_shape = list(input.shape[:-2]) + list(output_shape)
+ if float(torchvision.__version__[:3]) < 0.5:
+ return _NewEmptyTensorOp.apply(input, output_shape)
+ return _new_empty_tensor(input, output_shape)
+ else:
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+def get_total_grad_norm(parameters, norm_type=2):
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ device = parameters[0].grad.device
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
+ norm_type)
+ return total_norm
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1/x2)
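+
+# Note: inverse_sigmoid is the clamped logit function, so
+# inverse_sigmoid(torch.sigmoid(torch.tensor([-2.0, 0.0, 3.0]))) returns
+# approximately tensor([-2., 0., 3.]).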
+
diff --git a/util/plot_utils.py b/util/plot_utils.py
new file mode 100644
index 0000000..f6ff051
--- /dev/null
+++ b/util/plot_utils.py
@@ -0,0 +1,113 @@
+# ------------------------------------------------------------------------
+# Modified by Wei-Jie Huang
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Plotting utilities to visualize training logs.
+"""
+import torch
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from pathlib import Path, PurePath
+
+
+def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
+ '''
+ Function to plot specific fields from training log(s). Plots both training and test results.
+
+ :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
+ - fields = which results to plot from each log file - plots both training and test for each field.
+ - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
+ - log_name = optional, name of log file if different than default 'log.txt'.
+
+ :: Outputs - matplotlib plots of results in fields, color coded for each log file.
+ - solid lines are training results, dashed lines are test results.
+
+ '''
+ func_name = "plot_utils.py::plot_logs"
+
+ # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
+ # convert single Path to list to avoid 'not iterable' error
+
+ if not isinstance(logs, list):
+ if isinstance(logs, PurePath):
+ logs = [logs]
+ print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
+ else:
+ raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
+ Expect list[Path] or single Path obj, received {type(logs)}")
+
+ # verify valid dir(s) and that every item in list is Path object
+ for i, dir in enumerate(logs):
+ if not isinstance(dir, PurePath):
+ raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
+ if dir.exists():
+ continue
+ raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
+
+ # load log file(s) and plot
+ dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
+
+ fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
+
+ for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
+ for j, field in enumerate(fields):
+ if field == 'mAP':
+                coco_eval = pd.DataFrame(np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean()
+ axs[j].plot(coco_eval, c=color)
+ else:
+ df.interpolate().ewm(com=ewm_col).mean().plot(
+ y=[f'train_{field}', f'test_{field}'],
+ ax=axs[j],
+ color=[color] * 2,
+ style=['-', '--']
+ )
+ for ax, field in zip(axs, fields):
+ ax.legend([Path(p).name for p in logs])
+ ax.set_title(field)
+
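+# Example usage (illustrative; the directory names are placeholders, each one
+# expected to contain the 'log.txt' written during training):
+#
+#   from pathlib import Path
+#   plot_logs([Path('exps/baseline'), Path('exps/soma')],
+#             fields=('loss', 'class_error'))
+#   plt.show()
+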
+
+def plot_precision_recall(files, naming_scheme='iter'):
+ if naming_scheme == 'exp_id':
+ # name becomes exp_id
+ names = [f.parts[-3] for f in files]
+ elif naming_scheme == 'iter':
+ names = [f.stem for f in files]
+ else:
+ raise ValueError(f'not supported {naming_scheme}')
+ fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
+ for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
+ data = torch.load(f)
+ # precision is n_iou, n_points, n_cat, n_area, max_det
+ precision = data['precision']
+ recall = data['params'].recThrs
+ scores = data['scores']
+ # take precision for all classes, all areas and 100 detections
+ precision = precision[0, :, :, 0, -1].mean(1)
+ scores = scores[0, :, :, 0, -1].mean(1)
+ prec = precision.mean()
+ rec = data['recall'][0, :, 0, -1].mean()
+ print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
+ f'score={scores.mean():0.3f}, ' +
+ f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
+ )
+ axs[0].plot(recall, precision, c=color)
+ axs[1].plot(recall, scores, c=color)
+
+ axs[0].set_title('Precision / Recall')
+ axs[0].legend(names)
+ axs[1].set_title('Scores / Recall')
+ axs[1].legend(names)
+ return fig, axs
+
+
+