Skip to content

Commit

Permalink
fix: add option to use RANSAC to reject false matches
Browse files Browse the repository at this point in the history
  • Loading branch information
SuTanTank committed May 27, 2021
1 parent c7d09cb commit fe0a686
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 74 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,5 @@ __pycache__/
*.dll
*.exe
~case-cuhk_lib/
case*
case*
build
2 changes: 1 addition & 1 deletion Stitching-1.1.0/RANSAC/SURF2.m
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
threshold = threshold - 10;
pNew = detectSURFFeatures(I, 'ROI', roi, 'MetricThreshold', threshold);
end
while nMore * 2 < size(pNew, 1) && threshold < 200;
while nMore * 2 < size(pNew, 1) && threshold < 200
threshold = threshold + 20;
pNew = detectSURFFeatures(I, 'ROI', roi, 'MetricThreshold', threshold);
end
Expand Down
13 changes: 7 additions & 6 deletions Stitching-1.1.0/RunStitching.m
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
TracksPerFrame = 600; % # of trajectories in a frame
TrackWindowSize = 40; % the window size for motion segmentation
% ---------------
MeshSize = 8; % The mesh size of bundled camera path, 5 - 10 is OK
MaxIte = 15; % Number of iterations of the optimization scheme, 10 - 15 is OK
Smoothness = 1; % adjust how stable the output is, 0.5 - 3 is OK
Cropping = 1; % adjust how similar the result to the original video, usually set to 1;
Stitchness = 20; % adjust the weight of stitching term, 10 - 30 is OK
MeshSize = 8; % The mesh size of bundled camera path, 5 - 10 is OK
MaxIte = 15; % Number of iterations of the optimization scheme, 10 - 15 is OK
Smoothness = 3; % adjust how stable the output is, 1 - 4 is OK
Cropping = 1; % adjust how similar the result to the original video, usually set to 1;
Stitchness = 20; % adjust the weight of stitching term, 10 - 30 is OK
SKIP_BACKGROUND_SEGMENTATION = true; % skip the background segmentation - treat all tracks as background.
RANSAC = true; % use RANSAC to remove wrong sparse correspondences, set to true if overlap is small and stitch fails.
% ---------------
OutputPadding = 500; % the padding around the video
OutputPath = 'res_demo'; % the directory to store the output frames, auto create it if not exist
Expand Down Expand Up @@ -90,7 +91,7 @@
%% Matching SIFT in every frame pair
tic;
if ~exist([data 'ControlPoints' int2str(PointsPerFrame) '.mat'], 'file')
[CP, ppf] = getControlPoints([data input_A], [data input_B], 500);
[CP, ppf] = getControlPoints([data input_A], [data input_B], PointsPerFrame, RANSAC);
save([data 'ControlPoints' int2str(PointsPerFrame) '.mat'], 'CP', 'ppf');
else
load([data 'ControlPoints' int2str(PointsPerFrame) '.mat']);
Expand Down
103 changes: 39 additions & 64 deletions Stitching-1.1.0/stitch/getControlPoints.m
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
function [CP, ppf] = getControlPoints( input_A, input_B, maxppf )
function [CP, ppf] = getControlPoints( input_A, input_B, maxppf, ransac)
disp('Detecting and Matching SIFT features...');
fileListA = dir(input_A);
fileListA = fileListA(3:length(fileListA));
fileListB = dir(input_B);
fileListB = fileListB(3:length(fileListB));
nFrames = min(length(fileListA), length(fileListB))
nFrames = min(length(fileListA), length(fileListB));
CP = zeros(nFrames, maxppf, 4);
ppf = zeros(nFrames, 1);
trackerA = vision.PointTracker('MaxBidirectionalError', 1);
Expand All @@ -14,113 +14,88 @@
if mod(frameIndex, 20) == 0
fprintf('\n') ;
end

fileNameA = fileListA(frameIndex).name;
fileNameB = fileListB(frameIndex).name;
IA = imread([input_A fileNameA]);
IB = imread([input_B fileNameB]);

[H, W, ~] = size(IA);

if frameIndex > 1
setPoints(trackerA, trackA);
[trackAcont, validityA] = step(trackerA, IA);
[trackAcont, validityA] = step(trackerA, IA);
setPoints(trackerB, trackB);
[trackBcont, validityB] = step(trackerB, IB);
[trackBcont, validityB] = step(trackerB, IB);
trackAcont = trackAcont(validityA & validityB, :);
trackBcont = trackBcont(validityA & validityB, :);
end

% [~, ~, HH] = SURF(IA, IB);
HH = eye(3);
% IApre = imwarp(IA, projective2d(HH'), 'OutputView', imref2d(size(IA)));
IApre = IA;
% [trackApresift, trackBsift] = SIFT(IApre, IB);

[trackApresurf, trackBsurf] = SURF2(IApre, IB);
% trackerAB = vision.PointTracker('MaxBidirectionalError', 0.1);
% trackApre = getMorePoints(IApre, 20, 2000);
% initialize(trackerAB, trackApre, IApre);
% [trackB, validity] = step(trackerAB, IB);

% if length(trackB) > maxppf - length(trackBsurf)
% ordering = randperm(length(trackB));
% trackApre = trackApre(ordering(1:maxppf - length(trackBsurf)), :);
% trackB = trackB(ordering(1:maxppf - length(trackBsurf)), :);
% validity = validity(ordering(1:maxppf - length(trackBsurf)));
% end

% trackB = [trackB(validity, :); trackBsurf];
% trackApre = [trackApre(validity, :); trackApresurf];
if ransac
[trackAsurf, trackBsurf] = SURF(IA, IB);
else
[trackAsurf, trackBsurf] = SURF2(IA, IB);
end

% trackApre = [trackApresurf;trackApresift];
% trackB = [trackBsurf;trackBsift];
trackApre = trackApresurf;
trackA = trackAsurf;
trackB = trackBsurf;
trackA = HH \ [trackApre' ; ones(1, length(trackApre))];
trackA(1, :) = trackA(1, :) ./ trackA(3, :);
trackA(2, :) = trackA(2, :) ./ trackA(3, :);
trackA = trackA(1:2, :);
trackA = trackA';

valid = trackA(:, 1) > 0 & trackA(:, 1) < W & trackA(:, 2) > 0 & trackA(:, 2) < H ...
& trackB(:, 1) > 0 & trackB(:, 1) < W & trackB(:, 2) > 0 & trackB(:, 2) < H;

trackA = trackA(valid, :);
trackB = trackB(valid, :);
% valid = filtermask(IA, trackA);
% trackA = trackA(valid, :);
% trackB = trackB(valid, :);
% valid = filtermask(IB, trackB);
% trackA = trackA(valid, :);
% trackB = trackB(valid, :);

valid = filtermask(IA, trackA);
trackA = trackA(valid, :);
trackB = trackB(valid, :);
valid = filtermask(IB, trackB);
trackA = trackA(valid, :);
trackB = trackB(valid, :);

if frameIndex == 1
initialize(trackerA, trackA, IA);
initialize(trackerA, trackA, IA);
initialize(trackerB, trackB, IB);
else
if length(trackA) + length(trackAcont) > maxppf
ordering = randperm(length(trackAcont));
trackAcont = trackAcont(ordering(1:maxppf - length(trackA)), :);
trackBcont = trackBcont(ordering(1:maxppf - length(trackA)), :);
end
trackA = [trackA ; trackAcont];
trackB = [trackB ; trackBcont];
end
if length(trackA) > maxppf
ordering = randperm(length(trackA));
trackA = trackA(ordering(1:maxppf), :);
trackB = trackB(ordering(1:maxppf), :);
end

% IA = insertMarker(IA, trackA, 'o', 'color', 'red');
% IB = insertMarker(IB, trackB, 's', 'color', 'yellow');
% figure(1);
% imshow(IA);
% figure(2);
% imshow(IB);
% figure(3);
% imshow(IApre);

IA = insertMarker(IA, trackA, 'o', 'color', 'red');
IB = insertMarker(IB, trackB, 's', 'color', 'yellow');
figure(1);
imshow(IA);
figure(2);
imshow(IB);

if length(trackA) > maxppf
ppf(frameIndex) = maxppf;
CP(frameIndex, :, 1:2) = trackA(1:maxppf, :);
CP(frameIndex, :, 3:4) = trackB(1:maxppf, :);
else
ppf(frameIndex) = length(trackA);
CP(frameIndex, 1:ppf(frameIndex), 1:2) = trackA;
CP(frameIndex, 1:ppf(frameIndex), 3:4) = trackB;
CP(frameIndex, 1:ppf(frameIndex), 3:4) = trackB;
end
% CP(frameIndex, :, :) = [featuresA featuresB; zeros(maxppf - ppf(frameIndex), 4)];

end
end

function valid = filtermask(frame, points_)
[H, W, ~] = size(frame);
valid = points_(:, 1) > 0 & points_(:, 1) < W & points_(:, 2) > 0 & points_(:, 2) < H;
points_ = points_(valid, :);
points_ = points_(valid, :);
mask = frame(:, :, 1) < 20 & frame(:, :, 2) < 20 & frame(:, :, 3) < 20;
mask = imgaussfilt(double(mask), 50);
videoH = size(frame, 1);
mask(mask > 0.2) = 1;
mask(mask ~= 1) = 0;
valid = mask(round(points_(:, 1) - 1) * videoH + round(points_(:, 2))) == 0;
valid = mask(round(points_(:, 1) - 1) * videoH + round(points_(:, 2))) == 0;
end


4 changes: 2 additions & 2 deletions Stitching-1.1.0/stitch/getPath.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

fileList = dir(input);
fileList = fileList(3:length(fileList));
nFrames = length(fileList);
nFrames = tracks.nFrame;
if nFrames < 2
error('Wrong inputs') ;
end
Expand All @@ -23,7 +23,7 @@
end
end
fprintf('%5d', 1);
for frameIndex = 2:length(fileList)
for frameIndex = 2:nFrames
fprintf('%5d', frameIndex);
if mod(frameIndex, 20) == 0
fprintf('\n') ;
Expand Down

1 comment on commit fe0a686

@SuTanTank
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix #5

Please sign in to comment.