Sebastian Bodenstein c9534b315e Raise NotImplementedError instead of assert for unsupported Q dtype in fused attention.
The assert currently causes incorrect behaviour in jax.nn.dot_product_attention: an unsupported Q dtype should raise a proper error rather than trip an assertion.
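The pattern can be sketched in plain Python. The names check_query_dtype and pick_backend are hypothetical, and so is the supported set {float16, bfloat16}. None of them are taken from the JAX source, and dtypes are passed as name strings so the example needs no dependencies. The underlying point is general Python behaviour. An assert is removed entirely when Python runs with -O, and when it does fire it raises a generic AssertionError. An explicit NotImplementedError is always raised, and it gives callers a specific exception type they can catch.

```python
# Hypothetical set of query dtypes the fused kernel accepts (an assumption,
# not the actual list used by JAX).
SUPPORTED_Q_DTYPES = frozenset({"float16", "bfloat16"})


def check_query_dtype(q_dtype: str) -> None:
    """Reject unsupported query dtypes with an explicit exception.

    An `assert q_dtype in SUPPORTED_Q_DTYPES` would be stripped under
    `python -O` and otherwise raise a bare AssertionError. NotImplementedError
    is always raised and signals "valid input, but not supported here".
    """
    if q_dtype not in SUPPORTED_Q_DTYPES:
        raise NotImplementedError(
            f"Q dtype {q_dtype!r} is not supported by fused attention; "
            f"expected one of {sorted(SUPPORTED_Q_DTYPES)}"
        )


def pick_backend(q_dtype: str) -> str:
    """Hypothetical caller: use the fused path when possible, else fall back."""
    try:
        check_query_dtype(q_dtype)
    except NotImplementedError:
        # Only the "unsupported" case is caught; any other error still propagates.
        return "fallback"
    return "fused"


if __name__ == "__main__":
    print(pick_backend("bfloat16"))  # fused
    print(pick_backend("float32"))   # fallback
```

The fallback in pick_backend only shows that a typed exception can be caught selectively. It is not a description of how jax.nn.dot_product_attention dispatches between implementations.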

PiperOrigin-RevId: 650621750
2024-07-09 07:37:13 -07:00