mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix lint failure in fused attention test.
PiperOrigin-RevId: 606404377
This commit is contained in:
parent
8e96c49f85
commit
7aadabdc03
@ -14,7 +14,7 @@
|
||||
|
||||
from functools import partial
|
||||
from absl.testing import absltest
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
import os
|
||||
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
|
||||
|
||||
@ -77,6 +77,8 @@ def sdpa_ref(query: Array,
|
||||
if bias is not None:
|
||||
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
|
||||
if mask is not None:
|
||||
large_negative_number = jnp.asarray(
|
||||
-0.7 * jnp.finfo(attn_weights.dtype).max, dtype=attn_weights.dtype)
|
||||
attn_weights = jax.lax.select(mask, attn_weights, large_negative_number)
|
||||
attn_weights = jax.nn.softmax(attn_weights)
|
||||
if dropout_rate > 0.:
|
||||
|
Loading…
x
Reference in New Issue
Block a user