Hey, great job with nanodl!

I was just looking through the code and noticed that in LaMDA's Trainer the gradients are not being averaged across devices here:

nanodl/nanodl/__src/models/lamda.py, lines 564 to 565 in 18c7f8e

Not sure if this is happening elsewhere, but usually, to keep the weights in sync, you apply a jax.lax.pmean over the gradients before passing them to apply_gradients, e.g.

grads = jax.lax.pmean(grads, axis_name='devices')
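For reference, here is a minimal sketch of that pattern, not nanodl's actual trainer: the TinyModel, the 'devices' axis name, and the Flax/Optax setup are illustrative, but the key step is the jax.lax.pmean over the gradients before apply_gradients.

```python
from functools import partial

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state


class TinyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)


@partial(jax.pmap, axis_name='devices')
def train_step(state, inputs, targets):
    def loss_fn(params):
        preds = state.apply_fn(params, inputs)
        return jnp.mean((preds - targets) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average gradients (and the reported loss) across devices so every
    # replica applies the same update and the weights stay in sync.
    grads = jax.lax.pmean(grads, axis_name='devices')
    loss = jax.lax.pmean(loss, axis_name='devices')
    return state.apply_gradients(grads=grads), loss


if __name__ == '__main__':
    n = jax.local_device_count()
    model = TinyModel()
    params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
    state = train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.adam(1e-3))
    # Replicate the state across devices and shard a dummy batch along
    # the leading (device) axis.
    state = jax.device_put_replicated(state, jax.local_devices())
    inputs = jnp.ones((n, 4, 8))
    targets = jnp.zeros((n, 4, 1))
    state, loss = train_step(state, inputs, targets)
```

Without the pmean, each replica would apply its own per-shard gradient and the parameters would drift apart across devices.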
cgarciae changed the title from "Gradient synchronization in data-parallel traininers" to "Gradient synchronization in data-parallel trainers" on Feb 28, 2024.
Thanks for noticing this! It's often challenging to test these portions because I don't have a personal multi-GPU setup for development. However, I will have access to 2 GPUs around 10th March and will examine this immediately. You are more than welcome to make corrections from your end if that is convenient; in fact, I would very much appreciate that.