Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

Merge pull request #25812 from Cjkkkk:segment_ids

PiperOrigin-RevId: 722650439

This commit is contained in commit 7164c6ba3e.
@@ -119,9 +119,11 @@ def element_type_to_backend_config_type_mapping(dtype):
 def default_layouts(*shapes):
   return [range(len(shape) - 1, -1, -1) for shape in shapes]
 
+def get_max_seg_per_batch(q_offsets):
+  return q_offsets.shape[1] - 1 if len(q_offsets.shape) == 2 else 1
+
 def create_dot_product_attention_backend_config_base(
-    batch, num_heads, seq_q, seq_kv, dtype,fmha_scale, mask_type, layout, is_bwd
+    batch, num_heads, seq_q, seq_kv, dtype, fmha_scale, mask_type, layout, is_bwd
 ):
   # Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H
   # P: BMM1 output in shape of BNTS
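Aside (editorial, not part of the commit): the [B, M+1] offsets convention is what get_max_seg_per_batch keys on — a rank-2 offsets array advertises up to M segments per batch row, while the rank-1 empty placeholder used later in the wrapper falls back to the unpacked path. A minimal sketch reusing the two-line helper from the hunk above; values are made up:

import jax.numpy as jnp

def get_max_seg_per_batch(q_offsets):
  # [B, M+1] offsets describe up to M segments per batch row;
  # anything else (e.g. the empty 1-D placeholder) means "not packed".
  return q_offsets.shape[1] - 1 if len(q_offsets.shape) == 2 else 1

packed = jnp.asarray([[0, 170, 340, -1], [0, 190, -1, -1]], jnp.int32)  # B=2, M=3
placeholder = jnp.zeros(0, dtype=jnp.int32)
assert get_max_seg_per_batch(packed) == 3
assert get_max_seg_per_batch(placeholder) == 1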
@@ -226,6 +228,7 @@ def create_dot_product_attention_backend_config(
     mask_type,
     layout,
     sliding_window_length,
+    max_seg_per_batch,
     is_bwd
 ):
   backend_config = create_dot_product_attention_backend_config_base(
@@ -237,6 +240,7 @@ def create_dot_product_attention_backend_config(
   backend_config['cudnn_fmha_backend_config']["dropout_rate"] = dropout_rate
   backend_config['cudnn_fmha_backend_config']["seed"] = seed
   backend_config['cudnn_fmha_backend_config']["sliding_window_length"] = sliding_window_length
+  backend_config['cudnn_fmha_backend_config']["max_seg_per_batch"] = max_seg_per_batch
   return json.dumps(backend_config)
 
 def create_dot_product_attention_fp8_backend_config(
@@ -268,7 +272,8 @@ get_fp8_custom_call_name = functools.partial(
     get_custom_call_name, has_bias=False, has_dropout=False, is_fp8=True
 )
 
-def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
+def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
+                 q_offsets, kv_offsets, layout):
   def check_eq(a, b, c, msg):
     if not (a == b == c):
       raise ValueError(f"{msg} must be same, got {a}, {b}, {c}")
@@ -300,36 +305,36 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
   if kS != vS:
     raise ValueError(f"KV must have same seq length, got {kS} vs {vS}")
 
-  # check bias/q_seqlen/kv_seqlen
+  # check bias
   if bias is not None:
     _, _, bT, bS = bias.shape
     if bT != qT or bS != vS:
       raise ValueError(
           f"Bias must have same seq length as QKV, got {bT} and {bS}")
-  if q_seqlen is not None:
-    q_seq_dtype = q_seqlen.dtype
-    q_seq_rank = len(q_seqlen.shape)
-    if q_seq_dtype != jnp.int32:
-      raise ValueError(f"q_seqlen must have int32 datatype, got {q_seq_dtype}")
-    if q_seq_rank != 1:
-      raise ValueError(f"q_seqlen must have a rank of 1, got {q_seq_rank}")
-    q_seq_b = q_seqlen.shape[0]
-    if q_seq_b != qB:
-      raise ValueError(f"q_seqlen must have same batch as Q, got {q_seq_b}")
-  if kv_seqlen is not None:
-    kv_seq_dtype = kv_seqlen.dtype
-    kv_seq_rank = len(kv_seqlen.shape)
-    if kv_seq_dtype != jnp.int32:
-      raise ValueError(
-          f"kv_seqlen must have int32 datatype, got {kv_seq_dtype}")
-    if kv_seq_rank != 1:
-      raise ValueError(f"kv_seq_rank must have a rank of 1, got {kv_seq_rank}")
-    kv_seq_b = kv_seqlen.shape[0]
-    if kv_seq_b != qB:
-      raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")
 
+  # check q_seqlen/kv_seqlen/q_offsets/kv_offsets
+  expected_rank = 2 if q_offsets is not None else 1
+  def check_seqlen_offsets(tensor, name):
+    if tensor is not None:
+      dtype = tensor.dtype
+      rank = len(tensor.shape)
+      if dtype != jnp.int32:
+        raise ValueError(f"{name} must have int32 datatype, got {dtype}")
+      if rank != expected_rank:
+        raise ValueError(f"{name} must have a rank of {expected_rank}, got {rank}")
+      b = tensor.shape[0]
+      if b != qB:
+        raise ValueError(f"{name} must have same batch as Q, got {b}")
+
+  check_seqlen_offsets(q_seqlen, "q_seqlen")
+  check_seqlen_offsets(kv_seqlen, "kv_seqlen")
+  check_seqlen_offsets(q_offsets, "q_offsets")
+  check_seqlen_offsets(kv_offsets, "kv_offsets")
 
 
 def check_is_flash_attention(
-    query, key, layout: int, cudnn_version, has_bias, is_training, is_fp8=False):
+    query, key, layout: int, cudnn_version, has_bias, is_training, is_packed,
+    is_fp8=False):
   # Extract sequence length (T) and head dim (H) based on layout
   if layout == AttentionLayout.BNTH.value:
     _, _, T, H = query.shape
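Aside (editorial, not part of the commit): check_seqlen_offsets validates padded and packed inputs with one helper — the expected rank flips from 1 to 2 the moment q_offsets is supplied, and the leading dimension must always equal the query batch. A sketch of shapes that pass for a hypothetical batch of B=2:

import jax.numpy as jnp

B = 2
# Padded (non-packed) path: no offsets, per-row lengths of rank 1.
q_seqlen_padded = jnp.asarray([300, 512], dtype=jnp.int32)        # shape (B,)

# Packed path: offsets present, so seqlen and offsets must be rank 2,
# int32, leading dimension B ([B, M] lengths, [B, M+1] offsets, -1 padded).
q_offsets = jnp.asarray([[0, 170, 340, -1],
                         [0, 190,  -1, -1]], dtype=jnp.int32)     # shape (B, M+1)
q_seqlen_packed = jnp.asarray([[170, 170, 172],
                               [190, 322,  -1]], dtype=jnp.int32) # shape (B, M)
assert q_offsets.shape[0] == q_seqlen_packed.shape[0] == B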
@@ -363,6 +368,9 @@ def check_is_flash_attention(
         f"Unsupported sequence length Q {T}, KV {S}."
       )
 
+  if is_packed and cudnn_version < 90600:
+    raise NotImplementedError("Packed layout requires cudnn version >= 9.6.")
+
 def check_cudnn_version():
   # check if cuDNN is installed
   if cuda_versions is None:
@@ -378,78 +386,142 @@ def check_compute_capability(capability):
   return current >= target
 
 def _dot_product_attention_fwd(
-    query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
-    dropout_rate, variadic_args, mask_type, layout,
+    query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+    scale, seed, dropout_rate, variadic_args, mask_type, layout,
     sliding_window_length, cudnn_version):
   # check if flash attention is supported for this attention pattern
   check_is_flash_attention(
-      query, key, layout, cudnn_version, bias is not None, False)
+      query, key, layout, cudnn_version, bias is not None, False,
+      get_max_seg_per_batch(q_offsets) > 1)
   outputs = _dot_product_attention_fwd_p_wrapper.bind(
-      query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
-      seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
-      mask_type=mask_type, layout=layout,
+      query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+      scale=scale, seed=seed, dropout_rate=dropout_rate,
+      variadic_args=variadic_args, mask_type=mask_type, layout=layout,
       sliding_window_length=sliding_window_length, is_training=False)
   output = outputs[0]
   return output
 
 def _dot_product_attention_fwd_rule(
-    query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
-    dropout_rate, variadic_args, mask_type, layout,
+    query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+    scale, seed, dropout_rate, variadic_args, mask_type, layout,
     sliding_window_length, cudnn_version):
   # check if flash attention is supported for this attention pattern
   check_is_flash_attention(
-      query, key, layout, cudnn_version, bias is not None, True)
+      query, key, layout, cudnn_version, bias is not None, True,
+      get_max_seg_per_batch(q_offsets) > 1)
   outputs = _dot_product_attention_fwd_p_wrapper.bind(
-      query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
-      seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
-      mask_type=mask_type, layout=layout,
+      query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+      scale=scale, seed=seed, dropout_rate=dropout_rate,
+      variadic_args=variadic_args, mask_type=mask_type, layout=layout,
       sliding_window_length=sliding_window_length, is_training=True)
-  res = (query, key, value, bias, q_seqlen, kv_seqlen,
-         outputs[1], outputs[0])
+  res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
+         kv_offsets, outputs[1], outputs[0])
   return outputs[0], res
 
 def _dot_product_attention_bwd_rule(
     scale, seed, dropout_rate, variadic_args, mask_type, layout,
     sliding_window_length, is_training, res, grad_output):
-  (query, key, value, bias, q_seqlen, kv_seqlen, activation,
-   fwd_output) = res
+  (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+   activation, fwd_output) = res
   grads = _dot_product_attention_bwd_p_wrapper.bind(
-      query, key, value, bias, q_seqlen, kv_seqlen, activation,
-      fwd_output, grad_output, scale=scale, seed=seed,
+      query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+      activation, fwd_output, grad_output, scale=scale, seed=seed,
      dropout_rate=dropout_rate, variadic_args=variadic_args,
      mask_type=mask_type, layout=layout,
      sliding_window_length=sliding_window_length
   )
-  grads = (*grads,) + (None,) * (6 - len(grads))
+  grads = (*grads,) + (None,) * (8 - len(grads))
   return grads
 
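Aside (editorial, not part of the commit): the VJP wrapper now carries eight array inputs (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets), so the backward rule pads whatever gradients the cuDNN primitive returns up to eight entries; the trailing None entries stand in for inputs that get no gradient, such as the integer seqlen/offset operands. A tiny standalone sketch of the padding, with made-up gradient values:

# Hypothetical gradients standing in for dQ, dK, dV produced by the bwd primitive.
dq, dk, dv = 1.0, 2.0, 3.0
grads = (dq, dk, dv)
# Pad to one cotangent per primal input:
# (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets).
grads = (*grads,) + (None,) * (8 - len(grads))
assert len(grads) == 8 and grads[3:] == (None,) * 5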
+def _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key):
+  # fix seqlen and offsets to what cuDNN expects in sequence packing.
+  # cuDNN expects seqlen to have shape [S] where S is the total number of segments,
+  # while the SDPA API accepts seqlen with shape [B, M] where B is the batch and M
+  # is the maximum number of segments of one batch. B x M is larger than S and seqlen
+  # is filled with -1 for padded regions. Therefore, we need to shift all non-negative
+  # values to the left side to form a correct seqlen. A similar layout is required for
+  # the offsets tensors.
+  # cuDNN expects offsets to hold the offset of each segment starting from the first
+  # segment, while the SDPA API accepts offsets relative to the start of the current
+  # batch, therefore we need to calculate the cumulative offset of each segment
+  # starting from the first segment.
+  def _shift_to_left(x, fill_value):
+    # shift any non-negative value to the left
+    # [[1, 3, -1, -1], [2, 3, 4, -1]]
+    # -> [[1, 3, 2, 3], [4, -1, -1, -1]]
+    x_shape = x.shape
+    x = x.flatten()
+    size = x.size
+    indices = jnp.nonzero(x >= 0, size=size, fill_value=size)[0]
+    y = jnp.take(x, indices, fill_value=fill_value)
+    return jnp.reshape(y, x_shape)
+
+  def _cu_offset(offsets, max_seq):
+    # calculate the cumulative offset by batch
+    # [[1, 3, 5, 7], [4, 5, -1, -1]], max_seq = 8
+    # -> [[1, 3, 5, 7], [12, 13, -1, -1]]
+    batch = offsets.shape[0]
+    offsets = jnp.where(
+        offsets >= 0,
+        offsets + (jnp.arange(batch) * max_seq)[..., jnp.newaxis],
+        offsets,
+    )
+    return offsets
+
+  if get_max_seg_per_batch(q_offsets) > 1:
+    B, T, N, H = query.shape
+    _, S, _, _ = key.shape
+
+    q_seqlen = _shift_to_left(q_seqlen, -1)
+    kv_seqlen = _shift_to_left(kv_seqlen, -1)
+
+    q_offsets = _cu_offset(q_offsets, T)
+    kv_offsets = _cu_offset(kv_offsets, S)
+    q_offsets = _shift_to_left(q_offsets, -1)
+    kv_offsets = _shift_to_left(kv_offsets, -1)
+
+    # mark any invalid entries as maximum offset
+    q_offsets = jnp.where(q_offsets < 0, B * T, q_offsets)
+    kv_offsets = jnp.where(kv_offsets < 0, B * S, kv_offsets)
+
+    # multiply by stride_per_token to get correct offsets
+    # do it here because the real stride changes after sharding
+    q_offsets = q_offsets * N * H
+    kv_offsets = kv_offsets * N * H
+
+  return q_seqlen, kv_seqlen, q_offsets, kv_offsets
+
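Aside (editorial, not part of the commit): the fix-up above can be traced on a toy packed batch. Assuming B=2 rows of T=8 tokens, N=1 head and head dim H=4 (all hypothetical), the sketch below repeats the same steps — per-batch offsets are made cumulative across the flattened batch, compacted to the left, clamped to B*T, then scaled by the per-token stride N*H:

import jax.numpy as jnp

def _shift_to_left(x, fill_value):
  # same gather trick as in the diff: move all non-negative entries to the front
  flat = x.flatten()
  idx = jnp.nonzero(flat >= 0, size=flat.size, fill_value=flat.size)[0]
  return jnp.take(flat, idx, fill_value=fill_value).reshape(x.shape)

B, T, N, H = 2, 8, 1, 4
q_offsets = jnp.asarray([[0, 3, -1],     # row 0: segments start at tokens 0 and 3
                         [0, 5, -1]])    # row 1: segments start at tokens 0 and 5
q_offsets = jnp.where(q_offsets >= 0,
                      q_offsets + (jnp.arange(B) * T)[:, None], q_offsets)
q_offsets = _shift_to_left(q_offsets, -1)               # -> [[0 3 8] [13 -1 -1]]
q_offsets = jnp.where(q_offsets < 0, B * T, q_offsets)  # -> [[0 3 8] [13 16 16]]
q_offsets = q_offsets * N * H                           # -> [[0 12 32] [52 64 64]]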
 def _dot_product_attention_fwd_impl(
-    query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
-    dropout_rate, variadic_args, mask_type, layout,
+    query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+    scale, seed, dropout_rate, variadic_args, mask_type, layout,
     sliding_window_length, is_training):
   # args: {Q, K, V, mask*, bias*}
+  q_seqlen, kv_seqlen, q_offsets, kv_offsets = \
+      _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key)
   outputs = _dot_product_attention_fwd_p.bind(
-      query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
-      seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
-      mask_type=mask_type, layout=layout,
+      query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+      scale=scale, seed=seed, dropout_rate=dropout_rate,
+      variadic_args=variadic_args, mask_type=mask_type, layout=layout,
       sliding_window_length=sliding_window_length, is_training=is_training)
   return outputs
 
 def _dot_product_attention_bwd_impl(
-    query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
-    grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout,
-    sliding_window_length):
+    query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+    activation, fwd_output, grad_output, scale, seed, dropout_rate,
+    variadic_args, mask_type, layout, sliding_window_length):
+  q_seqlen, kv_seqlen, q_offsets, kv_offsets = \
+      _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key)
   grads = _dot_product_attention_bwd_p.bind(
-      query, key, value, bias, q_seqlen, kv_seqlen, activation,
-      fwd_output, grad_output, scale=scale, seed=seed,
+      query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+      activation, fwd_output, grad_output, scale=scale, seed=seed,
       dropout_rate=dropout_rate, variadic_args=variadic_args,
       mask_type=mask_type, layout=layout,
       sliding_window_length=sliding_window_length)
   return grads
 
 def _dot_product_attention_fwd_abstract(
-    query, key, value, bias, q_seqlen, kv_seqlen, *, scale, seed,
-    dropout_rate, variadic_args, mask_type, layout,
+    query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+    *, scale, seed, dropout_rate, variadic_args, mask_type, layout,
     sliding_window_length, is_training):
   query_dtype = dtypes.canonicalize_dtype(query.dtype)
   if layout == AttentionLayout.BNTH.value:
@@ -459,7 +531,9 @@ def _dot_product_attention_fwd_abstract(
     B, T, N, _ = query.shape
     _, S, _, _ = key.shape
   output_shape = query.shape
-  softmax_stat_shape = (B, N, T)
+
+  max_seg_per_batch = get_max_seg_per_batch(q_offsets)
+  softmax_stat_shape = (B * max_seg_per_batch, N, T)
 
   if is_training:
     return (
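Aside (editorial, not part of the commit): with packing, cuDNN keeps softmax statistics per segment rather than per batch row, so the leading dimension of the activation buffer grows from B to B * max_seg_per_batch. A quick shape check with assumed toy sizes:

B, N, T = 4, 8, 512          # hypothetical batch, heads, max sequence length
max_seg_per_batch = 3        # e.g. q_offsets of shape (B, 4)
softmax_stat_shape = (B * max_seg_per_batch, N, T)
assert softmax_stat_shape == (12, 8, 512)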
@@ -472,9 +546,9 @@ def _dot_product_attention_fwd_abstract(
   )
 
 def _dot_product_attention_bwd_abstract(
-    query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
-    grad_output, *, scale, seed, dropout_rate, variadic_args, mask_type,
-    layout, sliding_window_length):
+    query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+    activation, fwd_output, grad_output, *, scale, seed, dropout_rate,
+    variadic_args, mask_type, layout, sliding_window_length):
   query_dtype = dtypes.canonicalize_dtype(query.dtype)
   key_dtype = dtypes.canonicalize_dtype(key.dtype)
   value_dtype = dtypes.canonicalize_dtype(value.dtype)
@@ -511,9 +585,9 @@ def _dot_product_attention_bwd_abstract(
   )
 
 def _dot_product_attention_fwd_cuda_lowering(
-    ctx, query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
-    dropout_rate, variadic_args, mask_type, layout,
-    sliding_window_length, is_training):
+    ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
+    kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type,
+    layout, sliding_window_length, is_training):
   query_type = ir.RankedTensorType(query.type)
   query_shape = query_type.shape
   key_type = ir.RankedTensorType(key.type)
@@ -530,24 +604,30 @@ def _dot_product_attention_fwd_cuda_lowering(
     output_layout = (3, 1, 2, 0)
     output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
 
+  max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
   output_shape = (B, N, T, H)
-  softmax_stat_shape = (B, N, T)
+  softmax_stat_shape = (B * max_seg_per_batch, N, T)
   workspace_shape = (0,)
   workspace_type = ir.IntegerType.get_unsigned(8)
 
+  has_bias, _ = variadic_args
   backend_config = create_dot_product_attention_backend_config(
       B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
-      mask_type, layout, sliding_window_length, is_bwd=False,
-  )
-  # {Q, K, V, bias*, q_seqlen*, kv_seqlen*}
+      mask_type, layout, sliding_window_length, max_seg_per_batch,
+      is_bwd=False)
+  # {Q, K, V, bias*, q_seqlen*, kv_seqlen*, q_offsets*, kv_offsets*}
   # {output, activation*, workspace}
   has_dropout = dropout_rate > 0
-  has_bias, _ = variadic_args
   operands = [query, key, value]
   if has_bias:
     operands.append(bias)
-  if has_padding(mask_type):
+  if has_padding(mask_type) or max_seg_per_batch > 1:
     operands.append(q_seqlen)
     operands.append(kv_seqlen)
+  if max_seg_per_batch > 1:
+    operands.append(q_offsets)
+    operands.append(kv_offsets)
 
   custom_call_name = get_custom_call_name(has_bias, has_dropout, False)
 
   if is_training:
@@ -581,9 +661,9 @@ def _dot_product_attention_fwd_cuda_lowering(
     return [hlo.transpose(out.results[0], output_transpose_perm)]
 
 def _dot_product_attention_bwd_cuda_lowering(
-    ctx, query, key, value, bias, q_seqlen, kv_seqlen, activation,
-    fwd_output, grad_output, scale, seed, dropout_rate, variadic_args,
-    mask_type, layout, sliding_window_length):
+    ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+    activation, fwd_output, grad_output, scale, seed, dropout_rate,
+    variadic_args, mask_type, layout, sliding_window_length):
   query_type = ir.RankedTensorType(query.type)
   query_shape = query_type.shape
   key_type = ir.RankedTensorType(key.type)
@@ -607,23 +687,29 @@ def _dot_product_attention_bwd_cuda_lowering(
   grad_query_shape = (B, q_N, T, H)
   grad_key_shape = (B, k_N, S, H)
   grad_value_shape = (B, k_N, S, H)
 
+  has_bias, has_dbias = variadic_args
+  max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
   backend_config = create_dot_product_attention_backend_config(
       B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate,
-      mask_type, layout, sliding_window_length, is_bwd=True,
-  )
-  # {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*}
+      mask_type, layout, sliding_window_length, max_seg_per_batch,
+      is_bwd=True)
+  # {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*,
+  #  q_offsets*, kv_offsets*}
   # {dQ, dK, dV, dbias*, workspace}
   has_dropout = dropout_rate > 0
-  has_bias, has_dbias = variadic_args
   # create operands
   operands = [query, key, value, activation, grad_output]
   if has_bias:
     # flash attention requires bias in the bwd for remat
     operands.append(bias)
   operands.append(fwd_output)
-  if has_padding(mask_type):
+  if has_padding(mask_type) or max_seg_per_batch > 1:
     operands.append(q_seqlen)
     operands.append(kv_seqlen)
+  if max_seg_per_batch > 1:
+    operands.append(q_offsets)
+    operands.append(kv_offsets)
   # get custom call name
   custom_call_name = get_custom_call_name(has_bias, has_dropout, True)
 
@@ -674,7 +760,8 @@ def _dot_product_attention_fwd_batcher(
     batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
     mask_type, layout, sliding_window_length, is_training):
   _check_valid_batch_dims(batch_dims)
-  query, key, value, bias, q_seqlen, kv_seqlen = batched_args
+  query, key, value, bias, q_seqlen, kv_seqlen, \
+      q_offsets, kv_offsets = batched_args
   query_bdim = batch_dims[0]
   if is_training:
     out_bdims = query_bdim, query_bdim
@@ -701,9 +788,9 @@ def _dot_product_attention_fwd_batcher(
     kv_seqlen = jnp.reshape(kv_seqlen, (B, ))
 
   outputs = _dot_product_attention_fwd_p_wrapper.bind(
-      query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
-      seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
-      mask_type=mask_type, layout=layout,
+      query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+      scale=scale, seed=seed, dropout_rate=dropout_rate,
+      variadic_args=variadic_args, mask_type=mask_type, layout=layout,
      sliding_window_length=sliding_window_length, is_training=is_training)
 
   # reshape to original shape
@@ -720,8 +807,8 @@ def _dot_product_attention_bwd_batcher(
     batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
     mask_type, layout, sliding_window_length):
   _check_valid_batch_dims(batch_dims)
-  query, key, value, bias, q_seqlen, \
-      kv_seqlen, activation, fwd_output, grad_output = batched_args
+  query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, \
+      activation, fwd_output, grad_output = batched_args
   query_bdim = batch_dims[0]
   out_bdims = query_bdim, query_bdim, query_bdim
 
@@ -757,8 +844,8 @@ def _dot_product_attention_bwd_batcher(
   grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:])
 
   grads = _dot_product_attention_bwd_p_wrapper.bind(
-      query, key, value, bias, q_seqlen, kv_seqlen, activation,
-      fwd_output, grad_output, scale=scale, seed=seed,
+      query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+      activation, fwd_output, grad_output, scale=scale, seed=seed,
       dropout_rate=dropout_rate, variadic_args=variadic_args,
       mask_type=mask_type, layout=layout,
       sliding_window_length=sliding_window_length,
@@ -834,7 +921,7 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layo
   return [out_sharding]
 
 _dot_product_attention_fwd_lower = custom_partitioning(
-    _dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
+    _dot_product_attention_fwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
 
 def _dot_product_attention_fwd_infer_sharding_from_operands(
     scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length,
@@ -883,7 +970,7 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args):
   return out_shardings
 
 _dot_product_attention_bwd_lower = custom_partitioning(
-    _dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14, 15)
+    _dot_product_attention_bwd_impl, static_argnums=(11, 12, 13, 14, 15, 16, 17)
 )
 
 def _dot_product_attention_bwd_infer_sharding_from_operands(
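Aside (editorial, not part of the commit): q_offsets and kv_offsets are inserted as positional array operands right after kv_seqlen, which pushes every static (non-array) parameter two places to the right — hence static_argnums moving from (6..13) to (8..15) for the forward impl and from (9..15) to (11..17) for the backward impl, and the matching nondiff_argnums change below. A sketch of the assumed forward positions:

# Positional layout of _dot_product_attention_fwd_impl after this change:
#   0: query      1: key        2: value      3: bias
#   4: q_seqlen   5: kv_seqlen  6: q_offsets  7: kv_offsets   <- traced arrays
#   8: scale      9: seed      10: dropout_rate  11: variadic_args
#  12: mask_type 13: layout    14: sliding_window_length  15: is_training
static_argnums = tuple(range(8, 16))
assert static_argnums == (8, 9, 10, 11, 12, 13, 14, 15)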
@@ -1003,13 +1090,15 @@ dispatch.prim_requires_devices_during_lowering.add(
   _dot_product_attention_bwd_p_wrapper
 )
 
-@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
+@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
 def _dot_product_attention(query: Array,
                            key: Array,
                            value: Array,
                            bias: Array,
                            q_seqlen: Array,
                            kv_seqlen: Array,
+                           q_offsets: Array,
+                           kv_offsets: Array,
                            scale: float,
                            seed: int,
                            dropout_rate: float,
@@ -1019,9 +1108,10 @@ def _dot_product_attention(query: Array,
                            sliding_window_length: int | None,
                            cudnn_version: int):
   output = _dot_product_attention_fwd(
-      query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
-      seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
-      mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length,
+      query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+      scale=scale, seed=seed, dropout_rate=dropout_rate,
+      variadic_args=variadic_args, mask_type=mask_type, layout=layout,
+      sliding_window_length=sliding_window_length,
       cudnn_version=cudnn_version)
   return output
 
@@ -1612,6 +1702,8 @@ def dot_product_attention(
     mask: Array | None = None,
     q_seqlen: Array | None = None,
     kv_seqlen: Array | None = None,
+    q_offsets: Array | None = None,
+    kv_offsets: Array | None = None,
     fp8_params: FP8Params | None = None,
     *,
     scale: float = 1.0,
@@ -1647,8 +1739,26 @@ def dot_product_attention(
     value: Values to be used in attention with a shape of BSNH or BNSH.
     bias: Bias to be added to logits with a shape of BNTS.
     mask: Mask used to filter out logits with a shape of BNTS.
-    q_seqlen: Non padded sequence length of Queries with a shape of B.
-    kv_seqlen: Non padded sequence length of Keys and Values with a shape of B.
+    q_seqlen: Non-padded sequence length of query with a shape of B.
+      If q_offsets is set, q_seqlen should have shape [B, M] where M is the
+      maximum number of segments per batch. For batches with fewer segments
+      than the maximum, fill the padded entries with -1.
+    kv_seqlen: Non-padded sequence length of key and value with a shape of B.
+      If kv_offsets is set, kv_seqlen should have shape [B, M] where M is the
+      maximum number of segments per batch. For batches with fewer segments
+      than the maximum, fill the padded entries with -1.
+    q_offsets: Offset of each segment packed in query with a shape of [B, M+1],
+      where M is the maximum number of segments per batch. For batches with
+      fewer segments than the maximum, fill the padded entries with -1.
+      E.g., if 2 batches have 3 and 2 segments respectively, each of size 1,
+      then q_offsets = [[0,1,2,-1], [0,1,-1,-1]]. q_seqlen should be set
+      to indicate the size of each segment.
+    kv_offsets: Offset of each segment packed in key with a shape of [B, M+1],
+      where M is the maximum number of segments per batch. For batches with
+      fewer segments than the maximum, fill the padded entries with -1.
+      E.g., if 2 batches have 3 and 2 segments respectively, each of size 1,
+      then kv_offsets = [[0,1,2,-1], [0,1,-1,-1]]. kv_seqlen should be set
+      to indicate the size of each segment.
     scale: Scale for the query.
     dropout_rate: Dropout rate.
     qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH,
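Aside (editorial, not part of the commit): a minimal packed call under the conventions documented above. The import path, shapes and values are assumptions for illustration (the module is the internal cuDNN SDPA wrapper this diff touches), and running it needs an NVIDIA GPU with cuDNN >= 9.6 per the check added earlier in the diff:

import jax
import jax.numpy as jnp
# Assumed import path of the wrapper modified in this diff.
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType

B, T, N, H = 2, 512, 4, 64                     # BTNH layout
q = jax.random.normal(jax.random.key(0), (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(jax.random.key(1), (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(jax.random.key(2), (B, T, N, H), dtype=jnp.bfloat16)

# Row 0 packs three segments starting at tokens 0/170/340; row 1 packs two.
q_offsets = jnp.asarray([[0, 170, 340, -1], [0, 190, -1, -1]], jnp.int32)  # [B, M+1]
q_seqlen = jnp.asarray([[170, 170, 172], [190, 322, -1]], jnp.int32)       # [B, M]

out = dot_product_attention(
    q, k, v,
    q_seqlen=q_seqlen, kv_seqlen=q_seqlen,
    q_offsets=q_offsets, kv_offsets=q_offsets,
    scale=1.0, mask_type=MaskType.NO_MASK)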
@@ -1679,7 +1789,7 @@ def dot_product_attention(
         f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}"
       )
     check_fp8_params(fp8_params)
-    check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
+    check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
     output, amax_s, amax_o = _dot_product_attention_fp8(
         query, key, value, fp8_params,
         scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version
@@ -1691,6 +1801,8 @@ def dot_product_attention(
     if sliding_window_length is not None and sliding_window_length <= 0:
       raise ValueError(
           f"Require sliding_window_length > 0, got {sliding_window_length}")
+    if q_offsets is not None and (q_seqlen is None or kv_seqlen is None):
+      raise ValueError("Require q_seqlen and kv_seqlen to use packed layout")
 
     if bias is not None:
       # reshape bias to have 4D shape
@@ -1712,7 +1824,7 @@ def dot_product_attention(
       bias = bias + mask
 
     # check if input shape and data type are compatible
-    check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
+    check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
     has_bias = bias is not None
     has_dbias = has_bias and \
       should_export_dbias(bias.shape, query.shape, layout)  # type: ignore[union-attr]
@@ -1724,8 +1836,12 @@ def dot_product_attention(
       q_seqlen = jnp.zeros(0, dtype=query.dtype)
     if kv_seqlen is None:
       kv_seqlen = jnp.zeros(0, dtype=query.dtype)
+    if q_offsets is None:
+      q_offsets = jnp.zeros(0, dtype=query.dtype)
+    if kv_offsets is None:
+      kv_offsets = jnp.zeros(0, dtype=query.dtype)
     output = _dot_product_attention(
-        query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
-        dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length,
-        cudnn_version)
+        query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
+        scale, seed, dropout_rate, variadic_args, mask_type, layout.value,
+        sliding_window_length, cudnn_version)
     return output
@@ -96,6 +96,10 @@ def sdpa_train(query: Array,
                grad: Array,
                bias: Array | None = None,
                mask: Array | None = None,
                q_seqlen: Array | None = None,
                kv_seqlen: Array | None = None,
+               q_offsets: Array | None = None,
+               kv_offsets: Array | None = None,
                scale: float = 0.5,
                mask_type: MaskType = MaskType.NO_MASK,
                is_bnth: bool = False,
@@ -107,15 +111,13 @@ def sdpa_train(query: Array,
     else:
       B, S, _, _ = query.shape
     q_seqlen = kv_seqlen = jnp.full((B,), S // 2, jnp.int32)
-  else:
-    q_seqlen = kv_seqlen = None
   out, sdpa_vjp = jax.vjp(
       partial(dot_product_attention, scale=scale, mask_type=mask_type,
              dropout_rate=dropout_rate,
              qkv_layout="BNTH" if is_bnth else "BTNH",
              sliding_window_length=sliding_window_length),
-      query, key, value, bias, mask, q_seqlen, kv_seqlen)
-  query_grad, key_grad, value_grad, bias_grad, _, _, _ = sdpa_vjp(grad)
+      query, key, value, bias, mask, q_seqlen, kv_seqlen, q_offsets, kv_offsets)
+  query_grad, key_grad, value_grad, bias_grad = sdpa_vjp(grad)[:4]
   if bias is not None and len(bias.shape) == 3:
     # has dbias
     return out, (query_grad, key_grad, value_grad, bias_grad)
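Aside (editorial, not part of the commit): jax.vjp returns one cotangent per primal argument, so after adding q_offsets and kv_offsets the pullback yields nine entries instead of seven; slicing with [:4] keeps the query/key/value/bias gradients and ignores the mask/seqlen/offset cotangents no matter how many trailing operands there are. A small standalone sketch of the same pattern:

import jax
import jax.numpy as jnp

def f(x, y, z, w, a_int, b_int):   # stand-in with trailing non-differentiated args
  return (x * y + z * w).sum()

args = (jnp.ones(3), jnp.ones(3), jnp.ones(3), jnp.ones(3),
        jnp.zeros(3, jnp.int32), jnp.zeros(3, jnp.int32))
out, f_vjp = jax.vjp(f, *args)
grads = f_vjp(jnp.ones_like(out))
assert len(grads) == 6             # one cotangent per primal
dx, dy, dz, dw = grads[:4]         # keep only the leading differentiable ones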
@@ -612,6 +614,136 @@ class DotProductAttentionTest(jtu.JaxTestCase):
     self.assertArraysAllClose(grads_ref[1], grads_ans[1])
     self.assertArraysAllClose(grads_ref[2], grads_ans[2])
 
+  @jtu.run_on_devices("cuda")
+  def test_sdpa_packed_layout(self):
+    if jax.device_count() < 4:
+      self.skipTest("Requires at least 4 devices.")
+    try:
+      cudnn_version = check_cudnn_version()
+    except RuntimeError as e:
+      self.skipTest(str(e))
+      return
+    if cudnn_version < 90600:
+      self.skipTest("Requires >= cuDNN 9.6.0")
+    k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
+    query = jax.random.normal(
+        k1, (4, 512, 4, 64), dtype=jnp.bfloat16)
+    key = jax.random.normal(
+        k2, (4, 512, 4, 64), dtype=jnp.bfloat16)
+    value = jax.random.normal(
+        k3, (4, 512, 4, 64), dtype=jnp.bfloat16)
+    grad = jax.random.normal(
+        k4, (4, 512, 4, 64), dtype=jnp.bfloat16)
+
+    def generate_padding_mask(segment_ids, padding_id, shape, dtype):
+      # segment_ids [B, T]
+      encoded_padding = jnp.where(segment_ids >= padding_id, 0, 1).astype(dtype)
+      return jax.lax.broadcast_in_dim(
+          encoded_padding, shape, broadcast_dimensions=[0, 1])
+
+    def generate_segment_mask(segment_ids, dtype):
+      segment_ids_1 = jnp.expand_dims(segment_ids, axis=-1)
+      # segment_ids_1 = jnp.where(segment_ids_1 == 3, 4, segment_ids_1)
+      segment_ids_2 = jnp.expand_dims(segment_ids, axis=1)
+      mask = jnp.not_equal(segment_ids_1, segment_ids_2).astype(dtype)
+      # broadcast to [B, N, T, T]
+      mask = jnp.expand_dims(mask, 1)
+      mask *= get_large_negative_number(dtype)
+      return mask
+
+    # starting position of each segment
+    q_offsets = jnp.asarray([
+        [0, 170, 340, -1],  # 3 segments
+        [0, 150, 340, -1],  # 3 segments
+        [0, 190, -1, -1],   # 2 segments
+        [0, -1, -1, -1]     # 1 segment
+    ], dtype=np.int32)
+
+    # actual seqlen of each segment without padding
+    q_seqlen = jnp.asarray([
+        [170, 170, 172],  # no padding inside any segment
+        [150, 187, 172],  # 3 padding tokens inside the second segment
+        [190, 190, -1],   # 132 padding tokens inside the last segment
+        [400, -1, -1],    # 112 padding tokens inside the last segment
+    ], dtype=np.int32)
+
+    # the maximum number of segments is used as the id for padding tokens
+    segment_ids = jnp.asarray([
+        [0]*170 + [1]*170 + [2]*172,
+        [0]*150 + [1]*187 + [3]*3 + [2]*172,
+        [0]*190 + [1]*190 + [3]*132,
+        [0]*400 + [3]*112,
+    ], dtype=np.int32)
+
+    kv_offsets = q_offsets.copy()
+    kv_seqlen = q_seqlen.copy()
+
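Aside (editorial, not part of the commit): the three tables above are mutually consistent — the gap between consecutive offsets is each segment's capacity, the seqlen entry is its valid length, and the difference is the number of padding tokens that generate_padding_mask later zeroes out (padding id 3). A small check of that relationship for the first two rows:

import numpy as np

T = 512
q_offsets = np.array([[0, 170, 340, -1], [0, 150, 340, -1]])
q_seqlen = np.array([[170, 170, 172], [150, 187, 172]])

ends = np.where(q_offsets[:, 1:] >= 0, q_offsets[:, 1:], T)  # next start, or end of row
capacity = ends - q_offsets[:, :-1]                          # tokens reserved per segment
padding = np.where(q_seqlen >= 0, capacity - q_seqlen, 0)
print(padding)  # [[0 0 0] [0 3 0]] -> 3 padded tokens in the second segment of row 1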
+    mask = generate_padding_mask(segment_ids, q_seqlen.shape[1], query.shape, query.dtype)
+    bias = generate_segment_mask(segment_ids, query.dtype)
+
+    devices = np.array(jax.local_devices()[:4])
+    devices = devices.reshape((2, 2))
+    with Mesh(devices, ("dp", "tp")) as mesh:
+      qkv_spec = PartitionSpec("dp", None, "tp", None)
+      qkv_sharding = NamedSharding(mesh, qkv_spec)
+      bias_spec = PartitionSpec("dp", None, None, None)
+      bias_sharding = NamedSharding(mesh, bias_spec)
+      offsets_specs = PartitionSpec("dp", None)
+      offsets_sharding = NamedSharding(mesh, offsets_specs)
+
+      query = jax.device_put(query, qkv_sharding)
+      key = jax.device_put(key, qkv_sharding)
+      value = jax.device_put(value, qkv_sharding)
+      grad = jax.device_put(grad, qkv_sharding)
+      bias = jax.device_put(bias, bias_sharding)
+      q_offsets = jax.device_put(q_offsets, offsets_sharding)
+      kv_offsets = jax.device_put(kv_offsets, offsets_sharding)
+      q_seqlen = jax.device_put(q_seqlen, offsets_sharding)
+      kv_seqlen = jax.device_put(kv_seqlen, offsets_sharding)
+
+      jitted_sdpa_train = jax.jit(
+          partial(
+              sdpa_train, scale=0.1, mask_type=MaskType.NO_MASK, dropout_rate=0),
+          in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding,
+                        None, None, offsets_sharding, offsets_sharding, offsets_sharding, offsets_sharding),
+          out_shardings=(qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding))
+      )
+
+      jitted_sdpa_train_ref = jax.jit(
+          partial(
+              sdpa_train_ref, scale=0.1, mask_type=MaskType.NO_MASK, dropout_rate=0),
+          in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding,
+                        bias_sharding),
+          out_shardings=(qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding))
+      )
+
+      query = query * mask
+      key = key * mask
+      value = value * mask
+      grad = grad * mask
+
+      out, (query_grad, key_grad, value_grad) = \
+          jitted_sdpa_train(query, key, value, grad, None, None, q_seqlen, kv_seqlen, q_offsets, kv_offsets)
+      out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \
+          jitted_sdpa_train_ref(query, key, value, grad, bias)
+
+      out = out * mask
+      out_ref = out_ref * mask
+
+      query_grad = query_grad * mask
+      query_grad_ref = query_grad_ref * mask
+
+      key_grad = key_grad * mask
+      key_grad_ref = key_grad_ref * mask
+
+      value_grad = value_grad * mask
+      value_grad_ref = value_grad_ref * mask
+
+      self.assertArraysAllClose(out_ref, out, rtol=1e-2, atol=1e-2)
+      self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2)
+      self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-2, atol=1e-2)
+      self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2)
+
   @jtu.run_on_devices("cuda")
   def test_layouts(self):
     if jax.device_count() < 4: