mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix test
This commit is contained in:
parent
9b8a100039
commit
2d346149de
@ -106,7 +106,7 @@ def g(query: Array,
|
||||
keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
|
||||
keep = jnp.broadcast_to(keep, attn_weights.shape)
|
||||
multiplier = (
|
||||
keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
|
||||
keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=attn_weights.dtype))
|
||||
attn_weights = attn_weights * multiplier
|
||||
|
||||
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
|
||||
@ -149,7 +149,7 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
self.skipTest("Fused attention does not head dim = 128.")
|
||||
if len(jax.local_devices()) <= 4:
|
||||
self.skipTest("Require at least 4 devices to run sharding tests.")
|
||||
os.environ['XLA_FLAGS'] = '--xla_dump_hlo_as_text --xla_dump_to=./scratch/hlo --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
|
||||
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
|
||||
|
||||
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
|
||||
query = jax.random.normal(
|
||||
@ -194,7 +194,7 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
out_shardings=out_shardings
|
||||
)
|
||||
|
||||
out, (query_grad, key_grad, value_grad) = pjitted_g_train(query, key, value, grad, bias, None)
|
||||
out, (query_grad, key_grad, value_grad) = pjitted_f_train(query, key, value, grad, bias, None)
|
||||
out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = pjitted_g_train(query, key, value, grad, bias, None)
|
||||
assert jnp.allclose(out_ref, out, rtol=1e-5, atol=1e-5)
|
||||
if seq_len > 512:
|
||||
|
Loading…
x
Reference in New Issue
Block a user