[tensorflow] MirroredStrategy, LossScaledOptimizer - merge_call failed #18666

Closed
crohkohl opened this issue Oct 22, 2023 · 5 comments

@crohkohl

Hi,

Using the LossScaleOptimizer with MirroredStrategy fails with the following exception:

Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 105, in one_step_on_data  **
        return self.train_step(data)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 72, in train_step
        self.optimizer.apply_gradients(zip(gradients, trainable_weights))
    File "/usr/local/lib/python3.10/dist-packages/keras/src/optimizers/base_optimizer.py", line 206, in apply_gradients
        self.apply(grads, trainable_variables)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/optimizers/loss_scale_optimizer.py", line 183, in apply
        ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/ops/core.py", line 594, in cond
        return Cond()(pred, true_fn, false_fn)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/optimizer.py", line 82, in _internal_apply_gradients
        tf.__internal__.distribute.interim.maybe_merge_call(

    RuntimeError: Exception encountered when calling Cond.call().
    
    `merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid decorating overridden methods `test_step` and `train_step` in `tf.function`.

The reason for the exception is the following tf.cond() call:

ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
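
For context, here is a distilled sketch of the pattern the error message warns about (an illustration only, not the Keras source): control flow inside the `fn` passed to `strategy.run()` that contains a synchronization point.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def step():
    def replica_fn():
        def sync_branch():
            # apply_gradients reaches a merge_call internally; invoking it
            # from inside a tf.cond branch means merge_call runs while a new
            # graph is being built, which is what the RuntimeError rejects.
            tf.distribute.get_replica_context().merge_call(lambda _: None)

        tf.cond(tf.constant(True), sync_branch, lambda: None)

    strategy.run(replica_fn)

step()  # expected to raise the same `merge_call` RuntimeError
```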

To reproduce, change the following line:

optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.01),

to

optimizer=optimizers.LossScaleOptimizer(optimizers.SGD(learning_rate=0.001, momentum=0.01)),   

Alternatively, you can turn on the GPU and use mixed precision, which then automatically wraps the optimizer in a LossScaleOptimizer.
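
A minimal end-to-end reproduction could look as follows (a sketch only: the model and data are hypothetical placeholders, and it assumes Keras 3 on the TensorFlow backend with two or more GPUs visible to MirroredStrategy):

```python
import numpy as np
import tensorflow as tf
import keras
from keras import layers, optimizers

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = keras.Sequential([keras.Input(shape=(8,)), layers.Dense(1)])
    model.compile(
        loss="mse",
        # Wrapping SGD in LossScaleOptimizer introduces the ops.cond() call
        # quoted above and triggers the merge_call RuntimeError under
        # MirroredStrategy.
        optimizer=optimizers.LossScaleOptimizer(
            optimizers.SGD(learning_rate=0.001, momentum=0.01)
        ),
    )

x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model.fit(x, y, batch_size=16, epochs=1)  # fails in one_step_on_data
```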

@qlzh727
Member

qlzh727 commented Oct 23, 2023

Thanks for the report. Let me take a look.

@qlzh727 qlzh727 self-assigned this Oct 23, 2023
@qlzh727
Member

qlzh727 commented Oct 24, 2023

I see. It seems that we are missing a bunch of logic in the TF-specific backend when the loss scale optimizer runs with tf.distribute. Will fix that.

@qlzh727
Member

qlzh727 commented Oct 30, 2023

Should be addressed by #18691

@qlzh727 qlzh727 closed this as completed Oct 30, 2023
@iamsoroush

The same issue exists with the Adam and AdamW optimizers when setting use_ema=True and using MirroredStrategy.
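
For example (a sketch under the same assumptions as the reproduction above), swapping the optimizer for an EMA-enabled Adam or AdamW reportedly hits the same error:

```python
import keras

# use_ema makes the optimizer maintain an exponential moving average of the
# model weights; per the comment above, this also fails under MirroredStrategy.
optimizer = keras.optimizers.Adam(learning_rate=1e-3, use_ema=True)
# or:
# optimizer = keras.optimizers.AdamW(learning_rate=1e-3, use_ema=True)
```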

@IvanUkhov

IvanUkhov commented Dec 3, 2024

For me, gradient accumulation does not work with Adam, with or without use_ema, and results in the same error. It feels like the problem applies to the base optimizer in general. Opened #20582.
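
For example (again a sketch, assuming the `gradient_accumulation_steps` argument of the Keras 3 base optimizer), the reported failing configuration would be along the lines of:

```python
import keras

# Accumulate gradients over several steps before applying them; per the
# comment above, this reportedly triggers the same merge_call error when
# training under MirroredStrategy.
optimizer = keras.optimizers.Adam(
    learning_rate=1e-3,
    gradient_accumulation_steps=4,
)
```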
