Skip to content

Commit

Permalink
add Random ops shim (#2145)
Browse files Browse the repository at this point in the history
* add Random ops shim

* update random opss

* handle None case

* change seed handeling

* undo changes

* undo changes

* undo

* update int_seed

* change int to init

* update init_seed everywhere

* add kwargs

* undo bial

* code reformat

* remove dtype
  • Loading branch information
divyashreepathihalli authored Nov 14, 2023
1 parent cbd34f0 commit 0e71149
Showing 1 changed file with 119 additions and 1 deletion.
120 changes: 119 additions & 1 deletion keras_cv/backend/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,128 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.backend import keras
from keras_cv.backend.config import keras_3

if keras_3():
from keras.random import * # noqa: F403, F401
else:
from keras_core.random import * # noqa: F403, F401


class SeedGenerator:
def __init__(self, seed=None, **kwargs):
if keras_3():
self._seed_generator = keras.random.SeedGenerator(
seed=seed, **kwargs
)
else:
self._current_seed = [0, seed]

def next(self, ordered=True):
if keras_3():
return self._seed_generator.next(ordered=ordered)
else:
self._current_seed[0] += 1
return self._current_seed[:]


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed

kwargs = {}
if dtype:
kwargs["dtype"] = dtype
if keras_3():
return keras.random.normal(
shape,
mean=mean,
stddev=stddev,
seed=init_seed,
**kwargs,
)
else:
import tensorflow as tf

return tf.random.normal(
shape,
mean=mean,
stddev=stddev,
seed=init_seed,
**kwargs,
)


def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
if keras_3():
return keras.random.uniform(
shape,
minval=minval,
maxval=maxval,
seed=init_seed,
**kwargs,
)
else:
import tensorflow as tf

return tf.random.uniform(
shape,
minval=minval,
maxval=maxval,
seed=init_seed,
**kwargs,
)


def shuffle(x, axis=0, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed

if keras_3():
return keras.random.shuffle(x=x, axis=axis, seed=init_seed)
else:
import tensorflow as tf

return tf.random.shuffle(x=x, axis=axis, seed=init_seed)


def categorical(logits, num_samples, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
if keras_3():
return keras.random.categorical(
logits=logits,
num_samples=num_samples,
seed=init_seed,
**kwargs,
)
else:
import tensorflow as tf

return tf.random.categorical(
logits=logits,
num_samples=num_samples,
seed=init_seed,
**kwargs,
)

0 comments on commit 0e71149

Please sign in to comment.