Fix lint failure in fused attention test.

PiperOrigin-RevId: 606404377
This commit is contained in:
Peter Hawkins 2024-02-12 16:34:09 -08:00 committed by jax authors
parent 8e96c49f85
commit 7aadabdc03

View File

@@ -14,7 +14,7 @@
from functools import partial
from absl.testing import absltest
-from typing import Any, Optional
+from typing import Optional
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
@@ -77,6 +77,8 @@ def sdpa_ref(query: Array,
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
if mask is not None:
+    large_negative_number = jnp.asarray(
+        -0.7 * jnp.finfo(attn_weights.dtype).max, dtype=attn_weights.dtype)
attn_weights = jax.lax.select(mask, attn_weights, large_negative_number)
attn_weights = jax.nn.softmax(attn_weights)
if dropout_rate > 0.: