mirror of
https://github.com/ROCm/jax.git
synced 2025-04-23 22:06:06 +00:00

Imported from GitHub PR https://github.com/google/jax/pull/22404

* Only check the bias batch/num_head sharding spec if present, since both dims could be broadcast.
* The dbias calculation is incorrect under SPMD; an all_reduce is required.

Copybara import of the project:

--
cb81b80626bcf17db875bad5526cd2f24c989049 by cjkkkk <ske@nvidia.com>:

fix sharding

Merging this change closes #22404

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22404 from Cjkkkk:fix_bias_sharding_dbias_all_reduce cb81b80626bcf17db875bad5526cd2f24c989049
PiperOrigin-RevId: 653335832
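Why dbias needs an all_reduce can be illustrated with a minimal sketch. It assumes a toy `tanh` loss as a stand-in for fused attention, a bias of shape `[1, 1, q, k]` broadcast over the batch and head dims, and a batch dimension block-sharded across devices. Devices are simulated on one host with `jax.vmap` and a named axis; the helper names `local_dbias` and `reduced_dbias` are hypothetical and do not come from the patch. The gradient of a broadcast bias is a sum over the broadcast dims, so each device only holds a partial sum until the shards are combined with `jax.lax.psum`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup, not the fused-attention code touched by the PR:
# scores are [batch, heads, q_len, kv_len]; bias is [1, 1, q_len, kv_len] and
# is broadcast over the batch and head dims.
B, H, Q, K = 4, 2, 3, 3
N_DEV = 2  # number of simulated devices that shard the batch dim


def loss(scores, bias):
    # Stand-in for attention: an elementwise function of (scores + bias).
    return jnp.sum(jnp.tanh(scores + bias))


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
scores = jax.random.normal(k1, (B, H, Q, K))
bias = jax.random.normal(k2, (1, 1, Q, K))

# Reference dbias from the full, unsharded batch. Autodiff sums the incoming
# gradient over the broadcast batch and head dims.
dbias_ref = jax.grad(loss, argnums=1)(scores, bias)

# Block-shard the batch: shard i holds rows [i * B/N_DEV, (i + 1) * B/N_DEV).
scores_sharded = scores.reshape(N_DEV, B // N_DEV, H, Q, K)


def local_dbias(scores_shard, bias):
    # Each device only sees its batch shard, so this is a partial sum.
    return jax.grad(loss, argnums=1)(scores_shard, bias)


def reduced_dbias(scores_shard, bias):
    # all_reduce (psum) over the device axis adds the partial sums together.
    return jax.lax.psum(local_dbias(scores_shard, bias), axis_name="dev")


# vmap with a named axis simulates the SPMD devices on a single host.
partial = jax.vmap(local_dbias, in_axes=(0, None))(scores_sharded, bias)
reduced = jax.vmap(reduced_dbias, in_axes=(0, None), axis_name="dev")(
    scores_sharded, bias
)
```

In this sketch, `partial[i]` differs from `dbias_ref` on every simulated device, while each entry of `reduced` matches it. Under real SPMD partitioning the same reduction would run over the mesh axis that shards the batch (or head) dim; how the PR wires this into the partitioning rule is not shown here.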