Rethink the conditional in the gradient accumulation #20582

Open
IvanUkhov opened this issue Dec 3, 2024 · 2 comments
IvanUkhov commented Dec 3, 2024

The following conditional prevents using gradient accumulation under a distributed strategy in TensorFlow:

```python
ops.cond(
    is_update_step,
    lambda: _update_step_fn(grads, trainable_variables),
    lambda: self._backend_increment_gradient_accumulators(
        grads, acc_grads
    ),
)
```

The exception is as follows:

```
RuntimeError: Exception encountered when calling Cond.call().

merge_call called while defining a new graph or a tf.function. This can often happen if the function fn passed to strategy.run() contains a nested @tf.function, and the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function fn uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested tf.functions or control flow statements that may potentially cross a synchronization boundary, for example, wrap the fn passed to strategy.run or the entire strategy.run inside a tf.function or move the control flow out of fn. If you are subclassing a tf.keras.Model, please avoid decorating overridden methods test_step and train_step in tf.function.
```

This is probably related to this call:

```python
tf.__internal__.distribute.interim.maybe_merge_call(
    _distributed_tf_increment_grad_acc,
    self._distribution_strategy,
    grads,
    accumulators,
)
```

One could perhaps rewrite it as an implicit conditional via math manipulations: the code would execute unconditionally but lead to different outcomes depending on whether or not it is the end of an accumulation round.
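To make the idea concrete, here is a minimal sketch of such a branch-free formulation using plain Python scalars (the function name and signature are hypothetical, not Keras's actual code; in TensorFlow the same arithmetic would apply elementwise to tensors):

```python
def accumulate_and_maybe_apply(variable, accumulator, grad, step, accum_steps, lr=0.1):
    """Branch-free gradient accumulation: both outcomes are computed
    unconditionally and blended with a 0/1 mask, so no explicit cond
    (and hence no synchronization point inside control flow) is needed.
    """
    # 1.0 on update steps, 0.0 otherwise (steps are 0-indexed).
    is_update = float((step + 1) % accum_steps == 0)
    # Always add the incoming gradient to the accumulator.
    accumulator = accumulator + grad
    # The masked gradient is exactly zero on non-update steps, so the
    # variable only moves at the end of an accumulation round.
    variable = variable - lr * is_update * (accumulator / accum_steps)
    # Reset the accumulator on update steps, keep it otherwise.
    accumulator = (1.0 - is_update) * accumulator
    return variable, accumulator

# With accum_steps=2 and a constant gradient of 1.0, the variable
# moves by lr only every second step.
v, a = 1.0, 0.0
for step in range(4):
    v, a = accumulate_and_maybe_apply(v, a, grad=1.0, step=step, accum_steps=2)
```

Because every replica executes the same arithmetic on every step, there is no control flow crossing a synchronization boundary.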

IvanUkhov changed the title from "Remove the explicit conditional in the gradient accumulation" to "Rethink the conditional in the gradient accumulation" on Dec 3, 2024

fchollet commented Dec 4, 2024

> One could perhaps rewrite it as an implicit conditional via math manipulations: the code will be executed unconditionally but will be leading to different outcomes depending on whether it is the end of an accumulation round or not.

Indeed -- I vaguely recall I implemented it in this way at some point. I don't remember why I changed it though.

Are you able to open a PR along those lines?


IvanUkhov commented Dec 5, 2024

@fchollet, is it not a little strange that this does not work in a distributed setting, given that tf.distribute is used extensively in the code related to gradient accumulation? Are there tests for this? Perhaps it fails only in some specific cases, such as when using GPUs? I just want to make sure we are not chasing a nonexistent problem.
