mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Merge pull request from Cjkkkk:segment_ids

PiperOrigin-RevId: 722650439
jax authors 2025-02-03 07:28:25 -08:00
commit 7164c6ba3e
2 changed files with 350 additions and 102 deletions

@@ -119,9 +119,11 @@ def element_type_to_backend_config_type_mapping(dtype):
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
def get_max_seg_per_batch(q_offsets):
return q_offsets.shape[1] - 1 if len(q_offsets.shape) == 2 else 1
def create_dot_product_attention_backend_config_base(
batch, num_heads, seq_q, seq_kv, dtype,fmha_scale, mask_type, layout, is_bwd
batch, num_heads, seq_q, seq_kv, dtype, fmha_scale, mask_type, layout, is_bwd
):
# Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H
# P: BMM1 output in shape of BNTS
@@ -226,6 +228,7 @@ def create_dot_product_attention_backend_config(
mask_type,
layout,
sliding_window_length,
max_seg_per_batch,
is_bwd
):
backend_config = create_dot_product_attention_backend_config_base(
@@ -237,6 +240,7 @@ def create_dot_product_attention_backend_config(
backend_config['cudnn_fmha_backend_config']["dropout_rate"] = dropout_rate
backend_config['cudnn_fmha_backend_config']["seed"] = seed
backend_config['cudnn_fmha_backend_config']["sliding_window_length"] = sliding_window_length
backend_config['cudnn_fmha_backend_config']["max_seg_per_batch"] = max_seg_per_batch
return json.dumps(backend_config)
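For orientation, a minimal sketch (not part of the commit) of the JSON fragment these lines populate; only the keys visible in this hunk are shown with illustrative values, and the real config produced by the base builder contains many more fields.

import json

# Illustrative values only; a max_seg_per_batch greater than 1 marks a packed
# (segmented) layout for the cuDNN fused-attention custom call.
example_backend_config = {
    "cudnn_fmha_backend_config": {
        "dropout_rate": 0.0,
        "seed": 42,
        "sliding_window_length": 0,
        "max_seg_per_batch": 3,
    }
}
serialized = json.dumps(example_backend_config)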
def create_dot_product_attention_fp8_backend_config(
@@ -268,7 +272,8 @@ get_fp8_custom_call_name = functools.partial(
get_custom_call_name, has_bias=False, has_dropout=False, is_fp8=True
)
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
q_offsets, kv_offsets, layout):
def check_eq(a, b, c, msg):
if not (a == b == c):
raise ValueError(f"{msg} must be same, got {a}, {b}, {b}")
@@ -300,36 +305,36 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
if kS != vS:
raise ValueError(f"KV must have same seq length, got {kS} vs {vS}")
# check bias/q_seqlen/kv_seqlen
# check bias
if bias is not None:
_, _, bT, bS = bias.shape
if bT != qT or bS != vS:
raise ValueError(
f"Bias must have same seq length as QKV, got {bT} and {bS}")
if q_seqlen is not None:
q_seq_dtype = q_seqlen.dtype
q_seq_rank = len(q_seqlen.shape)
if q_seq_dtype != jnp.int32:
raise ValueError(f"q_seqlen must have int32 datatype, got {q_seq_dtype}")
if q_seq_rank != 1:
raise ValueError(f"q_seqlen must have a rank of 1, got {q_seq_rank}")
q_seq_b = q_seqlen.shape[0]
if q_seq_b != qB:
raise ValueError(f"q_seqlen must have same batch as Q, got {q_seq_b}")
if kv_seqlen is not None:
kv_seq_dtype = kv_seqlen.dtype
kv_seq_rank = len(kv_seqlen.shape)
if kv_seq_dtype != jnp.int32:
raise ValueError(
f"kv_seqlen must have int32 datatype, got {kv_seq_dtype}")
if kv_seq_rank != 1:
raise ValueError(f"kv_seq_rank must have a rank of 1, got {kv_seq_rank}")
kv_seq_b = kv_seqlen.shape[0]
if kv_seq_b != qB:
raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")
# check q_seqlen/kv_seqlen/q_offsets/kv_offsets
expected_rank = 2 if q_offsets is not None else 1
def check_seqlen_offsets(tensor, name):
if tensor is not None:
dtype = tensor.dtype
rank = len(tensor.shape)
if dtype != jnp.int32:
raise ValueError(f"{name} must have int32 datatype, got {dtype}")
if rank != expected_rank:
raise ValueError(f"{name} must have a rank of {expected_rank}, got {rank}")
b = tensor.shape[0]
if b != qB:
raise ValueError(f"{name} must have same batch as Q, got {b}")
check_seqlen_offsets(q_seqlen, "q_seqlen")
check_seqlen_offsets(kv_seqlen, "kv_seqlen")
check_seqlen_offsets(q_offsets, "q_offsets")
check_seqlen_offsets(kv_offsets, "kv_offsets")
def check_is_flash_attention(
query, key, layout: int, cudnn_version, has_bias, is_training, is_fp8=False):
query, key, layout: int, cudnn_version, has_bias, is_training, is_packed,
is_fp8=False):
# Extract sequence length (T) and head dim (H) based on layout
if layout == AttentionLayout.BNTH.value:
_, _, T, H = query.shape
@@ -363,6 +368,9 @@ def check_is_flash_attention(
f"Unsupported sequence length Q {T}, KV {S}."
)
if is_packed and cudnn_version < 90600:
raise NotImplementedError("Packed layout requires cudnn version >= 9.6.")
def check_cudnn_version():
# check if cuDNN is installed
if cuda_versions is None:
@@ -378,78 +386,142 @@ def check_compute_capability(capability):
return current >= target
def _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, False)
query, key, layout, cudnn_version, bias is not None, False,
get_max_seg_per_batch(q_offsets) > 1)
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=False)
output = outputs[0]
return output
def _dot_product_attention_fwd_rule(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, True)
query, key, layout, cudnn_version, bias is not None, True,
get_max_seg_per_batch(q_offsets) > 1)
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=True)
res = (query, key, value, bias, q_seqlen, kv_seqlen,
outputs[1], outputs[0])
res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
kv_offsets, outputs[1], outputs[0])
return outputs[0], res
def _dot_product_attention_bwd_rule(
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training, res, grad_output):
(query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output) = res
(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output) = res
grads = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length
)
grads = (*grads,) + (None,) * (6 - len(grads))
grads = (*grads,) + (None,) * (8 - len(grads))
return grads
def _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key):
# Fix seqlen and offsets to what cuDNN expects for sequence packing.
# cuDNN expects seqlen to have shape [S], where S is the total number of
# segments, while the SDPA API accepts seqlen with shape [B, M], where B is
# the batch size and M is the maximum number of segments in one batch.
# B x M is larger than S and seqlen is filled with -1 for padded regions,
# so all non-negative values have to be shifted to the left to form a valid
# seqlen. The offsets tensors need a similar layout.
# cuDNN also expects each segment's offset to be counted from the first
# segment overall, while the SDPA API accepts offsets counted from the start
# of the current batch, so we compute the cumulative offset of each segment
# starting from the first segment.
def _shift_to_left(x, fill_value):
# shift all non-negative values to the left
# [[1, 3, -1, -1], [2, 3, 4, -1]]
# -> [[1, 3, 2, 3], [4, -1, -1, -1]]
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(x >= 0, size=size, fill_value=size)[0]
y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape)
def _cu_offset(offsets, max_seq):
# compute cumulative offsets by adding each batch's base offset
# [[1, 3, 5, 7], [4, 5, -1, -1]], max_seq = 8
# -> [[1, 3, 5, 7], [12, 13, -1, -1]]
batch = offsets.shape[0]
offsets = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seq)[..., jnp.newaxis],
offsets,
)
return offsets
if get_max_seg_per_batch(q_offsets) > 1:
B, T, N, H = query.shape
_, S, _, _ = key.shape
q_seqlen = _shift_to_left(q_seqlen, -1)
kv_seqlen = _shift_to_left(kv_seqlen, -1)
q_offsets = _cu_offset(q_offsets, T)
kv_offsets = _cu_offset(kv_offsets, S)
q_offsets = _shift_to_left(q_offsets, -1)
kv_offsets = _shift_to_left(kv_offsets, -1)
# mark any invalid entries as maximum offset
q_offsets = jnp.where(q_offsets < 0, B * T, q_offsets)
kv_offsets = jnp.where(kv_offsets < 0, B * S, kv_offsets)
# multiply by stride_per_token to get correct offsets
# do it here because real stride changes after sharding
q_offsets = q_offsets * N * H
kv_offsets = kv_offsets * N * H
return q_seqlen, kv_seqlen, q_offsets, kv_offsets
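To make the transformation above concrete, here is a small self-contained sketch (not part of the commit) that replays the three offset-fixing steps on a toy input; the shapes and values are illustrative assumptions.

import jax.numpy as jnp

# Toy packed input: B=2 batches, T=8 tokens, N=1 head, H=4 head dim, at most
# two segments per batch. Batch 0 has segments starting at tokens 0 and 3,
# batch 1 has a single segment starting at token 0.
B, T, N, H = 2, 8, 1, 4
q_offsets = jnp.array([[0, 3, -1], [0, -1, -1]], dtype=jnp.int32)

# Step 1: per-batch offsets -> cumulative offsets (batch 1 starts at token 8).
cu = jnp.where(q_offsets >= 0,
               q_offsets + (jnp.arange(B) * T)[:, None],
               q_offsets)                                  # [[0, 3, -1], [8, -1, -1]]

# Step 2: compact all non-negative entries to the left, as _shift_to_left does.
flat = cu.flatten()
idx = jnp.nonzero(flat >= 0, size=flat.size, fill_value=flat.size)[0]
cu = jnp.take(flat, idx, fill_value=-1).reshape(cu.shape)  # [[0, 3, 8], [-1, -1, -1]]

# Step 3: replace padding with the maximum offset and scale by the per-token stride.
cu = jnp.where(cu < 0, B * T, cu) * N * H                  # [[0, 12, 32], [64, 64, 64]]

q_seqlen and kv_seqlen only need the compaction of step 2, which is why _fix_seqlen_offsets applies _shift_to_left to them directly.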
def _dot_product_attention_fwd_impl(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
# args: {Q, K, V, mask*, bias*}
q_seqlen, kv_seqlen, q_offsets, kv_offsets = \
_fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key)
outputs = _dot_product_attention_fwd_p.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)
return outputs
def _dot_product_attention_bwd_impl(
query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length):
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
q_seqlen, kv_seqlen, q_offsets, kv_offsets = \
_fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key)
grads = _dot_product_attention_bwd_p.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length)
return grads
def _dot_product_attention_fwd_abstract(
query, key, value, bias, q_seqlen, kv_seqlen, *, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
*, scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
if layout == AttentionLayout.BNTH.value:
@@ -459,7 +531,9 @@ def _dot_product_attention_fwd_abstract(
B, T, N, _ = query.shape
_, S, _, _ = key.shape
output_shape = query.shape
softmax_stat_shape = (B, N, T)
max_seg_per_batch = get_max_seg_per_batch(q_offsets)
softmax_stat_shape = (B * max_seg_per_batch, N, T)
if is_training:
return (
@@ -472,9 +546,9 @@ def _dot_product_attention_fwd_abstract(
)
def _dot_product_attention_bwd_abstract(
query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
grad_output, *, scale, seed, dropout_rate, variadic_args, mask_type,
layout, sliding_window_length):
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, *, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
key_dtype = dtypes.canonicalize_dtype(key.dtype)
value_dtype = dtypes.canonicalize_dtype(value.dtype)
@@ -511,9 +585,9 @@ def _dot_product_attention_bwd_abstract(
)
def _dot_product_attention_fwd_cuda_lowering(
ctx, query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type,
layout, sliding_window_length, is_training):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@@ -530,24 +604,30 @@ def _dot_product_attention_fwd_cuda_lowering(
output_layout = (3, 1, 2, 0)
output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
output_shape = (B, N, T, H)
softmax_stat_shape = (B, N, T)
softmax_stat_shape = (B * max_seg_per_batch, N, T)
workspace_shape = (0,)
workspace_type = ir.IntegerType.get_unsigned(8)
has_bias, _ = variadic_args
backend_config = create_dot_product_attention_backend_config(
B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
mask_type, layout, sliding_window_length, is_bwd=False,
)
# {Q, K, V, bias*, q_seqlen*, kv_seqlen*}
mask_type, layout, sliding_window_length, max_seg_per_batch,
is_bwd=False)
# {Q, K, V, bias*, q_seqlen*, kv_seqlen*, q_offsets*, kv_offsets*}
# {output, activation*, workspace}
has_dropout = dropout_rate > 0
has_bias, _ = variadic_args
operands = [query, key, value]
if has_bias:
operands.append(bias)
if has_padding(mask_type):
if has_padding(mask_type) or max_seg_per_batch > 1:
operands.append(q_seqlen)
operands.append(kv_seqlen)
if max_seg_per_batch > 1:
operands.append(q_offsets)
operands.append(kv_offsets)
custom_call_name = get_custom_call_name(has_bias, has_dropout, False)
if is_training:
@@ -581,9 +661,9 @@ def _dot_product_attention_fwd_cuda_lowering(
return [hlo.transpose(out.results[0], output_transpose_perm)]
def _dot_product_attention_bwd_cuda_lowering(
ctx, query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length):
ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@@ -607,23 +687,29 @@ def _dot_product_attention_bwd_cuda_lowering(
grad_query_shape = (B, q_N, T, H)
grad_key_shape = (B, k_N, S, H)
grad_value_shape = (B, k_N, S, H)
has_bias, has_dbias = variadic_args
max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
backend_config = create_dot_product_attention_backend_config(
B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate,
mask_type, layout, sliding_window_length, is_bwd=True,
)
# {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*}
mask_type, layout, sliding_window_length, max_seg_per_batch,
is_bwd=True)
# {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*,
# q_offsets*, kv_offsets*}
# {dQ, dK, dV, dbias*, workspace}
has_dropout = dropout_rate > 0
has_bias, has_dbias = variadic_args
# create operands
operands = [query, key, value, activation, grad_output]
if has_bias:
# flash attention requires bias in the bwd for remat
operands.append(bias)
operands.append(fwd_output)
if has_padding(mask_type):
if has_padding(mask_type) or max_seg_per_batch > 1:
operands.append(q_seqlen)
operands.append(kv_seqlen)
if max_seg_per_batch > 1:
operands.append(q_offsets)
operands.append(kv_offsets)
# get custom call name
custom_call_name = get_custom_call_name(has_bias, has_dropout, True)
@@ -674,7 +760,8 @@ def _dot_product_attention_fwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length, is_training):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, q_seqlen, kv_seqlen = batched_args
query, key, value, bias, q_seqlen, kv_seqlen, \
q_offsets, kv_offsets = batched_args
query_bdim = batch_dims[0]
if is_training:
out_bdims = query_bdim, query_bdim
@@ -701,9 +788,9 @@ def _dot_product_attention_fwd_batcher(
kv_seqlen = jnp.reshape(kv_seqlen, (B, ))
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)
# reshape to original shape
@@ -720,8 +807,8 @@ def _dot_product_attention_bwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, q_seqlen, \
kv_seqlen, activation, fwd_output, grad_output = batched_args
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, \
activation, fwd_output, grad_output = batched_args
query_bdim = batch_dims[0]
out_bdims = query_bdim, query_bdim, query_bdim
@@ -757,8 +844,8 @@ def _dot_product_attention_bwd_batcher(
grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:])
grads = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
@@ -834,7 +921,7 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layo
return [out_sharding]
_dot_product_attention_fwd_lower = custom_partitioning(
_dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
_dot_product_attention_fwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
def _dot_product_attention_fwd_infer_sharding_from_operands(
scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length,
@@ -883,7 +970,7 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args):
return out_shardings
_dot_product_attention_bwd_lower = custom_partitioning(
_dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14, 15)
_dot_product_attention_bwd_impl, static_argnums=(11, 12, 13, 14, 15, 16, 17)
)
def _dot_product_attention_bwd_infer_sharding_from_operands(
@@ -1003,13 +1090,15 @@ dispatch.prim_requires_devices_during_lowering.add(
_dot_product_attention_bwd_p_wrapper
)
@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
def _dot_product_attention(query: Array,
key: Array,
value: Array,
bias: Array,
q_seqlen: Array,
kv_seqlen: Array,
q_offsets: Array,
kv_offsets: Array,
scale: float,
seed: int,
dropout_rate: float,
@@ -1019,9 +1108,10 @@ def _dot_product_attention(query: Array,
sliding_window_length: int | None,
cudnn_version: int):
output = _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
cudnn_version=cudnn_version)
return output
@@ -1612,6 +1702,8 @@ def dot_product_attention(
mask: Array | None = None,
q_seqlen: Array | None = None,
kv_seqlen: Array | None = None,
q_offsets: Array | None = None,
kv_offsets: Array | None = None,
fp8_params: FP8Params | None = None,
*,
scale: float = 1.0,
@@ -1647,8 +1739,26 @@ def dot_product_attention(
value: Values to be used in attention with a shape of BSNH or BNSH.
bias: Bias to be added to logits with a shape of BNTS.
mask: Mask used to filter out logits with a shape of BNTS.
q_seqlen: Non padded sequence length of Queries with a shape of B.
kv_seqlen: Non padded sequence length of Keys and Values with a shape of B.
q_seqlen: Non-padded sequence length of query with a shape of B.
If q_offsets is set, q_seqlen should have shape [B, M] where M is the
maximum number of segments per batch. For batches that have fewer
segments than the maximum, fill the padded entries with -1.
kv_seqlen: Non-padded sequence length of key and value with a shape of B.
If kv_offsets is set, kv_seqlen should have shape [B, M] where M is the
maximum number of segments per batch. For batches that have fewer
segments than the maximum, fill the padded entries with -1.
q_offsets: Offset of each segment packed in query with a shape of [B, M+1],
where M is the maximum number of segments per batch. For batches that
have fewer segments than the maximum, fill the padded entries with -1.
E.g., if 2 batches have 3 and 2 segments respectively and each segment
has size 1, then q_offsets = [[0,1,2,-1], [0,1,-1,-1]]. q_seqlen should
be set to indicate the size of each segment.
kv_offsets: Offset of each segment packed in key and value with a shape of
[B, M+1], where M is the maximum number of segments per batch. For batches
that have fewer segments than the maximum, fill the padded entries with -1.
E.g., if 2 batches have 3 and 2 segments respectively and each segment
has size 1, then kv_offsets = [[0,1,2,-1], [0,1,-1,-1]]. kv_seqlen should
be set to indicate the size of each segment.
scale: Scale for the query.
dropout_rate: Dropout rate.
qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH,
@@ -1679,7 +1789,7 @@ def dot_product_attention(
f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}"
)
check_fp8_params(fp8_params)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
output, amax_s, amax_o = _dot_product_attention_fp8(
query, key, value, fp8_params,
scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version
@@ -1691,6 +1801,8 @@ def dot_product_attention(
if sliding_window_length is not None and sliding_window_length <= 0:
raise ValueError(
f"Require sliding_window_length > 0, got {sliding_window_length}")
if q_offsets is not None and (q_seqlen is None or kv_seqlen is None):
raise ValueError("Require q_seqlen and kv_seqlen to use packed layout")
if bias is not None:
# reshape bias to have 4D shape
@@ -1712,7 +1824,7 @@ def dot_product_attention(
bias = bias + mask
# check if input shape and data type are compatible
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
has_bias = bias is not None
has_dbias = has_bias and \
should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr]
@@ -1724,8 +1836,12 @@ def dot_product_attention(
q_seqlen = jnp.zeros(0, dtype=query.dtype)
if kv_seqlen is None:
kv_seqlen = jnp.zeros(0, dtype=query.dtype)
if q_offsets is None:
q_offsets = jnp.zeros(0, dtype=query.dtype)
if kv_offsets is None:
kv_offsets = jnp.zeros(0, dtype=query.dtype)
output = _dot_product_attention(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length,
cudnn_version)
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout.value,
sliding_window_length, cudnn_version)
return output
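For reference, a hedged usage sketch of the new packed-layout arguments based on the docstring above; the import path, the MaskType usage, and the concrete values are assumptions rather than part of the commit, and actually running it requires an NVIDIA GPU with cuDNN >= 9.6.

import jax.numpy as jnp
# Assumed import path for this module; adjust to however it is exposed in your tree.
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType

B, T, N, H = 2, 8, 2, 64   # batch, packed sequence length, heads, head dim
query = key = value = jnp.ones((B, T, N, H), jnp.bfloat16)

# Batch 0 packs two segments (lengths 3 and 5); batch 1 packs one segment (length 8).
q_seqlen = kv_seqlen = jnp.array([[3, 5], [8, -1]], jnp.int32)            # [B, M]
q_offsets = kv_offsets = jnp.array([[0, 3, -1], [0, -1, -1]], jnp.int32)  # [B, M+1]

out = dot_product_attention(
    query, key, value,
    q_seqlen=q_seqlen, kv_seqlen=kv_seqlen,
    q_offsets=q_offsets, kv_offsets=kv_offsets,
    scale=1.0, mask_type=MaskType.NO_MASK)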

@@ -96,6 +96,10 @@ def sdpa_train(query: Array,
grad: Array,
bias: Array | None = None,
mask: Array | None = None,
q_seqlen: Array | None = None,
kv_seqlen: Array | None = None,
q_offsets: Array | None = None,
kv_offsets: Array | None = None,
scale: float = 0.5,
mask_type: MaskType = MaskType.NO_MASK,
is_bnth: bool = False,
@@ -107,15 +111,13 @@ def sdpa_train(query: Array,
else:
B, S, _, _ = query.shape
q_seqlen = kv_seqlen = jnp.full((B,), S // 2, jnp.int32)
else:
q_seqlen = kv_seqlen = None
out, sdpa_vjp = jax.vjp(
partial(dot_product_attention, scale=scale, mask_type=mask_type,
dropout_rate=dropout_rate,
qkv_layout="BNTH" if is_bnth else "BTNH",
sliding_window_length=sliding_window_length),
query, key, value, bias, mask, q_seqlen, kv_seqlen)
query_grad, key_grad, value_grad, bias_grad, _, _, _ = sdpa_vjp(grad)
query, key, value, bias, mask, q_seqlen, kv_seqlen, q_offsets, kv_offsets)
query_grad, key_grad, value_grad, bias_grad = sdpa_vjp(grad)[:4]
if bias is not None and len(bias.shape) == 3:
# has dbias
return out, (query_grad, key_grad, value_grad, bias_grad)
@@ -612,6 +614,136 @@ class DotProductAttentionTest(jtu.JaxTestCase):
self.assertArraysAllClose(grads_ref[1], grads_ans[1])
self.assertArraysAllClose(grads_ref[2], grads_ans[2])
@jtu.run_on_devices("cuda")
def test_sdpa_packed_layout(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
self.skipTest(str(e))
return
if cudnn_version < 90600:
self.skipTest("Requires >= cuDNN 9.6.0")
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(
k1, (4, 512, 4, 64), dtype=jnp.bfloat16)
key = jax.random.normal(
k2, (4, 512, 4, 64), dtype=jnp.bfloat16)
value = jax.random.normal(
k3, (4, 512, 4, 64), dtype=jnp.bfloat16)
grad = jax.random.normal(
k4, (4, 512, 4, 64), dtype=jnp.bfloat16)
def generate_padding_mask(segment_ids, padding_id, shape, dtype):
# segment_ids [B, T]
encoded_padding = jnp.where(segment_ids >= padding_id, 0, 1).astype(dtype)
return jax.lax.broadcast_in_dim(
encoded_padding, shape, broadcast_dimensions=[0, 1])
def generate_segment_mask(segment_ids, dtype):
segment_ids_1 = jnp.expand_dims(segment_ids, axis=-1)
# segment_ids_1 = jnp.where(segment_ids_1 == 3, 4, segment_ids_1)
segment_ids_2 = jnp.expand_dims(segment_ids, axis=1)
mask = jnp.not_equal(segment_ids_1, segment_ids_2).astype(dtype)
# broadcast to [B, N, T, T]
mask = jnp.expand_dims(mask, 1)
mask *= get_large_negative_number(dtype)
return mask
# starting pos of each segment
q_offsets = jnp.asarray([
[0, 170, 340, -1], # 3 segments
[0, 150, 340, -1], # 3 segments
[0, 190, -1, -1], # 2 segments
[0, -1, -1, -1] # 1 segment
], dtype=np.int32)
# actual seqlen of each segment without padding
q_seqlen = jnp.asarray([
[170, 170, 172], # No padding inside each segment
[150, 187, 172], # 3 padding tokens inside second segment
[190, 190, -1], # 132 padding tokens inside last segment
[400, -1, -1], # 112 padding tokens inside last segment
], dtype=np.int32)
# the maximum number of segments is used as the segment id for padding tokens
segment_ids = jnp.asarray([
[0]*170 + [1]*170 + [2]*172,
[0]*150 + [1]*187 + [3]*3 + [2]*172,
[0]*190 + [1]*190 + [3]*132,
[0]*400 + [3]*112,
], dtype=np.int32)
kv_offsets = q_offsets.copy()
kv_seqlen = q_seqlen.copy()
mask = generate_padding_mask(segment_ids, q_seqlen.shape[1], query.shape, query.dtype)
bias = generate_segment_mask(segment_ids, query.dtype)
devices = np.array(jax.local_devices()[:4])
devices = devices.reshape((2, 2))
with Mesh(devices, ("dp", "tp")) as mesh:
qkv_spec = PartitionSpec("dp", None, "tp", None)
qkv_sharding = NamedSharding(mesh, qkv_spec)
bias_spec = PartitionSpec("dp", None, None, None)
bias_sharding = NamedSharding(mesh, bias_spec)
offsets_specs = PartitionSpec("dp", None)
offsets_sharding = NamedSharding(mesh, offsets_specs)
query = jax.device_put(query, qkv_sharding)
key = jax.device_put(key, qkv_sharding)
value = jax.device_put(value, qkv_sharding)
grad = jax.device_put(grad, qkv_sharding)
bias = jax.device_put(bias, bias_sharding)
q_offsets = jax.device_put(q_offsets, offsets_sharding)
kv_offsets = jax.device_put(kv_offsets, offsets_sharding)
q_seqlen = jax.device_put(q_seqlen, offsets_sharding)
kv_seqlen = jax.device_put(kv_seqlen, offsets_sharding)
jitted_sdpa_train = jax.jit(
partial(
sdpa_train, scale=0.1, mask_type=MaskType.NO_MASK, dropout_rate=0),
in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding,
None, None, offsets_sharding, offsets_sharding, offsets_sharding, offsets_sharding),
out_shardings=(qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding))
)
jitted_sdpa_train_ref = jax.jit(
partial(
sdpa_train_ref, scale=0.1, mask_type=MaskType.NO_MASK, dropout_rate=0),
in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding,
bias_sharding),
out_shardings=(qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding))
)
query = query * mask
key = key * mask
value = value * mask
grad = grad * mask
out, (query_grad, key_grad, value_grad) = \
jitted_sdpa_train(query, key, value, grad, None, None, q_seqlen, kv_seqlen, q_offsets, kv_offsets)
out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \
jitted_sdpa_train_ref(query, key, value, grad, bias)
out = out * mask
out_ref = out_ref * mask
query_grad = query_grad * mask
query_grad_ref = query_grad_ref * mask
key_grad = key_grad * mask
key_grad_ref = key_grad_ref * mask
value_grad = value_grad * mask
value_grad_ref = value_grad_ref * mask
self.assertArraysAllClose(out_ref, out, rtol=1e-2, atol=1e-2)
self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2)
self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-2, atol=1e-2)
self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2)
@jtu.run_on_devices("cuda")
def test_layouts(self):
if jax.device_count() < 4: