updated dataloader and formatted code
Fangchang Ma committed Oct 2, 2019
1 parent a3374ba commit 898aa91
Showing 14 changed files with 795 additions and 524 deletions.
76 changes: 42 additions & 34 deletions README.md
# self-supervised-depth-completion

This repo is the PyTorch implementation of our ICRA'19 paper on ["Self-supervised Sparse-to-Dense: Self-supervised Depth Completion from LiDAR and Monocular Camera"](https://arxiv.org/pdf/1807.00275.pdf), developed by [Fangchang Ma](http://www.mit.edu/~fcma/), Guilherme Venturelli Cavalheiro, and [Sertac Karaman](http://karaman.mit.edu/) at MIT. A video demonstration is available on [YouTube](https://youtu.be/bGXfvF261pc).

<p align="center">
<img src="https://j.gifs.com/rRrOW4.gif" alt="photo not available" height="50%">
</p>


Please create a new issue for code-related questions. Pull requests are welcome.

## Contents
0. [Notes](#notes)
0. [Dependency](#dependency)
0. [Data](#data)
0. [Trained Models](#trained-models)
0. [Commands](#commands)
0. [Citation](#citation)

## Notes
Our network is trained with the KITTI dataset alone, without pretraining on Cityscapes or other similar driving datasets (either synthetic or real). The use of additional data is very likely to further improve the accuracy.

## Dependency
This code was tested with Python 3 and PyTorch 1.0 on Ubuntu 16.04.
- Install [PyTorch](https://pytorch.org/get-started/locally/) on a machine with a CUDA GPU.
```bash
pip install numpy matplotlib Pillow
pip install torch torchvision # pytorch

# self-supervised training requires OpenCV along with the contrib modules
pip install opencv-contrib-python==3.4.2.16
```
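To verify that the contrib modules are actually available (a quick sanity check, not from the original README; `xfeatures2d` ships only in the contrib build):
```python
import cv2

print(cv2.__version__)  # expect 3.4.2
# xfeatures2d is contrib-only, so this raises AttributeError
# if plain opencv-python is installed instead
sift = cv2.xfeatures2d.SIFT_create()
print(type(sift))
```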

## Data
- Download the [KITTI Depth](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion) Dataset from their website. Use the following scripts to extract corresponding RGB images from the raw dataset.
```bash
./download/rgb_train_downloader.sh
./download/rgb_val_downloader.sh
```
The downloaded RGB files will be stored in the `../data/data_rgb` folder. The overall code, data, and results directory is structured as follows (updated on Oct 1, 2019):
```
.
├── self-supervised-depth-completion
├── data
| ├── data_depth_annotated
| | ├── train
| | ├── val
| ├── data_depth_velodyne
| | ├── train
| | ├── val
| ├── depth_selection
| | ├── test_depth_completion_anonymous
| | ├── test_depth_prediction_anonymous
| | ├── val_selection_cropped
| | | ├── groundtruth_depth
| | | ├── image
| | | ├── intrinsics
| | | ├── velodyne_raw
| └── data_rgb
| | ├── train
| | | ├── 2011_09_26_drive_0001_sync
| | | | ├── image_02
| | | | | ├── data
| | | | | | ├── 0000000000.png
| | | | | | ├── ...
| | | | ├── image_03
| | ├── val
├── results
```
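Before training, a short script along these lines can confirm that the layout matches (a sketch based on the tree above; it assumes it is run from inside `self-supervised-depth-completion` and is not part of the repo):
```python
import os

# expected locations relative to self-supervised-depth-completion, per the tree above
required = [
    "../data/data_depth_annotated/train",
    "../data/data_depth_annotated/val",
    "../data/data_depth_velodyne/train",
    "../data/data_depth_velodyne/val",
    "../data/depth_selection/val_selection_cropped",
    "../data/data_rgb/train",
    "../data/data_rgb/val",
]
for d in required:
    print(("ok      " if os.path.isdir(d) else "MISSING ") + d)
```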
## Trained Models
Download our trained models at http://datasets.lids.mit.edu/self-supervised-depth-completion:
- supervised training (i.e., models trained with semi-dense lidar ground truth): http://datasets.lids.mit.edu/self-supervised-depth-completion/supervised/
- self-supervised (i.e., photometric loss + sparse depth loss + smoothness loss): http://datasets.lids.mit.edu/self-supervised-depth-completion/self-supervised/

## Commands
A complete list of training options is available with
```bash
python main.py -h
```
For instance,
```bash
# train with the KITTI semi-dense annotations, rgbd input, and batch size 1
python main.py --train-mode dense -b 1 --input rgbd

# train with the self-supervised framework, not using ground truth
python main.py --train-mode sparse+photo

# resume previous training
python main.py --resume [checkpoint-path]

# test the trained model on the val_selection_cropped data
python main.py --evaluate [checkpoint-path] --val select
```
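For example, a complete cycle of self-supervised training followed by evaluation could look like the sketch below. The checkpoint path is illustrative (checkpoints land under the `results` folder shown in the directory tree above); substitute the file your run actually produces.
```bash
# illustrative workflow; the exact checkpoint filename under ../results
# depends on your run and is hypothetical here
python main.py --train-mode sparse+photo --input rgbd
python main.py --evaluate ../results/model_best.pth.tar --val select
```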

## Citation
If you use our code or method in your work, please cite the following:
```
@inproceedings{ma2019self,
  title={Self-supervised Sparse-to-Dense: Self-supervised Depth Completion from LiDAR and Monocular Camera},
  author={Ma, Fangchang and Cavalheiro, Guilherme Venturelli and Karaman, Sertac},
  booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
  year={2019}
}
```
40 changes: 28 additions & 12 deletions criteria.py
import torch
import torch.nn as nn

loss_names = ['l1', 'l2']


class MaskedMSELoss(nn.Module):
    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, pred, target):
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = (target > 0).detach()  # only pixels with ground-truth depth
        diff = target - pred
        diff = diff[valid_mask]
        self.loss = (diff**2).mean()
        return self.loss


class MaskedL1Loss(nn.Module):
    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, target, weight=None):
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = (target > 0).detach()
        diff = target - pred
        diff = diff[valid_mask]
        self.loss = diff.abs().mean()
        return self.loss
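Both masked losses skip pixels whose target depth is zero, i.e. pixels without ground truth, so only valid entries contribute to the mean. A minimal usage sketch with toy tensors (illustrative, not part of criteria.py):
```python
import torch
from criteria import MaskedMSELoss, MaskedL1Loss

pred = torch.rand(1, 1, 4, 4)
target = torch.rand(1, 1, 4, 4)
target[:, :, :2, :] = 0  # top half has no ground truth and is ignored

print(MaskedMSELoss()(pred, target).item())  # MSE over valid pixels only
print(MaskedL1Loss()(pred, target).item())   # L1 over valid pixels only
```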


class PhotometricLoss(nn.Module):
    def __init__(self):
        super(PhotometricLoss, self).__init__()

    def forward(self, target, recon, mask=None):
        assert recon.dim() == 4, "expected recon dimension to be 4, but instead got {}.".format(recon.dim())
        assert target.dim() == 4, "expected target dimension to be 4, but instead got {}.".format(target.dim())
        assert recon.size() == target.size(), "expected recon and target to have the same size, but got {} and {} instead".format(recon.size(), target.size())
        diff = (target - recon).abs()
        diff = torch.sum(diff, 1)  # sum along the color channel

        # compare only pixels that are not black
        valid_mask = (torch.sum(recon, 1) > 0).float() * (torch.sum(target, 1) > 0).float()
        if mask is not None:
            valid_mask = valid_mask * torch.squeeze(mask).float()
        valid_mask = valid_mask.byte().detach()
@@ -50,23 +58,31 @@ def forward(self, target, recon, mask=None):
            if diff.nelement() > 0:
                self.loss = diff.mean()
            else:
                print("warning: diff.nelement() == 0 in PhotometricLoss (this is expected during early stages of training, try a larger batch size).")
                self.loss = 0
        else:
            print("warning: 0 valid pixels in PhotometricLoss")
            self.loss = 0
        return self.loss
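PhotometricLoss compares the current RGB frame against a nearby frame warped into the same view, summing the absolute difference over color channels and ignoring black pixels (which have no warp source). A toy call might look like this (shapes and the black border are illustrative, not part of criteria.py):
```python
import torch
from criteria import PhotometricLoss

target = torch.rand(1, 3, 8, 8)  # current RGB frame, (N, C, H, W)
recon = torch.rand(1, 3, 8, 8)   # nearby frame warped into the current view
recon[:, :, :, :2] = 0           # black columns are excluded from the loss

print(PhotometricLoss()(target, recon))
```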


class SmoothnessLoss(nn.Module):
    def __init__(self):
        super(SmoothnessLoss, self).__init__()

    def forward(self, depth):
        def second_derivative(x):
            assert x.dim() == 4, "expected 4-dimensional data, but instead got {}".format(x.dim())
            horizontal = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, 1:-1, :-2] - x[:, :, 1:-1, 2:]
            vertical = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, :-2, 1:-1] - x[:, :, 2:, 1:-1]
            der_2nd = horizontal.abs() + vertical.abs()
            return der_2nd.mean()

        self.loss = second_derivative(depth)
        return self.loss
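SmoothnessLoss penalizes the mean absolute second derivative of the predicted depth, so an affine (planar) depth map incurs zero loss while a noisy one is penalized. A small illustration (assumed usage, not part of the file):
```python
import torch
from criteria import SmoothnessLoss

loss_fn = SmoothnessLoss()

# a planar depth ramp has zero second derivative, hence zero loss
ramp = torch.arange(16.0).view(1, 1, 4, 4)
print(loss_fn(ramp).item())  # 0.0

# a random (noisy) depth map is penalized
print(loss_fn(torch.rand(1, 1, 8, 8)).item())  # > 0
```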
