Merge pull request #21169 from justinjfu/splash_precision_fix

Disable bfloat16 on long seq lengths for splash attention kernel test
Justin Fu 2024-05-13 09:34:39 -07:00 committed by GitHub
commit e4f3b3ff8f


@@ -224,7 +224,13 @@ def sequence_length_strategy(draw: Draw) -> tuple[int, int]:
 def attention_strategy(draw: Draw) -> tuple[int, int, int, np.dtype]:
   q_seq_len, kv_seq_len = draw(sequence_length_strategy())
   head_dim = draw(hps.sampled_from([128, 256]))
-  dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)]))
+  if q_seq_len >= 4096 and kv_seq_len >= 4096:
+    # Do not draw bfloat16 on longer sequence lengths, as this increases
+    # the risk of numerical precision errors causing false positives in
+    # tests.
+    dtype = np.dtype("float32")
+  else:
+    dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)]))
   return q_seq_len, kv_seq_len, head_dim, dtype
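
For context, a minimal sketch (not part of this PR, and not the splash attention kernel itself) of the precision effect the new guard works around: plain softmax attention computed in bfloat16 drifts further from a float32 reference as the key/value sequence length grows, which is why long-sequence cases are restricted to float32 in the test strategy. The helper names attention and max_abs_error below are illustrative, not from the test file.

import numpy as np
import jax
import jax.numpy as jnp

def attention(q, k, v):
  # Unfused reference attention; the input dtype controls the compute dtype.
  logits = (q @ k.T) / q.shape[-1] ** 0.5
  return jax.nn.softmax(logits, axis=-1) @ v

def max_abs_error(kv_seq_len, head_dim=128, seed=0):
  # Compare bfloat16 attention against a float32 reference on random inputs.
  rng = np.random.default_rng(seed)
  q = rng.standard_normal((1, head_dim), dtype=np.float32)
  k = rng.standard_normal((kv_seq_len, head_dim), dtype=np.float32)
  v = rng.standard_normal((kv_seq_len, head_dim), dtype=np.float32)
  ref = attention(jnp.asarray(q), jnp.asarray(k), jnp.asarray(v))
  low = attention(jnp.asarray(q, jnp.bfloat16),
                  jnp.asarray(k, jnp.bfloat16),
                  jnp.asarray(v, jnp.bfloat16))
  return float(jnp.max(jnp.abs(ref - low.astype(jnp.float32))))

for n in (512, 4096):
  print(n, max_abs_error(n))

With a fixed error tolerance, the larger drift at 4096+ keys can flag correct kernel outputs as mismatches, so the strategy above simply stops drawing bfloat16 for those cases instead of loosening the tolerance.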