Skip to content

Commit

Permalink
reformat and resolve error
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli committed Nov 14, 2023
1 parent 31ebec7 commit 2830b10
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 17 deletions.
5 changes: 4 additions & 1 deletion benchmarks/vectorized_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.layers import Mosaic
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501
IMAGES,
)
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import (
LABELS,
)
from keras_cv.utils import preprocessing as preprocessing_utils
from tensorflow import keras


class OldMosaic(BaseImageAugmentationLayer):
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/vectorized_random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501
BOUNDING_BOXES,
)
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import (
IMAGES,
)

Expand Down Expand Up @@ -103,11 +103,11 @@ def get_random_transformation(self, **kwargs):
flip_vertical = False
if self.horizontal:
flip_horizontal = (
random.uniform(shape=[]) > 0.5, seed=self._seed_generator
random.uniform(shape=[], seed=self._seed_generator) > 0.5
)
if self.vertical:
flip_vertical = (
random.uniform(shape=[]) > 0.5, seed=self._seed_generator
random.uniform(shape=[], seed=self._seed_generator) > 0.5
)
return {
"flip_horizontal": tf.cast(flip_horizontal, dtype=tf.bool),
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/vectorized_random_hue.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras_cv.layers import RandomHue
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing as preprocessing_utils
from tensorflow import keras


class OldRandomHue(BaseImageAugmentationLayer):
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/vectorized_random_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import RandomRotation
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing as preprocessing_utils
from tensorflow import keras

H_AXIS = -3
W_AXIS = -2
Expand Down
5 changes: 3 additions & 2 deletions benchmarks/vectorized_random_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
import warnings
from unittest.mock import MagicMock

import keras_cv
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

import keras_cv
from keras_cv import bounding_box
from keras_cv.layers import RandomShear
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing
from matplotlib import pyplot as plt


# Copied from:
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/vectorized_random_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import numpy as np
import tensorflow as tf
from keras import backend
from tensorflow import keras

from keras_cv.backend import random
from keras_cv.layers import RandomTranslation
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing as preprocessing_utils
from tensorflow import keras

H_AXIS = -3
W_AXIS = -2
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/vectorized_random_zoom.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import numpy as np
import tensorflow as tf
from keras import backend
from tensorflow import keras

from keras_cv.backend import random
from keras_cv.layers import RandomZoom
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing as preprocessing_utils
from tensorflow import keras

# In order to support both unbatched and batched inputs, the horizontal
# and vertical axis is reverse indexed
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/vectorized_randomly_zoomed_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras_cv import core
from keras_cv.layers import RandomlyZoomedCrop
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing as preprocessing_utils
from tensorflow import keras


class OldRandomlyZoomedCrop(BaseImageAugmentationLayer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

import keras
import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras, scope
from keras_cv.backend import keras
from keras_cv.backend import scope
from keras_cv.backend.config import multi_backend
from keras_cv.backend.random import SeedGenerator
from keras_cv.utils import preprocessing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

import keras
import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras, scope
from keras_cv.backend import keras
from keras_cv.backend import scope
from keras_cv.backend.config import multi_backend
from keras_cv.backend.random import SeedGenerator
from keras_cv.utils import preprocessing
Expand Down
15 changes: 11 additions & 4 deletions keras_cv/utils/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import tensorflow as tf
from keras_cv import core
from keras_cv.backend import ops, random
from tensorflow import keras
from tensorflow.keras import backend

from keras_cv import core
from keras_cv.backend import ops
from keras_cv.backend import random

_TF_INTERPOLATION_METHODS = {
"bilinear": tf.image.ResizeMethod.BILINEAR,
"nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR,
Expand Down Expand Up @@ -183,14 +185,19 @@ def random_inversion(seed_generator):
Returns:
either -1, or -1.
"""
negate = keras.backend.uniform((), 0, 1, dtype=tf.float32) > 0.5
negate = (
keras.backend.uniform((), 0, 1, dtype=tf.float32, seed=seed_generator)
> 0.5
)
negate = tf.cond(negate, lambda: -1.0, lambda: 1.0)
return negate


def batch_random_inversion(seed_generator, batch_size):
"""Same as `random_inversion` but for batched inputs."""
negate = random.uniform((batch_size, 1), 0, 1, dtype=tf.float32)
negate = random.uniform(
(batch_size, 1), 0, 1, dtype=tf.float32, seed=seed_generator
)
negate = tf.where(negate > 0.5, -1.0, 1.0)
return negate

Expand Down

0 comments on commit 2830b10

Please sign in to comment.