mirror of https://github.com/ROCm/jax.git
add is_training && fix seqlen/head_dim checks
parent 9a00721a54
commit 204ee7ff0b
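In short, dot_product_attention gains an is_training flag (the forward pass only keeps the activations the backward pass needs when it is set), and check_is_flash_attention now takes the cuDNN version and enforces the seqlen/head_dim constraints directly. A rough usage sketch, modeled on the updated test further below (assumes a CUDA build with the cuDNN fused-attention path available; shapes and values are illustrative):

    from functools import partial
    import jax
    import jax.numpy as jnp
    from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

    q = jnp.zeros((4, 1024, 4, 64), dtype=jnp.bfloat16)
    k = jnp.zeros((4, 1024, 4, 64), dtype=jnp.bfloat16)
    v = jnp.zeros((4, 1024, 4, 64), dtype=jnp.bfloat16)

    # inference: only the attention output is produced, no activation is kept
    out = dot_product_attention(q, k, v, None, None, scale=1.0, dropout_rate=0.0)

    # training: is_training=True makes the forward pass keep the residuals the VJP needs
    out, sdpa_vjp = jax.vjp(
        partial(dot_product_attention, scale=1.0, dropout_rate=0.0, is_training=True),
        q, k, v, None, None)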
@@ -200,52 +200,57 @@ def check_qkv_layout(query, key, value):
     "query should have layout [batch, q_seq, num_heads, head_dim], " \
     "key and value should have layout [batch, kv_seq, num_heads, head_dim].")
 
-def check_is_flash_attention(query, key):
+def check_is_flash_attention(query, key, cudnn_version, has_bias, is_training):
   batch, q_seq_len, num_heads, head_dim = query.shape
-  _, kv_sqe_len, _, _ = key.shape
-  is_cross_attention = q_seq_len != kv_sqe_len
+  _, kv_seq_len, _, _ = key.shape
+
-  # check if attention pattern is supported by flash attention or fused attention
-  if q_seq_len > 512 and kv_sqe_len > 512 and head_dim in [64, 128]:
-    # check if flash attention is supported
-    is_flash_attention = True
-  elif q_seq_len <= 512 and kv_sqe_len <= 512 and head_dim == 64:
+  if q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64 \
+    and (not is_training or q_seq_len % 64 == 0 and kv_seq_len % 64 == 0):
     # check if regular fused attention is supported
+    # for training, seqlen should be divisible by 64
     is_flash_attention = False
+  elif head_dim <= 128 and head_dim % 8 == 0 \
+    and (not is_training or not has_bias or q_seq_len % 2 == 0 and kv_seq_len % 2 == 0):
+    # check if flash attention is supported
+    # for training, for patterns with bias, seqlen should be divisible by 2
+    is_flash_attention = True
   else:
     raise NotImplementedError(
-      f"Unsupported sequence length Q {q_seq_len}, KV {kv_sqe_len} and head dim {head_dim}.")
-  return is_flash_attention, is_cross_attention
-
-def check_cudnn_version(is_flash_attention, is_cross_attention):
-  # check if cuDNN is installed and if cuDNN version contraint is satisfied
-  if cuda_versions is None:
-    raise RuntimeError("cuDNN is not detected.")
-  elif is_flash_attention:
-    if not is_cross_attention and cuda_versions.cudnn_get_version() < 8903:
-      raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.")
-    if is_cross_attention and cuda_versions.cudnn_get_version() < 8904:
-      raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.")
-  elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901:
+      f"Unsupported sequence length Q {q_seq_len}, KV {kv_seq_len} and head dim {head_dim}.")
+  # check if minimum cudnn version requirement is satisfied
+  if is_flash_attention and cudnn_version < 8904:
+    raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.")
+  elif not is_flash_attention and cudnn_version < 8901:
     raise RuntimeError("JAX requires cuDNN >= 8.9.1 to use fused attention.")
+
+  return is_flash_attention
+
+def check_cudnn_version():
+  # check if cuDNN is installed
+  if cuda_versions is None:
+    raise RuntimeError("cuDNN is not detected.")
+  return cuda_versions.cudnn_get_version()
 
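Note: the rewritten check above encodes the kernel constraints (fused attention: seqlen <= 512 and head_dim == 64, seqlens divisible by 64 when training; flash attention: head_dim <= 128 and divisible by 8, seqlens divisible by 2 when training with bias). A standalone restatement in plain Python, handy for sanity-checking shapes offline (the helper name is ours and the cuDNN-version check is omitted):

    def supports_flash_attention(q_seq_len, kv_seq_len, head_dim, has_bias, is_training):
      # False -> regular fused attention kernel, True -> flash attention kernel
      if q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64 \
          and (not is_training or q_seq_len % 64 == 0 and kv_seq_len % 64 == 0):
        return False
      if head_dim <= 128 and head_dim % 8 == 0 \
          and (not is_training or not has_bias or q_seq_len % 2 == 0 and kv_seq_len % 2 == 0):
        return True
      raise NotImplementedError(
          f"Unsupported sequence length Q {q_seq_len}, KV {kv_seq_len} and head dim {head_dim}.")

    # spot checks matching the test table added in this commit
    assert supports_flash_attention(256, 512, 64, has_bias=False, is_training=False) is False
    assert supports_flash_attention(1, 257, 64, has_bias=False, is_training=True) is True
    assert supports_flash_attention(1024, 1024, 128, has_bias=False, is_training=False) is True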
 def _dot_product_attention_fwd(query, key, value, bias, mask,
-    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
-  output, _ = _dot_product_attention_fwd_p_wrapper.bind(
+    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
+  outputs = _dot_product_attention_fwd_p_wrapper.bind(
       query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
       variadic_args=variadic_args, is_flash_attention=is_flash_attention,
-      is_causal_mask=is_causal_mask)
+      is_causal_mask=is_causal_mask, is_training=is_training)
+  output = outputs[0]
   return output
 
 def _dot_product_attention_fwd_rule(query, key, value, bias, mask,
-    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
-  output, activation = _dot_product_attention_fwd_p_wrapper.bind(
+    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
+  outputs = _dot_product_attention_fwd_p_wrapper.bind(
       query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
       variadic_args=variadic_args, is_flash_attention=is_flash_attention,
-      is_causal_mask=is_causal_mask)
-  res = (query, key, value, bias, mask, activation, output)
-  return output, res
+      is_causal_mask=is_causal_mask, is_training=is_training)
+  res = (query, key, value, bias, mask, outputs[1], outputs[0]) if is_training else None
+  return outputs[0], res
 
-def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, res, grad_output):
+def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, res, grad_output):
   query, key, value, bias, mask, activation, fwd_output = res
   grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
       query, key, value, bias, mask, activation, fwd_output, grad_output,
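The forward rule above only stashes residuals when is_training is set, returning None otherwise. A self-contained toy of the same jax.custom_vjp pattern, with the flag passed as a non-differentiable static argument (names here are illustrative, not part of this change):

    from functools import partial
    import jax
    import jax.numpy as jnp

    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def scaled_square(x, is_training):
      return (x * x).sum()

    def scaled_square_fwd(x, is_training):
      out = scaled_square(x, is_training)
      res = (x,) if is_training else None   # skip saving activations for inference
      return out, res

    def scaled_square_bwd(is_training, res, g):
      (x,) = res                            # only valid when is_training=True
      return (2.0 * x * g,)

    scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

    x = jnp.arange(4.0)
    print(jax.grad(scaled_square)(x, True))  # gradients require is_training=True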
@@ -256,13 +261,13 @@ def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is
   return grads
 
 def _dot_product_attention_fwd_impl(query, key, value, bias, mask,
-    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
+    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
   # args: {Q, K, V, mask*, bias*}
-  output, activation = _dot_product_attention_fwd_p.bind(
+  outputs = _dot_product_attention_fwd_p.bind(
       query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
       variadic_args=variadic_args, is_flash_attention=is_flash_attention,
-      is_causal_mask=is_causal_mask)
-  return output, activation
+      is_causal_mask=is_causal_mask, is_training=is_training)
+  return outputs
 
 def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output,
     scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
@@ -275,23 +280,34 @@ def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, f
   return grads
 
 def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
-    *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
+    *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
   query_dtype = dtypes.canonicalize_dtype(query.dtype)
   batch, q_seq_len, num_heads, head_dim = query.shape
   _, kv_seq_len, _, _ = key.shape
   output_shape = (batch, q_seq_len, num_heads, head_dim)
   activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
   softmax_stat_shape = (batch, num_heads, q_seq_len)
-  if q_seq_len > 512:
+
+  if is_flash_attention:
-    # is flash attention
+    if is_training:
      return (
        core.ShapedArray(output_shape, query_dtype),  # output
        core.ShapedArray(softmax_stat_shape, jnp.float32),  # softmax_stat
      )
+    else:
+      return (
+        core.ShapedArray(output_shape, query_dtype),  # output
+      )
+  if is_training:
+    return (
+      core.ShapedArray(output_shape, query_dtype),  # output
+      core.ShapedArray(activation_shape, query_dtype),  # activation
+    )
+  else:
+    return (
+      core.ShapedArray(output_shape, query_dtype),  # output
+      core.ShapedArray(softmax_stat_shape, jnp.float32),  # softmax_stat
+    )
-  return (
-    core.ShapedArray(output_shape, query_dtype),  # output
-    core.ShapedArray(activation_shape, query_dtype),  # activation
-  )
 
 def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output,
     *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
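For reference, the avals above work out as follows for a concrete shape (illustrative numbers only; plain Python):

    # Output shapes implied by the abstract rule above, for a
    # (batch=4, q_seq=1024, heads=4, head_dim=64) query with kv_seq=1024.
    batch, q_seq_len, num_heads, head_dim, kv_seq_len = 4, 1024, 4, 64, 1024
    output_shape = (batch, q_seq_len, num_heads, head_dim)        # always present
    softmax_stat_shape = (batch, num_heads, q_seq_len)            # extra output: flash attention + training
    activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)  # extra output: fused attention + training

    assert output_shape == (4, 1024, 4, 64)
    assert softmax_stat_shape == (4, 4, 1024)
    assert activation_shape == (4, 4, 1024, 1024)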
@@ -312,7 +328,7 @@ def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activatio
   )
 
 def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
-    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
+    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
   query_type = ir.RankedTensorType(query.type)
   query_shape = query_type.shape
   key_type = ir.RankedTensorType(key.type)
@@ -344,19 +360,33 @@ def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
   custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, False)
   # create output types and layouts
   if is_flash_attention:
-    result_types = [
-      ir.RankedTensorType.get(output_shape, query_type.element_type),
-      ir.RankedTensorType.get(scratch_shape, scratch_type),
-      ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()),
-    ]
-    result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape)
+    if is_training:
+      result_types = [
+        ir.RankedTensorType.get(output_shape, query_type.element_type),
+        ir.RankedTensorType.get(scratch_shape, scratch_type),
+        ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()),
+      ]
+      result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape)
+    else:
+      result_types = [
+        ir.RankedTensorType.get(output_shape, query_type.element_type),
+        ir.RankedTensorType.get(scratch_shape, scratch_type)
+      ]
+      result_layouts = [output_layout] + default_layouts(scratch_shape)
   else:
-    result_types = [
-      ir.RankedTensorType.get(output_shape, query_type.element_type),
-      ir.RankedTensorType.get(scratch_shape, scratch_type),
-      ir.RankedTensorType.get(activation_shape, query_type.element_type),
-    ]
-    result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape)
+    if is_training:
+      result_types = [
+        ir.RankedTensorType.get(output_shape, query_type.element_type),
+        ir.RankedTensorType.get(scratch_shape, scratch_type),
+        ir.RankedTensorType.get(activation_shape, query_type.element_type),
+      ]
+      result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape)
+    else:
+      result_types = [
+        ir.RankedTensorType.get(output_shape, query_type.element_type),
+        ir.RankedTensorType.get(scratch_shape, scratch_type),
+      ]
+      result_layouts = [output_layout] + default_layouts(scratch_shape)
   # create custom call here
   out = mlir.custom_call(
       custom_call_name,
@@ -368,7 +398,10 @@ def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
   )
   # drop scratch memory
   # output should be (batch, q_seq_len, num_heads, head_dim) instead of (batch, num_heads, q_seq_len, head_dim)
-  return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]]
+  if is_training:
+    return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]]
+  else:
+    return [hlo.transpose(out.results[0], output_transpose_perm)]
 
 def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output,
     scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
@@ -451,11 +484,14 @@ def _check_valid_batch_dims(bdims):
     raise NotImplementedError("Currently only support batch_dim in [0, None], " \
       f"but got {dim=}")
 
-def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
+def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training):
   _check_valid_batch_dims(batch_dims)
   query, key, value, bias, mask = batched_args
   query_bdim = batch_dims[0]
-  out_bdims = query_bdim, query_bdim
+  if is_training:
+    out_bdims = query_bdim, query_bdim
+  else:
+    out_bdims = (query_bdim,)
 
   *batch_tuple, q_seq_len, num_heads, head_dim = query.shape
   *_, kv_seq_len, _, _ = key.shape
@@ -470,19 +506,24 @@ def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed,
   if has_mask:
     mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len))
 
-  output, activation = _dot_product_attention_fwd_p_wrapper.bind(
+  outputs = _dot_product_attention_fwd_p_wrapper.bind(
       query, key, value, bias, mask,
       scale=scale, seed=seed, dropout_rate=dropout_rate,
       variadic_args=variadic_args, is_flash_attention=is_flash_attention,
-      is_causal_mask=is_causal_mask)
+      is_causal_mask=is_causal_mask, is_training=is_training)
 
   # reshape to original shape
+  output = outputs[0]
   output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim))
-  if is_flash_attention:
-    activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len))
+  if is_training:
+    activation = outputs[1]
+    if is_flash_attention:
+      activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len))
-  else:
-    activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len))
-  return (output, activation), out_bdims
+    else:
+      activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len))
+    return (output, activation), out_bdims
+  else:
+    return (output,), out_bdims
 
 def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
   _check_valid_batch_dims(batch_dims)
@@ -556,7 +597,7 @@ def _check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_
     raise ValueError("Sharding on mask sequence dim is not allowed.")
 
 # fwd custom partition
-def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args):
+def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training):
   # only sharding on batch and num_head dim is allowed
   # (*batch, q_seq, num_head, head)
   query_spec = _get_padded_spec(arg_shapes[0])
@@ -569,21 +610,24 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args):
   _check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_spec)
   # keep out sharding same as query sharding since they have same shape
   out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
-  # activation sharding
-  *batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec
-  activation_sharding = NamedSharding(mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None))
-  return (out_sharding, activation_sharding)
+  if is_training:
+    # activation sharding
+    *batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec
+    activation_sharding = NamedSharding(mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None))
+    return [out_sharding, activation_sharding]
+  return [out_sharding]
 
-_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10))
-def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
-  return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args)
+_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10,11))
+def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, mesh, arg_shapes, result_shape):
+  return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
 
-def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
+def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, is_training, mesh, arg_shapes, result_shape):
   # args sharding
   arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
-  out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args)
+  out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
   impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
-      variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
+      variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
+      is_training=is_training)
   return mesh, impl, out_shardings, arg_shardings
 
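Note: the sharding inference above keeps the output sharded like the query and, when training, shards the activation only on the batch and num_heads dims. A small sketch of that derivation with jax.sharding (assumes at least four local devices so the 2x2 'dp'/'tp' mesh from the test below can be built):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
    with Mesh(devices, ('dp', 'tp')) as mesh:
      # query is (batch, q_seq, num_heads, head_dim): shard batch on 'dp', heads on 'tp'
      query_spec = ('dp', None, 'tp', None)
      out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
      # activation is (batch, num_heads, q_seq, kv_seq): only batch/num_heads stay sharded
      *batch_spec, q_seq_spec, num_head_spec, _ = query_spec
      activation_sharding = NamedSharding(
          mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None))
      print(out_sharding)
      print(activation_sharding)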
 # bwd custom partition
@@ -673,7 +717,7 @@ dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_
 dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p)
 dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper)
 
-@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10))
+@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11))
 def _dot_product_attention(query: Array,
                            key: Array,
                            value: Array,
@@ -684,11 +728,13 @@ def _dot_product_attention(query: Array,
                            dropout_rate: float,
                            variadic_args: tuple[bool, ...],
                            is_flash_attention: bool,
-                           is_causal_mask: bool):
+                           is_causal_mask: bool,
+                           is_training: bool):
   output = _dot_product_attention_fwd(
       query, key, value, bias, mask,
       scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
-      is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
+      is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask,
+      is_training=is_training)
   return output
 
 # _dot_product_attention_fwd must have the same func signature as _dot_product_attention
@@ -704,7 +750,8 @@ def dot_product_attention(query: Array,
                           scale: float = 1.0,
                           is_causal_mask: bool = False,
                           seed: int = 42,
-                          dropout_rate: float = 0.):
+                          dropout_rate: float = 0.,
+                          is_training = False):
   """Computes dot-product attention given query, key, and value.
   This is the core function for applying attention based on
   https://arxiv.org/abs/1706.03762. It calculates the attention weights given
@@ -725,16 +772,19 @@ def dot_product_attention(query: Array,
     mask: mask used mask out logits with shape of `[batch, num_heads,
       q_length, kv_length]`.
     scale: scale for the query.
     is_causal_mask: choose to apply a causal mask or not.
     seed: used for dropout mask generation.
     dropout_rate: dropout rate.
+    is_training: choose to save activation or not.
   Returns:
     Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.
   """
-  # check if query, key and value layout meets cuDNN layout requirement
+  # check if cuDNN is installed
+  cudnn_version = check_cudnn_version()
+  # check query, key and value shape and data type
   check_qkv_layout(query, key, value)
   # check if flash attention is supported for this attention pattern
-  is_flash_attention, is_cross_attention = check_is_flash_attention(query, key)
-  # check if cuDNN is installed and if cuDNN version is sufficient
-  check_cudnn_version(is_flash_attention, is_cross_attention)
+  is_flash_attention = check_is_flash_attention(query, key, cudnn_version, bias is not None, is_training)
   if mask is not None and is_causal_mask:
     raise ValueError("can not apply a mask and generate a causal_mask at the same time.")
   if not is_flash_attention and is_causal_mask:
@@ -747,5 +797,5 @@ def dot_product_attention(query: Array,
   output = _dot_product_attention(
       query, key, value, bias, mask,
       scale, seed, dropout_rate, variadic_args,
-      is_flash_attention, is_causal_mask)
+      is_flash_attention, is_causal_mask, is_training)
   return output

@@ -16,7 +16,7 @@ from functools import partial
 from absl.testing import absltest
 from typing import Optional
 import os
-os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'
+os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true --xla_dump_hlo_as_text --xla_dump_to=./hlo'
 
 import numpy as np
 import jax
@@ -25,7 +25,7 @@ from jax.sharding import Mesh
 from jax.sharding import PartitionSpec, NamedSharding
 from jax._src import config
 from jax._src import test_util as jtu
-from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
+from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, check_is_flash_attention
 
 config.parse_flags_with_absl()
 Array = jnp.ndarray
@@ -43,7 +43,7 @@ def sdpa_train(query: Array,
   # convert bool mask to dtype mask
   mask = mask.astype(query.dtype)
   out, sdpa_vjp = jax.vjp(
-      partial(dot_product_attention, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
+      partial(dot_product_attention, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate, is_training=True),
       query, key, value, bias, mask)
   query_grad, key_grad, value_grad, _, _ = sdpa_vjp(grad)
   return out, (query_grad, key_grad, value_grad)
@@ -201,6 +201,58 @@ class DotProductAttentionTest(jtu.JaxTestCase):
     self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5)
     self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
     self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)
 
+  @jtu.run_on_devices("cuda")
+  def test_sdpa_inference(self):
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    query = jax.random.normal(
+        k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
+    key = jax.random.normal(
+        k2, (4, 1024, 4, 64), dtype=jnp.bfloat16)
+    value = jax.random.normal(
+        k3, (4, 1024, 4, 64), dtype=jnp.bfloat16)
+
+    devices = np.array(jax.local_devices()[:4])
+    devices = devices.reshape((2, 2))
+    with Mesh(devices, ('dp', 'tp')) as mesh:
+      qkv_spec = PartitionSpec('dp', None, 'tp', None)
+      qkv_sharding = NamedSharding(mesh, qkv_spec)
+      replicated = NamedSharding(mesh, PartitionSpec())
+      in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, replicated, replicated)
+      out_shardings = replicated
+      query = jax.device_put(query, qkv_sharding)
+      key = jax.device_put(key, qkv_sharding)
+      value = jax.device_put(value, qkv_sharding)
+      jitted_sdpa_inference = jax.jit(
+        partial(dot_product_attention, scale=1.0, is_causal_mask=False, dropout_rate=0),
+        in_shardings=in_shardings,
+        out_shardings=out_shardings
+      )
+
+      jitted_sdpa_inference_ref = jax.jit(
+        partial(sdpa_ref, scale=1.0, is_causal_mask=False, dropout_rate=0),
+        in_shardings=in_shardings,
+        out_shardings=out_shardings
+      )
+
+      out = jitted_sdpa_inference(query, key, value, None, None)
+      out_ref = jitted_sdpa_inference_ref(query, key, value, None, None)
+      self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5)
+
+  def test_sdpa_utils(self):
+    test_cases = {
+      (256, 512, 64, 8905, False, False): False,
+      (1, 257, 64, 8905, False, True): True,
+      (1, 1024, 64, 8905, False, False): True,
+      (1024, 1024, 64, 8905, False, False): True,
+      (1024, 1024, 128, 8905, False, False): True,
+    }
+
+    for k, v in test_cases.items():
+      sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k
+      query = jnp.empty((4, sql_q, 4, head_dim))
+      key = jnp.empty((4, sql_v, 4, head_dim))
+      self.assertEqual(check_is_flash_attention(query, key, cudnn_version, has_bias, is_training), v)
+
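For readability: each key in test_cases above is (q_seq_len, kv_seq_len, head_dim, cudnn_version, has_bias, is_training) and the value is the expected is_flash_attention result. For example (sketch using the helper imported above):

    import jax.numpy as jnp
    from jax._src.cudnn.fused_attention_stablehlo import check_is_flash_attention

    q = jnp.empty((4, 1024, 4, 128))
    k = jnp.empty((4, 1024, 4, 128))
    # head_dim 128 rules out the regular fused kernel, so flash attention is expected
    assert check_is_flash_attention(q, k, 8905, has_bias=False, is_training=False)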
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())