Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00.
Merge pull request #21169 from justinjfu/splash_precision_fix:
disable bfloat16 on long sequence lengths for the splash attention kernel test.
This commit is contained in:
commit
e4f3b3ff8f
@@ -224,7 +224,13 @@ def sequence_length_strategy(draw: Draw) -> tuple[int, int]:
|
||||
def attention_strategy(draw: Draw) -> tuple[int, int, int, np.dtype]:
  """Hypothesis strategy for attention-test parameters.

  Draws a `(q_seq_len, kv_seq_len, head_dim, dtype)` tuple, where the
  sequence lengths come from `sequence_length_strategy` and `head_dim`
  is one of {128, 256}.

  The dtype is drawn exactly once: when both sequence lengths are at
  least 4096, float32 is forced instead of sampling, because bfloat16
  at long sequence lengths accumulates enough rounding error to cause
  false-positive test failures. (The previous unconditional
  `dtype = draw(...)` before the branch was dead in every path — it was
  either overwritten by the float32 assignment or redrawn in `else` —
  so it has been removed to avoid wasting a draw.)
  """
  q_seq_len, kv_seq_len = draw(sequence_length_strategy())
  head_dim = draw(hps.sampled_from([128, 256]))
  if q_seq_len >= 4096 and kv_seq_len >= 4096:
    # Do not draw bfloat16 on longer sequence lengths, as this increases
    # the risk of numerical precision errors causing false positives in
    # tests.
    dtype = np.dtype("float32")
  else:
    dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)]))
  return q_seq_len, kv_seq_len, head_dim, dtype
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user