Skip to content

Commit

Permalink
Fixing #37 allow_shifted=False in MotionBlur (#38)
Browse files Browse the repository at this point in the history
* Fixing #37

* Black formatting

---------

Co-authored-by: Rémi Pautrat <32239569+rpautrat@users.noreply.github.com>
Co-authored-by: Rémi Pautrat <remi.pautrat@polytechnique.org>
  • Loading branch information
3 people authored Feb 12, 2024
1 parent 365ad7b commit dc01ec8
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 23 deletions.
12 changes: 9 additions & 3 deletions gluefactory/datasets/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,13 @@ def _init(self, conf):
[
A.Blur(**kwi(blur, p=0.1, blur_limit=(3, 9), n="blur")),
A.MotionBlur(
**kwi(blur, p=0.2, blur_limit=(3, 25), n="motion_blur")
**kwi(
blur,
p=0.2,
blur_limit=(3, 25),
allow_shifted=False,
n="motion_blur",
)
),
A.ISONoise(),
A.ImageCompression(),
Expand Down Expand Up @@ -222,14 +228,14 @@ def _init(self, conf):
A.OneOf(
[
A.Blur(blur_limit=(3, 9)),
A.MotionBlur(blur_limit=(3, 25)),
A.MotionBlur(blur_limit=(3, 25), allow_shifted=False),
A.ISONoise(),
A.ImageCompression(),
],
p=0.1,
),
A.Blur(p=0.1, blur_limit=(3, 9)),
A.MotionBlur(p=0.1, blur_limit=(3, 25)),
A.MotionBlur(p=0.1, blur_limit=(3, 25), allow_shifted=False),
A.RandomBrightnessContrast(
p=0.5, brightness_limit=(-0.4, 0.0), contrast_limit=(-0.3, 0.0)
),
Expand Down
1 change: 1 addition & 0 deletions gluefactory/datasets/eth3d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
ETH3D multi-view benchmark, used for line matching evaluation.
"""

import logging
import os
import shutil
Expand Down
2 changes: 1 addition & 1 deletion gluefactory/datasets/homographies.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def visualize(args):
images = []
for _, data in zip(range(args.num_items), loader):
images.append(
(data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2))
[data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
)
plot_image_grid(images, dpi=args.dpi)
plt.tight_layout()
Expand Down
3 changes: 2 additions & 1 deletion gluefactory/datasets/hpatches.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Simply load images from a folder or nested folders (does not have any split).
"""

import argparse
import logging
import tarfile
Expand Down Expand Up @@ -127,7 +128,7 @@ def visualize(args):
images = []
for _, data in zip(range(args.num_items), loader):
images.append(
(data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2))
[data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
)
plot_image_grid(images, dpi=args.dpi)
plt.tight_layout()
Expand Down
6 changes: 3 additions & 3 deletions gluefactory/geometry/gt_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,9 @@ def gt_line_matches_from_pose_depth(
all_in_batch = (
torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten()
)
positive[
all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()
] = True
positive[all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()] = (
True
)

m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
Expand Down
17 changes: 11 additions & 6 deletions gluefactory/models/cache_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ def pad_line_features(pred, seq_l: int = None):

def recursive_load(grp, pkeys):
return {
k: torch.from_numpy(grp[k].__array__())
if isinstance(grp[k], h5py.Dataset)
else recursive_load(grp[k], list(grp.keys()))
k: (
torch.from_numpy(grp[k].__array__())
if isinstance(grp[k], h5py.Dataset)
else recursive_load(grp[k], list(grp.keys()))
)
for k in pkeys
}

Expand Down Expand Up @@ -108,9 +110,12 @@ def _forward(self, data):
pred = recursive_load(grp, pkeys)
if self.numeric_dtype is not None:
pred = {
k: v
if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
else v.to(dtype=self.numeric_dtype)
k: (
v
if not isinstance(v, torch.Tensor)
or not torch.is_floating_point(v)
else v.to(dtype=self.numeric_dtype)
)
for k, v in pred.items()
}
pred = batch_to_device(pred, device)
Expand Down
8 changes: 5 additions & 3 deletions gluefactory/models/extractors/aliked.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,11 @@ def _init(self, conf):
radius=conf.nms_radius,
top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
scores_th=conf.detection_threshold,
n_limit=conf.max_num_keypoints
if conf.max_num_keypoints > 0
else self.n_limit_max,
n_limit=(
conf.max_num_keypoints
if conf.max_num_keypoints > 0
else self.n_limit_max
),
)

# load pretrained
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/extractors/superpoint_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
The implementation of this model and its trained weights are made
available under the MIT license.
"""

from collections import OrderedDict
from types import SimpleNamespace

Expand Down
6 changes: 3 additions & 3 deletions gluefactory/models/lines/wireframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ def _forward(self, data):
associativity = torch.eye(
len(all_points[-1]), dtype=torch.bool, device=device
)
associativity[
: n_true_junctions[bs], : n_true_junctions[bs]
] = line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
associativity[: n_true_junctions[bs], : n_true_junctions[bs]] = (
line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
)
pl_associativity.append(associativity)

all_points = torch.stack(all_points, dim=0)
Expand Down
8 changes: 5 additions & 3 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def create_input_data(cv_img0, cv_img1, device):
data = {"view0": ip(img0), "view1": ip(img1)}
data = map_tensor(
data,
lambda t: t[None].to(device)
if isinstance(t, Tensor)
else torch.from_numpy(t)[None].to(device),
lambda t: (
t[None].to(device)
if isinstance(t, Tensor)
else torch.from_numpy(t)[None].to(device)
),
)
return data

Expand Down

0 comments on commit dc01ec8

Please sign in to comment.