Cjkkkk 2024-01-03 11:06:07 -08:00
parent 9b8a100039
commit 2d346149de

@@ -106,7 +106,7 @@ def g(query: Array,
 keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
 keep = jnp.broadcast_to(keep, attn_weights.shape)
 multiplier = (
-    keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
+    keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=attn_weights.dtype))
 attn_weights = attn_weights * multiplier
 return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
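
For context, a minimal standalone sketch of the inverted-dropout scaling this hunk touches: weights are kept with probability keep_prob and rescaled by 1/keep_prob, with keep_prob cast to the attention weights' dtype rather than an unrelated `dtype` name. The function name and shapes below are illustrative, not part of the test file.

import jax
import jax.numpy as jnp

def dropout_attn_weights(attn_weights, dropout_rng, dropout_rate=0.1):
  # Inverted dropout: keep each weight with probability keep_prob and
  # rescale by 1/keep_prob so the expected value is unchanged.
  keep_prob = 1.0 - dropout_rate
  keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
  multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
      keep_prob, dtype=attn_weights.dtype)
  return attn_weights * multiplier

# Example with (batch, heads, q_len, kv_len)-shaped weights.
weights = jnp.ones((1, 2, 4, 4), dtype=jnp.bfloat16)
dropped = dropout_attn_weights(weights, jax.random.PRNGKey(0))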
@@ -149,7 +149,7 @@ class DotProductAttentionTest(jtu.JaxTestCase):
 self.skipTest("Fused attention does not support head dim = 128.")
 if len(jax.local_devices()) <= 4:
 self.skipTest("Require at least 4 devices to run sharding tests.")
-os.environ['XLA_FLAGS'] = '--xla_dump_hlo_as_text --xla_dump_to=./scratch/hlo --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
+os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
 k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
 query = jax.random.normal(
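
A hedged sketch of the guard-and-flags pattern around this hunk, with the HLO-dump flags dropped as in the new line. The class and test names are placeholders; only the two cuDNN flags from the diff are assumed to be real.

import os
import jax
from jax._src import test_util as jtu

class ShardedAttentionTest(jtu.JaxTestCase):  # placeholder name

  def test_fused_attention_sharding(self):  # placeholder name
    # Sharding tests need several local devices; skip otherwise.
    if len(jax.local_devices()) < 4:
      self.skipTest("Require at least 4 devices to run sharding tests.")
    # Enable cuDNN flash attention and its RNG in XLA (flags taken from
    # the diff). XLA may only read XLA_FLAGS at backend start-up, so in
    # practice this needs to happen before the first computation runs.
    os.environ['XLA_FLAGS'] = ('--xla_gpu_enable_cudnn_fmha=true '
                               '--xla_gpu_fused_attention_use_cudnn_rng=true')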
@@ -194,7 +194,7 @@ class DotProductAttentionTest(jtu.JaxTestCase):
 out_shardings=out_shardings
 )
-out, (query_grad, key_grad, value_grad) = pjitted_g_train(query, key, value, grad, bias, None)
+out, (query_grad, key_grad, value_grad) = pjitted_f_train(query, key, value, grad, bias, None)
 out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = pjitted_g_train(query, key, value, grad, bias, None)
 assert jnp.allclose(out_ref, out, rtol=1e-5, atol=1e-5)
 if seq_len > 512:
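
A minimal sketch of the comparison this hunk fixes: the fused path (f) and the pure-JAX reference (g) are run on the same inputs, then outputs and gradients are checked with jnp.allclose. The helper below is illustrative; only the tolerances mirror the diff.

import jax.numpy as jnp

def check_fused_vs_reference(f_train, g_train, *args, rtol=1e-5, atol=1e-5):
  # f_train: fused (cuDNN) attention forward+backward; g_train: pure-JAX
  # reference. Both return (output, gradients) for the same inputs.
  out, grads = f_train(*args)
  out_ref, grads_ref = g_train(*args)
  assert jnp.allclose(out_ref, out, rtol=rtol, atol=atol)
  for g_got, g_ref in zip(grads, grads_ref):
    assert jnp.allclose(g_ref, g_got, rtol=rtol, atol=atol)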