Shanbin Ke ab9fc2d839 PR #22404: [cuDNN SDPA] fix bias sharding check and dbias calculation
Imported from GitHub PR https://github.com/google/jax/pull/22404

* Only check the bias batch/num_head sharding spec if present, since both dims could be broadcast.
* The dbias calculation is incorrect under SPMD; an all_reduce is required.
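
The second point can be illustrated with a minimal sketch in plain JAX. This is not the cuDNN SDPA code path: `attention` and `loss` are hypothetical reference helpers, and sharding is simulated by slicing the batch rather than by running on multiple devices. It assumes a bias of shape (1, 1, S, S) that is broadcast over batch and heads, so each batch shard only sees a partial dbias and the partial results have to be summed.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v, bias):
    # Hypothetical reference attention, not the cuDNN kernel.
    # q, k, v: (batch, heads, seq, head_dim); bias broadcasts over batch and heads.
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1]) + bias
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", probs, v)


def loss(bias, q, k, v):
    # Scalar loss so that jax.grad gives dbias directly.
    return attention(q, k, v, bias).sum()


B, H, S, D = 4, 2, 8, 16
kq, kk, kv, kb = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (B, H, S, D))
k = jax.random.normal(kk, (B, H, S, D))
v = jax.random.normal(kv, (B, H, S, D))
bias = jax.random.normal(kb, (1, 1, S, S))  # broadcast over batch and heads

# Reference: dbias computed over the whole batch.
dbias_full = jax.grad(loss)(bias, q, k, v)

# Simulated SPMD: two "devices", each holding half of the batch.
shards = [slice(0, 2), slice(2, 4)]
dbias_local = [jax.grad(loss)(bias, q[s], k[s], v[s]) for s in shards]

# Stand-in for the all_reduce (sum) across devices.
dbias_reduced = dbias_local[0] + dbias_local[1]
```

In this sketch a single shard's `dbias_local` entry differs from `dbias_full`, while the summed `dbias_reduced` matches it up to floating-point tolerance.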
Copybara import of the project:

--
cb81b80626bcf17db875bad5526cd2f24c989049 by cjkkkk <ske@nvidia.com>:

fix sharding

Merging this change closes #22404

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22404 from Cjkkkk:fix_bias_sharding_dbias_all_reduce cb81b80626bcf17db875bad5526cd2f24c989049
PiperOrigin-RevId: 653335832
2024-07-17 13:07:15 -07:00