Mirror of https://github.com/ROCm/jax.git
Synced 2025-04-24 04:56:07 +00:00

This currently causes incorrect behaviour in jax.nn.dot_product_attention: it fails with an assert where it should raise a proper error.

PiperOrigin-RevId: 650621750
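The difference between failing with an assert and raising an error can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the change made in this commit. The helper names `check_qkv_shapes_assert` and `check_qkv_shapes` are hypothetical, and treating the failing condition as a query/key head-dimension mismatch in a (batch, seq, heads, head_dim) layout is an assumption made only for the example.

```python
# Hypothetical sketch: these helpers are NOT JAX's code. They only contrast
# an assert-based input check with an explicit, descriptive exception.


def check_qkv_shapes_assert(q_shape, k_shape):
    # Fragile: asserts are removed when Python runs with -O, and the bare
    # AssertionError tells the caller nothing about what was wrong.
    assert q_shape[-1] == k_shape[-1]


def check_qkv_shapes(q_shape, k_shape):
    """Raise a ValueError describing the problem instead of asserting.

    Shapes are assumed (for illustration only) to be
    (batch, seq, heads, head_dim).
    """
    if len(q_shape) != 4 or len(k_shape) != 4:
        raise ValueError(
            f"expected rank-4 query and key, got ranks "
            f"{len(q_shape)} and {len(k_shape)}"
        )
    if q_shape[-1] != k_shape[-1]:
        raise ValueError(
            f"query head_dim {q_shape[-1]} does not match "
            f"key head_dim {k_shape[-1]}"
        )


if __name__ == "__main__":
    try:
        check_qkv_shapes((2, 8, 4, 64), (2, 8, 4, 32))
    except ValueError as e:
        print(e)  # query head_dim 64 does not match key head_dim 32
```

An explicit `ValueError` survives optimized runs and names both offending values, so a user can catch and act on it. An assert is better reserved for internal invariants that no caller input should be able to violate.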