diff --git a/inference.py b/inference.py index d8e665b..8765b28 100644 --- a/inference.py +++ b/inference.py @@ -32,7 +32,18 @@ def main(checkpoint_path, image_path, save_path): pad_r = find_padding(image.shape[0]) pad_c = find_padding(image.shape[1]) image = np.pad(image, ((pad_r[0], pad_r[1]), (pad_c[0], pad_c[1]), (0, 0)), 'reflect') + + # solve no-pad index issue after inference + if pad_r[1] == 0: + pad_r = (pad_r[0], 1) + if pad_c[1] == 0: + pad_c = (pad_c[0], 1) + image = image.astype(np.float32) + + # remove nans (and infinity) - replace with 0s + image = np.nan_to_num(image, copy=False, nan=0.0, posinf=0.0, neginf=0.0) + image = image - np.min(image) image = image / np.maximum(np.max(image), 1)